format with Black

This commit is contained in:
Gilles Boccon-Gibod
2022-12-10 08:53:51 -08:00
parent 297246fa4c
commit 135df0dcc0
104 changed files with 8646 additions and 5766 deletions

View File

@@ -60,18 +60,18 @@ from prompt_toolkit.layout import (
FormattedTextControl,
FloatContainer,
ConditionalContainer,
Dimension
Dimension,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
BUMBLE_USER_DIR = os.path.expanduser('~/.bumble')
DEFAULT_RSSI_BAR_WIDTH = 20
BUMBLE_USER_DIR = os.path.expanduser('~/.bumble')
DEFAULT_RSSI_BAR_WIDTH = 20
DEFAULT_CONNECTION_TIMEOUT = 30.0
DISPLAY_MIN_RSSI = -100
DISPLAY_MAX_RSSI = -30
RSSI_MONITOR_INTERVAL = 5.0 # Seconds
DISPLAY_MIN_RSSI = -100
DISPLAY_MAX_RSSI = -30
RSSI_MONITOR_INTERVAL = 5.0 # Seconds
# -----------------------------------------------------------------------------
@@ -84,12 +84,11 @@ App = None
# Utils
# -----------------------------------------------------------------------------
def le_phy_name(phy_id):
return {
HCI_LE_1M_PHY: '1M',
HCI_LE_2M_PHY: '2M',
HCI_LE_CODED_PHY: 'CODED'
}.get(phy_id, HCI_Constant.le_phy_name(phy_id))
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id)
)
def rssi_bar(rssi):
@@ -124,20 +123,22 @@ def parse_phys(phys):
# -----------------------------------------------------------------------------
class ConsoleApp:
def __init__(self):
self.known_addresses = set()
self.known_addresses = set()
self.known_attributes = []
self.device = None
self.connected_peer = None
self.top_tab = 'device'
self.monitor_rssi = False
self.connection_rssi = None
self.device = None
self.connected_peer = None
self.top_tab = 'device'
self.monitor_rssi = False
self.connection_rssi = None
style = Style.from_dict({
'output-field': 'bg:#000044 #ffffff',
'input-field': 'bg:#000000 #ffffff',
'line': '#004400',
'error': 'fg:ansired'
})
style = Style.from_dict(
{
'output-field': 'bg:#000044 #ffffff',
'input-field': 'bg:#000000 #ffffff',
'line': '#004400',
'error': 'fg:ansired',
}
)
class LiveCompleter(Completer):
def __init__(self, words):
@@ -149,52 +150,37 @@ class ConsoleApp:
yield Completion(word, start_position=-len(prefix))
def make_completer():
return NestedCompleter.from_nested_dict({
'scan': {
'on': None,
'off': None,
'clear': None
},
'advertise': {
'on': None,
'off': None
},
'rssi': {
'on': None,
'off': None
},
'show': {
'scan': None,
'services': None,
'attributes': None,
'log': None,
'device': None
},
'filter': {
'address': None,
},
'connect': LiveCompleter(self.known_addresses),
'update-parameters': None,
'encrypt': None,
'disconnect': None,
'discover': {
'services': None,
'attributes': None
},
'request-mtu': None,
'read': LiveCompleter(self.known_attributes),
'write': LiveCompleter(self.known_attributes),
'subscribe': LiveCompleter(self.known_attributes),
'unsubscribe': LiveCompleter(self.known_attributes),
'set-phy': {
'1m': None,
'2m': None,
'coded': None
},
'set-default-phy': None,
'quit': None,
'exit': None
})
return NestedCompleter.from_nested_dict(
{
'scan': {'on': None, 'off': None, 'clear': None},
'advertise': {'on': None, 'off': None},
'rssi': {'on': None, 'off': None},
'show': {
'scan': None,
'services': None,
'attributes': None,
'log': None,
'device': None,
},
'filter': {
'address': None,
},
'connect': LiveCompleter(self.known_addresses),
'update-parameters': None,
'encrypt': None,
'disconnect': None,
'discover': {'services': None, 'attributes': None},
'request-mtu': None,
'read': LiveCompleter(self.known_attributes),
'write': LiveCompleter(self.known_attributes),
'subscribe': LiveCompleter(self.known_attributes),
'unsubscribe': LiveCompleter(self.known_attributes),
'set-phy': {'1m': None, '2m': None, 'coded': None},
'set-default-phy': None,
'quit': None,
'exit': None,
}
)
self.input_field = TextArea(
height=1,
@@ -202,49 +188,55 @@ class ConsoleApp:
multiline=False,
wrap_lines=False,
completer=make_completer(),
history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history'))
history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')),
)
self.input_field.accept_handler = self.accept_input
self.output_height = Dimension(min=7, max=7, weight=1)
self.output_lines = []
self.output = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1)))
self.output = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1))
)
self.output_max_lines = 20
self.scan_results_text = FormattedTextControl()
self.services_text = FormattedTextControl()
self.attributes_text = FormattedTextControl()
self.device_text = FormattedTextControl()
self.log_text = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1)))
self.log_text = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1))
)
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = []
container = HSplit([
ConditionalContainer(
Frame(Window(self.scan_results_text), title='Scan Results'),
filter=Condition(lambda: self.top_tab == 'scan')
),
ConditionalContainer(
Frame(Window(self.services_text), title='Services'),
filter=Condition(lambda: self.top_tab == 'services')
),
ConditionalContainer(
Frame(Window(self.attributes_text), title='Attributes'),
filter=Condition(lambda: self.top_tab == 'attributes')
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log')
),
ConditionalContainer(
Frame(Window(self.device_text), title='Device'),
filter=Condition(lambda: self.top_tab == 'device')
),
Frame(Window(self.output, height=self.output_height)),
FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'),
self.input_field
])
container = HSplit(
[
ConditionalContainer(
Frame(Window(self.scan_results_text), title='Scan Results'),
filter=Condition(lambda: self.top_tab == 'scan'),
),
ConditionalContainer(
Frame(Window(self.services_text), title='Services'),
filter=Condition(lambda: self.top_tab == 'services'),
),
ConditionalContainer(
Frame(Window(self.attributes_text), title='Attributes'),
filter=Condition(lambda: self.top_tab == 'attributes'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log'),
),
ConditionalContainer(
Frame(Window(self.device_text), title='Device'),
filter=Condition(lambda: self.top_tab == 'device'),
),
Frame(Window(self.output, height=self.output_height)),
FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'),
self.input_field,
]
)
container = FloatContainer(
container,
@@ -260,16 +252,14 @@ class ConsoleApp:
layout = Layout(container, focused_element=self.input_field)
kb = KeyBindings()
@kb.add("c-c")
@kb.add("c-q")
def _(event):
event.app.exit()
self.ui = Application(
layout=layout,
style=style,
key_bindings=kb,
full_screen=True
layout=layout, style=style, key_bindings=kb, full_screen=True
)
async def run_async(self, device_config, transport):
@@ -277,13 +267,19 @@ class ConsoleApp:
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
if device_config:
self.device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
self.device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
random_address = f"{random.randint(192,255):02X}" # address is static random
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for c in random.sample(range(255), 5):
random_address += f":{c:02X}"
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci('Bumble', random_address, hci_source, hci_sink)
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
)
self.device.listener = DeviceListener(self)
await self.device.power_on()
self.show_device(self.device)
@@ -321,7 +317,9 @@ class ConsoleApp:
else:
phy_state = ''
connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}{phy_state}'
encryption_state = 'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
encryption_state = (
'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
)
att_mtu = f'ATT_MTU: {connection.att_mtu}'
return [
@@ -333,10 +331,10 @@ class ConsoleApp:
('', ' '),
('ansicyan', f' {att_mtu} '),
('', ' '),
('ansiyellow', f' {rssi} ')
('ansiyellow', f' {rssi} '),
]
def show_error(self, title, details = None):
def show_error(self, title, details=None):
appended = [('class:error', title)]
if details:
appended.append(('', f' {details}'))
@@ -359,7 +357,9 @@ class ConsoleApp:
for characteristic in service.characteristics:
lines.append(('ansimagenta', ' ' + str(characteristic) + '\n'))
self.known_attributes.append(f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}')
self.known_attributes.append(
f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}'
)
self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}')
self.known_attributes.append(f'#{characteristic.handle:X}')
for descriptor in characteristic.descriptors:
@@ -418,7 +418,7 @@ class ConsoleApp:
def append_to_output(self, line, invalidate=True):
if type(line) is str:
line = [('', line)]
self.output_lines = self.output_lines[-self.output_max_lines:]
self.output_lines = self.output_lines[-self.output_max_lines :]
self.output_lines.append(line)
formatted_text = []
for line in self.output_lines:
@@ -430,7 +430,7 @@ class ConsoleApp:
def append_to_log(self, lines, invalidate=True):
self.log_lines.extend(lines.split('\n'))
self.log_lines = self.log_lines[-self.log_max_lines:]
self.log_lines = self.log_lines[-self.log_max_lines :]
self.log_text.text = ANSI('\n'.join(self.log_lines))
if invalidate:
self.ui.invalidate()
@@ -515,7 +515,10 @@ class ConsoleApp:
elif params[0] == 'on':
if len(params) == 2:
if not params[1].startswith("filter="):
self.show_error('invalid syntax', 'expected address filter=key1:value1,key2:value,... available filters: address')
self.show_error(
'invalid syntax',
'expected address filter=key1:value1,key2:value,... available filters: address',
)
# regex: (word):(any char except ,)
matches = re.findall(r"(\w+):([^,]+)", params[1])
for match in matches:
@@ -557,8 +560,7 @@ class ConsoleApp:
connection_parameters_preferences = None
else:
connection_parameters_preferences = {
phy: ConnectionParametersPreferences()
for phy in phys
phy: ConnectionParametersPreferences() for phy in phys
}
if self.device.is_scanning:
@@ -570,7 +572,7 @@ class ConsoleApp:
await self.device.connect(
params[0],
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except TimeoutError:
@@ -588,7 +590,10 @@ class ConsoleApp:
async def do_update_parameters(self, params):
if len(params) != 1 or len(params[0].split('/')) != 3:
self.show_error('invalid syntax', 'expected update-parameters <interval-min>-<interval-max>/<max-latency>/<supervision>')
self.show_error(
'invalid syntax',
'expected update-parameters <interval-min>-<interval-max>/<max-latency>/<supervision>',
)
return
if not self.connected_peer:
@@ -596,14 +601,16 @@ class ConsoleApp:
return
connection_intervals, max_latency, supervision_timeout = params[0].split('/')
connection_interval_min, connection_interval_max = [int(x) for x in connection_intervals.split('-')]
connection_interval_min, connection_interval_max = [
int(x) for x in connection_intervals.split('-')
]
max_latency = int(max_latency)
supervision_timeout = int(supervision_timeout)
await self.connected_peer.connection.update_parameters(
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout
supervision_timeout,
)
async def do_encrypt(self, params):
@@ -639,7 +646,9 @@ class ConsoleApp:
return
phy = await self.connected_peer.connection.get_phy()
self.append_to_output(f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}')
self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}'
)
async def do_request_mtu(self, params):
if len(params) != 1:
@@ -721,7 +730,9 @@ class ConsoleApp:
return
await characteristic.subscribe(
lambda value: self.append_to_output(f"{characteristic} VALUE: 0x{value.hex()}"),
lambda value: self.append_to_output(
f"{characteristic} VALUE: 0x{value.hex()}"
),
)
async def do_unsubscribe(self, params):
@@ -742,7 +753,9 @@ class ConsoleApp:
async def do_set_phy(self, params):
if len(params) != 1:
self.show_error('invalid syntax', 'expected set-phy <tx_rx_phys>|<tx_phys>/<rx_phys>')
self.show_error(
'invalid syntax', 'expected set-phy <tx_rx_phys>|<tx_phys>/<rx_phys>'
)
return
if not self.connected_peer:
@@ -756,13 +769,15 @@ class ConsoleApp:
rx_phys = tx_phys
await self.connected_peer.connection.set_phy(
tx_phys=parse_phys(tx_phys),
rx_phys=parse_phys(rx_phys)
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_set_default_phy(self, params):
if len(params) != 1:
self.show_error('invalid syntax', 'expected set-default-phy <tx_rx_phys>|<tx_phys>/<rx_phys>')
self.show_error(
'invalid syntax',
'expected set-default-phy <tx_rx_phys>|<tx_phys>/<rx_phys>',
)
return
if '/' in params[0]:
@@ -772,8 +787,7 @@ class ConsoleApp:
rx_phys = tx_phys
await self.device.set_default_phy(
tx_phys=parse_phys(tx_phys),
rx_phys=parse_phys(rx_phys)
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_exit(self, params):
@@ -789,6 +803,7 @@ class ConsoleApp:
return
self.device.listener.address_filter = params[1]
# -----------------------------------------------------------------------------
# Device and Connection Listener
# -----------------------------------------------------------------------------
@@ -808,7 +823,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
self._address_filter = re.compile(r".*")
else:
self._address_filter = re.compile(filter_addr)
self.scan_results = OrderedDict(filter(lambda x: self.filter_address_match(x), self.scan_results))
self.scan_results = OrderedDict(
filter(lambda x: self.filter_address_match(x), self.scan_results)
)
self.app.show_scan_results(self.scan_results)
def filter_address_match(self, address):
@@ -825,24 +842,36 @@ class DeviceListener(Device.Listener, Connection.Listener):
connection.listener = self
def on_disconnection(self, reason):
self.app.append_to_output(f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}')
self.app.append_to_output(
f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}'
)
self.app.connected_peer = None
self.app.connection_rssi = None
def on_connection_parameters_update(self):
self.app.append_to_output(f'connection parameters update: {self.app.connected_peer.connection.parameters}')
self.app.append_to_output(
f'connection parameters update: {self.app.connected_peer.connection.parameters}'
)
def on_connection_phy_update(self):
self.app.append_to_output(f'connection phy update: {self.app.connected_peer.connection.phy}')
self.app.append_to_output(
f'connection phy update: {self.app.connected_peer.connection.phy}'
)
def on_connection_att_mtu_update(self):
self.app.append_to_output(f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}')
self.app.append_to_output(
f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}'
)
def on_connection_encryption_change(self):
self.app.append_to_output(f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}')
self.app.append_to_output(
f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}'
)
def on_connection_data_length_change(self):
self.app.append_to_output(f'connection data length change: {self.app.connected_peer.connection.data_length}')
self.app.append_to_output(
f'connection data length change: {self.app.connected_peer.connection.data_length}'
)
def on_advertisement(self, advertisement):
if not self.filter_address_match(str(advertisement.address)):
@@ -851,12 +880,18 @@ class DeviceListener(Device.Listener, Connection.Listener):
entry_key = f'{advertisement.address}/{advertisement.address.address_type}'
entry = self.scan_results.get(entry_key)
if entry:
entry.ad_data = advertisement.data
entry.rssi = advertisement.rssi
entry.ad_data = advertisement.data
entry.rssi = advertisement.rssi
entry.connectable = advertisement.is_connectable
else:
self.app.add_known_address(str(advertisement.address))
self.scan_results[entry_key] = ScanResult(advertisement.address, advertisement.address.address_type, advertisement.data, advertisement.rssi, advertisement.is_connectable)
self.scan_results[entry_key] = ScanResult(
advertisement.address,
advertisement.address.address_type,
advertisement.data,
advertisement.rssi,
advertisement.is_connectable,
)
self.app.show_scan_results(self.scan_results)
@@ -866,11 +901,11 @@ class DeviceListener(Device.Listener, Connection.Listener):
# -----------------------------------------------------------------------------
class ScanResult:
def __init__(self, address, address_type, ad_data, rssi, connectable):
self.address = address
self.address = address
self.address_type = address_type
self.ad_data = ad_data
self.rssi = rssi
self.connectable = connectable
self.ad_data = ad_data
self.rssi = rssi
self.connectable = connectable
def to_display_string(self):
address_type_string = ('P', 'R', 'PI', 'RI')[self.address_type]

View File

@@ -39,7 +39,7 @@ from bumble.hci import (
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -51,13 +51,18 @@ async def get_classic_info(host):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr)
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name))
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# -----------------------------------------------------------------------------
@@ -65,21 +70,25 @@ async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command())
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n'
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Advertising_Data_Length_Command())
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n'
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
@@ -93,7 +102,7 @@ async def get_le_info(host):
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n'
'\n',
)
print(color('LE Features:', 'yellow'))
@@ -112,10 +121,19 @@ async def async_main(transport):
# Print version
print(color('Version:', 'yellow'))
print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier))
print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version))
print(
color(' Manufacturer: ', 'green'),
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version))
print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info
@@ -135,7 +153,7 @@ async def async_main(transport):
@click.command()
@click.argument('transport')
def main(transport):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(transport))

View File

@@ -28,7 +28,9 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) != 3:
print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]')
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2')
return
@@ -41,7 +43,12 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name)
transports.append(transport)
controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link)
controller = Controller(
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controllers.append(controller)
# Wait until the user interrupts
@@ -54,7 +61,7 @@ async def async_main():
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())

View File

@@ -64,9 +64,13 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
@@ -81,7 +85,12 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done)))
device.on(
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
@@ -99,7 +108,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))

View File

@@ -33,10 +33,12 @@ from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_PREFERRED_MTU = 256
@@ -44,8 +46,8 @@ GG_PREFERRED_MTU = 256
# -----------------------------------------------------------------------------
class GattlinkL2capEndpoint:
def __init__(self):
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
@@ -71,12 +73,12 @@ class GattlinkL2capEndpoint:
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
@@ -127,7 +129,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID:
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic)
@@ -135,7 +139,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received)
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
@@ -150,7 +156,12 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason):
print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red'))
print(
color(
f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}',
'red',
)
)
self.tx_characteristic = None
self.rx_characteristic = None
self.peer = None
@@ -193,10 +204,10 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
# Register as a listener
@@ -212,35 +223,37 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write)
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0])
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
self.rx_characteristic,
self.tx_characteristic,
self.psm_characteristic
]
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
@@ -270,7 +283,9 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}')
print(
f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
@@ -290,7 +305,15 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
async def run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -307,14 +330,12 @@ async def run(hci_transport, device_address, role_or_peer_address, send_host, se
# Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint(
lambda: bridge,
local_addr=(receive_host, receive_port)
lambda: bridge, local_addr=(receive_host, receive_port)
)
# Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port)
lambda: asyncio.DatagramProtocol(), remote_addr=(send_host, send_port)
)
await device.power_on()
@@ -328,15 +349,43 @@ async def run(hci_transport, device_address, role_or_peer_address, send_host, se
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port))
@click.option(
'-rh',
'--receive-host',
type=str,
default='127.0.0.1',
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
main()

View File

@@ -34,16 +34,26 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) < 3:
print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]')
print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078')
print(
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink):
async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source,
hci_host_sink,
):
print('>>> connected')
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink):
async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source,
hci_controller_sink,
):
print('>>> connected')
command_short_circuits = []
@@ -51,18 +61,23 @@ async def async_main():
for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str:
ogf, ocf = op_code_str.split(':')
command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16)))
command_short_circuits.append(
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
else:
command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet):
if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits:
if (
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
# Respond with a success response
logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = hci_packet.op_code,
return_parameters = bytes([hci.HCI_SUCCESS])
num_hci_command_packets=1,
command_opcode=hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]),
)
# Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True)
@@ -73,14 +88,14 @@ async def async_main():
hci_controller_source,
hci_controller_sink,
host_to_controller_filter,
None
None,
)
await asyncio.get_running_loop().create_future()
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())

View File

@@ -38,36 +38,32 @@ class ServerBridge:
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(
self,
psm,
max_credits,
mtu,
mps,
tcp_host,
tcp_port
):
self.psm = psm
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm = self.psm,
server = self.on_coc,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
@@ -82,7 +78,7 @@ class ServerBridge:
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
@@ -91,7 +87,12 @@ class ServerBridge:
async def connect_to_tcp(self):
# Connect to the TCP server
print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow'))
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
@@ -107,7 +108,10 @@ class ServerBridge:
self.pipe.l2cap_channel.write(data)
try:
self.tcp_transport, _ = await asyncio.get_running_loop().create_connection(
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
@@ -149,23 +153,14 @@ class ClientBridge:
READ_CHUNK_SIZE = 4096
def __init__(
self,
psm,
max_credits,
mtu,
mps,
address,
tcp_host,
tcp_port
):
self.psm = psm
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
@@ -174,7 +169,10 @@ class ClientBridge:
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
@@ -196,10 +194,10 @@ class ClientBridge:
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm = self.psm,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
@@ -215,7 +213,7 @@ class ClientBridge:
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain
writer.drain,
)
l2cap_to_tcp_pipe.start()
@@ -242,9 +240,13 @@ class ClientBridge:
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'))
# -----------------------------------------------------------------------------
@@ -266,20 +268,43 @@ async def run(device_config, hci_transport, bridge):
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128)
@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022)
@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024)
def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps):
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@@ -294,12 +319,9 @@ def server(context, tcp_host, tcp_port):
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@@ -316,16 +338,12 @@ def client(context, bluetooth_address, tcp_host, tcp_port):
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port
tcp_port,
)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})

View File

@@ -65,9 +65,9 @@ class Connection:
"""
def __init__(self, room, websocket):
self.room = room
self.room = room
self.websocket = websocket
self.address = str(uuid.uuid4())
self.address = str(uuid.uuid4())
async def send_message(self, message):
try:
@@ -110,9 +110,9 @@ class Room:
"""
def __init__(self, relay, name):
self.relay = relay
self.name = name
self.observers = []
self.relay = relay
self.name = name
self.observers = []
self.connections = []
async def add_connection(self, connection):
@@ -145,7 +145,9 @@ class Room:
# This is an RPC request
await self.on_rpc_request(connection, message)
else:
await connection.send_message(f'result:{error_to_json("error: invalid message")}')
await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message):
'''
@@ -155,7 +157,9 @@ class Room:
async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1)
if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None):
if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try:
result = await handler(connection, params)
except Exception as error:
@@ -192,7 +196,9 @@ class Room:
current_address = connection.address
new_address = params[0]
connection.set_address(new_address)
await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}')
await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
# ----------------------------------------------------------------------------
@@ -246,24 +252,24 @@ def main():
print('ERROR: Python 3.6.1 or higher is required')
sys.exit(1)
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Parse arguments
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument('--port',
type = int,
default = DEFAULT_RELAY_PORT,
help = 'Port to listen on')
arg_parser.add_argument(
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
)
args = arg_parser.parse_args()
# Setup logger
if args.log_config:
from logging import config
config.fileConfig(args.log_config)
else:
logging.basicConfig(level = getattr(logging, args.log_level.upper()))
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
# Start a relay
relay = Relay(args.port)

View File

@@ -33,30 +33,32 @@ from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE,
Service,
Characteristic,
CharacteristicValue
CharacteristicValue,
)
from bumble.att import (
ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
super().__init__({
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT
}[capability_string.lower()])
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
}[capability_string.lower()]
)
self.mode = mode
self.peer = Peer(connection)
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.prompt = prompt
self.prompt = prompt
async def update_peer_name(self):
if self.peer_name is not None:
@@ -103,7 +105,11 @@ class Delegate(PairingDelegate):
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow'))
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
)
response = response.lower().strip()
if response == 'yes':
return True
@@ -149,7 +155,9 @@ async def get_peer_name(peer, mode):
if not services:
return None
values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0])
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
@@ -183,14 +191,14 @@ def on_connection(connection, request):
print(color(f'<<< Connection: {connection}', 'green'))
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
connection.on(
'connection_encryption_change',
lambda: on_connection_encryption_change(connection)
lambda: on_connection_encryption_change(connection),
)
# Request pairing if needed
@@ -202,7 +210,12 @@ def on_connection(connection, request):
# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue'))
print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue'))
print(
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color('@@@-----------------------------------', 'blue'))
@@ -241,7 +254,7 @@ async def pair(
keystore_file,
device_config,
hci_transport,
address_or_name
address_or_name,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -272,9 +285,11 @@ async def pair(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=read_with_error, write=write_with_error)
CharacteristicValue(
read=read_with_error, write=write_with_error
),
)
]
],
)
)
@@ -288,10 +303,7 @@ async def pair(
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc,
mitm,
bond,
Delegate(mode, connection, io, prompt)
sc, mitm, bond, Delegate(mode, connection, io, prompt)
)
# Connect to a peer or wait for a connection
@@ -319,21 +331,70 @@ async def pair(
# -----------------------------------------------------------------------------
@click.command()
@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True)
@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True)
@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True)
@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True)
@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True)
@click.option(
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
)
@click.option(
'--sc',
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name))
def main(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(
pair(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
# -----------------------------------------------------------------------------

View File

@@ -31,8 +31,8 @@ from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
def make_rssi_bar(rssi):
DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30
DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30
DEFAULT_RSSI_BAR_WIDTH = 30
blocks = ['', '', '', '', '', '', '', '']
@@ -63,7 +63,9 @@ class AdvertisementPrinter:
resolution_qualifier = f'(resolved from {advertisement.address})'
address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type]
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public:
type_color = 'cyan'
else:
@@ -93,7 +95,8 @@ class AdvertisementPrinter:
f'[{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n')
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertisement(self, advertisement):
self.print_advertisement(advertisement)
@@ -114,16 +117,20 @@ async def scan(
raw,
keystore_file,
device_config,
transport
transport,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
@@ -153,7 +160,7 @@ async def scan(
scan_interval=scan_interval,
scan_window=scan_window,
filter_duplicates=filter_duplicates,
scanning_phys=scanning_phys
scanning_phys=scanning_phys,
)
await hci_source.wait_for_termination()
@@ -165,15 +172,51 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option('--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY')
@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level')
@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones')
@click.option(
'--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY'
)
@click.option(
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport')
def main(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport))
def main(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
)
)
# -----------------------------------------------------------------------------

View File

@@ -30,10 +30,10 @@ class SnoopPacketReader:
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...)
'''
DATALINK_H1 = 1001
DATALINK_H4 = 1002
DATALINK_H1 = 1001
DATALINK_H4 = 1002
DATALINK_BSCP = 1003
DATALINK_H5 = 1004
DATALINK_H5 = 1004
def __init__(self, source):
self.source = source
@@ -41,9 +41,16 @@ class SnoopPacketReader:
# Read the header
identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError('not a valid snoop file, unexpected identification pattern')
(self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8))
if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1:
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if (
self.data_link_type != self.DATALINK_H4
and self.data_link_type != self.DATALINK_H1
):
raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self):
@@ -57,7 +64,7 @@ class SnoopPacketReader:
packet_flags,
cumulative_drops,
timestamp_seconds,
timestamp_microsecond
timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
# Abort on truncated packets
@@ -79,7 +86,10 @@ class SnoopPacketReader:
else:
packet_type = hci.HCI_ACL_DATA_PACKET
return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length))
return (
packet_flags & 1,
bytes([packet_type]) + self.source.read(included_length),
)
else:
return (packet_flags & 1, self.source.read(included_length))
@@ -88,7 +98,12 @@ class SnoopPacketReader:
# Main
# -----------------------------------------------------------------------------
@click.command()
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
@click.option(
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.argument('filename')
def main(format, filename):
input = open(filename, 'rb')
@@ -97,6 +112,7 @@ def main(format, filename):
def read_next_packet():
(0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet

View File

@@ -54,7 +54,7 @@ async def unbond(keystore_file, device_config, address):
@click.argument('device-config')
@click.argument('address', required=False)
def main(keystore_file, device_config, address):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address))

View File

@@ -37,9 +37,9 @@ from colors import color
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_DEVICE_CLASSES = {
@@ -69,22 +69,22 @@ USB_DEVICE_CLASSES = {
0x01: 'Bluetooth',
0x02: 'UWB',
0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP'
0x04: 'Bluetooth AMP',
}
}
},
),
0xEF: 'Miscellaneous',
0xFE: 'Application Specific',
0xFF: 'Vendor Specific'
0xFF: 'Vendor Specific',
}
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
@@ -95,18 +95,24 @@ def show_device_details(device):
for interface in configuration:
for setting in interface:
alternateSetting = setting.getAlternateSetting()
suffix = f'/{alternateSetting}' if interface.getNumSettings() > 1 else ''
suffix = (
f'/{alternateSetting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(),
setting.getSubClass(),
setting.getProtocol()
setting.getClass(), setting.getSubClass(), setting.getProtocol()
)
details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = 'OUT' if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) else 'IN'
print(f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}')
endpoint_direction = (
'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}'
)
# -----------------------------------------------------------------------------
@@ -135,7 +141,11 @@ def get_class_info(cls, subclass, protocol):
# -----------------------------------------------------------------------------
def is_bluetooth_hci(device):
# Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == USB_BT_HCI_CLASS_TUPLE:
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
@@ -143,7 +153,11 @@ def is_bluetooth_hci(device):
for configuration in device:
for interface in configuration:
for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == USB_BT_HCI_CLASS_TUPLE:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
@@ -153,23 +167,21 @@ def is_bluetooth_hci(device):
@click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass()
device_class = device.getDeviceClass()
device_subclass = device.getDeviceSubClass()
device_protocol = device.getDeviceProtocol()
device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info(
device_class,
device_subclass,
device_protocol
device_class, device_subclass, device_protocol
)
try:
@@ -198,7 +210,9 @@ def main(verbose):
# Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = []
basic_transport_name = f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
@@ -206,17 +220,39 @@ def main(verbose):
if device_id not in devices:
bumble_transport_names.append(basic_transport_name)
else:
bumble_transport_names.append(f'{basic_transport_name}#{len(devices[device_id])}')
bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None:
if device_id not in devices or device_serial_number not in devices[device_id]:
bumble_transport_names.append(f'{basic_transport_name}/{device_serial_number}')
if (
device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results
print(color(f'ID {device.getVendorID():04X}:{device.getProductID():04X}', fg=fg_color, bg=bg_color))
print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names:
print(color(' Bumble Transport Names:', 'blue'), ' or '.join(color(x, 'cyan' if device_is_bluetooth_hci else 'red') for x in bumble_transport_names))
print(color(' Bus/Device: ', 'green'), f'{device.getBusNumber():03}/{device.getDeviceAddress():03}')
print(
color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None:

View File

@@ -30,7 +30,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from .core import (
BT_L2CAP_PROTOCOL_ID,
@@ -38,7 +38,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number
name_or_number,
)
@@ -51,6 +51,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -127,6 +128,8 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
# fmt: on
# -----------------------------------------------------------------------------
def flags_to_list(flags, values):
@@ -140,58 +143,98 @@ def flags_to_list(flags, values):
# -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
]
# -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SINK_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
]
@@ -206,8 +249,8 @@ class SbcMediaCodecInformation(
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value'
]
'maximum_bitpool_value',
],
)
):
'''
@@ -215,36 +258,25 @@ class SbcMediaCodecInformation(
'''
BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = {
16000: 1 << 3,
32000: 1 << 2,
44100: 1 << 1,
48000: 1
}
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1
}
BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
}
@staticmethod
def from_bytes(data):
return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data))
return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod
def from_discrete_values(
@@ -255,16 +287,16 @@ class SbcMediaCodecInformation(
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode = cls.CHANNEL_MODE_BITS[channel_mode],
block_length = cls.BLOCK_LENGTH_BITS[block_length],
subbands = cls.SUBBANDS_BITS[subbands],
allocation_method = cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value = minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
@classmethod
@@ -276,16 +308,20 @@ class SbcMediaCodecInformation(
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
channel_mode = sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length = sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands = sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods),
minimum_bitpool_value = minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self):
@@ -294,30 +330,25 @@ class SbcMediaCodecInformation(
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join([
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}'
')'
])
return '\n'.join(
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
[
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
)
):
'''
@@ -326,13 +357,13 @@ class AacMediaCodecInformation(
BIT_FIELDS = 'u8u12u2p2u1u23'
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
8000: 1 << 11,
11025: 1 << 10,
12000: 1 << 9,
16000: 1 << 8,
@@ -343,66 +374,65 @@ class AacMediaCodecInformation(
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
96000: 1,
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data):
return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data))
@classmethod
def from_discrete_values(
cls,
object_type,
sampling_frequency,
channels,
vbr,
bitrate
):
return AacMediaCodecInformation(
object_type = cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels = cls.CHANNELS_BITS[channels],
vbr = vbr,
bitrate = bitrate
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod
def from_lists(
cls,
object_types,
sampling_frequencies,
channels,
vbr,
bitrate
def from_discrete_values(
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
return AacMediaCodecInformation(
object_type = sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
channels = sum(cls.CHANNELS_BITS[x] for x in channels),
vbr = vbr,
bitrate = bitrate
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
vbr=vbr,
bitrate=bitrate,
)
@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
vbr=vbr,
bitrate=bitrate,
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self):
object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]']
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2]
return '\n'.join([
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}'
')'
])
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' ')',
]
)
# -----------------------------------------------------------------------------
@@ -418,37 +448,33 @@ class VendorSpecificMediaCodecInformation:
def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
self.codec_id = codec_id
self.value = value
def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self):
return '\n'.join([
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}'
')'
])
return '\n'.join(
[
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}' ')',
]
)
# -----------------------------------------------------------------------------
class SbcFrame:
def __init__(
self,
sampling_frequency,
block_count,
channel_mode,
subband_count,
payload
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
@property
def sample_count(self):
@@ -487,24 +513,30 @@ class SbcParser:
# Extract some of the header fields
sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3]
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
# Compute the frame length
frame_length = 4 + (4 * subbands * channels) // 8
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8
else:
frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8
frame_length += (
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
# Read the rest of the frame
payload = header + await self.read(frame_length - 4)
# Emit the next frame
yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload)
yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames()
@@ -512,8 +544,8 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities):
self.read = read
self.mtu = mtu
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
@@ -522,9 +554,9 @@ class SbcPacketSource:
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
frames = []
frames_size = 0
timestamp = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
# NOTE: this doesn't support frame fragments
@@ -532,12 +564,19 @@ class SbcPacketSource:
async for frame in sbc_parser.frames:
print(frame)
if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16:
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
# Need to flush what has been accumulated so far
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames])
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload)
sbc_payload = bytes([len(frames)]) + b''.join(
[frame.payload for frame in frames]
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet

View File

@@ -31,6 +31,8 @@ from .hci import *
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01
@@ -166,6 +168,8 @@ HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# -----------------------------------------------------------------------------
# Utils
@@ -196,6 +200,7 @@ class ATT_PDU:
'''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
'''
pdu_classes = {}
op_code = 0
@@ -274,11 +279,13 @@ class ATT_PDU:
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name})
])
@ATT_PDU.subclass(
[
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}),
]
)
class ATT_Error_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -286,9 +293,7 @@ class ATT_Error_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('client_rx_mtu', 2)
])
@ATT_PDU.subclass([('client_rx_mtu', 2)])
class ATT_Exchange_MTU_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -296,9 +301,7 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('server_rx_mtu', 2)
])
@ATT_PDU.subclass([('server_rx_mtu', 2)])
class ATT_Exchange_MTU_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -306,10 +309,9 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC)
])
@ATT_PDU.subclass(
[('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)]
)
class ATT_Find_Information_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -317,10 +319,7 @@ class ATT_Find_Information_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('format', 1),
('information_data', '*')
])
@ATT_PDU.subclass([('format', 1), ('information_data', '*')])
class ATT_Find_Information_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -332,7 +331,7 @@ class ATT_Find_Information_Response(ATT_PDU):
uuid_size = 2 if self.format == 1 else 16
while offset + uuid_size <= len(self.information_data):
handle = struct.unpack_from('<H', self.information_data, offset)[0]
uuid = self.information_data[2 + offset:2 + offset + uuid_size]
uuid = self.information_data[2 + offset : 2 + offset + uuid_size]
self.information.append((handle, uuid))
offset += 2 + uuid_size
@@ -346,20 +345,33 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('format', 1),
('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])})
], ' ')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('format', 1),
(
'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*'),
]
)
class ATT_Find_By_Type_Value_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -367,9 +379,7 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('handles_information_list', '*')
])
@ATT_PDU.subclass([('handles_information_list', '*')])
class ATT_Find_By_Type_Value_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -379,7 +389,9 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = []
offset = 0
while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset)
found_attribute_handle, group_end_handle = struct.unpack_from(
'<HH', self.handles_information_list, offset
)
self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4
@@ -393,18 +405,34 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])})
], ' ')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC)
])
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -412,10 +440,7 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
class ATT_Read_By_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -424,9 +449,15 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 2:offset + self.length]
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
self.attributes.append((attribute_handle, attribute_value))
offset += self.length
@@ -440,17 +471,26 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])})
], ' ')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC)
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)])
class ATT_Read_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -458,9 +498,7 @@ class ATT_Read_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_value', '*')])
class ATT_Read_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -468,10 +506,7 @@ class ATT_Read_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)])
class ATT_Read_Blob_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -479,9 +514,7 @@ class ATT_Read_Blob_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('part_attribute_value', '*')
])
@ATT_PDU.subclass([('part_attribute_value', '*')])
class ATT_Read_Blob_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -489,9 +522,7 @@ class ATT_Read_Blob_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('set_of_handles', '*')
])
@ATT_PDU.subclass([('set_of_handles', '*')])
class ATT_Read_Multiple_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -499,9 +530,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('set_of_values', '*')
])
@ATT_PDU.subclass([('set_of_values', '*')])
class ATT_Read_Multiple_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -509,11 +538,13 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC)
])
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Group_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -521,10 +552,7 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
class ATT_Read_By_Group_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -533,10 +561,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 4:offset + self.length]
self.attributes.append((attribute_handle, end_group_handle, attribute_value))
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
attribute_handle, end_group_handle = struct.unpack_from(
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
offset += self.length
def __init__(self, *args, **kwargs):
@@ -549,18 +585,29 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])})
], ' ')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -576,10 +623,7 @@ class ATT_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -587,11 +631,13 @@ class ATT_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
# ('authentication_signature', 'TODO')
])
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
# ('authentication_signature', 'TODO')
]
)
class ATT_Signed_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -599,11 +645,13 @@ class ATT_Signed_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -611,11 +659,13 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -639,10 +689,7 @@ class ATT_Execute_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Handle_Value_Notification(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -650,10 +697,7 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Handle_Value_Indication(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
@@ -671,20 +715,20 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class Attribute(EventEmitter):
# Permission flags
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
def __init__(self, attribute_type, permissions, value = b''):
def __init__(self, attribute_type, permissions, value=b''):
EventEmitter.__init__(self)
self.handle = 0
self.handle = 0
self.end_group_handle = 0
self.permissions = permissions
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if type(attribute_type) is str:

View File

@@ -26,7 +26,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError,
ProtocolError,
name_or_number
name_or_number,
)
from .a2dp import (
A2DP_CODEC_TYPE_NAMES,
@@ -35,7 +35,7 @@ from .a2dp import (
A2DP_SBC_CODEC_TYPE,
AacMediaCodecInformation,
SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation
VendorSpecificMediaCodecInformation,
)
from . import sdp
@@ -48,6 +48,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
AVDTP_PSM = 0x0019
AVDTP_DEFAULT_RTX_SIG_TIMER = 5 # Seconds
@@ -195,6 +197,8 @@ AVDTP_STATE_NAMES = {
AVDTP_ABORTING_STATE: 'AVDTP_ABORTING_STATE'
}
# fmt: on
# -----------------------------------------------------------------------------
async def find_avdtp_service_with_sdp_client(sdp_client):
@@ -206,14 +210,11 @@ async def find_avdtp_service_with_sdp_client(sdp_client):
# Search for services with an Audio Sink service class
search_result = await sdp_client.search_attributes(
[BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE],
[
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
]
[sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID],
)
for attribute_list in search_result:
profile_descriptor_list = sdp.ServiceAttribute.find_attribute_in_list(
attribute_list,
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value:
@@ -251,17 +252,19 @@ class RealtimeClock:
class MediaPacket:
@staticmethod
def from_bytes(data):
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
version = (data[0] >> 6) & 0x03
padding = (data[0] >> 5) & 0x01
extension = (data[0] >> 4) & 0x01
csrc_count = data[0] & 0x0F
marker = (data[1] >> 7) & 0x01
payload_type = data[1] & 0x7F
sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)]
payload = data[12 + csrc_count * 4:]
timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :]
return MediaPacket(
version,
@@ -273,7 +276,7 @@ class MediaPacket:
ssrc,
csrc_list,
payload_type,
payload
payload,
)
def __init__(
@@ -287,27 +290,29 @@ class MediaPacket:
ssrc,
csrc_list,
payload_type,
payload
payload,
):
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.version = version
self.padding = padding
self.extension = extension
self.marker = marker
self.sequence_number = sequence_number
self.timestamp = timestamp
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
self.timestamp = timestamp
self.ssrc = ssrc
self.csrc_list = csrc_list
self.payload_type = payload_type
self.payload = payload
def __bytes__(self):
header = (
bytes([
self.version << 6 | self.padding << 5 | self.extension << 4 | len(self.csrc_list),
self.marker << 7 | self.payload_type
]) +
struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
)
header = bytes(
[
self.version << 6
| self.padding << 5
| self.extension << 4
| len(self.csrc_list),
self.marker << 7 | self.payload_type,
]
) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
for csrc in self.csrc_list:
header += struct.pack('>I', csrc)
return header + self.payload
@@ -319,13 +324,13 @@ class MediaPacket:
# -----------------------------------------------------------------------------
class MediaPacketPump:
def __init__(self, packets, clock=RealtimeClock()):
self.packets = packets
self.clock = clock
self.packets = packets
self.clock = clock
self.pump_task = None
async def start(self, rtp_channel):
async def pump_packets():
start_time = 0
start_time = 0
start_timestamp = 0
try:
@@ -333,7 +338,7 @@ class MediaPacketPump:
async for packet in self.packets:
# Capture the timestamp of the first packet
if start_time == 0:
start_time = self.clock.now()
start_time = self.clock.now()
start_timestamp = packet.timestamp_seconds
# Wait until we can send
@@ -346,7 +351,9 @@ class MediaPacketPump:
# Emit
rtp_channel.send_pdu(bytes(packet))
logger.debug(f'{color(">>> sending RTP packet:", "green")} {packet}')
logger.debug(
f'{color(">>> sending RTP packet:", "green")} {packet}'
)
except asyncio.exceptions.CancelledError:
logger.debug('pump canceled')
@@ -368,67 +375,87 @@ class MessageAssembler:
self.reset()
def reset(self):
self.transaction_label = 0
self.message = None
self.message_type = 0
self.signal_identifier = 0
self.transaction_label = 0
self.message = None
self.message_type = 0
self.signal_identifier = 0
self.number_of_signal_packets = 0
self.packet_count = 0
self.packet_count = 0
def on_pdu(self, pdu):
self.packet_count += 1
transaction_label = pdu[0] >> 4
packet_type = (pdu[0] >> 2) & 3
message_type = pdu[0] & 3
packet_type = (pdu[0] >> 2) & 3
message_type = pdu[0] & 3
logger.debug(f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}')
if packet_type == Protocol.SINGLE_PACKET or packet_type == Protocol.START_PACKET:
logger.debug(
f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}'
)
if (
packet_type == Protocol.SINGLE_PACKET
or packet_type == Protocol.START_PACKET
):
if self.message is not None:
# The previous message has not been terminated
logger.warning('received a start or single packet when expecting an end or continuation')
logger.warning(
'received a start or single packet when expecting an end or continuation'
)
self.reset()
self.transaction_label = transaction_label
self.signal_identifier = pdu[1] & 0x3F
self.message_type = message_type
self.message_type = message_type
if packet_type == Protocol.SINGLE_PACKET:
self.message = pdu[2:]
self.on_message_complete()
else:
self.number_of_signal_packets = pdu[2]
self.message = pdu[3:]
elif packet_type == Protocol.CONTINUE_PACKET or packet_type == Protocol.END_PACKET:
self.message = pdu[3:]
elif (
packet_type == Protocol.CONTINUE_PACKET
or packet_type == Protocol.END_PACKET
):
if self.packet_count == 0:
logger.warning('unexpected continuation')
return
if transaction_label != self.transaction_label:
logger.warning(f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}')
logger.warning(
f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}'
)
return
if message_type != self.message_type:
logger.warning(f'message type mismatch: expected {self.message_type}, received {message_type}')
logger.warning(
f'message type mismatch: expected {self.message_type}, received {message_type}'
)
return
self.message += pdu[1:]
if packet_type == Protocol.END_PACKET:
if self.packet_count != self.number_of_signal_packets:
logger.warning(f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}')
logger.warning(
f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}'
)
self.reset()
return
self.on_message_complete()
else:
if self.packet_count > self.number_of_signal_packets:
logger.warning(f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}')
logger.warning(
f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}'
)
self.reset()
return
def on_message_complete(self):
message = Message.create(self.signal_identifier, self.message_type, self.message)
message = Message.create(
self.signal_identifier, self.message_type, self.message
)
try:
self.callback(self.transaction_label, message)
@@ -460,12 +487,14 @@ class ServiceCapabilities:
def parse_capabilities(payload):
capabilities = []
while payload:
service_category = payload[0]
service_category = payload[0]
length_of_service_capabilities = payload[1]
service_capabilities_bytes = payload[2:2 + length_of_service_capabilities]
capabilities.append(ServiceCapabilities.create(service_category, service_capabilities_bytes))
service_capabilities_bytes = payload[2 : 2 + length_of_service_capabilities]
capabilities.append(
ServiceCapabilities.create(service_category, service_capabilities_bytes)
)
payload = payload[2 + length_of_service_capabilities:]
payload = payload[2 + length_of_service_capabilities :]
return capabilities
@@ -473,21 +502,24 @@ class ServiceCapabilities:
def serialize_capabilities(capabilities):
serialized = b''
for item in capabilities:
serialized += bytes([
item.service_category,
len(item.service_capabilities_bytes)
]) + item.service_capabilities_bytes
serialized += (
bytes([item.service_category, len(item.service_capabilities_bytes)])
+ item.service_capabilities_bytes
)
return serialized
def init_from_bytes(self):
pass
def __init__(self, service_category, service_capabilities_bytes=b''):
self.service_category = service_category
self.service_category = service_category
self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details=[]):
attributes = ','.join([name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + details)
attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ details
)
return f'ServiceCapabilities({attributes})'
def __str__(self):
@@ -501,31 +533,39 @@ class ServiceCapabilities:
# -----------------------------------------------------------------------------
class MediaCodecCapabilities(ServiceCapabilities):
def init_from_bytes(self):
self.media_type = self.service_capabilities_bytes[0]
self.media_codec_type = self.service_capabilities_bytes[1]
self.media_type = self.service_capabilities_bytes[0]
self.media_codec_type = self.service_capabilities_bytes[1]
self.media_codec_information = self.service_capabilities_bytes[2:]
if self.media_codec_type == A2DP_SBC_CODEC_TYPE:
self.media_codec_information = SbcMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = SbcMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_MPEG_2_4_AAC_CODEC_TYPE:
self.media_codec_information = AacMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = AacMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
self.media_codec_information = VendorSpecificMediaCodecInformation.from_bytes(self.media_codec_information)
self.media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
def __init__(self, media_type, media_codec_type, media_codec_information):
super().__init__(
AVDTP_MEDIA_CODEC_SERVICE_CATEGORY,
bytes([media_type, media_codec_type]) + bytes(media_codec_information)
bytes([media_type, media_codec_type]) + bytes(media_codec_information),
)
self.media_type = media_type
self.media_codec_type = media_codec_type
self.media_type = media_type
self.media_codec_type = media_codec_type
self.media_codec_information = media_codec_information
def __str__(self):
details = [
f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}',
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}'
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}',
]
return self.to_string(details)
@@ -535,37 +575,33 @@ class EndPointInfo:
@staticmethod
def from_bytes(payload):
return EndPointInfo(
payload[0] >> 2,
payload[0] >> 1 & 1,
payload[1] >> 4,
payload[1] >> 3 & 1
payload[0] >> 2, payload[0] >> 1 & 1, payload[1] >> 4, payload[1] >> 3 & 1
)
def __bytes__(self):
return bytes([
self.seid << 2 | self.in_use << 1,
self.media_type << 4 | self.tsep << 3
])
return bytes(
[self.seid << 2 | self.in_use << 1, self.media_type << 4 | self.tsep << 3]
)
def __init__(self, seid, in_use, media_type, tsep):
self.seid = seid
self.in_use = in_use
self.seid = seid
self.in_use = in_use
self.media_type = media_type
self.tsep = tsep
self.tsep = tsep
# -----------------------------------------------------------------------------
class Message:
COMMAND = 0
GENERAL_REJECT = 1
COMMAND = 0
GENERAL_REJECT = 1
RESPONSE_ACCEPT = 2
RESPONSE_REJECT = 3
MESSAGE_TYPE_NAMES = {
COMMAND: 'COMMAND',
GENERAL_REJECT: 'GENERAL_REJECT',
COMMAND: 'COMMAND',
GENERAL_REJECT: 'GENERAL_REJECT',
RESPONSE_ACCEPT: 'RESPONSE_ACCEPT',
RESPONSE_REJECT: 'RESPONSE_REJECT'
RESPONSE_REJECT: 'RESPONSE_REJECT',
}
subclasses = {} # Subclasses, by signal identifier and message type
@@ -603,7 +639,9 @@ class Message:
break
# Register the subclass
Message.subclasses.setdefault(cls.signal_identifier, {})[cls.message_type] = cls
Message.subclasses.setdefault(cls.signal_identifier, {})[
cls.message_type
] = cls
return cls
@@ -635,7 +673,7 @@ class Message:
pass
def __init__(self, payload=b''):
self.payload = payload
self.payload = payload
def to_string(self, details):
base = f'{color(f"{name_or_number(AVDTP_SIGNAL_NAMES, self.signal_identifier)}_{Message.message_type_name(self.message_type)}", "yellow")}'
@@ -643,7 +681,11 @@ class Message:
if type(details) is str:
return f'{base}: {details}'
else:
return base + ':\n' + '\n'.join([' ' + color(detail, 'cyan') for detail in details])
return (
base
+ ':\n'
+ '\n'.join([' ' + color(detail, 'cyan') for detail in details])
)
else:
return base
@@ -682,9 +724,7 @@ class Simple_Reject(Message):
self.payload = bytes([self.error_code])
def __str__(self):
details = [
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
]
details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}']
return self.to_string(details)
@@ -707,11 +747,13 @@ class Discover_Response(Message):
self.endpoints = []
endpoint_count = len(self.payload) // 2
for i in range(endpoint_count):
self.endpoints.append(EndPointInfo.from_bytes(self.payload[i * 2:(i + 1) * 2]))
self.endpoints.append(
EndPointInfo.from_bytes(self.payload[i * 2 : (i + 1) * 2])
)
def __init__(self, endpoints):
self.endpoints = endpoints
self.payload = b''.join([bytes(endpoint) for endpoint in endpoints])
self.payload = b''.join([bytes(endpoint) for endpoint in endpoints])
def __str__(self):
details = []
@@ -721,7 +763,7 @@ class Discover_Response(Message):
f'ACP SEID: {endpoint.seid}',
f' in_use: {endpoint.in_use}',
f' media_type: {name_or_number(AVDTP_MEDIA_TYPE_NAMES, endpoint.media_type)}',
f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}'
f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}',
]
)
return self.to_string(details)
@@ -794,21 +836,22 @@ class Set_Configuration_Command(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.int_seid = self.payload[1] >> 2
self.acp_seid = self.payload[0] >> 2
self.int_seid = self.payload[1] >> 2
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[2:])
def __init__(self, acp_seid, int_seid, capabilities):
self.acp_seid = acp_seid
self.int_seid = int_seid
self.acp_seid = acp_seid
self.int_seid = int_seid
self.capabilities = capabilities
self.payload = bytes([acp_seid << 2, int_seid << 2]) + ServiceCapabilities.serialize_capabilities(capabilities)
self.payload = bytes(
[acp_seid << 2, int_seid << 2]
) + ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self):
details = [
f'ACP SEID: {self.acp_seid}',
f'INT SEID: {self.int_seid}'
] + [str(capability) for capability in self.capabilities]
details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [
str(capability) for capability in self.capabilities
]
return self.to_string(details)
@@ -829,17 +872,17 @@ class Set_Configuration_Reject(Message):
def init_from_payload(self):
self.service_category = self.payload[0]
self.error_code = self.payload[1]
self.error_code = self.payload[1]
def __init__(self, service_category, error_code):
self.service_category = service_category
self.error_code = error_code
self.payload = bytes([service_category, self.error_code])
self.error_code = error_code
self.payload = bytes([service_category, self.error_code])
def __str__(self):
details = [
f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
]
return self.to_string(details)
@@ -887,7 +930,7 @@ class Reconfigure_Command(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.acp_seid = self.payload[0] >> 2
self.capabilities = ServiceCapabilities.parse_capabilities(self.payload[1:])
def __str__(self):
@@ -971,18 +1014,18 @@ class Start_Reject(Message):
'''
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.acp_seid = self.payload[0] >> 2
self.error_code = self.payload[1]
def __init__(self, acp_seid, error_code):
self.acp_seid = acp_seid
self.acp_seid = acp_seid
self.error_code = error_code
self.payload = bytes([self.acp_seid << 2, self.error_code])
self.payload = bytes([self.acp_seid << 2, self.error_code])
def __str__(self):
details = [
f'acp_seid: {self.acp_seid}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
]
return self.to_string(details)
@@ -1095,13 +1138,10 @@ class DelayReport_Command(Message):
def init_from_payload(self):
self.acp_seid = self.payload[0] >> 2
self.delay = (self.payload[1] << 8) | (self.payload[2])
self.delay = (self.payload[1] << 8) | (self.payload[2])
def __str__(self):
return self.to_string([
f'ACP_SEID: {self.acp_seid}',
f'delay: {self.delay}'
])
return self.to_string([f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}'])
# -----------------------------------------------------------------------------
@@ -1122,16 +1162,16 @@ class DelayReport_Reject(Simple_Reject):
# -----------------------------------------------------------------------------
class Protocol:
SINGLE_PACKET = 0
START_PACKET = 1
SINGLE_PACKET = 0
START_PACKET = 1
CONTINUE_PACKET = 2
END_PACKET = 3
END_PACKET = 3
PACKET_TYPE_NAMES = {
SINGLE_PACKET: 'SINGLE_PACKET',
START_PACKET: 'START_PACKET',
SINGLE_PACKET: 'SINGLE_PACKET',
START_PACKET: 'START_PACKET',
CONTINUE_PACKET: 'CONTINUE_PACKET',
END_PACKET: 'END_PACKET'
END_PACKET: 'END_PACKET',
}
@staticmethod
@@ -1148,18 +1188,18 @@ class Protocol:
return protocol
def __init__(self, l2cap_channel, version=(1, 3)):
self.l2cap_channel = l2cap_channel
self.version = version
self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER
self.message_assembler = MessageAssembler(self.on_message)
self.transaction_results = [None] * 16 # Futures for up to 16 transactions
self.l2cap_channel = l2cap_channel
self.version = version
self.rtx_sig_timer = AVDTP_DEFAULT_RTX_SIG_TIMER
self.message_assembler = MessageAssembler(self.on_message)
self.transaction_results = [None] * 16 # Futures for up to 16 transactions
self.transaction_semaphore = asyncio.Semaphore(16)
self.transaction_count = 0
self.channel_acceptor = None
self.channel_connector = None
self.local_endpoints = [] # Local endpoints, with contiguous seid values
self.remote_endpoints = {} # Remote stream endpoints, by seid
self.streams = {} # Streams, by seid
self.transaction_count = 0
self.channel_acceptor = None
self.channel_connector = None
self.local_endpoints = [] # Local endpoints, with contiguous seid values
self.remote_endpoints = {} # Remote stream endpoints, by seid
self.streams = {} # Streams, by seid
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
@@ -1205,7 +1245,9 @@ class Protocol:
response = await self.send_command(Discover_Command())
for endpoint_entry in response.endpoints:
logger.debug(f'getting endpoint capabilities for endpoint {endpoint_entry.seid}')
logger.debug(
f'getting endpoint capabilities for endpoint {endpoint_entry.seid}'
)
get_capabilities_response = await self.get_capabilities(endpoint_entry.seid)
endpoint = DiscoveredStreamEndPoint(
self,
@@ -1213,7 +1255,7 @@ class Protocol:
endpoint_entry.media_type,
endpoint_entry.tsep,
endpoint_entry.in_use,
get_capabilities_response.capabilities
get_capabilities_response.capabilities,
)
self.remote_endpoints[endpoint_entry.seid] = endpoint
@@ -1221,14 +1263,27 @@ class Protocol:
def find_remote_sink_by_codec(self, media_type, codec_type):
for endpoint in self.remote_endpoints.values():
if not endpoint.in_use and endpoint.media_type == media_type and endpoint.tsep == AVDTP_TSEP_SNK:
if (
not endpoint.in_use
and endpoint.media_type == media_type
and endpoint.tsep == AVDTP_TSEP_SNK
):
has_media_transport = False
has_codec = False
for capabilities in endpoint.capabilities:
if capabilities.service_category == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY:
if (
capabilities.service_category
== AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY
):
has_media_transport = True
elif capabilities.service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY:
if capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE and capabilities.media_codec_type == codec_type:
elif (
capabilities.service_category
== AVDTP_MEDIA_CODEC_SERVICE_CATEGORY
):
if (
capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and capabilities.media_codec_type == codec_type
):
has_codec = True
if has_media_transport and has_codec:
return endpoint
@@ -1237,7 +1292,9 @@ class Protocol:
self.message_assembler.on_pdu(pdu)
def on_message(self, transaction_label, message):
logger.debug(f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}')
logger.debug(
f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
# Check that the identifier is not reserved
if message.signal_identifier == 0:
@@ -1245,7 +1302,10 @@ class Protocol:
return
# Check that the identifier is valid
if message.signal_identifier < 0 or message.signal_identifier > AVDTP_DELAYREPORT:
if (
message.signal_identifier < 0
or message.signal_identifier > AVDTP_DELAYREPORT
):
logger.warning('!!! invalid signal identifier')
self.send_message(transaction_label, General_Reject())
@@ -1258,7 +1318,9 @@ class Protocol:
response = handler(message)
self.send_message(transaction_label, response)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in handler:", "red")} {error}'
)
else:
logger.warning('unhandled command')
else:
@@ -1281,8 +1343,12 @@ class Protocol:
logger.debug(color('<<< L2CAP channel open', 'magenta'))
def send_message(self, transaction_label, message):
logger.debug(f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}')
max_fragment_size = self.l2cap_channel.mtu - 3 # Enough space for a 3-byte start packet header
logger.debug(
f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
) # Enough space for a 3-byte start packet header
payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu:
# Fits in a single packet
@@ -1292,13 +1358,19 @@ class Protocol:
done = False
while not done:
first_header_byte = transaction_label << 4 | packet_type << 2 | message.message_type
first_header_byte = (
transaction_label << 4 | packet_type << 2 | message.message_type
)
if packet_type == self.SINGLE_PACKET:
header = bytes([first_header_byte, message.signal_identifier])
elif packet_type == self.START_PACKET:
packet_count = (max_fragment_size - 1 + len(payload)) // max_fragment_size
header = bytes([first_header_byte, message.signal_identifier, packet_count])
packet_count = (
max_fragment_size - 1 + len(payload)
) // max_fragment_size
header = bytes(
[first_header_byte, message.signal_identifier, packet_count]
)
else:
header = bytes([first_header_byte])
@@ -1308,7 +1380,11 @@ class Protocol:
# Prepare for the next packet
payload = payload[max_fragment_size:]
if payload:
packet_type = self.CONTINUE_PACKET if payload > max_fragment_size else self.END_PACKET
packet_type = (
self.CONTINUE_PACKET
if payload > max_fragment_size
else self.END_PACKET
)
else:
done = True
@@ -1322,7 +1398,10 @@ class Protocol:
response = await transaction_result
# Check for errors
if response.message_type == Message.GENERAL_REJECT or response.message_type == Message.RESPONSE_REJECT:
if (
response.message_type == Message.GENERAL_REJECT
or response.message_type == Message.RESPONSE_REJECT
):
raise ProtocolError(response.error_code, 'avdtp')
return response
@@ -1340,7 +1419,7 @@ class Protocol:
self.transaction_count += 1
return (transaction_label, transaction_result)
assert(False) # Should never reach this
assert False # Should never reach this
async def get_capabilities(self, seid):
if self.version > (1, 2):
@@ -1349,7 +1428,9 @@ class Protocol:
return await self.send_command(Get_Capabilities_Command(seid))
async def set_configuration(self, acp_seid, int_seid, capabilities):
return await self.send_command(Set_Configuration_Command(acp_seid, int_seid, capabilities))
return await self.send_command(
Set_Configuration_Command(acp_seid, int_seid, capabilities)
)
async def get_configuration(self, seid):
response = await self.send_command(Get_Configuration_Command(seid))
@@ -1537,6 +1618,7 @@ class Listener(EventEmitter):
server = Protocol(channel, self.version)
self.set_server(channel.connection, server)
self.emit('connection', server)
channel.on('open', on_channel_open)
@@ -1562,8 +1644,7 @@ class Stream:
raise InvalidStateError('current state is not IDLE')
await self.remote_endpoint.set_configuration(
self.local_endpoint.seid,
self.local_endpoint.configuration
self.local_endpoint.seid, self.local_endpoint.configuration
)
self.change_state(AVDTP_CONFIGURED_STATE)
@@ -1639,7 +1720,11 @@ class Stream:
self.change_state(AVDTP_CONFIGURED_STATE)
def on_get_configuration_command(self, configuration):
if self.state not in {AVDTP_CONFIGURED_STATE, AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}:
if self.state not in {
AVDTP_CONFIGURED_STATE,
AVDTP_OPEN_STATE,
AVDTP_STREAMING_STATE,
}:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command(configuration)
@@ -1718,7 +1803,7 @@ class Stream:
def on_l2cap_connection(self, channel):
logger.debug(color('<<< stream channel connected', 'magenta'))
self.rtp_channel = channel
channel.on('open', self.on_l2cap_channel_open)
channel.on('open', self.on_l2cap_channel_open)
channel.on('close', self.on_l2cap_channel_close)
# We don't need more channels
@@ -1744,11 +1829,11 @@ class Stream:
remote_endpoint must be a subclass of StreamEndPointProxy
'''
self.protocol = protocol
self.local_endpoint = local_endpoint
self.protocol = protocol
self.local_endpoint = local_endpoint
self.remote_endpoint = remote_endpoint
self.rtp_channel = None
self.state = AVDTP_IDLE_STATE
self.rtp_channel = None
self.state = AVDTP_IDLE_STATE
local_endpoint.stream = self
local_endpoint.in_use = 1
@@ -1760,38 +1845,36 @@ class Stream:
# -----------------------------------------------------------------------------
class StreamEndPoint:
def __init__(self, seid, media_type, tsep, in_use, capabilities):
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.in_use = in_use
self.seid = seid
self.media_type = media_type
self.tsep = tsep
self.in_use = in_use
self.capabilities = capabilities
def __str__(self):
return '\n'.join([
'SEP(',
f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}',
f' in_use={self.in_use}',
' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]),
' ]',
')'
])
return '\n'.join(
[
'SEP(',
f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f' tsep={name_or_number(AVDTP_TSEP_NAMES, self.tsep)}',
f' in_use={self.in_use}',
' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]),
' ]',
')',
]
)
# -----------------------------------------------------------------------------
class StreamEndPointProxy:
def __init__(self, protocol, seid):
self.seid = seid
self.seid = seid
self.protocol = protocol
async def set_configuration(self, int_seid, configuration):
return await self.protocol.set_configuration(
self.seid,
int_seid,
configuration
)
return await self.protocol.set_configuration(self.seid, int_seid, configuration)
async def open(self):
return await self.protocol.open(self.seid)
@@ -1818,11 +1901,13 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
# -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint):
def __init__(self, protocol, seid, media_type, tsep, capabilities, configuration=[]):
def __init__(
self, protocol, seid, media_type, tsep, capabilities, configuration=[]
):
super().__init__(seid, media_type, tsep, 0, capabilities)
self.protocol = protocol
self.protocol = protocol
self.configuration = configuration
self.stream = None
self.stream = None
async def start(self):
pass
@@ -1866,9 +1951,17 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities, packet_pump):
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities
codec_capabilities,
]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SRC, capabilities, capabilities)
LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
)
EventEmitter.__init__(self)
self.packet_pump = packet_pump
@@ -1901,9 +1994,16 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities):
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities
codec_capabilities,
]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SNK, capabilities)
LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
)
EventEmitter.__init__(self)
def on_set_configuration_command(self, configuration):
@@ -1917,5 +2017,7 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def on_avdtp_packet(self, packet):
rtp_packet = MediaPacket.from_bytes(packet)
logger.debug(f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}')
logger.debug(
f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}'
)
self.emit('rtp_packet', rtp_packet)

View File

@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
class HCI_Bridge:
class Forwarder:
def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace):
self.hci_sink = hci_sink
self.hci_sink = hci_sink
self.sender_hci_sink = sender_hci_sink
self.packet_filter = packet_filter
self.trace = trace
self.packet_filter = packet_filter
self.trace = trace
def on_packet(self, packet):
# Convert the packet bytes to an object
@@ -61,15 +61,15 @@ class HCI_Bridge:
hci_host_sink,
hci_controller_source,
hci_controller_sink,
host_to_controller_filter = None,
controller_to_host_filter = None
host_to_controller_filter=None,
controller_to_host_filter=None,
):
tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink,
hci_host_sink,
host_to_controller_filter,
lambda packet: tracer.trace(packet, 0)
lambda packet: tracer.trace(packet, 0),
)
hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink,
hci_controller_sink,
controller_to_host_filter,
lambda packet: tracer.trace(packet, 1)
lambda packet: tracer.trace(packet, 1),
)
hci_controller_source.set_packet_sink(controller_to_host_forwarder)

View File

@@ -2704,5 +2704,5 @@ COMPANY_IDENTIFIERS = {
0x0A7C: "WAFERLOCK",
0x0A7D: "Freedman Electronics Pty Ltd",
0x0A7E: "Keba AG",
0x0A7F: "Intuity Medical"
}
0x0A7F: "Intuity Medical",
}

View File

@@ -39,67 +39,79 @@ class DataObject:
# -----------------------------------------------------------------------------
class Connection:
def __init__(self, controller, handle, role, peer_address, link):
self.controller = controller
self.handle = handle
self.role = role
self.controller = controller
self.handle = handle
self.role = role
self.peer_address = peer_address
self.link = link
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.link = link
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
self.controller.send_hci_packet(HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)]))
self.controller.send_hci_packet(
HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])
)
def on_acl_pdu(self, data):
if self.link:
self.link.send_acl_data(self.controller.random_address, self.peer_address, data)
self.link.send_acl_data(
self.controller.random_address, self.peer_address, data
)
# -----------------------------------------------------------------------------
class Controller:
def __init__(self, name, host_source = None, host_sink = None, link = None):
self.name = name
def __init__(self, name, host_source=None, host_sink=None, link=None):
self.name = name
self.hci_sink = None
self.link = link
self.link = link
self.central_connections = {} # Connections where this controller is the central
self.peripheral_connections = {} # Connections where this controller is the peripheral
self.central_connections = (
{}
) # Connections where this controller is the central
self.peripheral_connections = (
{}
) # Connections where this controller is the peripheral
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.lmp_subversion = 0
self.lmp_features = bytes.fromhex('0000000060000000') # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_le_data_packet_length = 27
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.lmp_subversion = 0
self.lmp_features = bytes.fromhex(
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.supported_commands = bytes.fromhex('2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000')
self.le_features = bytes.fromhex('ff49010000000000')
self.le_states = bytes.fromhex('ffff3fffff030000')
self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_features = bytes.fromhex('ff49010000000000')
self.le_states = bytes.fromhex('ffff3fffff030000')
self.advertising_channel_tx_power = 0
self.filter_accept_list_size = 8
self.resolving_list_size = 8
self.supported_max_tx_octets = 27
self.supported_max_tx_time = 10000 # microseconds
self.supported_max_rx_octets = 27
self.supported_max_rx_time = 10000 # microseconds
self.suggested_max_tx_octets = 27
self.suggested_max_tx_time = 0x0148 # microseconds
self.default_phy = bytes([0, 0, 0])
self.le_scan_type = 0
self.le_scan_interval = 0x10
self.le_scan_window = 0x10
self.le_scan_enable = 0
self.le_scan_own_address_type = Address.RANDOM_DEVICE_ADDRESS
self.le_scanning_filter_policy = 0
self.le_scan_response_data = None
self.le_address_resolution = False
self.le_rpa_timeout = 0
self.sync_flow_control = False
self.local_name = 'Bumble'
self.filter_accept_list_size = 8
self.resolving_list_size = 8
self.supported_max_tx_octets = 27
self.supported_max_tx_time = 10000 # microseconds
self.supported_max_rx_octets = 27
self.supported_max_rx_time = 10000 # microseconds
self.suggested_max_tx_octets = 27
self.suggested_max_tx_time = 0x0148 # microseconds
self.default_phy = bytes([0, 0, 0])
self.le_scan_type = 0
self.le_scan_interval = 0x10
self.le_scan_window = 0x10
self.le_scan_enable = 0
self.le_scan_own_address_type = Address.RANDOM_DEVICE_ADDRESS
self.le_scanning_filter_policy = 0
self.le_scan_response_data = None
self.le_address_resolution = False
self.le_rpa_timeout = 0
self.sync_flow_control = False
self.local_name = 'Bumble'
self.advertising_interval = 2000 # Fixed for now
self.advertising_data = None
self.advertising_interval = 2000 # Fixed for now
self.advertising_data = None
self.advertising_timer_handle = None
self._random_address = Address('00:00:00:00:00:00')
@@ -162,7 +174,9 @@ class Controller:
self.on_hci_packet(HCI_Packet.from_bytes(packet))
def on_hci_packet(self, packet):
logger.debug(f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}')
logger.debug(
f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}'
)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:
@@ -179,11 +193,13 @@ class Controller:
handler = getattr(self, handler_name, self.on_hci_command)
result = handler(command)
if type(result) is bytes:
self.send_hci_packet(HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = command.op_code,
return_parameters = result
))
self.send_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=command.op_code,
return_parameters=result,
)
)
def on_hci_event_packet(self, event):
logger.warning('!!! unexpected event packet')
@@ -192,14 +208,18 @@ class Controller:
# Look for the connection to which this data belongs
connection = self.find_connection_by_handle(packet.connection_handle)
if connection is None:
logger.warning(f'!!! no connection for handle 0x{packet.connection_handle:04X}')
logger.warning(
f'!!! no connection for handle 0x{packet.connection_handle:04X}'
)
return
# Pass the packet to the connection
connection.on_hci_acl_data_packet(packet)
def send_hci_packet(self, packet):
logger.debug(f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}')
logger.debug(
f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}'
)
if self.host:
self.host.on_packet(packet.to_bytes())
@@ -215,8 +235,7 @@ class Controller:
handle = 0
max_handle = 0
for connection in itertools.chain(
self.central_connections.values(),
self.peripheral_connections.values()
self.central_connections.values(), self.peripheral_connections.values()
):
max_handle = max(max_handle, connection.handle)
if connection.handle == handle:
@@ -225,12 +244,13 @@ class Controller:
return handle
def find_connection_by_address(self, address):
return self.central_connections.get(address) or self.peripheral_connections.get(address)
return self.central_connections.get(address) or self.peripheral_connections.get(
address
)
def find_connection_by_handle(self, handle):
for connection in itertools.chain(
self.central_connections.values(),
self.peripheral_connections.values()
self.central_connections.values(), self.peripheral_connections.values()
):
if connection.handle == handle:
return connection
@@ -253,22 +273,26 @@ class Controller:
connection = self.peripheral_connections.get(peer_address)
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link)
connection = Connection(
self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link
)
self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
# Then say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = connection.handle,
role = connection.role,
peer_address_type = peer_address_type,
peer_address = peer_address,
connection_interval = 10, # FIXME
peripheral_latency = 0, # FIXME
supervision_timeout = 10, # FIXME
central_clock_accuracy = 7 # FIXME
))
self.send_hci_packet(
HCI_LE_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
role=connection.role,
peer_address_type=peer_address_type,
peer_address=peer_address,
connection_interval=10, # FIXME
peripheral_latency=0, # FIXME
supervision_timeout=10, # FIXME
central_clock_accuracy=7, # FIXME
)
)
def on_link_central_disconnected(self, peer_address, reason):
'''
@@ -277,18 +301,22 @@ class Controller:
# Send a disconnection complete event
if connection := self.peripheral_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = connection.handle,
reason = reason
))
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason=reason,
)
)
# Remove the connection
del self.peripheral_connections[peer_address]
else:
logger.warn(f'!!! No peripheral connection found for {peer_address}')
def on_link_peripheral_connection_complete(self, le_create_connection_command, status):
def on_link_peripheral_connection_complete(
self, le_create_connection_command, status
):
'''
Called by the link when a connection has been made or has failed to be made
'''
@@ -300,29 +328,29 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
self,
connection_handle,
BT_CENTRAL_ROLE,
peer_address,
self.link
self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link
)
self.central_connections[peer_address] = connection
logger.debug(f'New CENTRAL connection handle: 0x{connection_handle:04X}')
logger.debug(
f'New CENTRAL connection handle: 0x{connection_handle:04X}'
)
else:
connection = None
# Say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event(
status = status,
connection_handle = connection.handle if connection else 0,
role = BT_CENTRAL_ROLE,
peer_address_type = le_create_connection_command.peer_address_type,
peer_address = le_create_connection_command.peer_address,
connection_interval = le_create_connection_command.connection_interval_min,
peripheral_latency = le_create_connection_command.max_latency,
supervision_timeout = le_create_connection_command.supervision_timeout,
central_clock_accuracy = 0
))
self.send_hci_packet(
HCI_LE_Connection_Complete_Event(
status=status,
connection_handle=connection.handle if connection else 0,
role=BT_CENTRAL_ROLE,
peer_address_type=le_create_connection_command.peer_address_type,
peer_address=le_create_connection_command.peer_address,
connection_interval=le_create_connection_command.connection_interval_min,
peripheral_latency=le_create_connection_command.max_latency,
supervision_timeout=le_create_connection_command.supervision_timeout,
central_clock_accuracy=0,
)
)
def on_link_peripheral_disconnection_complete(self, disconnection_command, status):
'''
@@ -330,14 +358,18 @@ class Controller:
'''
# Send a disconnection complete event
self.send_hci_packet(HCI_Disconnection_Complete_Event(
status = status,
connection_handle = disconnection_command.connection_handle,
reason = disconnection_command.reason
))
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=status,
connection_handle=disconnection_command.connection_handle,
reason=disconnection_command.reason,
)
)
# Remove the connection
if connection := self.find_central_connection_by_handle(disconnection_command.connection_handle):
if connection := self.find_central_connection_by_handle(
disconnection_command.connection_handle
):
logger.debug(f'CENTRAL Connection removed: {connection}')
del self.central_connections[connection.peer_address]
@@ -348,11 +380,13 @@ class Controller:
# Send a disconnection complete event
if connection := self.central_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = connection.handle,
reason = HCI_CONNECTION_TIMEOUT_ERROR
))
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason=HCI_CONNECTION_TIMEOUT_ERROR,
)
)
# Remove the connection
del self.central_connections[peer_address]
@@ -364,9 +398,7 @@ class Controller:
if connection := self.find_connection_by_address(peer_address):
self.send_hci_packet(
HCI_Encryption_Change_Event(
status = 0,
connection_handle = connection.handle,
encryption_enabled = 1
status=0, connection_handle=connection.handle, encryption_enabled=1
)
)
@@ -390,22 +422,22 @@ class Controller:
# Send a scan report
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
event_type = HCI_LE_Advertising_Report_Event.ADV_IND,
address_type = sender_address.address_type,
address = sender_address,
data = data,
rssi = -50
event_type=HCI_LE_Advertising_Report_Event.ADV_IND,
address_type=sender_address.address_type,
address=sender_address,
data=data,
rssi=-50,
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
# Simulate a scan response
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
event_type = HCI_LE_Advertising_Report_Event.SCAN_RSP,
address_type = sender_address.address_type,
address = sender_address,
data = data,
rssi = -50
event_type=HCI_LE_Advertising_Report_Event.SCAN_RSP,
address_type=sender_address.address_type,
address=sender_address,
data=data,
rssi=-50,
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
@@ -414,14 +446,18 @@ class Controller:
############################################################
def on_advertising_timer_fired(self):
self.send_advertising_data()
self.advertising_timer_handle = asyncio.get_running_loop().call_later(self.advertising_interval / 1000.0, self.on_advertising_timer_fired)
self.advertising_timer_handle = asyncio.get_running_loop().call_later(
self.advertising_interval / 1000.0, self.on_advertising_timer_fired
)
def start_advertising(self):
# Stop any ongoing advertising before we start again
self.stop_advertising()
# Advertise now
self.advertising_timer_handle = asyncio.get_running_loop().call_soon(self.on_advertising_timer_fired)
self.advertising_timer_handle = asyncio.get_running_loop().call_soon(
self.on_advertising_timer_fired
)
def stop_advertising(self):
if self.advertising_timer_handle is not None:
@@ -455,14 +491,20 @@ class Controller:
See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command
'''
# First, say that the disconnection is pending
self.send_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = command.op_code
))
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
# Notify the link of the disconnection
if not (connection := self.find_central_connection_by_handle(command.connection_handle)):
if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warn('connection not found')
return
@@ -590,7 +632,7 @@ class Controller:
self.hci_revision,
self.lmp_version,
self.manufacturer_name,
self.lmp_subversion
self.lmp_subversion,
)
def on_hci_read_local_supported_commands_command(self, command):
@@ -609,7 +651,11 @@ class Controller:
'''
See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command
'''
bd_addr = self._public_address.to_bytes() if self._public_address is not None else bytes(6)
bd_addr = (
self._public_address.to_bytes()
if self._public_address is not None
else bytes(6)
)
return bytes([HCI_SUCCESS]) + bd_addr
def on_hci_le_set_event_mask_command(self, command):
@@ -623,10 +669,12 @@ class Controller:
'''
See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack('<BHB',
HCI_SUCCESS,
self.hc_le_data_packet_length,
self.hc_total_num_le_data_packets)
return struct.pack(
'<BHB',
HCI_SUCCESS,
self.hc_le_data_packet_length,
self.hc_total_num_le_data_packets,
)
def on_hci_le_read_local_supported_features_command(self, command):
'''
@@ -683,10 +731,10 @@ class Controller:
'''
See Bluetooth spec Vol 2, Part E - 7.8.10 LE Set Scan Parameters Command
'''
self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval
self.le_scan_window = command.le_scan_window
self.le_scan_own_address_type = command.own_address_type
self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval
self.le_scan_window = command.le_scan_window
self.le_scan_own_address_type = command.own_address_type
self.le_scanning_filter_policy = command.scanning_filter_policy
return bytes([HCI_SUCCESS])
@@ -694,7 +742,7 @@ class Controller:
'''
See Bluetooth spec Vol 2, Part E - 7.8.11 LE Set Scan Enable Command
'''
self.le_scan_enable = command.le_scan_enable
self.le_scan_enable = command.le_scan_enable
self.filter_duplicates = command.filter_duplicates
return bytes([HCI_SUCCESS])
@@ -710,22 +758,26 @@ class Controller:
# Check that we don't already have a pending connection
if self.link.get_pending_connection():
self.send_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_DISALLOWED_ERROR,
num_hci_command_packets = 1,
command_opcode = command.op_code
))
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_DISALLOWED_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
# Initiate the connection
self.link.connect(self.random_address, command)
# Say that the connection is pending
self.send_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = command.op_code
))
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_create_connection_cancel_command(self, command):
'''
@@ -763,18 +815,22 @@ class Controller:
'''
# First, say that the command is pending
self.send_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = command.op_code
))
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
# Then send the remote features
self.send_hci_packet(HCI_LE_Read_Remote_Features_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0,
le_features = bytes.fromhex('dd40000000000000')
))
self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0,
le_features=bytes.fromhex('dd40000000000000'),
)
)
def on_hci_le_rand_command(self, command):
'''
@@ -788,7 +844,11 @@ class Controller:
'''
# Check the parameters
if not (connection := self.find_central_connection_by_handle(command.connection_handle)):
if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warn('connection not found')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
@@ -798,14 +858,16 @@ class Controller:
connection.peer_address,
command.random_number,
command.encrypted_diversifier,
command.long_term_key
command.long_term_key,
)
self.send_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = command.op_code
))
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_read_supported_states_command(self, command):
'''
@@ -817,16 +879,20 @@ class Controller:
'''
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length Command
'''
return struct.pack('<BHH',
HCI_SUCCESS,
self.suggested_max_tx_octets,
self.suggested_max_tx_time)
return struct.pack(
'<BHH',
HCI_SUCCESS,
self.suggested_max_tx_octets,
self.suggested_max_tx_time,
)
def on_hci_le_write_suggested_default_data_length_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length Command
'''
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack('<HH', command.parameters[:4])
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack(
'<HH', command.parameters[:4]
)
return bytes([HCI_SUCCESS])
def on_hci_le_read_local_p_256_public_key_command(self, command):
@@ -884,7 +950,7 @@ class Controller:
self.supported_max_tx_octets,
self.supported_max_tx_time,
self.supported_max_rx_octets,
self.supported_max_rx_time
self.supported_max_rx_time,
)
def on_hci_le_read_phy_command(self, command):
@@ -896,7 +962,7 @@ class Controller:
HCI_SUCCESS,
command.connection_handle,
HCI_LE_1M_PHY,
HCI_LE_1M_PHY
HCI_LE_1M_PHY,
)
def on_hci_le_set_default_phy_command(self, command):
@@ -905,8 +971,7 @@ class Controller:
'''
self.default_phy = {
'all_phys': command.all_phys,
'tx_phys': command.tx_phys,
'rx_phys': command.rx_phys
'tx_phys': command.tx_phys,
'rx_phys': command.rx_phys,
}
return bytes([HCI_SUCCESS])

View File

@@ -23,6 +23,8 @@ from .company_ids import COMPANY_IDENTIFIERS
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
BT_CENTRAL_ROLE = 0
BT_PERIPHERAL_ROLE = 1
@@ -30,6 +32,9 @@ BT_BR_EDR_TRANSPORT = 0
BT_LE_TRANSPORT = 1
# fmt: on
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -64,17 +69,19 @@ def get_dict_key_by_value(dictionary, value):
return key
return None
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
class BaseError(Exception):
""" Base class for errors with an error code, error name and namespace"""
"""Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''):
super().__init__()
self.error_code = error_code
self.error_code = error_code
self.error_namespace = error_namespace
self.error_name = error_name
self.details = details
self.error_name = error_name
self.details = details
def __str__(self):
if self.error_namespace:
@@ -90,27 +97,36 @@ class BaseError(Exception):
class ProtocolError(BaseError):
""" Protocol Error """
"""Protocol Error"""
class TimeoutError(Exception):
""" Timeout Error """
"""Timeout Error"""
class CommandTimeoutError(Exception):
""" Command Timeout Error """
"""Command Timeout Error"""
class InvalidStateError(Exception):
""" Invalid State Error """
"""Invalid State Error"""
class ConnectionError(BaseError):
""" Connection Error """
FAILURE = 0x01
"""Connection Error"""
FAILURE = 0x01
CONNECTION_REFUSED = 0x02
def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''):
def __init__(
self,
error_code,
transport,
peer_address,
error_namespace='',
error_name='',
details='',
):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
@@ -127,15 +143,21 @@ class UUID:
'''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name = None):
def __init__(self, uuid_str_or_int, name=None):
if type(uuid_str_or_int) is int:
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
if len(uuid_str_or_int) == 36:
if uuid_str_or_int[8] != '-' or uuid_str_or_int[13] != '-' or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-':
if (
uuid_str_or_int[8] != '-'
or uuid_str_or_int[13] != '-'
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
raise ValueError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '')
else:
@@ -157,7 +179,7 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes, name = None):
def from_bytes(cls, uuid_bytes, name=None):
if len(uuid_bytes) in {2, 4, 16}:
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
@@ -168,11 +190,11 @@ class UUID:
raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16, name = None):
def from_16_bits(cls, uuid_16, name=None):
return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod
def from_32_bits(cls, uuid_32, name = None):
def from_32_bits(cls, uuid_32, name=None):
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
@@ -181,9 +203,9 @@ class UUID:
@classmethod
def parse_uuid_2(cls, bytes, offset):
return offset + 2, cls.from_bytes(bytes[offset:offset + 2])
return offset + 2, cls.from_bytes(bytes[offset : offset + 2])
def to_bytes(self, force_128 = False):
def to_bytes(self, force_128=False):
if len(self.uuid_bytes) == 16 or not force_128:
return self.uuid_bytes
elif len(self.uuid_bytes) == 4:
@@ -198,26 +220,28 @@ class UUID:
"All 32-bit Attribute UUIDs shall be converted to 128-bit UUIDs when the
Attribute UUID is contained in an ATT PDU."
'''
return self.to_bytes(force_128 = (len(self.uuid_bytes) == 4))
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self):
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
else:
return ''.join([
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
def __bytes__(self):
return self.to_bytes()
def __eq__(self, other):
if isinstance(other, UUID):
return self.to_bytes(force_128 = True) == other.to_bytes(force_128 = True)
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
elif type(other) is str:
return UUID(other) == self
@@ -234,13 +258,15 @@ class UUID:
v = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}'
else:
result = '-'.join([
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
if self.name is not None:
return result + f' ({self.name})'
else:
@@ -253,6 +279,7 @@ class UUID:
# -----------------------------------------------------------------------------
# Common UUID constants
# -----------------------------------------------------------------------------
# fmt: off
# Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -358,11 +385,15 @@ BT_HDP_SERVICE = UUID.from_16_bits(0x1400,
BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source')
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# -----------------------------------------------------------------------------
# DeviceClass
# -----------------------------------------------------------------------------
class DeviceClass:
# fmt: off
# Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
LE_AUDIO_SERVICE_CLASS = (1 << 1)
@@ -530,11 +561,17 @@ class DeviceClass:
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
}
# fmt: on
@staticmethod
def split_class_of_device(class_of_device):
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return ((class_of_device >> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F))
return (
(class_of_device >> 13 & 0x7FF),
(class_of_device >> 8 & 0x1F),
(class_of_device >> 2 & 0x3F),
)
@staticmethod
def pack_class_of_device(service_classes, major_device_class, minor_device_class):
@@ -542,7 +579,9 @@ class DeviceClass:
@staticmethod
def service_class_labels(service_class_flags):
return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS)
return bit_flags_to_strings(
service_class_flags, DeviceClass.SERVICE_CLASS_LABELS
)
@staticmethod
def major_device_class_name(device_class):
@@ -560,6 +599,8 @@ class DeviceClass:
# Advertising Data
# -----------------------------------------------------------------------------
class AdvertisingData:
# fmt: off
# This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
@@ -671,7 +712,9 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10
def __init__(self, ad_structures = []):
# fmt: on
def __init__(self, ad_structures=[]):
self.ad_structures = ad_structures[:]
@staticmethod
@@ -682,19 +725,17 @@ class AdvertisingData:
@staticmethod
def flags_to_string(flags, short=False):
flag_names = [
'LE Limited',
'LE General',
'No BR/EDR',
'BR/EDR C',
'BR/EDR H'
] if short else [
'LE Limited Discoverable Mode',
'LE General Discoverable Mode',
'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)'
]
flag_names = (
['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H']
if short
else [
'LE Limited Discoverable Mode',
'LE General Discoverable Mode',
'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)',
]
)
return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod
@@ -702,16 +743,18 @@ class AdvertisingData:
uuids = []
offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data):
uuids.append(UUID.from_bytes(ad_data[offset:offset + uuid_size]))
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
offset += uuid_size
return uuids
@staticmethod
def uuid_list_to_string(ad_data, uuid_size):
return ', '.join([
str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
])
return ', '.join(
[
str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
]
)
@staticmethod
def ad_data_to_string(ad_type, ad_data):
@@ -776,19 +819,19 @@ class AdvertisingData:
if ad_type in {
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 2)
elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 4)
elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
@@ -800,17 +843,14 @@ class AdvertisingData:
elif ad_type in {
AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI
AdvertisingData.URI,
}:
return ad_data.decode("utf-8")
elif ad_type in {
AdvertisingData.TX_POWER_LEVEL,
AdvertisingData.FLAGS
}:
elif ad_type in {AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS}:
return ad_data[0]
elif ad_type in {
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL
AdvertisingData.ADVERTISING_INTERVAL,
}:
return struct.unpack('<H', ad_data)[0]
elif ad_type == AdvertisingData.CLASS_OF_DEVICE:
@@ -829,7 +869,7 @@ class AdvertisingData:
offset += 1
if length > 0:
ad_type = data[offset]
ad_data = data[offset + 1:offset + length]
ad_data = data[offset + 1 : offset + length]
self.ad_structures.append((ad_type, ad_data))
offset += length
@@ -840,19 +880,33 @@ class AdvertisingData:
If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches.
'''
def process_ad_data(ad_data):
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
return [
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
]
else:
return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None)
return next(
(
process_ad_data(ad[1])
for ad in self.ad_structures
if ad[0] == type_id
),
None,
)
def __bytes__(self):
return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures])
return b''.join(
[bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]
)
def to_string(self, separator=', '):
return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures])
return separator.join(
[AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]
)
def __str__(self):
return self.to_string()
@@ -864,7 +918,7 @@ class AdvertisingData:
class ConnectionParameters:
def __init__(self, connection_interval, peripheral_latency, supervision_timeout):
self.connection_interval = connection_interval
self.peripheral_latency = peripheral_latency
self.peripheral_latency = peripheral_latency
self.supervision_timeout = supervision_timeout
def __str__(self):

View File

@@ -24,19 +24,16 @@
import logging
import operator
import platform
if platform.system() != 'Emscripten':
import secrets
from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes
)
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
else:
@@ -66,16 +63,26 @@ class EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key()
private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key)
@property
def x(self):
return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big')
return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property
def y(self):
return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big')
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
@@ -92,7 +99,7 @@ class EccKey:
# -----------------------------------------------------------------------------
def xor(x, y):
assert(len(x) == len(y))
assert len(x) == len(y)
return bytes(map(operator.xor, x, y))
@@ -165,7 +172,11 @@ def f4(u, v, x, z):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
'''
return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))))
return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# -----------------------------------------------------------------------------
@@ -177,28 +188,36 @@ def f5(w, n1, n2, a1, a2):
'''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6c, 0x65])
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return (
bytes(reversed(aes_cmac(
bytes([0]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
))),
bytes(reversed(aes_cmac(
bytes([1]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
)))
bytes(
reversed(
aes_cmac(
bytes([0])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
bytes(
reversed(
aes_cmac(
bytes([1])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
)
@@ -207,15 +226,19 @@ def f6(w, n1, n2, r, io_cap, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
'''
return bytes(reversed(aes_cmac(
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(r)) +
bytes(reversed(io_cap)) +
bytes(reversed(a1)) +
bytes(reversed(a2)),
bytes(reversed(w))
)))
return bytes(
reversed(
aes_cmac(
bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(r))
+ bytes(reversed(io_cap))
+ bytes(reversed(a1))
+ bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
# -----------------------------------------------------------------------------
@@ -224,10 +247,14 @@ def g2(u, v, x, y):
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:],
byteorder='big'
aes_cmac(
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
)
# -----------------------------------------------------------------------------
def h6(w, key_id):
'''
@@ -235,6 +262,7 @@ def h6(w, key_id):
'''
return aes_cmac(key_id, w)
# -----------------------------------------------------------------------------
def h7(salt, w):
'''

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ from .gatt import (
Characteristic,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC
GATT_APPEARANCE_CHARACTERISTIC,
)
# -----------------------------------------------------------------------------
@@ -38,22 +38,22 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance = (0, 0)):
def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248]
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1])
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(GATT_GENERIC_ACCESS_SERVICE, [
device_name_characteristic,
appearance_characteristic
])
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)

View File

@@ -42,6 +42,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -174,11 +176,14 @@ GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bi
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services):
for service in services:
print(color(str(service), 'cyan'))
@@ -202,14 +207,16 @@ class Service(Attribute):
uuid = UUID(uuid)
super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Attribute.READABLE,
uuid.to_pdu_bytes()
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.uuid = uuid
self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
self.characteristics = characteristics[:]
self.primary = primary
def get_advertising_data(self):
"""
@@ -229,6 +236,7 @@ class TemplateService(Service):
Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property
'''
UUID = None
def __init__(self, characteristics, primary=True):
@@ -242,24 +250,24 @@ class Characteristic(Attribute):
'''
# Property flags
BROADCAST = 0x01
READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0X20
AUTHENTICATED_SIGNED_WRITES = 0X40
EXTENDED_PROPERTIES = 0X80
BROADCAST = 0x01
READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
PROPERTY_NAMES = {
BROADCAST: 'BROADCAST',
READ: 'READ',
WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE',
WRITE: 'WRITE',
NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE',
BROADCAST: 'BROADCAST',
READ: 'READ',
WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE',
WRITE: 'WRITE',
NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES'
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES',
}
@staticmethod
@@ -268,10 +276,13 @@ class Characteristic(Attribute):
@staticmethod
def properties_as_string(properties):
return ','.join([
Characteristic.property_name(p) for p in Characteristic.PROPERTY_NAMES.keys()
if properties & p
])
return ','.join(
[
Characteristic.property_name(p)
for p in Characteristic.PROPERTY_NAMES.keys()
if properties & p
]
)
@staticmethod
def string_to_properties(properties_str: str):
@@ -281,9 +292,16 @@ class Characteristic(Attribute):
0,
)
def __init__(self, uuid, properties, permissions, value = b'', descriptors: list[Descriptor] = []):
def __init__(
self,
uuid,
properties,
permissions,
value=b'',
descriptors: list[Descriptor] = [],
):
super().__init__(uuid, permissions, value)
self.uuid = self.type
self.uuid = self.type
if type(properties) is str:
self.properties = Characteristic.string_to_properties(properties)
else:
@@ -304,25 +322,29 @@ class CharacteristicDeclaration(Attribute):
'''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
'''
def __init__(self, characteristic, value_handle):
declaration_bytes = struct.pack(
'<BH',
characteristic.properties,
value_handle
) + characteristic.uuid.to_pdu_bytes()
super().__init__(GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes)
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
return f'CharacteristicDeclaration(handle=0x{self.handle:04X}, value_handle=0x{self.value_handle:04X}, uuid={self.characteristic.uuid}, properties={Characteristic.properties_as_string(self.characteristic.properties)})'
# -----------------------------------------------------------------------------
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
@@ -349,18 +371,18 @@ class CharacteristicAdapter:
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if (
asyncio.iscoroutinefunction(characteristic.read_value) and
asyncio.iscoroutinefunction(characteristic.write_value)
):
self.read_value = self.read_decoded_value
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
@@ -379,7 +401,7 @@ class CharacteristicAdapter:
'read_value',
'write_value',
'subscribe',
'unsubscribe'
'unsubscribe',
}:
super().__setattr__(name, value)
else:
@@ -389,15 +411,16 @@ class CharacteristicAdapter:
return self.encode_value(self.wrapped_characteristic.read_value(connection))
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(connection, self.decode_value(value))
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(
self.encode_value(value),
with_response
self.encode_value(value), with_response
)
def encode_value(self, value):
@@ -417,6 +440,7 @@ class CharacteristicAdapter:
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
@@ -438,6 +462,7 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic)
self.encode = encode
@@ -460,6 +485,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic, format):
super().__init__(characteristic)
self.struct = struct.Struct(format)
@@ -487,6 +513,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
is packed/unpacked according to format, with the arguments extracted from the dictionary
by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, format, keys):
super().__init__(characteristic, format)
self.keys = keys
@@ -503,6 +530,7 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
return value.encode('utf-8')
@@ -516,7 +544,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __init__(self, descriptor_type, permissions, value = b''):
def __init__(self, descriptor_type, permissions, value=b''):
super().__init__(descriptor_type, permissions, value)
def __str__(self):
@@ -527,6 +555,7 @@ class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002

View File

@@ -31,11 +31,15 @@ from colors import color
from .att import *
from .core import InvalidStateError, ProtocolError, TimeoutError
from .gatt import (GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic,
ClientCharacteristicConfigurationBits)
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
)
from .hci import *
# -----------------------------------------------------------------------------
@@ -50,16 +54,20 @@ logger = logging.getLogger(__name__)
class AttributeProxy(EventEmitter):
def __init__(self, client, handle, end_group_handle, attribute_type):
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
self.type = attribute_type
async def read_value(self, no_long_read=False):
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
async def write_value(self, value, with_response=False):
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
return await self.client.write_value(
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
return value
@@ -80,9 +88,13 @@ class ServiceProxy(AttributeProxy):
return cls(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
attribute_type = (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
)
super().__init__(client, handle, end_group_handle, attribute_type)
self.uuid = uuid
self.uuid = uuid
self.characteristics = []
async def discover_characteristics(self, uuids=[]):
@@ -98,11 +110,11 @@ class ServiceProxy(AttributeProxy):
class CharacteristicProxy(AttributeProxy):
def __init__(self, client, handle, end_group_handle, uuid, properties):
super().__init__(client, handle, end_group_handle, uuid)
self.uuid = uuid
self.properties = properties
self.descriptors = []
self.uuid = uuid
self.properties = properties
self.descriptors = []
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
@@ -123,6 +135,7 @@ class CharacteristicProxy(AttributeProxy):
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
@@ -150,6 +163,7 @@ class ProfileServiceProxy:
'''
Base class for profile-specific service proxies
'''
@classmethod
def from_client(cls, client):
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -160,24 +174,30 @@ class ProfileServiceProxy:
# -----------------------------------------------------------------------------
class Client:
def __init__(self, connection):
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = {} # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
def send_gatt_pdu(self, pdu):
self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command):
logger.debug(f'GATT Command from client: [0x{self.connection.handle:04X}] {command}')
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request):
logger.debug(f'GATT Request from client: [0x{self.connection.handle:04X}] {request}')
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -187,22 +207,26 @@ class Client:
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request
self.pending_request = request
try:
self.send_gatt_pdu(request.to_bytes())
response = await asyncio.wait_for(self.pending_response, GATT_REQUEST_TIMEOUT)
response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError:
logger.warning(color('!!! GATT Request timeout', 'red'))
raise TimeoutError(f'GATT timeout for {request.name}')
finally:
self.pending_request = None
self.pending_request = None
self.pending_response = None
return response
def send_confirmation(self, confirmation):
logger.debug(f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}')
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu):
@@ -218,13 +242,13 @@ class Client:
# Send the request
self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu = mtu))
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
# Compute the final MTU
@@ -235,12 +259,16 @@ class Client:
def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service = None):
def get_characteristics_by_uuid(self, uuid, service=None):
services = [service] if service else self.services
return [c for c in [c for s in services for c in s.characteristics] if c.uuid == uuid]
return [
c
for c in [c for s in services for c in s.characteristics]
if c.uuid == uuid
]
def on_service_discovered(self, service):
''' Add a service to the service list if it wasn't already there '''
'''Add a service to the service list if it wasn't already there'''
already_known = False
for existing_service in self.services:
if existing_service.handle == service.handle:
@@ -249,7 +277,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids = None):
async def discover_services(self, uuids=None):
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -258,9 +286,9 @@ class Client:
while starting_handle < 0xFFFF:
response = await self.send_request(
ATT_Read_By_Group_Type_Request(
starting_handle = starting_handle,
ending_handle = 0xFFFF,
attribute_group_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
)
)
if response is None:
@@ -271,15 +299,26 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
for attribute_handle, end_group_handle, attribute_value in response.attributes:
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
for (
attribute_handle,
end_group_handle,
attribute_value,
) in response.attributes:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
# Create a service proxy for this service
@@ -288,7 +327,7 @@ class Client:
attribute_handle,
end_group_handle,
UUID.from_bytes(attribute_value),
True
True,
)
# Filter out returned services based on the given uuids list
@@ -321,10 +360,10 @@ class Client:
while starting_handle < 0xFFFF:
response = await self.send_request(
ATT_Find_By_Type_Value_Request(
starting_handle = starting_handle,
ending_handle = 0xFFFF,
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value = uuid.to_pdu_bytes()
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value=uuid.to_pdu_bytes(),
)
)
if response is None:
@@ -335,19 +374,28 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
for attribute_handle, end_group_handle in response.handles_information:
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
# Create a service proxy for this service
service = ServiceProxy(self, attribute_handle, end_group_handle, uuid, True)
service = ServiceProxy(
self, attribute_handle, end_group_handle, uuid, True
)
# Add the service to the peer's service list
services.append(service)
@@ -388,15 +436,15 @@ class Client:
discovered_characteristics = []
for service in services:
starting_handle = service.handle
ending_handle = service.end_group_handle
ending_handle = service.end_group_handle
characteristics = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
starting_handle = starting_handle,
ending_handle = ending_handle,
attribute_type = GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
)
)
if response is None:
@@ -407,7 +455,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
@@ -425,7 +475,9 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = CharacteristicProxy(self, handle, 0, characteristic_uuid, properties)
characteristic = CharacteristicProxy(
self, handle, 0, characteristic_uuid, properties
)
# Set the previous characteristic's end handle
if characteristics:
@@ -441,22 +493,26 @@ class Client:
characteristics[-1].end_group_handle = service.end_group_handle
# Set the service's characteristics
characteristics = [c for c in characteristics if not uuids or c.uuid in uuids]
characteristics = [
c for c in characteristics if not uuids or c.uuid in uuids
]
service.characteristics = characteristics
discovered_characteristics.extend(characteristics)
return discovered_characteristics
async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None):
async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
):
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
'''
if characteristic:
starting_handle = characteristic.handle + 1
ending_handle = characteristic.end_group_handle
ending_handle = characteristic.end_group_handle
elif start_handle and end_handle:
starting_handle = start_handle
ending_handle = end_handle
ending_handle = end_handle
else:
return []
@@ -464,8 +520,7 @@ class Client:
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle = starting_handle,
ending_handle = ending_handle
starting_handle=starting_handle, ending_handle=ending_handle
)
)
if response is None:
@@ -476,7 +531,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
break
@@ -492,7 +549,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
descriptor = DescriptorProxy(self, attribute_handle, UUID.from_bytes(attribute_uuid))
descriptor = DescriptorProxy(
self, attribute_handle, UUID.from_bytes(attribute_uuid)
)
descriptors.append(descriptor)
# TODO: read descriptor value
@@ -510,13 +569,12 @@ class Client:
Discover all attributes, regardless of type
'''
starting_handle = 0x0001
ending_handle = 0xFFFF
ending_handle = 0xFFFF
attributes = []
while True:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle = starting_handle,
ending_handle = ending_handle
starting_handle=starting_handle, ending_handle=ending_handle
)
)
if response is None:
@@ -526,7 +584,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}'
)
return []
break
@@ -536,7 +596,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
attribute = AttributeProxy(self, attribute_handle, 0, UUID.from_bytes(attribute_uuid))
attribute = AttributeProxy(
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
)
attributes.append(attribute)
# Move on to the next attributes
@@ -550,7 +612,9 @@ class Client:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd:
logger.warning('subscribing to characteristic with no CCCD descriptor')
return
@@ -590,14 +654,19 @@ class Client:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (self.notification_subscribers, self.indication_subscribers):
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
@@ -623,7 +692,9 @@ class Client:
# Send a request to read
attribute_handle = attribute if type(attribute) is int else attribute.handle
response = await self.send_request(ATT_Read_Request(attribute_handle = attribute_handle))
response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
@@ -631,7 +702,7 @@ class Client:
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
# If the value is the max size for the MTU, try to read more unless the caller
@@ -642,18 +713,23 @@ class Client:
offset = len(attribute_value)
while True:
response = await self.send_request(
ATT_Read_Blob_Request(attribute_handle = attribute_handle, value_offset = offset)
ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset
)
)
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR or response.error_code == ATT_INVALID_OFFSET_ERROR:
if (
response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR
or response.error_code == ATT_INVALID_OFFSET_ERROR
):
break
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
part = response.part_attribute_value
@@ -674,18 +750,18 @@ class Client:
if service is None:
starting_handle = 0x0001
ending_handle = 0xFFFF
ending_handle = 0xFFFF
else:
starting_handle = service.handle
ending_handle = service.end_group_handle
ending_handle = service.end_group_handle
characteristics_values = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
starting_handle = starting_handle,
ending_handle = ending_handle,
attribute_type = uuid
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type=uuid,
)
)
if response is None:
@@ -696,7 +772,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
break
@@ -731,26 +809,27 @@ class Client:
if with_response:
response = await self.send_request(
ATT_Write_Request(
attribute_handle = attribute_handle,
attribute_value = value
attribute_handle=attribute_handle, attribute_value=value
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code), response
ATT_PDU.error_name(response.error_code),
response,
)
else:
await self.send_command(
ATT_Write_Command(
attribute_handle = attribute_handle,
attribute_value = value
attribute_handle=attribute_handle, attribute_value=value
)
)
def on_gatt_pdu(self, att_pdu):
logger.debug(f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}')
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -759,9 +838,13 @@ class Client:
# Sanity check: the response should match the pending request unless it is an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace('_REQUEST', '_RESPONSE')
expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
)
if att_pdu.name != expected_response_name:
logger.warning(f'!!! mismatched response: expected {expected_response_name}')
logger.warning(
f'!!! mismatched response: expected {expected_response_name}'
)
return
# Return the response to the coroutine that is waiting for it
@@ -772,11 +855,15 @@ class Client:
if handler is not None:
handler(att_pdu)
else:
logger.warning(f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}')
logger.warning(
f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}'
)
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(notification.attribute_handle, [])
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
for subscriber in subscribers:

View File

@@ -53,11 +53,15 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
class Server(EventEmitter):
def __init__(self, device):
super().__init__()
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = (
GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
)
self.subscribers = (
{}
) # Map of subscriber states by connection handle and attribute handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
@@ -72,8 +76,10 @@ class Server(EventEmitter):
def get_advertising_service_data(self):
return {
attribute: data for attribute in self.attributes
if isinstance(attribute, Service) and (data := attribute.get_advertising_data())
attribute: data
for attribute in self.attributes
if isinstance(attribute, Service)
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle):
@@ -149,7 +155,9 @@ class Server(EventEmitter):
def add_attribute(self, attribute):
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = attribute.handle # TODO: keep track of descriptors in the group
attribute.end_group_handle = (
attribute.handle
) # TODO: keep track of descriptors in the group
# Add this attribute to the list
self.attributes.append(attribute)
@@ -178,17 +186,25 @@ class Server(EventEmitter):
# If the characteristic supports subscriptions, add a CCCD descriptor
# unless there is one already
if (
characteristic.properties & (Characteristic.NOTIFY | Characteristic.INDICATE) and
characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) is None
characteristic.properties
& (Characteristic.NOTIFY | Characteristic.INDICATE)
and characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
is None
):
self.add_attribute(
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(connection, characteristic),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(connection, characteristic, value)
)
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
)
)
@@ -215,7 +231,9 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check
if len(value) != 2:
@@ -225,13 +243,23 @@ class Server(EventEmitter):
cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = (value[0] & 0x01 != 0)
indicate_enabled = (value[0] & 0x02 != 0)
characteristic.emit('subscription', connection, notify_enabled, indicate_enabled)
self.emit('characteristic_subscription', connection, characteristic, notify_enabled, indicate_enabled)
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
'subscription', connection, notify_enabled, indicate_enabled
)
self.emit(
'characteristic_subscription',
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection, response):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
@@ -243,25 +271,32 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(f'not notifying, no subscribers for handle {attribute.handle:04X}')
logger.debug(
f'not notifying, no subscribers for handle {attribute.handle:04X}'
)
return
if len(cccd) != 2 or (cccd[0] & 0x01 == 0):
logger.debug(f'not notifying, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
value = value[: connection.att_mtu - 3]
# Notify
notification = ATT_Handle_Value_Notification(
attribute_handle = attribute.handle,
attribute_value = value
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
@@ -273,46 +308,60 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}')
logger.debug(
f'not indicating, no subscribers for handle {attribute.handle:04X}'
)
return
if len(cccd) != 2 or (cccd[0] & 0x02 == 0):
logger.debug(f'not indicating, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]
value = value[: connection.att_mtu - 3]
# Indicate
indication = ATT_Handle_Value_Indication(
attribute_handle = attribute.handle,
attribute_value = value
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert(self.pending_confirmations[connection.handle] is None)
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future()
self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT)
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}')
finally:
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
@@ -323,10 +372,12 @@ class Server(EventEmitter):
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait([
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
])
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
@@ -352,17 +403,17 @@ class Server(EventEmitter):
except ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = ATT_Error_Response(
request_opcode_in_error = att_pdu.op_code,
attribute_handle_in_error = error.att_handle,
error_code = error.error_code
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
)
self.send_response(connection, response)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = ATT_Error_Response(
request_opcode_in_error = att_pdu.op_code,
attribute_handle_in_error = 0x0000,
error_code = ATT_UNLIKELY_ERROR_ERROR
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=0x0000,
error_code=ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(connection, response)
raise error
@@ -373,7 +424,9 @@ class Server(EventEmitter):
self.on_att_request(connection, att_pdu)
else:
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
logger.warning(
f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}'
)
#######################################################
# ATT handlers
@@ -382,11 +435,13 @@ class Server(EventEmitter):
'''
Handler for requests without a more specific handler
'''
logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}')
logger.warning(
f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}'
)
response = ATT_Error_Response(
request_opcode_in_error = pdu.op_code,
attribute_handle_in_error = 0x0000,
error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR
request_opcode_in_error=pdu.op_code,
attribute_handle_in_error=0x0000,
error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(connection, response)
@@ -394,7 +449,9 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
self.send_response(
connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
@@ -411,12 +468,18 @@ class Server(EventEmitter):
'''
# Check the request parameters
if request.starting_handle == 0 or request.starting_handle > request.ending_handle:
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
if (
request.starting_handle == 0
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return
# Build list of returned attributes
@@ -424,9 +487,10 @@ class Server(EventEmitter):
attributes = []
uuid_size = 0
for attribute in (
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
):
# TODO: check permissions
@@ -453,14 +517,14 @@ class Server(EventEmitter):
for attribute in attributes
]
response = ATT_Find_Information_Response(
format = 1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data = b''.join(information_data_list)
format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data=b''.join(information_data_list),
)
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -474,12 +538,13 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
attribute.type == request.attribute_type and
attribute.read_value(connection) == request.attribute_value and
pdu_space_available >= 4
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -494,22 +559,24 @@ class Server(EventEmitter):
if attribute.type in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
}:
# Part of a group
group_end_handle = attribute.end_group_handle
else:
# Not part of a group
group_end_handle = attribute.handle
handles_information_list.append(struct.pack('<HH', attribute.handle, group_end_handle))
handles_information_list.append(
struct.pack('<HH', attribute.handle, group_end_handle)
)
response = ATT_Find_By_Type_Value_Response(
handles_information_list = b''.join(handles_information_list)
handles_information_list=b''.join(handles_information_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -522,11 +589,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.type == request.attribute_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
attribute
for attribute in self.attributes
if attribute.type == request.attribute_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# TODO: check permissions
@@ -550,16 +618,17 @@ class Server(EventEmitter):
pdu_space_available -= entry_size
if attributes:
attribute_data_list = [struct.pack('<H', handle) + value for handle, value in attributes]
attribute_data_list = [
struct.pack('<H', handle) + value for handle, value in attributes
]
response = ATT_Read_By_Type_Response(
length = entry_size,
attribute_data_list = b''.join(attribute_data_list)
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -573,14 +642,12 @@ class Server(EventEmitter):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
response = ATT_Read_Response(attribute_value=value[:value_size])
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
@@ -594,26 +661,30 @@ class Server(EventEmitter):
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
)
else:
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
@@ -624,12 +695,12 @@ class Server(EventEmitter):
if request.attribute_group_type not in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE
GATT_INCLUDE_ATTRIBUTE_TYPE,
}:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(connection, response)
return
@@ -637,11 +708,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.type == request.attribute_group_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
attribute
for attribute in self.attributes
if attribute.type == request.attribute_group_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
@@ -659,7 +731,9 @@ class Server(EventEmitter):
break
# Add the attribute to the list
attributes.append((attribute.handle, attribute.end_group_handle, attribute_value))
attributes.append(
(attribute.handle, attribute.end_group_handle, attribute_value)
)
pdu_space_available -= entry_size
if attributes:
@@ -668,14 +742,14 @@ class Server(EventEmitter):
for handle, end_group_handle, value in attributes
]
response = ATT_Read_By_Group_Type_Response(
length = len(attribute_data_list[0]),
attribute_data_list = b''.join(attribute_data_list)
length=len(attribute_data_list[0]),
attribute_data_list=b''.join(attribute_data_list),
)
else:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -688,22 +762,28 @@ class Server(EventEmitter):
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return
# TODO: check permissions
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR
))
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
),
)
return
# Accept the value
@@ -740,7 +820,9 @@ class Server(EventEmitter):
'''
if self.pending_confirmations[connection.handle] is None:
# Not expected!
logger.warning('!!! unexpected confirmation, there is no pending indication')
logger.warning(
'!!! unexpected confirmation, there is no pending indication'
)
return
self.pending_confirmations[connection.handle].set_result(None)

File diff suppressed because it is too large Load Diff

View File

@@ -29,20 +29,17 @@ from .l2cap import (
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Response
L2CAP_Connection_Response,
)
from .hci import (
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler
HCI_AclDataPacketAssembler,
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import (
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
)
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -53,8 +50,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
# TODO: add more PSM values
}
@@ -63,11 +60,11 @@ PSM_NAMES = {
class PacketTracer:
class AclStream:
def __init__(self, analyzer):
self.analyzer = analyzer
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
@@ -78,7 +75,10 @@ class PacketTracer:
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
elif (
l2cap_pdu.cid == L2CAP_SIGNALING_CID
or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID
):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
@@ -86,7 +86,10 @@ class PacketTracer:
if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
if (
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection
@@ -94,8 +97,14 @@ class PacketTracer:
# For AVDTP connections, create a packet assembler for each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message)
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
else:
# Try to find the PSM associated with this PDU
@@ -107,31 +116,39 @@ class PacketTracer:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}')
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}')
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}')
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet)
class Analyzer:
def __init__(self, label, emit_message):
self.label = label
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle):
logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}')
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream
@@ -144,7 +161,9 @@ class PacketTracer:
def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams:
logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}')
logger.info(
f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well
@@ -176,9 +195,13 @@ class PacketTracer:
self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info
emit_message=logger.info,
):
self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message)
self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message)
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer

View File

@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class HfpProtocol:
def __init__(self, dlc):
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
dlc.sink = self.feed
@@ -52,7 +52,7 @@ class HfpProtocol:
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1:]
self.buffer = self.buffer[separator + 1 :]
if len(line) > 0:
self.on_line(line)
@@ -79,16 +79,16 @@ class HfpProtocol:
async def initialize_service(self):
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
line = await(self.next_line())
line = await(self.next_line())
line = await (self.next_line())
line = await (self.next_line())
self.send_command_line('AT+CIND=?')
line = await(self.next_line())
line = await(self.next_line())
line = await (self.next_line())
line = await (self.next_line())
self.send_command_line('AT+CIND?')
line = await(self.next_line())
line = await(self.next_line())
line = await (self.next_line())
line = await (self.next_line())
self.send_command_line('AT+CMER=3,0,0,1')
line = await(self.next_line())
line = await (self.next_line())

View File

@@ -36,21 +36,25 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# fmt: on
# -----------------------------------------------------------------------------
class Connection:
def __init__(self, host, handle, role, peer_address, transport):
self.host = host
self.handle = handle
self.role = role
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
self.host = host
self.handle = handle
self.role = role
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
@@ -62,29 +66,29 @@ class Connection:
# -----------------------------------------------------------------------------
class Host(EventEmitter):
def __init__(self, controller_source = None, controller_sink = None):
def __init__(self, controller_source=None, controller_sink=None):
super().__init__()
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
# Connect to the source and sink if specified
if controller_source:
@@ -96,30 +100,51 @@ class Host(EventEmitter):
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
response = await self.send_command(HCI_Read_Local_Supported_Commands_Command(), check_result=True)
response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True
)
self.local_supported_commands = response.return_parameters.supported_commands
if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(HCI_LE_Read_Local_Supported_Features_Command(), check_result=True)
self.local_le_features = struct.unpack('<Q', response.return_parameters.le_features)[0]
response = await self.send_command(
HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(HCI_Read_Local_Version_Information_Command(), check_result=True)
response = await self.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
self.local_version = response.return_parameters
await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFF3F')))
await self.send_command(
HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('FFFFFFFFFFFFFF3F'))
)
if self.local_version is not None and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0:
if (
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't understand
le_event_mask = bytes.fromhex('1F00000000000000')
else:
le_event_mask = bytes.fromhex('FFFFF00000000000')
await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = le_event_mask))
await self.send_command(
HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(HCI_Read_Buffer_Size_Command(), check_result=True)
self.hc_acl_data_packet_length = response.return_parameters.hc_acl_data_packet_length
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_acl_data_packets
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
self.hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
@@ -127,9 +152,15 @@ class Host(EventEmitter):
)
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command(), check_result=True)
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
@@ -137,28 +168,33 @@ class Host(EventEmitter):
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0 or
response.return_parameters.hc_total_num_le_acl_data_packets == 0
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = self.hc_total_num_acl_data_packets
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if (
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
):
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
) and self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await self.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets or
suggested_max_tx_time != self.suggested_max_tx_time
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets = self.suggested_max_tx_octets,
suggested_max_tx_time = self.suggested_max_tx_time
))
await self.send_command(
HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
)
)
self.reset_done = True
@@ -205,12 +241,16 @@ class Host(EventEmitter):
status = response.return_parameters.status
if status != HCI_SUCCESS:
logger.warning(f'{command.name} failed ({HCI_Constant.error_name(status)})')
logger.warning(
f'{command.name} failed ({HCI_Constant.error_name(status)})'
)
raise HCI_Error(status)
return response
except Exception as error:
logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
logger.warning(
f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
)
raise error
finally:
self.pending_command = None
@@ -234,13 +274,15 @@ class Host(EventEmitter):
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
acl_packet = HCI_AclDataPacket(
connection_handle = connection_handle,
pb_flag = pb_flag,
bc_flag = 0,
data_total_length = data_total_length,
data = l2cap_pdu[offset:offset + data_total_length]
connection_handle=connection_handle,
pb_flag=pb_flag,
bc_flag=0,
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
)
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
self.queue_acl_packet(acl_packet)
pb_flag = 1
offset += data_total_length
@@ -251,11 +293,16 @@ class Host(EventEmitter):
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue')
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self):
# Send all we can (TODO: support different LE/Classic limits)
while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets:
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet)
self.acl_packets_in_flight += 1
@@ -267,7 +314,9 @@ class Host(EventEmitter):
if value == command:
# Check if the flag is set
if octet < len(self.local_supported_commands) and flag_position < 8:
return (self.local_supported_commands[octet] & (1 << flag_position)) != 0
return (
self.local_supported_commands[octet] & (1 << flag_position)
) != 0
return False
@@ -289,15 +338,17 @@ class Host(EventEmitter):
@property
def supported_le_features(self):
return [feature for feature in range(64) if self.local_le_features & (1 << feature)]
return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
]
# Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet):
hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET and
hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and
hci_packet.command_opcode == HCI_RESET_COMMAND
hci_packet.hci_packet_type == HCI_EVENT_PACKET
and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
and hci_packet.command_opcode == HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
@@ -336,7 +387,9 @@ class Host(EventEmitter):
if self.pending_response:
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}')
logger.warning(
f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}'
)
self.pending_response.set_result(event)
else:
@@ -364,7 +417,11 @@ class Host(EventEmitter):
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'))
logger.warning(
color(
f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'
)
)
self.acl_packets_in_flight = 0
# Classic only
@@ -381,18 +438,26 @@ class Host(EventEmitter):
# Check if this is a cancellation
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}')
logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(self, event.connection_handle, event.role, event.peer_address, BT_LE_TRANSPORT)
connection = Connection(
self,
event.connection_handle,
event.role,
event.peer_address,
BT_LE_TRANSPORT,
)
self.connections[event.connection_handle] = connection
# Notify the client
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout
event.supervision_timeout,
)
self.emit(
'connection',
@@ -401,13 +466,15 @@ class Host(EventEmitter):
event.peer_address,
None,
event.role,
connection_parameters
connection_parameters,
)
else:
logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status)
self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
)
def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now
@@ -416,11 +483,19 @@ class Host(EventEmitter):
def on_hci_connection_complete_event(self, event):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}')
logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}'
)
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr, BT_BR_EDR_TRANSPORT)
connection = Connection(
self,
event.connection_handle,
BT_CENTRAL_ROLE,
event.bd_addr,
BT_BR_EDR_TRANSPORT,
)
self.connections[event.connection_handle] = connection
# Notify the client
@@ -431,13 +506,15 @@ class Host(EventEmitter):
event.bd_addr,
None,
BT_CENTRAL_ROLE,
None
None,
)
else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client
self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status)
self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
)
def on_hci_disconnection_complete_event(self, event):
# Find the connection
@@ -446,7 +523,9 @@ class Host(EventEmitter):
return
if event.status == HCI_SUCCESS:
logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}')
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}'
)
del self.connections[event.connection_handle]
# Notify the listeners
@@ -467,11 +546,15 @@ class Host(EventEmitter):
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout
event.supervision_timeout,
)
self.emit(
'connection_parameters_update', connection.handle, connection_parameters
)
self.emit('connection_parameters_update', connection.handle, connection_parameters)
else:
self.emit('connection_parameters_update_failure', connection.handle, event.status)
self.emit(
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -501,13 +584,13 @@ class Host(EventEmitter):
# TODO: delegate the decision
self.send_command_sync(
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle = event.connection_handle,
interval_min = event.interval_min,
interval_max = event.interval_max,
latency = event.latency,
timeout = event.timeout,
min_ce_length = 0,
max_ce_length = 0
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
latency=event.latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
)
)
@@ -522,18 +605,16 @@ class Host(EventEmitter):
long_term_key = None
else:
long_term_key = await self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier
connection.handle, event.random_number, event.encryption_diversifier
)
if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle = event.connection_handle,
long_term_key = long_term_key
connection_handle=event.connection_handle,
long_term_key=long_term_key,
)
else:
response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
connection_handle = event.connection_handle
connection_handle=event.connection_handle
)
await self.send_command(response)
@@ -548,10 +629,14 @@ class Host(EventEmitter):
def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS:
logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}')
logger.debug(
f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}'
)
# TODO: lookup the connection and update the role
else:
logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}')
logger.debug(
f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}'
)
def on_hci_le_data_length_change_event(self, event):
self.emit(
@@ -560,7 +645,7 @@ class Host(EventEmitter):
event.max_tx_octets,
event.max_tx_time,
event.max_rx_octets,
event.max_rx_time
event.max_rx_time,
)
def on_hci_authentication_complete_event(self, event):
@@ -568,21 +653,35 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle)
else:
self.emit('connection_authentication_failure', event.connection_handle, event.status)
self.emit(
'connection_authentication_failure',
event.connection_handle,
event.status,
)
def on_hci_encryption_change_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled)
self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
)
else:
self.emit('connection_encryption_failure', event.connection_handle, event.status)
self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle)
else:
self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status)
self.emit(
'connection_encryption_key_refresh_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
pass
@@ -594,11 +693,15 @@ class Host(EventEmitter):
pass
def on_hci_link_key_notification_event(self, event):
logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}')
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
logger.debug(
f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}'
)
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('ssp_complete', event.bd_addr)
@@ -607,9 +710,7 @@ class Host(EventEmitter):
# For now, just refuse all requests
# TODO: delegate the decision
self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command(
bd_addr = event.bd_addr
)
HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
)
def on_hci_link_key_request_event(self, event):
@@ -621,12 +722,11 @@ class Host(EventEmitter):
link_key = await self.link_key_provider(event.bd_addr)
if link_key:
response = HCI_Link_Key_Request_Reply_Command(
bd_addr = event.bd_addr,
link_key = link_key
bd_addr=event.bd_addr, link_key=link_key
)
else:
response = HCI_Link_Key_Request_Negative_Reply_Command(
bd_addr = event.bd_addr
bd_addr=event.bd_addr
)
await self.send_command(response)
@@ -640,13 +740,19 @@ class Host(EventEmitter):
pass
def on_hci_user_confirmation_request_event(self, event):
self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value)
self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
self.emit('authentication_user_passkey_notification', event.bd_addr, event.passkey)
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, event):
self.emit('inquiry_complete')
@@ -658,7 +764,7 @@ class Host(EventEmitter):
response.bd_addr,
response.class_of_device,
b'',
response.rssi
response.rssi,
)
def on_hci_extended_inquiry_result_event(self, event):
@@ -667,7 +773,7 @@ class Host(EventEmitter):
event.bd_addr,
event.class_of_device,
event.extended_inquiry_response,
event.rssi
event.rssi,
)
def on_hci_remote_name_request_complete_event(self, event):
@@ -677,4 +783,8 @@ class Host(EventEmitter):
self.emit('remote_name', event.bd_addr, event.remote_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
self.emit('remote_host_supported_features', event.bd_addr, event.host_supported_features)
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)

View File

@@ -39,10 +39,10 @@ logger = logging.getLogger(__name__)
class PairingKeys:
class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None):
self.value = value
self.value = value
self.authenticated = authenticated
self.ediv = ediv
self.rand = rand
self.ediv = ediv
self.rand = rand
@classmethod
def from_dict(cls, key_dict):
@@ -65,13 +65,13 @@ class PairingKeys:
return key_dict
def __init__(self):
self.address_type = None
self.ltk = None
self.ltk_central = None
self.address_type = None
self.ltk = None
self.ltk_central = None
self.ltk_peripheral = None
self.irk = None
self.csrk = None
self.link_key = None # Classic
self.irk = None
self.csrk = None
self.link_key = None # Classic
@staticmethod
def key_from_dict(keys_dict, key_name):
@@ -83,13 +83,13 @@ class PairingKeys:
def from_dict(keys_dict):
keys = PairingKeys()
keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys
@@ -166,7 +166,7 @@ class KeyStore:
separator = ''
for (name, keys) in entries:
print(separator + prefix + color(name, 'yellow'))
keys.print(prefix = prefix + ' ')
keys.print(prefix=prefix + ' ')
separator = '\n'
@staticmethod
@@ -183,9 +183,9 @@ class KeyStore:
# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__'
def __init__(self, namespace, filename=None):
@@ -194,9 +194,9 @@ class JsonKeyStore(KeyStore):
if filename is None:
# Use a default for the current user
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR),
self.KEYS_DIR
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename)
@@ -262,7 +262,9 @@ class JsonKeyStore(KeyStore):
if namespace is None:
return []
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
async def delete_all(self):
db = await self.load()

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR
HCI_CONNECTION_TIMEOUT_ERROR,
)
# -----------------------------------------------------------------------------
@@ -55,7 +55,7 @@ class LocalLink:
'''
def __init__(self):
self.controllers = set()
self.controllers = set()
self.pending_connection = None
def add_controller(self, controller):
@@ -103,23 +103,30 @@ class LocalLink:
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(le_create_connection_command.peer_address):
central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS)
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(self, central_address, le_create_connection_command):
logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}')
logger.debug(
f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command):
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
@@ -127,16 +134,24 @@ class LocalLink:
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason)
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
logger.debug(
f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
@@ -152,15 +167,18 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a
WebSocket relay
'''
def __init__(self, uri):
self.controller = None
self.uri = uri
self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = set() # List of addresses that have connected to us
self.controller = None
self.uri = uri
self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = (
set()
) # List of addresses that have connected to us
# Connect and run asynchronously
asyncio.create_task(self.run_connection())
@@ -192,7 +210,9 @@ class RemoteLink:
try:
await item
except Exception as error:
logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self):
# Connect to the relay
@@ -227,7 +247,9 @@ class RemoteLink:
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR)
self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target):
@@ -244,7 +266,9 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement))
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
except Exception:
logger.exception('exception')
@@ -275,7 +299,9 @@ class RemoteLink:
# Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS)
self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message):
# Notify the controller
@@ -296,7 +322,7 @@ class RemoteLink:
websocket = await self.websocket
# Create a future value to hold the eventual result
assert(self.rpc_result is None)
assert self.rpc_result is None
self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command
@@ -345,16 +371,43 @@ class RemoteLink:
logger.warn('connection already pending')
return
self.pending_connection = le_create_connection_command
self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address)))
self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}'))
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command)
logger.debug(
f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targetted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk)
self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}'))
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targetted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)

View File

@@ -29,7 +29,7 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter
PackedCharacteristicAdapter,
)
# -----------------------------------------------------------------------------
@@ -66,7 +66,8 @@ class AshaService(TemplateService):
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}'
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
elif opcode == AshaService.OPCODE_STATUS:
@@ -79,34 +80,36 @@ class AshaService(TemplateService):
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]) +
bytes(self.hisyncid) +
bytes(AshaService.FEATURE_MAP) +
bytes(AshaService.RENDER_DELAY) +
bytes(AshaService.RESERVED_FOR_FUTURE_USE) +
bytes(AshaService.SUPPORTED_CODEC_ID)
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write)
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0])
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write)
CharacteristicValue(write=on_volume_write),
)
# TODO add real psm value
@@ -116,26 +119,39 @@ class AshaService(TemplateService):
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm)
struct.pack('<H', self.psm),
)
characteristics = [self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic]
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData([
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(GATT_ASHA_SERVICE)),
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(GATT_ASHA_SERVICE) + bytes([
AshaService.PROTOCOL_VERSION,
self.capability,
]) + bytes(self.hisyncid[:4]))
])
AdvertisingData(
[
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -23,7 +23,7 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter
PackedCharacteristicAdapter,
)
@@ -38,9 +38,9 @@ class BatteryService(TemplateService):
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level)
CharacteristicValue(read=read_battery_level),
),
format=BatteryService.BATTERY_LEVEL_FORMAT
format=BatteryService.BATTERY_LEVEL_FORMAT,
)
super().__init__([self.battery_level_characteristic])
@@ -52,10 +52,11 @@ class BatteryServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BATTERY_LEVEL_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter(
characteristics[0],
format=BatteryService.BATTERY_LEVEL_FORMAT
characteristics[0], format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None

View File

@@ -33,7 +33,7 @@ from ..gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter
UTF8CharacteristicAdapter,
)
@@ -63,38 +63,37 @@ class DeviceInformationService(TemplateService):
# TODO: pnp_id
):
characteristics = [
Characteristic(
uuid,
Characteristic.READ,
Characteristic.READABLE,
field
)
Characteristic(uuid, Characteristic.READ, Characteristic.READABLE, field)
for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
)
if field is not None
]
if system_id is not None:
characteristics.append(Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id)
))
characteristics.append(
Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id),
)
)
if ieee_regulatory_certification_data_list is not None:
characteristics.append(Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list
))
characteristics.append(
Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list,
)
)
super().__init__(characteristics)
@@ -108,11 +107,11 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
for (field, uuid) in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0])
@@ -120,16 +119,20 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
characteristic = None
self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_SYSTEM_ID_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC
):
self.system_id = DelegatedCharacteristicAdapter(
characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id
decode=DeviceInformationService.unpack_system_id,
)
else:
self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
self.ieee_regulatory_certification_data_list = characteristics[0]
else:
self.ieee_regulatory_certification_data_list = None

View File

@@ -30,25 +30,25 @@ from ..gatt import (
Characteristic,
CharacteristicValue,
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter
PackedCharacteristicAdapter,
)
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
UUID = GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = 0,
CHEST = 1,
WRIST = 2,
FINGER = 3,
HAND = 4,
EAR_LOBE = 5,
FOOT = 6
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
FOOT = 6
class HeartRateMeasurement:
def __init__(
@@ -56,12 +56,14 @@ class HeartRateService(TemplateService):
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF):
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
if rr_intervals:
@@ -69,10 +71,10 @@ class HeartRateService(TemplateService):
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
self.heart_rate = heart_rate
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
@@ -87,7 +89,7 @@ class HeartRateService(TemplateService):
offset += 1
if flags & (1 << 2):
sensor_contact_detected = (flags & (1 << 1) != 0)
sensor_contact_detected = flags & (1 << 1) != 0
else:
sensor_contact_detected = None
@@ -119,38 +121,42 @@ class HeartRateService(TemplateService):
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None:
flags |= (1 << 3)
flags |= 1 << 3
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= (1 << 4)
data += b''.join([
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
])
flags |= 1 << 4
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
]
)
return bytes([flags]) + data
def __str__(self):
return f'HeartRateMeasurement(heart_rate={self.heart_rate},'\
f' sensor_contact_detected={self.sensor_contact_detected},'\
f' energy_expended={self.energy_expended},'\
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None
reset_energy_expended=None,
):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement)
CharacteristicValue(read=read_heart_rate_measurement),
),
encode=lambda value: bytes(value)
encode=lambda value: bytes(value),
)
characteristics = [self.heart_rate_measurement_characteristic]
@@ -159,11 +165,12 @@ class HeartRateService(TemplateService):
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)])
bytes([int(body_sensor_location)]),
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
@@ -176,9 +183,9 @@ class HeartRateService(TemplateService):
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value)
CharacteristicValue(write=write_heart_rate_control_point_value),
),
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -192,30 +199,38 @@ class HeartRateServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes
decode=HeartRateService.HeartRateMeasurement.from_bytes,
)
else:
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0])
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(HeartRateService.RESET_ENERGY_EXPENDED)
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
)

View File

@@ -32,6 +32,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
RFCOMM_PSM = 0x0003
@@ -98,6 +100,8 @@ RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on
# -----------------------------------------------------------------------------
def fcs(buffer):
@@ -109,11 +113,11 @@ def fcs(buffer):
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, type, c_r, dlci, p_f, information = b'', with_credits = False):
self.type = type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
def __init__(self, type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
self.information = information
length = len(information)
if with_credits:
@@ -124,8 +128,8 @@ class RFCOMM_Frame:
else:
# 1-byte length indicator
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4)
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4)
if type == RFCOMM_UIH_FRAME:
self.fcs = fcs(bytes([self.address, self.control]))
else:
@@ -144,13 +148,16 @@ class RFCOMM_Frame:
value = data[2:]
else:
length = (data[3] << 7) & (length >> 1)
value = data[3:3 + length]
value = data[3 : 3 + length]
return (type, c_r, value)
@staticmethod
def make_mcc(type, c_r, data):
return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data
return (
bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod
def sabm(c_r, dlci):
@@ -169,8 +176,10 @@ class RFCOMM_Frame:
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r, dlci, information, p_f = 0):
return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1))
def uih(c_r, dlci, information, p_f=0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
def from_bytes(data):
@@ -197,7 +206,12 @@ class RFCOMM_Frame:
return frame
def __bytes__(self):
return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs])
return (
bytes([self.address, self.control])
+ self.length
+ self.information
+ bytes([self.fcs])
)
def __str__(self):
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})'
@@ -205,38 +219,49 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
class RFCOMM_MCC_PN:
def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size):
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
):
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
self.window_size = window_size
@staticmethod
def from_bytes(data):
return RFCOMM_MCC_PN(
dlci = data[0],
cl = data[1],
priority = data[2],
ack_timer = data[3],
max_frame_size = data[4] | data[5] << 8,
max_retransmissions = data[6],
window_size = data[7]
dlci=data[0],
cl=data[1],
priority=data[2],
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7],
)
def __bytes__(self):
return bytes([
self.dlci & 0xFF,
self.cl & 0xFF,
self.priority & 0xFF,
self.ack_timer & 0xFF,
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF
])
return bytes(
[
self.dlci & 0xFF,
self.cl & 0xFF,
self.priority & 0xFF,
self.ack_timer & 0xFF,
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF,
]
)
def __str__(self):
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})'
@@ -246,28 +271,35 @@ class RFCOMM_MCC_PN:
class RFCOMM_MCC_MSC:
def __init__(self, dlci, fc, rtc, rtr, ic, dv):
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data):
return RFCOMM_MCC_MSC(
dlci = data[0] >> 2,
fc = data[1] >> 1 & 1,
rtc = data[1] >> 2 & 1,
rtr = data[1] >> 3 & 1,
ic = data[1] >> 6 & 1,
dv = data[1] >> 7 & 1
dlci=data[0] >> 2,
fc=data[1] >> 1 & 1,
rtc=data[1] >> 2 & 1,
rtr=data[1] >> 3 & 1,
ic=data[1] >> 6 & 1,
dv=data[1] >> 7 & 1,
)
def __bytes__(self):
return bytes([
(self.dlci << 2) | 3,
1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7
])
return bytes(
[
(self.dlci << 2) | 3,
1
| self.fc << 1
| self.rtc << 2
| self.rtr << 3
| self.ic << 6
| self.dv << 7,
]
)
def __str__(self):
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})'
@@ -276,45 +308,49 @@ class RFCOMM_MCC_MSC:
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
# States
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
DISCONNECTING = 0x03
DISCONNECTED = 0x04
RESET = 0x05
DISCONNECTED = 0x04
RESET = 0x05
STATE_NAMES = {
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET',
}
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead)
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
@staticmethod
def state_name(state):
return DLC.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}')
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state
def send_frame(self, frame):
@@ -329,26 +365,13 @@ class DLC(EventEmitter):
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
return
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci))
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTED)
self.emit('open')
@@ -359,23 +382,10 @@ class DLC(EventEmitter):
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
@@ -386,7 +396,7 @@ class DLC(EventEmitter):
def on_disc_frame(self, frame):
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci))
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
def on_uih_frame(self, frame):
data = frame.information
@@ -395,10 +405,14 @@ class DLC(EventEmitter):
credits = frame.information[0]
self.tx_credits += credits
logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}')
logger.debug(
f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}'
)
data = data[1:]
logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}')
logger.debug(
f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink:
self.sink(data)
@@ -418,23 +432,12 @@ class DLC(EventEmitter):
if c_r:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
# Response
logger.debug(f'<<< MCC MSC Response: {msc}')
@@ -445,35 +448,24 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(
RFCOMM_Frame.sabm(
c_r = self.c_r,
dlci = self.dlci
)
)
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
def accept(self):
if not self.state == DLC.INIT:
raise InvalidStateError('invalid state')
pn = RFCOMM_MCC_PN(
dlci = self.dlci,
cl = 0xE0,
priority = 7,
ack_timer = 0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
dlci=self.dlci,
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 0, data = bytes(pn))
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTING)
def rx_credits_needed(self):
@@ -488,13 +480,13 @@ class DLC(EventEmitter):
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
# Get the next chunk, up to MTU size
if rx_credits_needed > 0:
chunk = bytes([rx_credits_needed]) + self.tx_buffer[:self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1:]
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
self.rx_credits += rx_credits_needed
tx_credit_spent = (len(chunk) > 1)
tx_credit_spent = len(chunk) > 1
else:
chunk = self.tx_buffer[:self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk):]
chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :]
tx_credit_spent = True
# Update the tx credits
@@ -503,13 +495,15 @@ class DLC(EventEmitter):
self.tx_credits -= 1
# Send the frame
logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}')
logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}'
)
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = self.dlci,
information = chunk,
p_f = 1 if rx_credits_needed > 0 else 0
c_r=self.c_r,
dlci=self.dlci,
information=chunk,
p_f=1 if rx_credits_needed > 0 else 0,
)
)
@@ -543,34 +537,34 @@ class Multiplexer(EventEmitter):
RESPONDER = 0x01
# States
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
OPENING = 0x03
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
OPENING = 0x03
DISCONNECTING = 0x04
DISCONNECTED = 0x05
RESET = 0x06
DISCONNECTED = 0x05
RESET = 0x06
STATE_NAMES = {
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
OPENING: 'OPENING',
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
OPENING: 'OPENING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET',
}
def __init__(self, l2cap_channel, role):
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
self.state = Multiplexer.INIT
self.dlcs = {} # DLCs, by DLCI
self.connection_result = None
self.role = role
self.l2cap_channel = l2cap_channel
self.state = Multiplexer.INIT
self.dlcs = {} # DLCs, by DLCI
self.connection_result = None
self.disconnection_result = None
self.open_result = None
self.acceptor = None
self.open_result = None
self.acceptor = None
# Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu
@@ -580,7 +574,9 @@ class Multiplexer(EventEmitter):
return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}')
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state
def send_frame(self, frame):
@@ -616,7 +612,7 @@ class Multiplexer(EventEmitter):
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 1, dlci = 0))
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, frame):
if self.state == Multiplexer.CONNECTING:
@@ -634,18 +630,22 @@ class Multiplexer(EventEmitter):
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
self.open_result.set_exception(ConnectionError(
ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm'
))
self.open_result.set_exception(
ConnectionError(
ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm',
)
)
else:
logger.warn(f'unexpected state for DM: {self}')
def on_disc_frame(self, frame):
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0))
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame):
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
@@ -685,7 +685,7 @@ class Multiplexer(EventEmitter):
dlc.accept()
else:
# No acceptor, we're in Disconnected Mode
self.send_frame(RFCOMM_Frame.dm(c_r = 1, dlci = pn.dlci))
self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
else:
# No acceptor?? shouldn't happen
logger.warn(color('!!! no acceptor registered', 'red'))
@@ -712,7 +712,7 @@ class Multiplexer(EventEmitter):
self.change_state(Multiplexer.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r = 1, dlci = 0))
self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
return await self.connection_result
async def disconnect(self):
@@ -721,7 +721,11 @@ class Multiplexer(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.DISCONNECTING)
self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0))
self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
)
)
await self.disconnection_result
async def open_dlc(self, channel):
@@ -732,23 +736,23 @@ class Multiplexer(EventEmitter):
raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN(
dlci = channel << 1,
cl = 0xF0,
priority = 7,
ack_timer = 0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
dlci=channel << 1,
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 1, data = bytes(pn))
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.OPENING)
self.send_frame(
RFCOMM_Frame.uih(
c_r = 1 if self.role == Multiplexer.INITIATOR else 0,
dlci = 0,
information = mcc
c_r=1 if self.role == Multiplexer.INITIATOR else 0,
dlci=0,
information=mcc,
)
)
result = await self.open_result
@@ -768,15 +772,17 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device, connection):
self.device = device
self.connection = connection
self.device = device
self.connection = connection
self.l2cap_channel = None
self.multiplexer = None
self.multiplexer = None
async def start(self):
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM)
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
except ProtocolError as error:
logger.warn(f'L2CAP connection failed: {error}')
raise
@@ -802,16 +808,18 @@ class Client:
class Server(EventEmitter):
def __init__(self, device):
super().__init__()
self.device = device
self.device = device
self.multiplexer = None
self.acceptors = {}
self.acceptors = {}
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor):
# Find a free channel number
for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1):
for channel in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
):
if channel not in self.acceptors:
self.acceptors[channel] = acceptor
return channel

View File

@@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
SDP_PSM = 0x0001
@@ -112,48 +114,64 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# -----------------------------------------------------------------------------
class DataElement:
NIL = 0
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
TYPE_NAMES = {
NIL: 'NIL',
NIL: 'NIL',
UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL'
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL',
}
type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y),
SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y),
UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)),
ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8'))
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
def __init__(self, type, value, value_size=None):
self.type = type
self.value = value
self.type = type
self.value = value
self.value_size = value_size
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
if value_size is None:
raise ValueError('integer types must have a value size specified')
@@ -250,7 +268,7 @@ class DataElement:
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)):]
data = data[len(bytes(element)) :]
return elements
@staticmethod
@@ -261,7 +279,7 @@ class DataElement:
@staticmethod
def from_bytes(data):
type = data[0] >> 3
size_index = data[0] & 7
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if type == DataElement.NIL:
@@ -286,16 +304,21 @@ class DataElement:
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset:1 + value_offset + value_size]
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type)
if constructor:
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
if (
type == DataElement.UNSIGNED_INTEGER
or type == DataElement.SIGNED_INTEGER
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(type, value_data)
result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica
result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def to_bytes(self):
@@ -349,9 +372,11 @@ class DataElement:
if size != 0:
raise ValueError('NIL must be empty')
size_index = 0
elif (self.type == DataElement.UNSIGNED_INTEGER or
self.type == DataElement.SIGNED_INTEGER or
self.type == DataElement.UUID):
elif (
self.type == DataElement.UNSIGNED_INTEGER
or self.type == DataElement.SIGNED_INTEGER
or self.type == DataElement.UUID
):
if size <= 1:
size_index = 0
elif size == 2:
@@ -364,10 +389,12 @@ class DataElement:
size_index = 4
else:
raise ValueError('invalid data size')
elif (self.type == DataElement.TEXT_STRING or
self.type == DataElement.SEQUENCE or
self.type == DataElement.ALTERNATIVE or
self.type == DataElement.URL):
elif (
self.type == DataElement.TEXT_STRING
or self.type == DataElement.SEQUENCE
or self.type == DataElement.ALTERNATIVE
or self.type == DataElement.URL
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
@@ -396,7 +423,10 @@ class DataElement:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]'
elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER:
elif (
self.type == DataElement.UNSIGNED_INTEGER
or self.type == DataElement.SIGNED_INTEGER
):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
@@ -411,14 +441,14 @@ class DataElement:
# -----------------------------------------------------------------------------
class ServiceAttribute:
def __init__(self, id, value):
self.id = id
self.id = id
self.value = value
@staticmethod
def list_from_data_elements(elements):
attribute_list = []
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i:2 * (i + 1)]
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
if attribute_id.type != DataElement.UNSIGNED_INTEGER:
logger.warn('attribute ID element is not an integer')
continue
@@ -428,7 +458,14 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(attribute_list, attribute_id):
return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None)
return next(
(
attribute.value
for attribute in attribute_list
if attribute.id == attribute_id
),
None,
)
@staticmethod
def id_name(id):
@@ -462,6 +499,7 @@ class SDP_PDU:
'''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
sdp_pdu_classes = {}
@staticmethod
@@ -484,13 +522,15 @@ class SDP_PDU:
@staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset):
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
@staticmethod
def parse_bytes_preceded_by_length(data, offset):
length = struct.unpack_from('>H', data, offset - 2)[0]
return offset + length, data[offset:offset + length]
return offset + length, data[offset : offset + length]
@staticmethod
def error_name(error_code):
@@ -532,7 +572,10 @@ class SDP_PDU:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters
pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu
self.transaction_id = transaction_id
@@ -555,9 +598,7 @@ class SDP_PDU:
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})
])
@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
class SDP_ErrorResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
@@ -565,11 +606,13 @@ class SDP_ErrorResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
@@ -577,12 +620,17 @@ class SDP_ServiceSearchRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
(
'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
@@ -590,12 +638,14 @@ class SDP_ServiceSearchResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
@@ -603,11 +653,13 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
@@ -615,12 +667,14 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
@@ -628,11 +682,13 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
@SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
@@ -642,9 +698,9 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device):
self.device = device
self.device = device
self.pending_request = None
self.channel = None
self.channel = None
async def connect(self, connection):
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
@@ -659,7 +715,9 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
# Request and accumulate until there's no more continuation
service_record_handle_list = []
@@ -668,10 +726,10 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceSearchRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern,
maximum_service_record_count = 0xFFFF,
continuation_state = continuation_state
transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF,
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -689,10 +747,14 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
@@ -706,11 +768,11 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceSearchAttributeRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern,
maximum_attribute_byte_count = 0xFFFF,
attribute_id_list = attribute_id_list,
continuation_state = continuation_state
transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -740,7 +802,9 @@ class Client:
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
@@ -754,11 +818,11 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceAttributeRequest(
transaction_id = 0, # Transaction ID TODO: pick a real value
service_record_handle = service_record_handle,
maximum_attribute_byte_count = 0xFFFF,
attribute_id_list = attribute_id_list,
continuation_state = continuation_state
transaction_id=0, # Transaction ID TODO: pick a real value
service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -784,8 +848,8 @@ class Server:
CONTINUATION_STATE = bytes([0x01, 0x43])
def __init__(self, device):
self.device = device
self.service_records = {} # Service records maps, by record handle
self.device = device
self.service_records = {} # Service records maps, by record handle
self.current_response = None
def register(self, l2cap_channel_manager):
@@ -823,8 +887,7 @@ class Server:
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id = 0,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
)
)
@@ -840,16 +903,16 @@ class Server:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response(
SDP_ErrorResponse(
transaction_id = sdp_pdu.transaction_id,
error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
)
)
else:
logger.error(color('SDP Request not handled???', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id = sdp_pdu.transaction_id,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
)
)
@@ -872,17 +935,18 @@ class Server:
if attribute_id.value_size == 4:
# Attribute ID range
id_range_start = attribute_id.value >> 16
id_range_end = attribute_id.value & 0xFFFF
id_range_end = attribute_id.value & 0xFFFF
else:
id_range_start = attribute_id.value
id_range_end = attribute_id.value
id_range_end = attribute_id.value
attributes += [
attribute for attribute in service
attribute
for attribute in service
if attribute.id >= id_range_start and attribute.id <= id_range_end
]
# Return the maching attributes, sorted by attribute id
attributes.sort(key = lambda x: x.id)
attributes.sort(key=lambda x: x.id)
attribute_list = DataElement.sequence([])
for attribute in attributes:
attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
@@ -896,8 +960,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return
@@ -910,30 +974,38 @@ class Server:
service_record_handles = list(matching_services.keys())
# Only return up to the maximum requested
service_record_handles_subset = service_record_handles[:request.maximum_service_record_count]
service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count
]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = (
len(service_record_handles),
service_record_handles_subset
service_record_handles_subset,
)
# Respond, keeping any unsent handles for later
service_record_handles = self.current_response[1][:request.maximum_service_record_count]
service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
self.current_response = (
self.current_response[0],
self.current_response[1][request.maximum_service_record_count:]
self.current_response[1][request.maximum_service_record_count :],
)
continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
)
continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles])
self.send_response(
SDP_ServiceSearchResponse(
transaction_id = request.transaction_id,
total_service_record_count = self.current_response[0],
current_service_record_count = len(service_record_handles),
service_record_handle_list = service_record_handle_list,
continuation_state = continuation_state
transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0],
current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state,
)
)
@@ -943,8 +1015,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return
@@ -957,27 +1029,31 @@ class Server:
if service is None:
self.send_response(
SDP_ErrorResponse(
transaction_id = request.transaction_id,
error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR
transaction_id=request.transaction_id,
error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
)
)
return
# Get the attributes for the service
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
# Serialize to a byte array
logger.debug(f'Attributes: {attribute_list}')
self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
attribute_list, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id = request.transaction_id,
attribute_list_byte_count = len(attribute_list),
attribute_list = attribute_list,
continuation_state = continuation_state
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list),
attribute_list=attribute_list,
continuation_state=continuation_state,
)
)
@@ -987,8 +1063,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
else:
@@ -996,12 +1072,16 @@ class Server:
self.current_response = None
# Find the matching services
matching_services = self.match_services(request.service_search_pattern).values()
matching_services = self.match_services(
request.service_search_pattern
).values()
# Filter the required attributes
attribute_lists = DataElement.sequence([])
for service in matching_services:
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
if attribute_list.value:
attribute_lists.value.append(attribute_list)
@@ -1010,12 +1090,14 @@ class Server:
self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
attribute_lists, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id = request.transaction_id,
attribute_lists_byte_count = len(attribute_lists),
attribute_lists = attribute_lists,
continuation_state = continuation_state
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists),
attribute_lists=attribute_lists,
continuation_state=continuation_state,
)
)

File diff suppressed because it is too large Load Diff

View File

@@ -38,42 +38,55 @@ async def open_transport(name):
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
elif scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
elif scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
elif scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
elif scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
elif scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
elif scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
elif scheme == 'file':
from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
elif scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
elif scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
elif scheme == 'usb':
from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
elif scheme == 'pyusb':
from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
elif scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
else:
raise ValueError('unknown transport scheme')
@@ -84,7 +97,7 @@ async def open_transport_or_link(name):
if name.startswith('link-relay:'):
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link = link)
controller = Controller('remote', link=link)
class LinkTransport(Transport):
async def close(self):

View File

@@ -59,15 +59,10 @@ async def open_android_emulator_transport(spec):
return bytes([packet.type]) + packet.packet
async def write(self, packet):
await self.hci_device.write(
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:]))
# Parse the parameters
mode = 'host'
mode = 'host'
server_host = 'localhost'
server_port = 8554
if spec is not None:
@@ -100,7 +95,7 @@ async def open_android_emulator_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
channel.close
channel.close,
)
transport.start()

View File

@@ -33,10 +33,10 @@ logger = logging.getLogger(__name__)
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B')
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
}
@@ -48,7 +48,7 @@ class PacketPump:
def __init__(self, reader, sink):
self.reader = reader
self.sink = sink
self.sink = sink
async def run(self):
while True:
@@ -67,41 +67,46 @@ class PacketParser:
'''
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
'''
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
def __init__(self, sink = None):
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
def __init__(self, sink=None):
self.sink = sink
self.extended_packet_info = {}
self.reset()
def reset(self):
self.state = PacketParser.NEED_TYPE
self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1
self.packet = bytearray()
self.packet_info = None
self.packet = bytearray()
self.packet_info = None
def feed_data(self, data):
data_offset = 0
data_left = len(data)
while data_left and self.bytes_needed:
consumed = min(self.bytes_needed, data_left)
self.packet.extend(data[data_offset:data_offset + consumed])
data_offset += consumed
data_left -= consumed
self.packet.extend(data[data_offset : data_offset + consumed])
data_offset += consumed
data_left -= consumed
self.bytes_needed -= consumed
if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type)
self.packet_info = HCI_PACKET_INFO.get(
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0]
body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY
@@ -111,7 +116,9 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset()
def set_packet_sink(self, sink):
@@ -187,6 +194,7 @@ class AsyncPipeSink:
'''
Sink that forwards packets asynchronously to another sink
'''
def __init__(self, sink):
self.sink = sink
self.loop = asyncio.get_running_loop()
@@ -202,7 +210,7 @@ class ParserSource:
"""
def __init__(self):
self.parser = PacketParser()
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink):
@@ -237,7 +245,7 @@ class StreamPacketSink:
class Transport:
def __init__(self, source, sink):
self.source = source
self.sink = sink
self.sink = sink
async def __aenter__(self):
return self
@@ -258,7 +266,7 @@ class PumpedPacketSource(ParserSource):
def __init__(self, receive):
super().__init__()
self.receive_function = receive
self.pump_task = None
self.pump_task = None
def start(self):
async def pump_packets():
@@ -285,8 +293,8 @@ class PumpedPacketSource(ParserSource):
class PumpedPacketSink:
def __init__(self, send):
self.send_function = send
self.packet_queue = asyncio.Queue()
self.pump_task = None
self.packet_queue = asyncio.Queue()
self.pump_task = None
def on_packet(self, packet):
self.packet_queue.put_nowait(packet)

View File

@@ -21,32 +21,36 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), {
'DESCRIPTOR' : _HCIPACKET,
'__module__' : 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
})
HCIPacket = _reflection.GeneratedProtocolMessageType(
'HCIPacket',
(_message.Message,),
{
'DESCRIPTOR': _HCIPACKET,
'__module__': 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
},
)
_sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start=66
_HCIPACKET._serialized_end=317
_HCIPACKET_PACKETTYPE._serialized_start=161
_HCIPACKET_PACKETTYPE._serialized_end=317
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start = 66
_HCIPACKET._serialized_end = 317
_HCIPACKET_PACKETTYPE._serialized_start = 161
_HCIPACKET_PACKETTYPE._serialized_end = 317
# @@protoc_insertion_point(module_scope)

View File

@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -29,25 +30,30 @@ _sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), {
'DESCRIPTOR' : _RAWDATA,
'__module__' : 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
})
RawData = _reflection.GeneratedProtocolMessageType(
'RawData',
(_message.Message,),
{
'DESCRIPTOR': _RAWDATA,
'__module__': 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
},
)
_sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start=91
_RAWDATA._serialized_end=116
_EMULATEDBLUETOOTHSERVICE._serialized_start=119
_EMULATEDBLUETOOTHSERVICE._serialized_end=450
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start = 91
_RAWDATA._serialized_end = 116
_EMULATEDBLUETOOTHSERVICE._serialized_start = 119
_EMULATEDBLUETOOTHSERVICE._serialized_end = 450
# @@protoc_insertion_point(module_scope)

View File

@@ -39,20 +39,20 @@ class EmulatedBluetoothServiceStub(object):
channel: A grpc.Channel.
"""
self.registerClassicPhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
self.registerBlePhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
self.registerHCIDevice = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
class EmulatedBluetoothServiceServicer(object):
@@ -121,28 +121,29 @@ class EmulatedBluetoothServiceServicer(object):
def add_EmulatedBluetoothServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'registerClassicPhy': grpc.stream_stream_rpc_method_handler(
servicer.registerClassicPhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerBlePhy': grpc.stream_stream_rpc_method_handler(
servicer.registerBlePhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerHCIDevice': grpc.stream_stream_rpc_method_handler(
servicer.registerHCIDevice,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
'registerClassicPhy': grpc.stream_stream_rpc_method_handler(
servicer.registerClassicPhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerBlePhy': grpc.stream_stream_rpc_method_handler(
servicer.registerBlePhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerHCIDevice': grpc.stream_stream_rpc_method_handler(
servicer.registerHCIDevice,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers)
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class EmulatedBluetoothService(object):
"""An Emulated Bluetooth Service exposes the emulated bluetooth chip from the
android emulator. It allows you to register emulated bluetooth devices and
@@ -156,52 +157,88 @@ class EmulatedBluetoothService(object):
"""
@staticmethod
def registerClassicPhy(request_iterator,
def registerClassicPhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod
def registerBlePhy(request_iterator,
def registerBlePhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod
def registerHCIDevice(request_iterator,
def registerHCIDevice(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -29,15 +30,16 @@ _sym_db = _symbol_database.Default()
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_VHCIFORWARDINGSERVICE._serialized_start=96
_VHCIFORWARDINGSERVICE._serialized_end=217
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_VHCIFORWARDINGSERVICE._serialized_start = 96
_VHCIFORWARDINGSERVICE._serialized_end = 217
# @@protoc_insertion_point(module_scope)

View File

@@ -35,10 +35,10 @@ class VhciForwardingServiceStub(object):
channel: A grpc.Channel.
"""
self.attachVhci = channel.stream_stream(
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
class VhciForwardingServiceServicer(object):
@@ -75,18 +75,19 @@ class VhciForwardingServiceServicer(object):
def add_VhciForwardingServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'attachVhci': grpc.stream_stream_rpc_method_handler(
servicer.attachVhci,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
'attachVhci': grpc.stream_stream_rpc_method_handler(
servicer.attachVhci,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers)
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class VhciForwardingService(object):
"""This is a service which allows you to directly intercept the VHCI packets
that are coming and going to the device before they are delivered to
@@ -97,18 +98,30 @@ class VhciForwardingService(object):
"""
@staticmethod
def attachVhci(request_iterator,
def attachVhci(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -39,14 +39,12 @@ async def open_file_transport(spec):
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(),
file
lambda: StreamPacketSource(), file
)
# Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(),
file
lambda: asyncio.BaseProtocol(), file
)
packet_sink = StreamPacketSink(write_transport)
@@ -57,4 +55,3 @@ async def open_file_transport(spec):
file.close()
return FileTransport(packet_source, packet_sink)

View File

@@ -44,7 +44,11 @@ async def open_hci_socket_transport(spec):
# Create a raw HCI socket
try:
hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI)
hci_socket = socket.socket(
socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError:
# Not supported on this platform
logger.info("HCI sockets not supported on this platform")
@@ -67,15 +71,26 @@ async def open_hci_socket_transport(spec):
raise Exception('Bluetooth HCI sockets not supported on this platform')
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int
bind_address = struct.pack('<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER)
if libc.bind(hci_socket.fileno(), ctypes.create_string_buffer(bind_address), len(bind_address)) != 0:
bind_address = struct.pack(
'<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER
)
if (
libc.bind(
hci_socket.fileno(),
ctypes.create_string_buffer(bind_address),
len(bind_address),
)
!= 0
):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, socket):
super().__init__()
self.socket = socket
asyncio.get_running_loop().add_reader(socket.fileno(), self.recv_until_would_block)
self.socket = socket
asyncio.get_running_loop().add_reader(
socket.fileno(), self.recv_until_would_block
)
def recv_until_would_block(self):
logger.debug('recv until would block +++')
@@ -93,8 +108,8 @@ async def open_hci_socket_transport(spec):
class HciSocketSink:
def __init__(self, socket):
self.socket = socket
self.packets = collections.deque()
self.socket = socket
self.packets = collections.deque()
self.writer_added = False
def send_until_would_block(self):
@@ -114,7 +129,9 @@ async def open_hci_socket_transport(spec):
if self.packets:
# There's still something to send, ensure that we are monitoring the socket
if not self.writer_added:
asyncio.get_running_loop().add_writer(socket.fileno(), self.send_until_would_block)
asyncio.get_running_loop().add_writer(
socket.fileno(), self.send_until_would_block
)
self.writer_added = True
else:
# Nothing left to send, stop monitoring the socket

View File

@@ -47,13 +47,11 @@ async def open_pty_transport(spec):
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(),
io.open(primary, 'rb', closefd=False)
lambda: StreamPacketSource(), io.open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(),
io.open(primary, 'wb', closefd=False)
lambda: asyncio.BaseProtocol(), io.open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

@@ -48,25 +48,25 @@ async def open_pyusb_transport(spec):
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
'''
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_SCO_IN = 0x83
USB_ENDPOINT_ACL_OUT = 0x02
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_SCO_IN = 0x83
USB_ENDPOINT_ACL_OUT = 0x02
# USB_ENDPOINT_SCO_OUT = 0x03
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
READ_SIZE = 1024
READ_SIZE = 1024
READ_TIMEOUT = 1000
class UsbPacketSink:
def __init__(self, device):
self.device = device
self.thread = threading.Thread(target=self.run)
self.loop = asyncio.get_running_loop()
self.device = device
self.thread = threading.Thread(target=self.run)
self.loop = asyncio.get_running_loop()
self.stop_event = None
def on_packet(self, packet):
@@ -80,9 +80,17 @@ async def open_pyusb_transport(spec):
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
elif packet_type == hci.HCI_COMMAND_PACKET:
self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:])
self.device.ctrl_transfer(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
)
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
except usb.core.USBTimeoutError:
logger.warning('USB Write Timeout')
except usb.core.USBError as error:
@@ -105,17 +113,15 @@ async def open_pyusb_transport(spec):
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled):
super().__init__()
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.event_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
)
self.event_thread.stop_event = None
self.acl_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
)
self.acl_thread.stop_event = None
@@ -124,7 +130,7 @@ async def open_pyusb_transport(spec):
if sco_enabled:
self.sco_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET)
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
)
self.sco_thread.stop_event = None
@@ -155,7 +161,7 @@ async def open_pyusb_transport(spec):
# Create stop events and wait for them to be signaled
self.event_thread.stop_event = asyncio.Event()
self.acl_thread.stop_event = asyncio.Event()
self.acl_thread.stop_event = asyncio.Event()
await self.event_thread.stop_event.wait()
await self.acl_thread.stop_event.wait()
if self.sco_enabled:
@@ -197,15 +203,19 @@ async def open_pyusb_transport(spec):
# Find the device according to the spec moniker
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
device = usb.core.find(
idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
else:
device_index = int(spec)
devices = list(usb.core.find(
find_all = 1,
bDeviceClass = USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass = USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
))
devices = list(
usb.core.find(
find_all=1,
bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
)
if len(devices) > device_index:
device = devices[device_index]
else:
@@ -273,4 +283,4 @@ async def open_pyusb_transport(spec):
packet_source.start()
packet_sink.start()
return UsbTransport(device, packet_source, packet_sink)
return UsbTransport(device, packet_source, packet_sink)

View File

@@ -64,9 +64,8 @@ async def open_serial_transport(spec):
device,
baudrate=speed,
rtscts=rtscts,
dsrdtr=dsrdtr
dsrdtr=dsrdtr,
)
packet_sink = StreamPacketSink(serial_transport)
return Transport(packet_source, packet_sink)

View File

@@ -45,7 +45,7 @@ async def open_tcp_server_transport(spec):
class TcpServerProtocol:
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink
self.packet_sink = packet_sink
# Called when a new connection is established
def connection_made(self, transport):
@@ -78,7 +78,7 @@ async def open_tcp_server_transport(spec):
local_host, local_port = spec.split(':')
packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink()
packet_sink = TcpServerPacketSink()
await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink),
host=local_host if local_host != '_' else None,

View File

@@ -53,10 +53,13 @@ async def open_udp_transport(spec):
local, remote = spec.split(',')
local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':')
udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint(
(
udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: UdpPacketSource(),
local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port))
remote_addr=(remote_host, int(remote_port)),
)
packet_sink = UdpPacketSink(udp_transport)

View File

@@ -59,33 +59,33 @@ async def open_usb_transport(spec):
usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
'''
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.cancel_done = self.loop.create_future()
self.closed = False
self.closed = False
def start(self):
pass
@@ -114,7 +114,9 @@ async def open_usb_transport(spec):
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
logger.warning(color(f'!!! out transfer not completed: status={status}', 'red'))
logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
)
def on_packet_sent_(self):
if self.packets:
@@ -129,17 +131,18 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
self.acl_out,
packet[1:],
callback=self.on_packet_sent
self.acl_out, packet[1:], callback=self.on_packet_sent
)
logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0,
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.on_packet_sent
callback=self.on_packet_sent,
)
logger.debug('submit COMMAND')
self.transfer.submit()
@@ -167,17 +170,17 @@ async def open_usb_transport(spec):
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device, acl_in, events_in):
super().__init__()
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future()
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
}
# Create a thread to process events
@@ -190,7 +193,7 @@ async def open_usb_transport(spec):
self.events_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET
user_data=hci.HCI_EVENT_PACKET,
)
self.events_in_transfer.submit()
@@ -199,7 +202,7 @@ async def open_usb_transport(spec):
self.acl_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET
user_data=hci.HCI_ACL_DATA_PACKET,
)
self.acl_in_transfer.submit()
@@ -212,13 +215,20 @@ async def open_usb_transport(spec):
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}')
if status == usb1.TRANSFER_COMPLETED:
packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()]
packet = (
bytes([packet_type])
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None)
self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return
else:
logger.warning(color(f'!!! transfer not completed: status={status}', 'red'))
logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
)
# Re-submit the transfer so we can receive more data
transfer.submit()
@@ -233,7 +243,10 @@ async def open_usb_transport(spec):
def run(self):
logger.debug('starting USB event loop')
while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted():
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
@@ -253,11 +266,15 @@ async def open_usb_transport(spec):
packet_type = transfer.getUserData()
try:
transfer.cancel()
logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...')
logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation to be done...'
)
await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError:
logger.debug(f'IN[{packet_type}] transfer likely already completed')
logger.debug(
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate
await self.event_loop_done
@@ -265,8 +282,8 @@ async def open_usb_transport(spec):
class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
self.context = context
self.device = device
self.interface = interface
# Get exclusive access
@@ -315,9 +332,9 @@ async def open_usb_transport(spec):
except usb1.USBError:
device_serial_number = None
if (
device.getVendorID() == int(vendor_id, 16) and
device.getProductID() == int(product_id, 16) and
(serial_number is None or serial_number == device_serial_number)
device.getVendorID() == int(vendor_id, 16)
and device.getProductID() == int(product_id, 16)
and (serial_number is None or serial_number == device_serial_number)
):
if device_index == 0:
found = device
@@ -328,8 +345,11 @@ async def open_usb_transport(spec):
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
# Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == \
USB_BT_HCI_CLASS_TUPLE:
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
@@ -337,8 +357,11 @@ async def open_usb_transport(spec):
for configuration in device:
for interface in configuration:
for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == \
USB_BT_HCI_CLASS_TUPLE:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
@@ -366,38 +389,52 @@ async def open_usb_transport(spec):
setting = None
for setting in interface:
if (
not forced_mode and
(setting.getClass(), setting.getSubClass(), setting.getProtocol()) != USB_BT_HCI_CLASS_TUPLE
not forced_mode
and (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
)
!= USB_BT_HCI_CLASS_TUPLE
):
continue
events_in = None
acl_in = None
acl_out = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
elif (
attributes & 0x03
== USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT
):
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None:
if (
acl_in is not None
and acl_out is not None
and events_in is not None
):
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in
events_in,
)
else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}')
logger.debug(
f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}'
)
endpoints = find_endpoints(found)
if endpoints is None:
@@ -437,7 +474,7 @@ async def open_usb_transport(spec):
logger.warning('failed to set configuration')
source = UsbPacketSource(context, device, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))

View File

@@ -33,7 +33,7 @@ async def open_vhci_transport(spec):
path at /dev/vhci), or the path of a VHCI device
'''
HCI_VENDOR_PKT = 0xff
HCI_VENDOR_PKT = 0xFF
HCI_BREDR = 0x00 # Controller type
# Open the VHCI device
@@ -56,4 +56,3 @@ async def open_vhci_transport(spec):
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
return transport

View File

@@ -43,7 +43,7 @@ async def open_ws_client_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send),
websocket.close
websocket.close,
)
transport.start()
return transport

View File

@@ -41,8 +41,8 @@ async def open_ws_server_transport(spec):
class WsServerTransport(Transport):
def __init__(self):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future()
super().__init__(source, sink)
@@ -50,14 +50,16 @@ async def open_ws_server_transport(spec):
async def serve(self, local_host, local_port):
self.sink.start()
self.server = await websockets.serve(
ws_handler = self.on_connection,
host = local_host if local_host != '_' else None,
port = int(local_port)
ws_handler=self.on_connection,
host=local_host if local_host != '_' else None,
port=int(local_port),
)
logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection):
logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}')
logger.debug(
f'new connection on {connection.local_address} from {connection.remote_address}'
)
self.connection.set_result(connection)
try:
async for packet in connection:

View File

@@ -34,6 +34,7 @@ logger = logging.getLogger(__name__)
def setup_event_forwarding(emitter, forwarder, event_name):
def emit(*args, **kwargs):
forwarder.emit(event_name, *args, **kwargs)
emitter.on(event_name, emit)
@@ -44,6 +45,7 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter.
"""
def register(self, emitter):
for method_name in dir(cls):
if method_name.startswith('on_'):
@@ -54,7 +56,7 @@ def composite_listener(cls):
if method_name.startswith('on_'):
emitter.remove_listener(method_name[3:], getattr(self, method_name))
cls._bumble_register_composite = register
cls._bumble_register_composite = register
cls._bumble_deregister_composite = deregister
return cls
@@ -110,7 +112,9 @@ class AsyncRunner:
try:
await item
except Exception as error:
logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in work queue:", "red")} {error}'
)
# Shared default queue
default_queue = WorkQueue()
@@ -131,7 +135,9 @@ class AsyncRunner:
try:
await coroutine
except Exception:
logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}')
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}'
)
asyncio.create_task(run())
else:
@@ -150,18 +156,26 @@ class FlowControlAsyncPipe:
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
self.pause_source = pause_source
def __init__(
self,
pause_source,
resume_source,
write_to_sink=None,
drain_sink=None,
threshold=0,
):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.paused = False
self.source_paused = False
self.pump_task = None
self.pump_task = None
def start(self):
if self.pump_task is None:

View File

@@ -80,6 +80,7 @@ async def main():
await my_work_queue2.run()
print("MAIN: end (should never get here)")
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -55,7 +55,9 @@ async def main():
# Subscribe to and read the battery level
if battery_service.battery_level:
await battery_service.battery_level.subscribe(
lambda value: print(f'{color("Battery Level Update:", "green")} {value}')
lambda value: print(
f'{color("Battery Level Update:", "green")} {value}'
)
)
value = await battery_service.battery_level.read_value()
print(f'{color("Initial Battery Level:", "green")} {value}')
@@ -64,5 +66,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -44,11 +44,19 @@ async def main():
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Battery', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(battery_service.uuid)),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Battery', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(battery_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -58,9 +66,11 @@ async def main():
# Notify every 3 seconds
while True:
await asyncio.sleep(3.0)
await device.notify_subscribers(battery_service.battery_level_characteristic)
await device.notify_subscribers(
battery_service.battery_level_characteristic
)
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -28,7 +28,9 @@ from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 3:
print('Usage: device_information_client.py <transport-spec> <bluetooth-address>')
print(
'Usage: device_information_client.py <transport-spec> <bluetooth-address>'
)
print('example: device_information_client.py usb:0 E1:CA:72:48:C4:E8')
return
@@ -49,7 +51,9 @@ async def main():
# Discover the Device Information service
peer = Peer(connection)
print('=== Discovering Device Information Service')
device_information_service = await peer.discover_service_and_create_proxy(DeviceInformationServiceProxy)
device_information_service = await peer.discover_service_and_create_proxy(
DeviceInformationServiceProxy
)
# Check that the service was found
if device_information_service is None:
@@ -58,23 +62,52 @@ async def main():
# Read and print the fields
if device_information_service.manufacturer_name is not None:
print(color('Manufacturer Name: ', 'green'), await device_information_service.manufacturer_name.read_value())
print(
color('Manufacturer Name: ', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number is not None:
print(color('Model Number: ', 'green'), await device_information_service.model_number.read_value())
print(
color('Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number is not None:
print(color('Serial Number: ', 'green'), await device_information_service.serial_number.read_value())
print(
color('Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.hardware_revision is not None:
print(color('Hardware Revision: ', 'green'), await device_information_service.hardware_revision.read_value())
print(
color('Hardware Revision: ', 'green'),
await device_information_service.hardware_revision.read_value(),
)
if device_information_service.firmware_revision is not None:
print(color('Firmware Revision: ', 'green'), await device_information_service.firmware_revision.read_value())
print(
color('Firmware Revision: ', 'green'),
await device_information_service.firmware_revision.read_value(),
)
if device_information_service.software_revision is not None:
print(color('Software Revision: ', 'green'), await device_information_service.software_revision.read_value())
print(
color('Software Revision: ', 'green'),
await device_information_service.software_revision.read_value(),
)
if device_information_service.system_id is not None:
print(color('System ID: ', 'green'), await device_information_service.system_id.read_value())
if device_information_service.ieee_regulatory_certification_data_list is not None:
print(color('Regulatory Certification:', 'green'), (await device_information_service.ieee_regulatory_certification_data_list.read_value()).hex())
print(
color('System ID: ', 'green'),
await device_information_service.system_id.read_value(),
)
if (
device_information_service.ieee_regulatory_certification_data_list
is not None
):
print(
color('Regulatory Certification:', 'green'),
(
await device_information_service.ieee_regulatory_certification_data_list.read_value()
).hex(),
)
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -39,21 +39,26 @@ async def main():
# Add a Device Information Service to the GATT sever
device_information_service = DeviceInformationService(
manufacturer_name = 'ACME',
model_number = 'AB-102',
serial_number = '7654321',
hardware_revision = '1.1.3',
software_revision = '2.5.6',
system_id = (0x123456, 0x8877665544)
manufacturer_name='ACME',
model_number='AB-102',
serial_number='7654321',
hardware_revision='1.1.3',
software_revision='2.5.6',
system_id=(0x123456, 0x8877665544),
)
device.add_service(device_information_service)
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Device', 'utf-8')),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Device', 'utf-8'),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -61,6 +66,7 @@ async def main():
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -61,12 +61,14 @@ async def main():
# Subscribe to the heart rate measurement
if heart_rate_service.heart_rate_measurement:
await heart_rate_service.heart_rate_measurement.subscribe(
lambda value: print(f'{color("Heart Rate Measurement:", "green")} {value}')
lambda value: print(
f'{color("Heart Rate Measurement:", "green")} {value}'
)
)
await peer.sustain()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -50,34 +50,52 @@ async def main():
# Add a Device Information Service and Heart Rate Service to the GATT sever
device_information_service = DeviceInformationService(
manufacturer_name = 'ACME',
model_number = 'HR-102',
serial_number = '7654321',
hardware_revision = '1.1.3',
software_revision = '2.5.6',
system_id = (0x123456, 0x8877665544)
manufacturer_name='ACME',
model_number='HR-102',
serial_number='7654321',
hardware_revision='1.1.3',
software_revision='2.5.6',
system_id=(0x123456, 0x8877665544),
)
heart_rate_service = HeartRateService(
read_heart_rate_measurement = lambda _: HeartRateService.HeartRateMeasurement(
heart_rate = 100 + int(50 * math.sin(time.time() * math.pi / 60)),
sensor_contact_detected = random.choice((True, False, None)),
energy_expended = random.choice((int((time.time() - energy_start_time) * 100), None)),
rr_intervals = random.choice(((random.randint(900, 1100) / 1000, random.randint(900, 1100) / 1000), None))
read_heart_rate_measurement=lambda _: HeartRateService.HeartRateMeasurement(
heart_rate=100 + int(50 * math.sin(time.time() * math.pi / 60)),
sensor_contact_detected=random.choice((True, False, None)),
energy_expended=random.choice(
(int((time.time() - energy_start_time) * 100), None)
),
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
),
None,
)
),
),
body_sensor_location=HeartRateService.BodySensorLocation.WRIST,
reset_energy_expended=lambda _: reset_energy_expended()
reset_energy_expended=lambda _: reset_energy_expended(),
)
device.add_services([device_information_service, heart_rate_service])
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Heart', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(heart_rate_service.uuid)),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Heart', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(heart_rate_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -87,9 +105,11 @@ async def main():
# Notify every 3 seconds
while True:
await asyncio.sleep(3.0)
await device.notify_subscribers(heart_rate_service.heart_rate_measurement_characteristic)
await device.notify_subscribers(
heart_rate_service.heart_rate_measurement_characteristic
)
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -43,56 +43,90 @@ from bumble.gatt import (
GATT_PROTOCOL_MODE_CHARACTERISTIC,
GATT_HID_INFORMATION_CHARACTERISTIC,
GATT_HID_CONTROL_POINT_CHARACTERISTIC,
GATT_REPORT_REFERENCE_DESCRIPTOR
GATT_REPORT_REFERENCE_DESCRIPTOR,
)
# -----------------------------------------------------------------------------
# Protocol Modes
HID_BOOT_PROTOCOL = 0x00
HID_BOOT_PROTOCOL = 0x00
HID_REPORT_PROTOCOL = 0x01
# Report Types
HID_INPUT_REPORT = 0x01
HID_OUTPUT_REPORT = 0x02
HID_INPUT_REPORT = 0x01
HID_OUTPUT_REPORT = 0x02
HID_FEATURE_REPORT = 0x03
# Report Map
HID_KEYBOARD_REPORT_MAP = bytes([
0x05, 0x01, # Usage Page (Generic Desktop Ctrls)
0x09, 0x06, # Usage (Keyboard)
0xA1, 0x01, # Collection (Application)
0x85, 0x01, # . Report ID (1)
0x05, 0x07, # . Usage Page (Kbrd/Keypad)
0x19, 0xE0, # . Usage Minimum (0xE0)
0x29, 0xE7, # . Usage Maximum (0xE7)
0x15, 0x00, # . Logical Minimum (0)
0x25, 0x01, # . Logical Maximum (1)
0x75, 0x01, # . Report Size (1)
0x95, 0x08, # . Report Count (8)
0x81, 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x01, # . Report Count (1)
0x75, 0x08, # . Report Size (8)
0x81, 0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x06, # . Report Count (6)
0x75, 0x08, # . Report Size (8)
0x15, 0x00, # . Logical Minimum (0x00)
0x25, 0x94, # . Logical Maximum (0x94)
0x05, 0x07, # . Usage Page (Kbrd/Keypad)
0x19, 0x00, # . Usage Minimum (0x00)
0x29, 0x94, # . Usage Maximum (0x94)
0x81, 0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x05, # . Report Count (5)
0x75, 0x01, # . Report Size (1)
0x05, 0x08, # . Usage Page (LEDs)
0x19, 0x01, # . Usage Minimum (Num Lock)
0x29, 0x05, # . Usage Maximum (Kana)
0x91, 0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95, 0x01, # . Report Count (1)
0x75, 0x03, # . Report Size (3)
0x91, 0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0xC0 # End Collection
])
HID_KEYBOARD_REPORT_MAP = bytes(
[
0x05,
0x01, # Usage Page (Generic Desktop Ctrls)
0x09,
0x06, # Usage (Keyboard)
0xA1,
0x01, # Collection (Application)
0x85,
0x01, # . Report ID (1)
0x05,
0x07, # . Usage Page (Kbrd/Keypad)
0x19,
0xE0, # . Usage Minimum (0xE0)
0x29,
0xE7, # . Usage Maximum (0xE7)
0x15,
0x00, # . Logical Minimum (0)
0x25,
0x01, # . Logical Maximum (1)
0x75,
0x01, # . Report Size (1)
0x95,
0x08, # . Report Count (8)
0x81,
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x01, # . Report Count (1)
0x75,
0x08, # . Report Size (8)
0x81,
0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x06, # . Report Count (6)
0x75,
0x08, # . Report Size (8)
0x15,
0x00, # . Logical Minimum (0x00)
0x25,
0x94, # . Logical Maximum (0x94)
0x05,
0x07, # . Usage Page (Kbrd/Keypad)
0x19,
0x00, # . Usage Minimum (0x00)
0x29,
0x94, # . Usage Maximum (0x94)
0x81,
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x05, # . Report Count (5)
0x75,
0x01, # . Report Size (1)
0x05,
0x08, # . Usage Page (LEDs)
0x19,
0x01, # . Usage Minimum (Num Lock)
0x29,
0x05, # . Usage Maximum (Kana)
0x91,
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95,
0x01, # . Report Count (1)
0x75,
0x03, # . Report Size (3)
0x91,
0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0xC0, # End Collection
]
)
# -----------------------------------------------------------------------------
@@ -133,31 +167,41 @@ async def keyboard_host(device, peer_address):
return
await peer.discover_characteristics()
protocol_mode_characteristics = peer.get_characteristics_by_uuid(GATT_PROTOCOL_MODE_CHARACTERISTIC)
protocol_mode_characteristics = peer.get_characteristics_by_uuid(
GATT_PROTOCOL_MODE_CHARACTERISTIC
)
if not protocol_mode_characteristics:
print(color('!!! No Protocol Mode characteristic', 'red'))
return
protocol_mode_characteristic = protocol_mode_characteristics[0]
hid_information_characteristics = peer.get_characteristics_by_uuid(GATT_HID_INFORMATION_CHARACTERISTIC)
hid_information_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_INFORMATION_CHARACTERISTIC
)
if not hid_information_characteristics:
print(color('!!! No HID Information characteristic', 'red'))
return
hid_information_characteristic = hid_information_characteristics[0]
report_map_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_MAP_CHARACTERISTIC)
report_map_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_MAP_CHARACTERISTIC
)
if not report_map_characteristics:
print(color('!!! No Report Map characteristic', 'red'))
return
report_map_characteristic = report_map_characteristics[0]
control_point_characteristics = peer.get_characteristics_by_uuid(GATT_HID_CONTROL_POINT_CHARACTERISTIC)
control_point_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_CONTROL_POINT_CHARACTERISTIC
)
if not control_point_characteristics:
print(color('!!! No Control Point characteristic', 'red'))
return
# control_point_characteristic = control_point_characteristics[0]
report_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_CHARACTERISTIC)
report_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_CHARACTERISTIC
)
if not report_characteristics:
print(color('!!! No Report characteristic', 'red'))
return
@@ -165,13 +209,20 @@ async def keyboard_host(device, peer_address):
print(color('REPORT:', 'yellow'), characteristic)
if characteristic.properties & Characteristic.NOTIFY:
await peer.discover_descriptors(characteristic)
report_reference_descriptor = characteristic.get_descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR)
report_reference_descriptor = characteristic.get_descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR
)
if report_reference_descriptor:
report_reference = await peer.read_value(report_reference_descriptor)
print(color(' Report Reference:', 'blue'), report_reference.hex())
else:
report_reference = bytes([0, 0])
await peer.subscribe(characteristic, lambda value, param=f'[{i}] {report_reference.hex()}': on_report(param, value))
await peer.subscribe(
characteristic,
lambda value, param=f'[{i}] {report_reference.hex()}': on_report(
param, value
),
)
protocol_mode = await peer.read_value(protocol_mode_characteristic)
print(f'Protocol Mode: {protocol_mode.hex()}')
@@ -192,77 +243,91 @@ async def keyboard_device(device, command):
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0, 0, 0, 0, 0, 0, 0, 0]),
[
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_INPUT_REPORT]))
]
Descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_INPUT_REPORT]),
)
],
)
# Create an 'output report' characteristic to receive keyboard reports from the host
output_report_characteristic = Characteristic(
GATT_REPORT_CHARACTERISTIC,
Characteristic.READ | Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.READ
| Characteristic.WRITE
| Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0]),
[
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_OUTPUT_REPORT]))
]
Descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_OUTPUT_REPORT]),
)
],
)
# Add the services to the GATT sever
device.add_services([
Service(
GATT_DEVICE_INFORMATION_SERVICE,
[
Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
'Bumble'
)
]
),
Service(
GATT_HUMAN_INTERFACE_DEVICE_SERVICE,
[
Characteristic(
GATT_PROTOCOL_MODE_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([HID_REPORT_PROTOCOL])
),
Characteristic(
GATT_HID_INFORMATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([0x11, 0x01, 0x00, 0x03]) # bcdHID=1.1, bCountryCode=0x00, Flags=RemoteWake|NormallyConnectable
),
Characteristic(
GATT_HID_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_hid_control_point_write)
),
Characteristic(
GATT_REPORT_MAP_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
HID_KEYBOARD_REPORT_MAP
),
input_report_characteristic,
output_report_characteristic
]
),
Service(
GATT_BATTERY_SERVICE,
[
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([100])
)
]
)
])
device.add_services(
[
Service(
GATT_DEVICE_INFORMATION_SERVICE,
[
Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
'Bumble',
)
],
),
Service(
GATT_HUMAN_INTERFACE_DEVICE_SERVICE,
[
Characteristic(
GATT_PROTOCOL_MODE_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([HID_REPORT_PROTOCOL]),
),
Characteristic(
GATT_HID_INFORMATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[0x11, 0x01, 0x00, 0x03]
), # bcdHID=1.1, bCountryCode=0x00, Flags=RemoteWake|NormallyConnectable
),
Characteristic(
GATT_HID_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_hid_control_point_write),
),
Characteristic(
GATT_REPORT_MAP_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
HID_KEYBOARD_REPORT_MAP,
),
input_report_characteristic,
output_report_characteristic,
],
),
Service(
GATT_BATTERY_SERVICE,
[
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([100]),
)
],
),
]
)
# Debug print
for attribute in device.gatt_server.attributes:
@@ -270,13 +335,20 @@ async def keyboard_device(device, command):
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Keyboard', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE)),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)),
(AdvertisingData.FLAGS, bytes([0x05]))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Keyboard', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)),
(AdvertisingData.FLAGS, bytes([0x05])),
]
)
)
# Attach a listener
@@ -303,14 +375,21 @@ async def keyboard_device(device, command):
code = ord(key)
if code >= ord('a') and code <= ord('z'):
hid_code = 0x04 + code - ord('a')
input_report_characteristic.value = bytes([0, 0, hid_code, 0, 0, 0, 0, 0])
await device.notify_subscribers(input_report_characteristic)
input_report_characteristic.value = bytes(
[0, 0, hid_code, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(
input_report_characteristic
)
elif message_type == 'keyup':
input_report_characteristic.value = bytes.fromhex('0000000000000000')
input_report_characteristic.value = bytes.fromhex(
'0000000000000000'
)
await device.notify_subscribers(input_report_characteristic)
except websockets.exceptions.ConnectionClosedOK:
pass
await websockets.serve(serve, 'localhost', 8989)
await asyncio.get_event_loop().create_future()
else:
@@ -321,7 +400,9 @@ async def keyboard_device(device, command):
# Keypress for the letter
keycode = 0x04 + letter - 0x61
input_report_characteristic.value = bytes([0, 0, keycode, 0, 0, 0, 0, 0])
input_report_characteristic.value = bytes(
[0, 0, keycode, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(input_report_characteristic)
# Key release
@@ -335,10 +416,16 @@ async def main():
print('Usage: python keyboard.py <device-config> <transport-spec> <command>')
print(' where <command> is one of:')
print(' connect <address> (run a keyboard host, connecting to a keyboard)')
print(' web (run a keyboard with keypress input from a web page, see keyboard.html')
print(' sim (run a keyboard simulation, emitting a canned sequence of keystrokes')
print(
' web (run a keyboard with keypress input from a web page, see keyboard.html'
)
print(
' sim (run a keyboard simulation, emitting a canned sequence of keystrokes'
)
print('example: python keyboard.py keyboard.json usb:0 sim')
print('example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5')
print(
'example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5'
)
return
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
@@ -355,5 +442,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -27,12 +27,9 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_AVDTP_PROTOCOL_ID,
BT_AUDIO_SINK_SERVICE,
BT_L2CAP_PROTOCOL_ID
)
from bumble.avdtp import (
Protocol as AVDTP_Protocol,
find_avdtp_service_with_connection
BT_L2CAP_PROTOCOL_ID,
)
from bumble.avdtp import Protocol as AVDTP_Protocol, find_avdtp_service_with_connection
from bumble.a2dp import make_audio_source_service_sdp_records
from bumble.sdp import (
Client as SDP_Client,
@@ -40,7 +37,7 @@ from bumble.sdp import (
DataElement,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
)
@@ -48,7 +45,9 @@ from bumble.sdp import (
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle)
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
@@ -64,8 +63,8 @@ async def find_a2dp_service(device, connection):
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
]
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
print(color('==================================', 'blue'))
@@ -78,8 +77,7 @@ async def find_a2dp_service(device, connection):
# Service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
@@ -89,8 +87,7 @@ async def find_a2dp_service(device, connection):
# Protocol info
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
print(color(' Protocol:', 'green'))
@@ -103,18 +100,24 @@ async def find_a2dp_service(device, connection):
if len(protocol_descriptor.value) >= 2:
avdtp_version_major = protocol_descriptor.value[1].value >> 8
avdtp_version_minor = protocol_descriptor.value[1].value & 0xFF
print(f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}')
print(
f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}'
)
service_version = (avdtp_version_major, avdtp_version_minor)
# Profile info
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE:
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value
if (
bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else:
# Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list]
@@ -123,7 +126,9 @@ async def find_a2dp_service(device, connection):
for bluetooth_profile_descriptor in bluetooth_profile_descriptors:
version_major = bluetooth_profile_descriptor.value[1].value >> 8
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}')
print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
await sdp_client.disconnect()
return service_version
@@ -184,5 +189,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -28,7 +28,7 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Protocol,
Listener,
MediaCodecCapabilities
MediaCodecCapabilities,
)
from bumble.a2dp import (
make_audio_sink_service_sdp_records,
@@ -39,19 +39,19 @@ from bumble.a2dp import (
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation
SbcMediaCodecInformation,
)
Context = {
'output': None
}
Context = {'output': None}
# -----------------------------------------------------------------------------
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_sink_service_sdp_records(service_record_handle)
service_record_handle: make_audio_sink_service_sdp_records(
service_record_handle
)
}
@@ -59,22 +59,25 @@ def sdp_records():
def codec_capabilities():
# NOTE: this shouldn't be hardcoded, but passed on the command line instead
return MediaCodecCapabilities(
media_type = AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type = A2DP_SBC_CODEC_TYPE,
media_codec_information = SbcMediaCodecInformation.from_lists(
sampling_frequencies = [48000, 44100, 32000, 16000],
channel_modes = [
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths = [4, 8, 12, 16],
subbands = [4, 8],
allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD],
minimum_bitpool_value = 2,
maximum_bitpool_value = 53
)
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
@@ -87,10 +90,10 @@ def on_avdtp_connection(server):
# -----------------------------------------------------------------------------
def on_rtp_packet(packet):
header = packet.payload[0]
fragmented = header >> 7
start = (header >> 6) & 0x01
last = (header >> 5) & 0x01
header = packet.payload[0]
fragmented = header >> 7
start = (header >> 6) & 0x01
last = (header >> 5) & 0x01
number_of_frames = header & 0x0F
if fragmented:
@@ -104,7 +107,9 @@ def on_rtp_packet(packet):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> [<bt-addr>]')
print(
'Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> [<bt-addr>]'
)
print('example: run_a2dp_sink.py classic1.json usb:0 output.sbc')
return
@@ -133,7 +138,9 @@ async def main():
# Connect to the source
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication
@@ -159,5 +166,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -30,7 +30,7 @@ from bumble.avdtp import (
MediaCodecCapabilities,
MediaPacketPump,
Protocol,
Listener
Listener,
)
from bumble.a2dp import (
SBC_JOINT_STEREO_CHANNEL_MODE,
@@ -38,7 +38,7 @@ from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation,
SbcPacketSource
SbcPacketSource,
)
@@ -46,7 +46,9 @@ from bumble.a2dp import (
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle)
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
@@ -54,23 +56,25 @@ def sdp_records():
def codec_capabilities():
# NOTE: this shouldn't be hardcoded, but should be inferred from the input file instead
return MediaCodecCapabilities(
media_type = AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type = A2DP_SBC_CODEC_TYPE,
media_codec_information = SbcMediaCodecInformation.from_discrete_values(
sampling_frequency = 44100,
channel_mode = SBC_JOINT_STEREO_CHANNEL_MODE,
block_length = 16,
subbands = 8,
allocation_method = SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value = 2,
maximum_bitpool_value = 53
)
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_discrete_values(
sampling_frequency=44100,
channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE,
block_length=16,
subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities())
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -83,14 +87,18 @@ async def stream_packets(read_function, protocol):
print('@@@', endpoint)
# Select a sink
sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE)
sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE
)
if sink is None:
print(color('!!! no SBC sink found', 'red'))
return
print(f'### Selected sink: {sink.seid}')
# Stream the packets
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities())
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink)
@@ -107,8 +115,12 @@ async def stream_packets(read_function, protocol):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> [<bluetooth-address>]')
print('example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8')
print(
'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> [<bluetooth-address>]'
)
print(
'example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8'
)
return
print('<<< connecting to HCI...')
@@ -134,7 +146,9 @@ async def main():
# Connect to a peer
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication
@@ -148,7 +162,9 @@ async def main():
print('*** Encryption on')
# Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(device, connection)
avdtp_version = await find_avdtp_service_with_connection(
device, connection
)
if not avdtp_version:
print(color('!!! no A2DP service found'))
return
@@ -161,7 +177,9 @@ async def main():
else:
# Create a listener to wait for AVDTP connections
listener = Listener(Listener.create_registrar(device), version=(1, 2))
listener.on('connection', lambda protocol: on_avdtp_connection(read, protocol))
listener.on(
'connection', lambda protocol: on_avdtp_connection(read, protocol)
)
# Become connectable and wait for a connection
await device.set_discoverable(True)
@@ -171,5 +189,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -30,7 +30,9 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]')
print(
'Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]'
)
print('example: run_advertiser.py device1.json usb:0')
return
@@ -56,6 +58,7 @@ async def main():
await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -25,28 +25,34 @@ from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.hci import UUID
from bumble.gatt import (
Service,
Characteristic,
CharacteristicValue
)
from bumble.gatt import Service, Characteristic, CharacteristicValue
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID(
'6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties'
)
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID(
'f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint'
)
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID(
'38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus'
)
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID(
'2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT'
)
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 4:
print('Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>')
print(
'Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>'
)
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
@@ -62,14 +68,18 @@ async def main():
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
print(
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}'
)
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
# Respond with a status
asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
asyncio.create_task(
device.notify_subscribers(audio_status_characteristic, force=True)
)
# Handler for volume control
def on_volume_write(connection, value):
@@ -91,63 +101,91 @@ async def main():
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00, 0x00, # Render Delay
0x00, 0x00, # RFU
0x02, 0x00 # Codec IDs [G.722 at 16 kHz]
])
bytes(
[
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01,
0x02,
0x03,
0x04,
0x05,
0x06,
0x07,
0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00,
0x00, # Render Delay
0x00,
0x00, # RFU
0x02,
0x00, # Codec IDs [G.722 at 16 kHz]
]
),
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write)
CharacteristicValue(write=on_audio_control_point_write),
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0])
bytes([0]),
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write)
CharacteristicValue(write=on_volume_write),
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', psm)
struct.pack('<H', psm),
)
device.add_service(
Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic,
],
)
)
device.add_service(Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic
]
))
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(ASHA_SERVICE)),
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(ASHA_SERVICE) + bytes([
0x01, # Protocol Version
0x00, # Capability
0x01, 0x02, 0x03, 0x04 # Truncated HiSyncID
]))
])
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(ASHA_SERVICE)
+ bytes(
[
0x01, # Protocol Version
0x00, # Capability
0x01,
0x02,
0x03,
0x04, # Truncated HiSyncID
]
),
),
]
)
)
# Go!
@@ -156,6 +194,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -24,14 +24,22 @@ from colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.core import BT_BR_EDR_TRANSPORT, BT_L2CAP_PROTOCOL_ID
from bumble.sdp import Client as SDP_Client, SDP_PUBLIC_BROWSE_ROOT, SDP_ALL_ATTRIBUTES_RANGE
from bumble.sdp import (
Client as SDP_Client,
SDP_PUBLIC_BROWSE_ROOT,
SDP_ALL_ATTRIBUTES_RANGE,
)
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-addresses..>')
print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8')
print(
'Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-addresses..>'
)
print(
'example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8'
)
return
print('<<< connecting to HCI...')
@@ -53,32 +61,49 @@ async def main():
await sdp_client.connect(connection)
# List all services in the root browse group
service_record_handles = await sdp_client.search_services([SDP_PUBLIC_BROWSE_ROOT])
service_record_handles = await sdp_client.search_services(
[SDP_PUBLIC_BROWSE_ROOT]
)
print(color('\n==================================', 'blue'))
print(color('SERVICES:', 'yellow'), service_record_handles)
# For each service in the root browse group, get all its attributes
for service_record_handle in service_record_handles:
attributes = await sdp_client.get_attributes(service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE])
attributes = await sdp_client.get_attributes(
service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE]
)
print(color(f'SERVICE {service_record_handle:04X} attributes:', 'yellow'))
for attribute in attributes:
print(' ', attribute.to_string(color=True))
# Search for services with an L2CAP service attribute
search_result = await sdp_client.search_attributes([BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE])
search_result = await sdp_client.search_attributes(
[BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE]
)
print(color('\n==================================', 'blue'))
print(color('SEARCH RESULTS:', 'yellow'))
for attribute_list in search_result:
print(color('SERVICE:', 'green'))
print(' ' + '\n '.join([attribute.to_string(color=True) for attribute in attribute_list]))
print(
' '
+ '\n '.join(
[attribute.to_string(color=True) for attribute in attribute_list]
)
)
await sdp_client.disconnect()
await hci_source.wait_for_termination()
# Connect to a peer
target_addresses = sys.argv[3:]
await asyncio.wait([asyncio.create_task(connect(target_address)) for target_address in target_addresses])
await asyncio.wait(
[
asyncio.create_task(connect(target_address))
for target_address in target_addresses
]
)
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -30,48 +30,62 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.core import (
BT_AUDIO_SINK_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
)
# -----------------------------------------------------------------------------
SDP_SERVICE_RECORDS = {
0x00010001: [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)])
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(25)
]),
DataElement.sequence([
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(256)
])
])
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(25),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(256),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(256)
])
])
)
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(256),
]
)
]
),
),
]
}
@@ -99,6 +113,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -29,13 +29,23 @@ from bumble.core import DeviceClass
# -----------------------------------------------------------------------------
class DiscoveryListener(Device.Listener):
def on_inquiry_result(self, address, class_of_device, eir_data, rssi):
service_classes, major_device_class, minor_device_class = DeviceClass.split_class_of_device(class_of_device)
(
service_classes,
major_device_class,
minor_device_class,
) = DeviceClass.split_class_of_device(class_of_device)
separator = '\n '
print(f'>>> {color(address, "yellow")}:')
print(f' Device Class (raw): {class_of_device:06X}')
print(f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}')
print(f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}')
print(f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}')
print(
f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}'
)
print(
f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}'
)
print(
f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}'
)
print(f' RSSI: {rssi}')
if eir_data.ad_structures:
print(f' {eir_data.to_string(separator)}')
@@ -59,6 +69,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -27,8 +27,12 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_connect_and_encrypt.py <device-config> <transport-spec> <bluetooth-address>')
print('example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8')
print(
'Usage: run_connect_and_encrypt.py <device-config> <transport-spec> <bluetooth-address>'
)
print(
'example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8'
)
return
print('<<< connecting to HCI...')
@@ -53,6 +57,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -32,8 +32,12 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 4:
print('Usage: run_controller.py <controller-address> <device-config> <transport-spec>')
print('example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333')
print(
'Usage: run_controller.py <controller-address> <device-config> <transport-spec>'
)
print(
'example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333'
)
return
print('>>> connecting to HCI...')
@@ -44,11 +48,13 @@ async def main():
link = LocalLink()
# Create a first controller using the packet source/sink as its host interface
controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link)
controller1 = Controller(
'C1', host_source=hci_source, host_sink=hci_sink, link=link
)
controller1.random_address = sys.argv[1]
# Create a second controller using the same link
controller2 = Controller('C2', link = link)
controller2 = Controller('C2', link=link)
# Create a host for the second controller
host = Host()
@@ -59,17 +65,21 @@ async def main():
device.host = host
# Add some basic services to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description')
descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
"Fitbit",
[descriptor]
[descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
device.add_service(device_info_service)
# Debug print
@@ -82,6 +92,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -37,7 +37,9 @@ class ScannerListener(Device.Listener):
else:
type_color = 'cyan'
print(f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]: RSSI={advertisement.rssi}, {advertisement.data}')
print(
f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]: RSSI={advertisement.rssi}, {advertisement.data}'
)
# -----------------------------------------------------------------------------
@@ -55,20 +57,25 @@ async def main():
link = LocalLink()
# Create a first controller using the packet source/sink as its host interface
controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link)
controller1 = Controller(
'C1', host_source=hci_source, host_sink=hci_sink, link=link
)
controller1.address = 'E0:E1:E2:E3:E4:E5'
# Create a second controller using the same link
controller2 = Controller('C2', link = link)
controller2 = Controller('C2', link=link)
# Create a device with a scanner listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2)
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2
)
device.listener = ScannerListener()
await device.power_on()
await device.start_scanning()
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -70,7 +70,9 @@ class Listener(Device.Listener):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_gatt_client.py <device-config> <transport-spec> [<bluetooth-address>]')
print(
'Usage: run_gatt_client.py <device-config> <transport-spec> [<bluetooth-address>]'
)
print('example: run_gatt_client.py device1.json usb:0 E1:CA:72:48:C4:E8')
return
@@ -93,6 +95,7 @@ async def main():
await asyncio.get_running_loop().create_future()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -32,7 +32,7 @@ from bumble.gatt import (
show_services,
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_DEVICE_INFORMATION_SERVICE
GATT_DEVICE_INFORMATION_SERVICE,
)
@@ -48,32 +48,36 @@ async def main():
link = LocalLink()
# Setup a stack for the client
client_controller = Controller("client controller", link = link)
client_controller = Controller("client controller", link=link)
client_host = Host()
client_host.controller = client_controller
client_device = Device("client", address = 'F0:F1:F2:F3:F4:F5', host = client_host)
client_device = Device("client", address='F0:F1:F2:F3:F4:F5', host=client_host)
await client_device.power_on()
# Setup a stack for the server
server_controller = Controller("server controller", link = link)
server_controller = Controller("server controller", link=link)
server_host = Host()
server_host.controller = server_controller
server_device = Device("server", address = 'F6:F7:F8:F9:FA:FB', host = server_host)
server_device = Device("server", address='F6:F7:F8:F9:FA:FB', host=server_host)
server_device.listener = ServerListener()
await server_device.power_on()
# Add a few entries to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description')
descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
"Fitbit",
[descriptor]
[descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
server_device.add_service(device_info_service)
# Connect the client to the server
@@ -109,6 +113,7 @@ async def main():
await asyncio.get_running_loop().create_future()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -22,10 +22,7 @@ import logging
from bumble.device import Device, Connection
from bumble.transport import open_transport_or_link
from bumble.att import (
ATT_Error,
ATT_INSUFFICIENT_ENCRYPTION_ERROR
)
from bumble.att import ATT_Error, ATT_INSUFFICIENT_ENCRYPTION_ERROR
from bumble.gatt import (
Service,
Characteristic,
@@ -33,7 +30,7 @@ from bumble.gatt import (
Descriptor,
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_DEVICE_INFORMATION_SERVICE
GATT_DEVICE_INFORMATION_SERVICE,
)
@@ -76,7 +73,9 @@ def my_custom_write_with_error(connection, value):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_gatt_server.py <device-config> <transport-spec> [<bluetooth-address>]')
print(
'Usage: run_gatt_server.py <device-config> <transport-spec> [<bluetooth-address>]'
)
print('example: run_gatt_server.py device1.json usb:0 E1:CA:72:48:C4:E8')
return
@@ -89,17 +88,21 @@ async def main():
device.listener = Listener(device)
# Add a few entries to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description')
descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
'Fitbit',
[descriptor]
[descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
)
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
custom_service1 = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[
@@ -107,21 +110,23 @@ async def main():
'D901B45B-4916-412E-ACCA-376ECB603B2C',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=my_custom_read, write=my_custom_write)
CharacteristicValue(read=my_custom_read, write=my_custom_write),
),
Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=my_custom_read_with_error, write=my_custom_write_with_error)
CharacteristicValue(
read=my_custom_read_with_error, write=my_custom_write_with_error
),
),
Characteristic(
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
'hello'
)
]
'hello',
),
],
)
device.add_services([device_info_service, custom_service1])
@@ -142,6 +147,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -31,12 +31,9 @@ from bumble.sdp import (
ServiceAttribute,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
from bumble.hci import (
BT_HANDSFREE_SERVICE,
BT_RFCOMM_PROTOCOL_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.hci import BT_HANDSFREE_SERVICE, BT_RFCOMM_PROTOCOL_ID
from bumble.hfp import HfpProtocol
@@ -52,8 +49,8 @@ async def list_rfcomm_channels(device, connection):
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
]
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
print(color('==================================', 'blue'))
print(color('Handsfree Services:', 'yellow'))
@@ -61,40 +58,59 @@ async def list_rfcomm_channels(device, connection):
for attribute_list in search_result:
# Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
print(color('SERVICE:', 'green'))
print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value)
print(
color(' RFCOMM Channel:', 'cyan'),
protocol_descriptor.value[1].value,
)
rfcomm_channels.append(protocol_descriptor.value[1].value)
# List profiles
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
bluetooth_profile_descriptor_list = (
ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
)
if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE:
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value
if (
bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else:
# Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list]
bluetooth_profile_descriptors = [
bluetooth_profile_descriptor_list
]
print(color(' Profiles:', 'green'))
for bluetooth_profile_descriptor in bluetooth_profile_descriptors:
version_major = bluetooth_profile_descriptor.value[1].value >> 8
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}')
for (
bluetooth_profile_descriptor
) in bluetooth_profile_descriptors:
version_major = (
bluetooth_profile_descriptor.value[1].value >> 8
)
version_minor = (
bluetooth_profile_descriptor.value[1].value
& 0xFF
)
print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
# List service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
@@ -109,9 +125,15 @@ async def list_rfcomm_channels(device, connection):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: run_hfp_gateway.py <device-config> <transport-spec> <bluetooth-address>')
print(' specifying a channel number, or "discover" to list all RFCOMM channels')
print('example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8')
print(
'Usage: run_hfp_gateway.py <device-config> <transport-spec> <bluetooth-address>'
)
print(
' specifying a channel number, or "discover" to list all RFCOMM channels'
)
print(
'example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8'
)
return
print('<<< connecting to HCI...')
@@ -173,7 +195,9 @@ async def main():
protocol.send_response_line('+BRSF: 30')
protocol.send_response_line('OK')
elif line.startswith('AT+CIND=?'):
protocol.send_response_line('+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))')
protocol.send_response_line(
'+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))'
)
protocol.send_response_line('OK')
elif line.startswith('AT+CIND?'):
protocol.send_response_line('+CIND: 0,0,1,4,1,5,0')
@@ -193,7 +217,9 @@ async def main():
elif line.startswith('AT+BIA='):
protocol.send_response_line('OK')
elif line.startswith('AT+BVRA='):
protocol.send_response_line('+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"')
protocol.send_response_line(
'+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"'
)
elif line.startswith('AT+XEVENT='):
protocol.send_response_line('OK')
elif line.startswith('AT+XAPL='):
@@ -204,6 +230,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -32,13 +32,13 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.hfp import HfpProtocol
@@ -49,36 +49,44 @@ def make_sdp_records(rfcomm_channel):
0x00010001: [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001)
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE)
])
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID)
]),
DataElement.sequence([
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel)
])
])
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105)
])
])
)
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105),
]
)
]
),
),
]
}
@@ -103,6 +111,7 @@ class UiServer:
except websockets.exceptions.ConnectionClosedOK:
pass
await websockets.serve(serve, 'localhost', 8989)
@@ -111,7 +120,7 @@ async def protocol_loop(protocol):
await protocol.initialize_service()
while True:
await(protocol.next_line())
await (protocol.next_line())
# -----------------------------------------------------------------------------
@@ -160,6 +169,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -23,10 +23,7 @@ import logging
from bumble.device import Device, Connection
from bumble.transport import open_transport_or_link
from bumble.gatt import (
Service,
Characteristic
)
from bumble.gatt import Service, Characteristic
# -----------------------------------------------------------------------------
@@ -41,7 +38,9 @@ class Listener(Device.Listener, Connection.Listener):
def on_disconnection(self, reason):
print(f'### Disconnected, reason={reason}')
def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled):
def on_characteristic_subscription(
self, connection, characteristic, notify_enabled, indicate_enabled
):
print(
f'$$$ Characteristic subscription for handle {characteristic.handle} from {connection}: '
f'notify {"enabled" if notify_enabled else "disabled"}, '
@@ -55,6 +54,7 @@ class Listener(Device.Listener, Connection.Listener):
def on_my_characteristic_subscription(peer, enabled):
print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}')
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
@@ -75,24 +75,24 @@ async def main():
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0x40])
bytes([0x40]),
)
characteristic2 = Characteristic(
'8EBDEBAE-0017-418E-8D3B-3A3809492165',
Characteristic.READ | Characteristic.INDICATE,
Characteristic.READABLE,
bytes([0x41])
bytes([0x41]),
)
characteristic3 = Characteristic(
'8EBDEBAE-0017-418E-8D3B-3A3809492165',
Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE,
Characteristic.READABLE,
bytes([0x42])
bytes([0x42]),
)
characteristic3.on('subscription', on_my_characteristic_subscription)
custom_service = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[characteristic1, characteristic2, characteristic3]
[characteristic1, characteristic2, characteristic3],
)
device.add_services([custom_service])
@@ -123,5 +123,5 @@ async def main():
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -31,7 +31,7 @@ from bumble.sdp import (
ServiceAttribute,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
@@ -48,47 +48,66 @@ async def list_rfcomm_channels(device, connection):
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
]
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
print(color('==================================', 'blue'))
print(color('RFCOMM Services:', 'yellow'))
for attribute_list in search_result:
# Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
print(color('SERVICE:', 'green'))
print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value)
print(
color(' RFCOMM Channel:', 'cyan'),
protocol_descriptor.value[1].value,
)
# List profiles
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
bluetooth_profile_descriptor_list = (
ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
)
if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE:
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value
if (
bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else:
# Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list]
bluetooth_profile_descriptors = [
bluetooth_profile_descriptor_list
]
print(color(' Profiles:', 'green'))
for bluetooth_profile_descriptor in bluetooth_profile_descriptors:
version_major = bluetooth_profile_descriptor.value[1].value >> 8
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}')
for (
bluetooth_profile_descriptor
) in bluetooth_profile_descriptors:
version_major = (
bluetooth_profile_descriptor.value[1].value >> 8
)
version_minor = (
bluetooth_profile_descriptor.value[1].value
& 0xFF
)
print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
# List service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
@@ -138,9 +157,15 @@ async def tcp_server(tcp_port, rfcomm_session):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 5:
print('Usage: run_rfcomm_client.py <device-config> <transport-spec> <bluetooth-address> <channel>|discover [tcp-port]')
print(' specifying a channel number, or "discover" to list all RFCOMM channels')
print('example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8')
print(
'Usage: run_rfcomm_client.py <device-config> <transport-spec> <bluetooth-address> <channel>|discover [tcp-port]'
)
print(
' specifying a channel number, or "discover" to list all RFCOMM channels'
)
print(
'example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8'
)
return
print('<<< connecting to HCI...')
@@ -197,6 +222,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -31,7 +31,7 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
@@ -40,22 +40,34 @@ from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
def sdp_records(channel):
return {
0x00010001: [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID)
]),
DataElement.sequence([
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel)
])
]))
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
]
),
]
),
),
]
}
@@ -113,6 +125,7 @@ async def main():
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -35,13 +35,15 @@ async def main():
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[1]) as (hci_source, hci_sink):
print('<<< connected')
filter_duplicates = (len(sys.argv) == 3 and sys.argv[2] == 'filter')
filter_duplicates = len(sys.argv) == 3 and sys.argv[2] == 'filter'
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
@device.on('advertisement')
def _(advertisement):
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[advertisement.address.address_type]
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
advertisement.address.address_type
]
address_color = 'yellow' if advertisement.is_connectable else 'red'
address_qualifier = ''
if address_type_string.startswith('P'):
@@ -57,13 +59,16 @@ async def main():
type_color = 'white'
separator = '\n '
print(f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{advertisement.rssi}{separator}{advertisement.data.to_string(separator)}')
print(
f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{advertisement.rssi}{separator}{advertisement.data.to_string(separator)}'
)
await device.power_on()
await device.start_scanning(filter_duplicates=filter_duplicates)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -13,4 +13,5 @@
# limitations under the License.
from setuptools import setup
setup()

View File

@@ -27,6 +27,7 @@ ns = Collection()
build_tasks = Collection()
ns.add_collection(build_tasks, name="build")
@task
def build(ctx, install=False):
if install:
@@ -34,18 +35,23 @@ def build(ctx, install=False):
ctx.run("python -m build")
build_tasks.add_task(build, default=True)
@task
def release_build(ctx):
build(ctx, install=True)
build_tasks.add_task(release_build, name="release")
@task
def mkdocs(ctx):
ctx.run("mkdocs build -f docs/mkdocs/mkdocs.yml")
build_tasks.add_task(mkdocs, name="mkdocs")
# Testing
@@ -70,10 +76,13 @@ def test(ctx, filter=None, junit=False, install=False, html=False, verbose=0):
args += f" -{'v' * verbose}"
ctx.run(f"python -m pytest {os.path.join(ROOT_DIR, 'tests')} {args}")
test_tasks.add_task(test, default=True)
@task
def release_test(ctx):
test(ctx, install=True)
test_tasks.add_task(release_test, name="release")

View File

@@ -35,7 +35,7 @@ from bumble.avdtp import (
MediaPacket,
AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_TSEP_SNK,
A2DP_SBC_CODEC_TYPE
A2DP_SBC_CODEC_TYPE,
)
from bumble.a2dp import (
SbcMediaCodecInformation,
@@ -44,7 +44,7 @@ from bumble.a2dp import (
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD
SBC_SNR_ALLOCATION_METHOD,
)
# -----------------------------------------------------------------------------
@@ -60,18 +60,18 @@ class TwoDevices:
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
Controller('C1', link=self.link),
Controller('C2', link=self.link),
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
address='F0:F1:F2:F3:F4:F5',
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
)
address='F5:F4:F3:F2:F1:F0',
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
self.paired = [None, None]
@@ -87,8 +87,12 @@ async def test_self_connection():
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection))
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection))
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Start
await two_devices.devices[0].power_on()
@@ -98,46 +102,49 @@ async def test_self_connection():
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
# Check the post conditions
assert(two_devices.connections[0] is not None)
assert(two_devices.connections[1] is not None)
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
# -----------------------------------------------------------------------------
def source_codec_capabilities():
return MediaCodecCapabilities(
media_type = AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type = A2DP_SBC_CODEC_TYPE,
media_codec_information = SbcMediaCodecInformation.from_discrete_values(
sampling_frequency = 44100,
channel_mode = SBC_JOINT_STEREO_CHANNEL_MODE,
block_length = 16,
subbands = 8,
allocation_method = SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value = 2,
maximum_bitpool_value = 53
)
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_discrete_values(
sampling_frequency=44100,
channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE,
block_length=16,
subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def sink_codec_capabilities():
return MediaCodecCapabilities(
media_type = AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type = A2DP_SBC_CODEC_TYPE,
media_codec_information = SbcMediaCodecInformation.from_lists(
sampling_frequencies = [48000, 44100, 32000, 16000],
channel_modes = [
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths = [4, 8, 12, 16],
subbands = [4, 8],
allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD],
minimum_bitpool_value = 2,
maximum_bitpool_value = 53
)
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
@@ -164,21 +171,25 @@ async def test_source_sink_1():
listener = Listener(Listener.create_registrar(two_devices.devices[1]))
listener.on('connection', on_avdtp_connection)
connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address)
connection = await two_devices.devices[0].connect(
two_devices.devices[1].random_address
)
client = await Protocol.connect(connection)
endpoints = await client.discover_remote_endpoints()
assert(len(endpoints) == 1)
assert len(endpoints) == 1
remote_sink = list(endpoints)[0]
assert(remote_sink.in_use == 0)
assert(remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE)
assert(remote_sink.tsep == AVDTP_TSEP_SNK)
assert remote_sink.in_use == 0
assert remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE
assert remote_sink.tsep == AVDTP_TSEP_SNK
async def generate_packets(packet_count):
sequence_number = 0
timestamp = 0
for i in range(packet_count):
payload = bytes([sequence_number % 256])
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload
)
packet.timestamp_seconds = timestamp / 44100
timestamp += 10
sequence_number += 1
@@ -192,50 +203,49 @@ async def test_source_sink_1():
source = client.add_source(source_codec_capabilities(), pump)
stream = await client.create_stream(source, remote_sink)
await stream.start()
assert(stream.state == AVDTP_STREAMING_STATE)
assert(stream.local_endpoint.in_use == 1)
assert(stream.rtp_channel is not None)
assert(sink.in_use == 1)
assert(sink.stream is not None)
assert(sink.stream.state == AVDTP_STREAMING_STATE)
assert stream.state == AVDTP_STREAMING_STATE
assert stream.local_endpoint.in_use == 1
assert stream.rtp_channel is not None
assert sink.in_use == 1
assert sink.stream is not None
assert sink.stream.state == AVDTP_STREAMING_STATE
await rtp_packets_fully_received
await stream.close()
assert(stream.rtp_channel is None)
assert(source.in_use == 0)
assert(source.stream.state == AVDTP_IDLE_STATE)
assert(sink.in_use == 0)
assert(sink.stream.state == AVDTP_IDLE_STATE)
assert stream.rtp_channel is None
assert source.in_use == 0
assert source.stream.state == AVDTP_IDLE_STATE
assert sink.in_use == 0
assert sink.stream.state == AVDTP_IDLE_STATE
# Send packets manually
rtp_packets_fully_received = asyncio.get_running_loop().create_future()
rtp_packets_expected = 3
rtp_packets = []
source_packets = [
MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i]))
for i in range(3)
MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i])) for i in range(3)
]
source = client.add_source(source_codec_capabilities(), None)
stream = await client.create_stream(source, remote_sink)
await stream.start()
assert(stream.state == AVDTP_STREAMING_STATE)
assert(stream.local_endpoint.in_use == 1)
assert(stream.rtp_channel is not None)
assert(sink.in_use == 1)
assert(sink.stream is not None)
assert(sink.stream.state == AVDTP_STREAMING_STATE)
assert stream.state == AVDTP_STREAMING_STATE
assert stream.local_endpoint.in_use == 1
assert stream.rtp_channel is not None
assert sink.in_use == 1
assert sink.stream is not None
assert sink.stream.state == AVDTP_STREAMING_STATE
stream.send_media_packet(source_packets[0])
stream.send_media_packet(source_packets[1])
stream.send_media_packet(source_packets[2])
await stream.close()
assert(stream.rtp_channel is None)
assert(len(rtp_packets) == 3)
assert(source.in_use == 0)
assert(source.stream.state == AVDTP_IDLE_STATE)
assert(sink.in_use == 0)
assert(sink.stream.state == AVDTP_IDLE_STATE)
assert stream.rtp_channel is None
assert len(rtp_packets) == 3
assert source.in_use == 0
assert source.stream.state == AVDTP_IDLE_STATE
assert sink.in_use == 0
assert sink.stream.state == AVDTP_IDLE_STATE
# -----------------------------------------------------------------------------
@@ -246,5 +256,5 @@ async def run_test_self():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_self())

View File

@@ -28,7 +28,7 @@ from bumble.avdtp import (
Set_Configuration_Command,
Set_Configuration_Response,
ServiceCapabilities,
MediaCodecCapabilities
MediaCodecCapabilities,
)
@@ -37,24 +37,28 @@ def test_messages():
capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
MediaCodecCapabilities(
media_type = AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type = A2DP_SBC_CODEC_TYPE,
media_codec_information = bytes.fromhex('211502fa')
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=bytes.fromhex('211502fa'),
),
ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY)
ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY),
]
message = Get_Capabilities_Response(capabilities)
parsed = Message.create(AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload)
assert(message.payload == parsed.payload)
parsed = Message.create(
AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload
)
assert message.payload == parsed.payload
message = Set_Configuration_Command(3, 4, capabilities)
parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload)
assert(message.payload == parsed.payload)
assert message.payload == parsed.payload
# -----------------------------------------------------------------------------
def test_rtp():
packet = bytes.fromhex('8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093')
packet = bytes.fromhex(
'8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093'
)
media_packet = MediaPacket.from_bytes(packet)
print(media_packet)
@@ -62,4 +66,4 @@ def test_rtp():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_messages()
test_rtp()
test_rtp()

View File

@@ -22,32 +22,35 @@ def test_ad_data():
data = bytes([2, AdvertisingData.TX_POWER_LEVEL, 123])
ad = AdvertisingData.from_bytes(data)
ad_bytes = bytes(ad)
assert(data == ad_bytes)
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None)
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]))
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [])
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123])])
assert data == ad_bytes
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
bytes([123])
]
data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234])
ad.append(data2)
ad_bytes = bytes(ad)
assert(ad_bytes == data + data2)
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None)
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123]))
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == [])
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [bytes([123]), bytes([234])])
assert ad_bytes == data + data2
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
bytes([123]),
bytes([234]),
]
# -----------------------------------------------------------------------------
def test_get_dict_key_by_value():
dictionary = {
"A": 1,
"B": 2
}
dictionary = {"A": 1, "B": 2}
assert get_dict_key_by_value(dictionary, 1) == "A"
assert get_dict_key_by_value(dictionary, 2) == "B"
assert get_dict_key_by_value(dictionary, 3) is None
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_ad_data()
test_ad_data()

View File

@@ -25,10 +25,23 @@ from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS,
Address, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, HCI_Packet
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING,
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
HCI_Connection_Request_Event,
HCI_Packet,
)
from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
from bumble.gatt import GATT_GENERIC_ACCESS_SERVICE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_DEVICE_NAME_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC
# -----------------------------------------------------------------------------
# Logging
@@ -59,70 +72,90 @@ async def test_device_connect_parallel():
d2.classic_enabled = True
# set public addresses
d0.public_address = Address('F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
d1.public_address = Address('F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS)
d2.public_address = Address('F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS)
d0.public_address = Address(
'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
d1.public_address = Address(
'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
d2.public_address = Address(
'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
def d0_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d1.public_address
d0.host.on_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = HCI_CREATE_CONNECTION_COMMAND
))
d0.host.on_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_CREATE_CONNECTION_COMMAND,
)
)
d1.host.on_hci_packet(HCI_Connection_Request_Event(
bd_addr = d0.public_address,
class_of_device = 0,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
))
d1.host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=d0.public_address,
class_of_device=0,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
)
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d2.public_address
d0.host.on_hci_packet(HCI_Command_Status_Event(
status = HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets = 1,
command_opcode = HCI_CREATE_CONNECTION_COMMAND
))
d0.host.on_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_CREATE_CONNECTION_COMMAND,
)
)
d2.host.on_hci_packet(HCI_Connection_Request_Event(
bd_addr = d0.public_address,
class_of_device = 0,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE
))
d2.host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=d0.public_address,
class_of_device=0,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
)
assert (yield) == None
def d1_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters = b"\x00"
))
d1.host.on_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
d1.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x100,
bd_addr = d0.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d1.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x100,
bd_addr=d0.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x100,
bd_addr = d1.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d0.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x100,
bd_addr=d1.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
assert (yield) == None
@@ -130,27 +163,33 @@ async def test_device_connect_parallel():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters = b"\x00"
))
d2.host.on_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
d2.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x101,
bd_addr = d0.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d2.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x101,
bd_addr=d0.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
d0.host.on_hci_packet(HCI_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x101,
bd_addr = d2.public_address,
link_type = HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled = True,
))
d0.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x101,
bd_addr=d2.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
assert (yield) == None
@@ -158,13 +197,19 @@ async def test_device_connect_parallel():
d1.host.set_packet_sink(Sink(d1_flow()))
d2.host.set_packet_sink(Sink(d2_flow()))
[c01, c02, a10, a20, a01] = await asyncio.gather(*[
asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)),
asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)),
asyncio.create_task(d1.accept(peer_address=d0.public_address)),
asyncio.create_task(d2.accept()),
asyncio.create_task(d0.accept(peer_address=d1.public_address)),
])
[c01, c02, a10, a20, a01] = await asyncio.gather(
*[
asyncio.create_task(
d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
),
asyncio.create_task(
d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
),
asyncio.create_task(d1.accept(peer_address=d0.public_address)),
asyncio.create_task(d2.accept()),
asyncio.create_task(d0.accept(peer_address=d1.public_address)),
]
)
assert type(c01) == Connection
assert type(c02) == Connection
@@ -205,5 +250,5 @@ def test_gatt_services_without_gas():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_device())

View File

@@ -36,7 +36,7 @@ from bumble.gatt import (
UTF8CharacteristicAdapter,
Service,
Characteristic,
CharacteristicValue
CharacteristicValue,
)
from bumble.transport import AsyncPipeSink
from bumble.core import UUID
@@ -45,7 +45,7 @@ from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU,
ATT_Error_Response,
ATT_Read_By_Group_Type_Request
ATT_Read_By_Group_Type_Request,
)
@@ -72,20 +72,20 @@ def test_UUID():
assert str(w) == '61A3512C-09BE-4DDC-A6A6-0B03667AAFC6'
u1 = UUID.from_16_bits(0x1234)
b1 = u1.to_bytes(force_128 = True)
b1 = u1.to_bytes(force_128=True)
u2 = UUID.from_bytes(b1)
assert u1 == u2
u3 = UUID.from_16_bits(0x180a)
u3 = UUID.from_16_bits(0x180A)
assert str(u3) == 'UUID-16:180A (Device Information)'
# -----------------------------------------------------------------------------
def test_ATT_Error_Response():
pdu = ATT_Error_Response(
request_opcode_in_error = ATT_EXCHANGE_MTU_REQUEST,
attribute_handle_in_error = 0x0000,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
request_opcode_in_error=ATT_EXCHANGE_MTU_REQUEST,
attribute_handle_in_error=0x0000,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
basic_check(pdu)
@@ -93,9 +93,9 @@ def test_ATT_Error_Response():
# -----------------------------------------------------------------------------
def test_ATT_Read_By_Group_Type_Request():
pdu = ATT_Read_By_Group_Type_Request(
starting_handle = 0x0001,
ending_handle = 0xFFFF,
attribute_group_type = UUID.from_16_bits(0x2800)
starting_handle=0x0001,
ending_handle=0xFFFF,
attribute_group_type=UUID.from_16_bits(0x2800),
)
basic_check(pdu)
@@ -110,7 +110,12 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes):
return value_bytes[0]
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123)
c = Foo(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
123,
)
x = c.read_value(None)
assert x == bytes([123])
c.write_value(None, bytes([122]))
@@ -123,7 +128,7 @@ async def test_characteristic_encoding():
characteristic.handle,
characteristic.end_group_handle,
characteristic.uuid,
characteristic.properties
characteristic.properties,
)
def encode_value(self, value):
@@ -138,13 +143,10 @@ async def test_characteristic_encoding():
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123])
bytes([123]),
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic]
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
server.add_service(service)
await client.power_on()
@@ -237,7 +239,7 @@ async def test_attribute_getters():
characteristic_uuid,
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123])
bytes([123]),
)
service_uuid = UUID('3A657F47-D34F-46B3-B1EC-698E29B6B829')
@@ -247,22 +249,43 @@ async def test_attribute_getters():
service_attr = server.gatt_server.get_service_attribute(service_uuid)
assert service_attr
(char_decl_attr, char_value_attr) = server.gatt_server.get_characteristic_attributes(service_uuid, characteristic_uuid)
(
char_decl_attr,
char_value_attr,
) = server.gatt_server.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
assert char_decl_attr and char_value_attr
desc_attr = server.gatt_server.get_descriptor_attribute(service_uuid, characteristic_uuid, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
desc_attr = server.gatt_server.get_descriptor_attribute(
service_uuid,
characteristic_uuid,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
)
assert desc_attr
# assert all handles are in expected order
assert service_attr.handle < char_decl_attr.handle < char_value_attr.handle < desc_attr.handle == service_attr.end_group_handle
assert (
service_attr.handle
< char_decl_attr.handle
< char_value_attr.handle
< desc_attr.handle
== service_attr.end_group_handle
)
# assert characteristic declarations attribute is followed by characteristic value attribute
assert char_decl_attr.handle + 1 == char_value_attr.handle
# -----------------------------------------------------------------------------
def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3])
c = Characteristic(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, v)
c = Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
v,
)
a = CharacteristicAdapter(c)
value = a.read_value(None)
@@ -273,7 +296,9 @@ def test_CharacteristicAdapter():
assert c.value == v
# Simple delegated adapter
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)))
a = DelegatedCharacteristicAdapter(
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
value = a.read_value(None)
assert value == bytes(reversed(v))
@@ -342,7 +367,9 @@ def test_CharacteristicValue():
assert x == b
result = []
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value)))
c = CharacteristicValue(
write=lambda connection, value: result.append((connection, value))
)
z = object()
c.write(z, b)
assert result == [(z, b)]
@@ -355,23 +382,23 @@ class LinkedDevices:
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link),
Controller('C3', link = self.link)
Controller('C1', link=self.link),
Controller('C2', link=self.link),
Controller('C3', link=self.link),
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
address='F0:F1:F2:F3:F4:F5',
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address = 'F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
address='F1:F2:F3:F4:F5:F6',
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
)
address='F2:F3:F4:F5:F6:F7',
host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])),
),
]
self.paired = [None, None, None]
@@ -392,7 +419,7 @@ async def test_read_write():
characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE
Characteristic.READABLE | Characteristic.WRITEABLE,
)
def on_characteristic1_write(connection, value):
@@ -410,15 +437,13 @@ async def test_read_write():
'66DE9057-C848-4ACA-B993-D675644EBB85',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=on_characteristic2_read, write=on_characteristic2_write)
CharacteristicValue(
read=on_characteristic2_read, write=on_characteristic2_write
),
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[
characteristic1,
characteristic2
]
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1, characteristic2]
)
server.add_services([service1])
@@ -446,7 +471,9 @@ async def test_read_write():
assert v1 == b
assert type(characteristic1._last_value is tuple)
assert len(characteristic1._last_value) == 2
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address)
assert str(characteristic1._last_value[0].peer_address) == str(
client.random_address
)
assert characteristic1._last_value[1] == b
bb = bytes([3, 4, 5, 6])
characteristic1.value = bb
@@ -457,7 +484,9 @@ async def test_read_write():
await async_barrier()
assert type(characteristic2._last_value is tuple)
assert len(characteristic2._last_value) == 2
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address)
assert str(characteristic2._last_value[0].peer_address) == str(
client.random_address
)
assert characteristic2._last_value[1] == b
@@ -471,15 +500,10 @@ async def test_read_write2():
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
value=v
value=v,
)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[
characteristic1
]
)
service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1])
server.add_services([service1])
await client.power_on()
@@ -520,11 +544,15 @@ async def test_subscribe_notify():
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([1, 2, 3])
bytes([1, 2, 3]),
)
def on_characteristic1_subscription(connection, notify_enabled, indicate_enabled):
characteristic1._last_subscription = (connection, notify_enabled, indicate_enabled)
characteristic1._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic1.on('subscription', on_characteristic1_subscription)
@@ -532,11 +560,15 @@ async def test_subscribe_notify():
'66DE9057-C848-4ACA-B993-D675644EBB85',
Characteristic.READ | Characteristic.INDICATE,
Characteristic.READABLE,
bytes([4, 5, 6])
bytes([4, 5, 6]),
)
def on_characteristic2_subscription(connection, notify_enabled, indicate_enabled):
characteristic2._last_subscription = (connection, notify_enabled, indicate_enabled)
characteristic2._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic2.on('subscription', on_characteristic2_subscription)
@@ -544,26 +576,33 @@ async def test_subscribe_notify():
'AB5E639C-40C1-4238-B9CB-AF41F8B806E4',
Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE,
Characteristic.READABLE,
bytes([7, 8, 9])
bytes([7, 8, 9]),
)
def on_characteristic3_subscription(connection, notify_enabled, indicate_enabled):
characteristic3._last_subscription = (connection, notify_enabled, indicate_enabled)
characteristic3._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic3.on('subscription', on_characteristic3_subscription)
service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[
characteristic1,
characteristic2,
characteristic3
]
[characteristic1, characteristic2, characteristic3],
)
server.add_services([service1])
def on_characteristic_subscription(connection, characteristic, notify_enabled, indicate_enabled):
server._last_subscription = (connection, characteristic, notify_enabled, indicate_enabled)
def on_characteristic_subscription(
connection, characteristic, notify_enabled, indicate_enabled
):
server._last_subscription = (
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
server.on('characteristic_subscription', on_characteristic_subscription)
@@ -630,17 +669,23 @@ async def test_subscribe_notify():
await peer.subscribe(c2, on_c2_update)
await async_barrier()
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2)
await server.notify_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier()
assert not c2._called
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await server.indicate_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier()
assert c2._called
assert c2._last_update == characteristic2.value
c2._called = False
await peer.unsubscribe(c2, on_c2_update)
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2)
await server.indicate_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier()
assert not c2._called
@@ -666,7 +711,9 @@ async def test_subscribe_notify():
c3.on('update', on_c3_update)
await peer.subscribe(c3, on_c3_update_2)
await async_barrier()
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.notify_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await async_barrier()
assert c3._called
assert c3._last_update == characteristic3.value
@@ -681,7 +728,9 @@ async def test_subscribe_notify():
await peer.subscribe(c3, on_c3_update_3, prefer_notify=False)
await async_barrier()
characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.indicate_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await async_barrier()
assert c3._called
assert c3._last_update == characteristic3.value
@@ -693,8 +742,12 @@ async def test_subscribe_notify():
c3._called_2 = False
c3._called_3 = False
await peer.unsubscribe(c3)
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await server.notify_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await server.indicate_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await async_barrier()
assert not c3._called
assert not c3._called_2
@@ -709,6 +762,7 @@ async def test_mtu_exchange():
d3.gatt_server.max_mtu = 100
d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)
@@ -745,7 +799,12 @@ def test_char_property_to_string():
# double
assert Characteristic.properties_as_string(0x03) == "BROADCAST,READ"
assert Characteristic.properties_as_string(Characteristic.BROADCAST | Characteristic.READ) == "BROADCAST,READ"
assert (
Characteristic.properties_as_string(
Characteristic.BROADCAST | Characteristic.READ
)
== "BROADCAST,READ"
)
# -----------------------------------------------------------------------------
@@ -754,8 +813,14 @@ def test_char_property_string_to_type():
assert Characteristic.string_to_properties("BROADCAST") == Characteristic.BROADCAST
# double
assert Characteristic.string_to_properties("BROADCAST,READ") == Characteristic.BROADCAST | Characteristic.READ
assert Characteristic.string_to_properties("READ,BROADCAST") == Characteristic.BROADCAST | Characteristic.READ
assert (
Characteristic.string_to_properties("BROADCAST,READ")
== Characteristic.BROADCAST | Characteristic.READ
)
assert (
Characteristic.string_to_properties("READ,BROADCAST")
== Characteristic.BROADCAST | Characteristic.READ
)
# -----------------------------------------------------------------------------
@@ -767,16 +832,15 @@ async def test_server_string():
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123])
bytes([123]),
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic]
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
server.add_service(service)
assert str(server.gatt_server) == """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access))
assert (
str(server.gatt_server)
== """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access))
CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ)
@@ -785,6 +849,8 @@ Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
)
# -----------------------------------------------------------------------------
async def async_main():
@@ -797,7 +863,7 @@ async def async_main():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()

View File

@@ -44,15 +44,15 @@ def test_HCI_Event():
def test_HCI_LE_Connection_Complete_Event():
address = Address('00:11:22:33:44:55')
event = HCI_LE_Connection_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 1,
role = 1,
peer_address_type = 1,
peer_address = address,
connection_interval = 3,
peripheral_latency = 4,
supervision_timeout = 5,
central_clock_accuracy = 6
status=HCI_SUCCESS,
connection_handle=1,
role=1,
peer_address_type=1,
peer_address=address,
connection_interval=3,
peripheral_latency=4,
supervision_timeout=5,
central_clock_accuracy=6,
)
basic_check(event)
@@ -62,11 +62,13 @@ def test_HCI_LE_Advertising_Report_Event():
address = Address('00:11:22:33:44:55/P')
report = HCI_LE_Advertising_Report_Event.Report(
HCI_LE_Advertising_Report_Event.Report.FIELDS,
event_type = HCI_LE_Advertising_Report_Event.ADV_IND,
address_type = Address.PUBLIC_DEVICE_ADDRESS,
address = address,
data = bytes.fromhex('0201061106ba5689a6fabfa2bd01467d6e00fbabad08160a181604659b03'),
rssi = 100
event_type=HCI_LE_Advertising_Report_Event.ADV_IND,
address_type=Address.PUBLIC_DEVICE_ADDRESS,
address=address,
data=bytes.fromhex(
'0201061106ba5689a6fabfa2bd01467d6e00fbabad08160a181604659b03'
),
rssi=100,
)
event = HCI_LE_Advertising_Report_Event([report])
basic_check(event)
@@ -75,9 +77,9 @@ def test_HCI_LE_Advertising_Report_Event():
# -----------------------------------------------------------------------------
def test_HCI_LE_Read_Remote_Features_Complete_Event():
event = HCI_LE_Read_Remote_Features_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x007,
le_features = bytes.fromhex('0011223344556677')
status=HCI_SUCCESS,
connection_handle=0x007,
le_features=bytes.fromhex('0011223344556677'),
)
basic_check(event)
@@ -85,11 +87,11 @@ def test_HCI_LE_Read_Remote_Features_Complete_Event():
# -----------------------------------------------------------------------------
def test_HCI_LE_Connection_Update_Complete_Event():
event = HCI_LE_Connection_Update_Complete_Event(
status = HCI_SUCCESS,
connection_handle = 0x007,
connection_interval = 10,
peripheral_latency = 3,
supervision_timeout = 5
status=HCI_SUCCESS,
connection_handle=0x007,
connection_interval=10,
peripheral_latency=3,
supervision_timeout=5,
)
basic_check(event)
@@ -97,8 +99,7 @@ def test_HCI_LE_Connection_Update_Complete_Event():
# -----------------------------------------------------------------------------
def test_HCI_LE_Channel_Selection_Algorithm_Event():
event = HCI_LE_Channel_Selection_Algorithm_Event(
connection_handle = 7,
channel_selection_algorithm = 1
connection_handle=7, channel_selection_algorithm=1
)
basic_check(event)
@@ -107,29 +108,29 @@ def test_HCI_LE_Channel_Selection_Algorithm_Event():
def test_HCI_Command_Complete_Event():
# With a serializable object
event = HCI_Command_Complete_Event(
num_hci_command_packets = 34,
command_opcode = HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters = HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
status = 0,
hc_le_acl_data_packet_length = 1234,
hc_total_num_le_acl_data_packets = 56
)
num_hci_command_packets=34,
command_opcode=HCI_LE_READ_BUFFER_SIZE_COMMAND,
return_parameters=HCI_LE_Read_Buffer_Size_Command.create_return_parameters(
status=0,
hc_le_acl_data_packet_length=1234,
hc_total_num_le_acl_data_packets=56,
),
)
basic_check(event)
# With an arbitrary byte array
event = HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_RESET_COMMAND,
return_parameters = bytes([1, 2, 3, 4])
num_hci_command_packets=1,
command_opcode=HCI_RESET_COMMAND,
return_parameters=bytes([1, 2, 3, 4]),
)
basic_check(event)
# With a simple status as a 1-byte array
event = HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_RESET_COMMAND,
return_parameters = bytes([7])
num_hci_command_packets=1,
command_opcode=HCI_RESET_COMMAND,
return_parameters=bytes([7]),
)
basic_check(event)
event = HCI_Packet.from_bytes(event.to_bytes())
@@ -137,9 +138,7 @@ def test_HCI_Command_Complete_Event():
# With a simple status as an integer status
event = HCI_Command_Complete_Event(
num_hci_command_packets = 1,
command_opcode = HCI_RESET_COMMAND,
return_parameters = 9
num_hci_command_packets=1, command_opcode=HCI_RESET_COMMAND, return_parameters=9
)
basic_check(event)
assert event.return_parameters == 9
@@ -148,19 +147,14 @@ def test_HCI_Command_Complete_Event():
# -----------------------------------------------------------------------------
def test_HCI_Command_Status_Event():
event = HCI_Command_Status_Event(
status = 0,
num_hci_command_packets = 37,
command_opcode = HCI_DISCONNECT_COMMAND
status=0, num_hci_command_packets=37, command_opcode=HCI_DISCONNECT_COMMAND
)
basic_check(event)
# -----------------------------------------------------------------------------
def test_HCI_Number_Of_Completed_Packets_Event():
event = HCI_Number_Of_Completed_Packets_Event([
(1, 2),
(3, 4)
])
event = HCI_Number_Of_Completed_Packets_Event([(1, 2), (3, 4)])
basic_check(event)
@@ -199,25 +193,20 @@ def test_HCI_Read_Local_Supported_Features_Command():
# -----------------------------------------------------------------------------
def test_HCI_Disconnect_Command():
command = HCI_Disconnect_Command(
connection_handle = 123,
reason = 0x11
)
command = HCI_Disconnect_Command(connection_handle=123, reason=0x11)
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_Set_Event_Mask_Command():
command = HCI_Set_Event_Mask_Command(
event_mask = bytes.fromhex('0011223344556677')
)
command = HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('0011223344556677'))
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Event_Mask_Command():
command = HCI_LE_Set_Event_Mask_Command(
le_event_mask = bytes.fromhex('0011223344556677')
le_event_mask=bytes.fromhex('0011223344556677')
)
basic_check(command)
@@ -225,7 +214,7 @@ def test_HCI_LE_Set_Event_Mask_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Random_Address_Command():
command = HCI_LE_Set_Random_Address_Command(
random_address = Address('00:11:22:33:44:55')
random_address=Address('00:11:22:33:44:55')
)
basic_check(command)
@@ -233,14 +222,14 @@ def test_HCI_LE_Set_Random_Address_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Advertising_Parameters_Command():
command = HCI_LE_Set_Advertising_Parameters_Command(
advertising_interval_min = 20,
advertising_interval_max = 30,
advertising_type = HCI_LE_Set_Advertising_Parameters_Command.ADV_NONCONN_IND,
own_address_type = Address.PUBLIC_DEVICE_ADDRESS,
peer_address_type = Address.RANDOM_DEVICE_ADDRESS,
peer_address = Address('00:11:22:33:44:55'),
advertising_channel_map = 0x03,
advertising_filter_policy = 1
advertising_interval_min=20,
advertising_interval_max=30,
advertising_type=HCI_LE_Set_Advertising_Parameters_Command.ADV_NONCONN_IND,
own_address_type=Address.PUBLIC_DEVICE_ADDRESS,
peer_address_type=Address.RANDOM_DEVICE_ADDRESS,
peer_address=Address('00:11:22:33:44:55'),
advertising_channel_map=0x03,
advertising_filter_policy=1,
)
basic_check(command)
@@ -248,7 +237,7 @@ def test_HCI_LE_Set_Advertising_Parameters_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Advertising_Data_Command():
command = HCI_LE_Set_Advertising_Data_Command(
advertising_data = bytes.fromhex('AABBCC')
advertising_data=bytes.fromhex('AABBCC')
)
basic_check(command)
@@ -256,39 +245,36 @@ def test_HCI_LE_Set_Advertising_Data_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Scan_Parameters_Command():
command = HCI_LE_Set_Scan_Parameters_Command(
le_scan_type = 1,
le_scan_interval = 20,
le_scan_window = 10,
own_address_type = 1,
scanning_filter_policy = 0
le_scan_type=1,
le_scan_interval=20,
le_scan_window=10,
own_address_type=1,
scanning_filter_policy=0,
)
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Scan_Enable_Command():
command = HCI_LE_Set_Scan_Enable_Command(
le_scan_enable = 1,
filter_duplicates = 0
)
command = HCI_LE_Set_Scan_Enable_Command(le_scan_enable=1, filter_duplicates=0)
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Create_Connection_Command():
command = HCI_LE_Create_Connection_Command(
le_scan_interval = 4,
le_scan_window = 5,
initiator_filter_policy = 1,
peer_address_type = 1,
peer_address = Address('00:11:22:33:44:55'),
own_address_type = 2,
connection_interval_min = 7,
connection_interval_max = 8,
max_latency = 9,
supervision_timeout = 10,
min_ce_length = 11,
max_ce_length = 12
le_scan_interval=4,
le_scan_window=5,
initiator_filter_policy=1,
peer_address_type=1,
peer_address=Address('00:11:22:33:44:55'),
own_address_type=2,
connection_interval_min=7,
connection_interval_max=8,
max_latency=9,
supervision_timeout=10,
min_ce_length=11,
max_ce_length=12,
)
basic_check(command)
@@ -296,19 +282,19 @@ def test_HCI_LE_Create_Connection_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Extended_Create_Connection_Command():
command = HCI_LE_Extended_Create_Connection_Command(
initiator_filter_policy = 0,
own_address_type = 0,
peer_address_type = 1,
peer_address = Address('00:11:22:33:44:55'),
initiating_phys = 3,
scan_intervals = (10, 11),
scan_windows = (12, 13),
connection_interval_mins = (14, 15),
connection_interval_maxs = (16, 17),
max_latencies = (18, 19),
supervision_timeouts = (20, 21),
min_ce_lengths = (100, 101),
max_ce_lengths = (102, 103)
initiator_filter_policy=0,
own_address_type=0,
peer_address_type=1,
peer_address=Address('00:11:22:33:44:55'),
initiating_phys=3,
scan_intervals=(10, 11),
scan_windows=(12, 13),
connection_interval_mins=(14, 15),
connection_interval_maxs=(16, 17),
max_latencies=(18, 19),
supervision_timeouts=(20, 21),
min_ce_lengths=(100, 101),
max_ce_lengths=(102, 103),
)
basic_check(command)
@@ -316,8 +302,7 @@ def test_HCI_LE_Extended_Create_Connection_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Add_Device_To_Filter_Accept_List_Command():
command = HCI_LE_Add_Device_To_Filter_Accept_List_Command(
address_type = 1,
address = Address('00:11:22:33:44:55')
address_type=1, address=Address('00:11:22:33:44:55')
)
basic_check(command)
@@ -325,8 +310,7 @@ def test_HCI_LE_Add_Device_To_Filter_Accept_List_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Remove_Device_From_Filter_Accept_List_Command():
command = HCI_LE_Remove_Device_From_Filter_Accept_List_Command(
address_type = 1,
address = Address('00:11:22:33:44:55')
address_type=1, address=Address('00:11:22:33:44:55')
)
basic_check(command)
@@ -334,32 +318,26 @@ def test_HCI_LE_Remove_Device_From_Filter_Accept_List_Command():
# -----------------------------------------------------------------------------
def test_HCI_LE_Connection_Update_Command():
command = HCI_LE_Connection_Update_Command(
connection_handle = 0x0002,
connection_interval_min = 10,
connection_interval_max = 20,
max_latency = 7,
supervision_timeout = 3,
min_ce_length = 100,
max_ce_length = 200
connection_handle=0x0002,
connection_interval_min=10,
connection_interval_max=20,
max_latency=7,
supervision_timeout=3,
min_ce_length=100,
max_ce_length=200,
)
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Read_Remote_Features_Command():
command = HCI_LE_Read_Remote_Features_Command(
connection_handle = 0x0002
)
command = HCI_LE_Read_Remote_Features_Command(connection_handle=0x0002)
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Default_PHY_Command():
command = HCI_LE_Set_Default_PHY_Command(
all_phys = 0,
tx_phys = 1,
rx_phys = 1
)
command = HCI_LE_Set_Default_PHY_Command(all_phys=0, tx_phys=1, rx_phys=1)
basic_check(command)
@@ -372,10 +350,10 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command():
scan_types=[
HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING,
HCI_LE_Set_Extended_Scan_Parameters_Command.ACTIVE_SCANNING,
HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING
HCI_LE_Set_Extended_Scan_Parameters_Command.PASSIVE_SCANNING,
],
scan_intervals=[1, 2, 3],
scan_windows=[4, 5, 6]
scan_windows=[4, 5, 6],
)
basic_check(command)

View File

@@ -38,7 +38,7 @@ def test_import():
sdp,
smp,
transport,
utils
utils,
)
assert att
@@ -65,36 +65,47 @@ def test_import():
# -----------------------------------------------------------------------------
def test_app_imports():
from apps.console import main
assert main
from apps.controller_info import main
assert main
from apps.controllers import main
assert main
from apps.gatt_dump import main
assert main
from apps.gg_bridge import main
assert main
from apps.hci_bridge import main
assert main
from apps.pair import main
assert main
from apps.scan import main
assert main
from apps.show import main
assert main
from apps.unbond import main
assert main
from apps.usb_probe import main
assert main
@@ -103,7 +114,7 @@ def test_profiles_imports():
from bumble.profiles import (
battery_service,
device_information_service,
heart_rate_service
heart_rate_service,
)
assert battery_service

View File

@@ -27,9 +27,7 @@ from bumble.device import Device
from bumble.host import Host
from bumble.transport import AsyncPipeSink
from bumble.core import ProtocolError
from bumble.l2cap import (
L2CAP_Connection_Request
)
from bumble.l2cap import L2CAP_Connection_Request
# -----------------------------------------------------------------------------
@@ -45,18 +43,18 @@ class TwoDevices:
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
Controller('C1', link=self.link),
Controller('C2', link=self.link),
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
address='F0:F1:F2:F3:F4:F5',
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
)
address='F5:F4:F3:F2:F1:F0',
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
self.paired = [None, None]
@@ -74,8 +72,12 @@ async def setup_connection():
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection))
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection))
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Start
await two_devices.devices[0].power_on()
@@ -102,19 +104,25 @@ def test_helpers():
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
assert psm == bytes([0x11, 0x23, 0x24])
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1)
(offset, psm) = L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x01, 0x00, 0x44]), 1
)
assert offset == 3
assert psm == 0x01
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1)
(offset, psm) = L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x23, 0x10, 0x44]), 1
)
assert offset == 3
assert psm == 0x1023
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1)
(offset, psm) = L2CAP_Connection_Request.parse_psm(
bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1
)
assert offset == 4
assert psm == 0x242311
rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44)
rq = L2CAP_Connection_Request(psm=0x01, source_cid=0x44)
brq = bytes(rq)
srq = L2CAP_Connection_Request.from_bytes(brq)
assert srq.psm == rq.psm
@@ -147,11 +155,7 @@ async def test_basic_connection():
devices.devices[1].register_l2cap_channel_server(psm, on_coc)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = (
bytes([1, 2, 3]),
bytes([4, 5, 6]),
bytes(10000)
)
messages = (bytes([1, 2, 3]), bytes([4, 5, 6]), bytes(10000))
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
@@ -191,18 +195,11 @@ async def transfer_payload(max_credits, mtu, mps):
channel.sink = on_data
psm = devices.devices[1].register_l2cap_channel_server(
psm = 0,
server = on_coc,
max_credits = max_credits,
mtu = mtu,
mps = mps
psm=0, server=on_coc, max_credits=max_credits, mtu=mtu, mps=mps
)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100, 789)
]
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100, 789)]
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
@@ -233,7 +230,7 @@ async def test_bidirectional_transfer():
client_received = []
server_received = []
server_channel = None
server_channel = None
def on_server_coc(channel):
nonlocal server_channel
@@ -251,10 +248,7 @@ async def test_bidirectional_transfer():
client_channel = await devices.connections[0].open_l2cap_channel(psm)
client_channel.sink = on_client_data
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100)
]
messages = [bytes([1, 2, 3, 4, 5, 6, 7]) * x for x in (3, 10, 100)]
for message in messages:
client_channel.write(message)
await client_channel.drain()
@@ -278,7 +272,8 @@ async def run():
await test_transfer()
await test_bidirectional_transfer()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())

View File

@@ -30,10 +30,10 @@ def basic_frame_check(x):
parsed_bytes = bytes(parsed)
if len(serialized) < 500:
print('Parsed Bytes:', parsed_bytes.hex())
assert(parsed_bytes == serialized)
assert parsed_bytes == serialized
x_str = str(x)
parsed_str = str(parsed)
assert(x_str == parsed_str)
assert x_str == parsed_str
# -----------------------------------------------------------------------------

View File

@@ -31,10 +31,10 @@ def basic_check(x):
parsed_bytes = bytes(parsed)
if len(serialized) < 500:
print('Parsed Bytes:', parsed_bytes.hex())
assert(parsed_bytes == serialized)
assert parsed_bytes == serialized
x_str = str(x)
parsed_str = str(parsed)
assert(x_str == parsed_str)
assert x_str == parsed_str
# -----------------------------------------------------------------------------
@@ -99,19 +99,25 @@ def test_data_elements():
e = DataElement(DataElement.SEQUENCE, [DataElement(DataElement.BOOLEAN, True)])
basic_check(e)
e = DataElement(DataElement.SEQUENCE, [
DataElement(DataElement.BOOLEAN, True),
DataElement(DataElement.TEXT_STRING, 'hello')
])
e = DataElement(
DataElement.SEQUENCE,
[
DataElement(DataElement.BOOLEAN, True),
DataElement(DataElement.TEXT_STRING, 'hello'),
],
)
basic_check(e)
e = DataElement(DataElement.ALTERNATIVE, [DataElement(DataElement.BOOLEAN, True)])
basic_check(e)
e = DataElement(DataElement.ALTERNATIVE, [
DataElement(DataElement.BOOLEAN, True),
DataElement(DataElement.TEXT_STRING, 'hello')
])
e = DataElement(
DataElement.ALTERNATIVE,
[
DataElement(DataElement.BOOLEAN, True),
DataElement(DataElement.TEXT_STRING, 'hello'),
],
)
basic_check(e)
e = DataElement(DataElement.URL, 'http://example.com')
@@ -133,10 +139,14 @@ def test_data_elements():
e = DataElement.boolean(True)
basic_check(e)
e = DataElement.sequence([DataElement.signed_integer(0, 1), DataElement.text_string('hello')])
e = DataElement.sequence(
[DataElement.signed_integer(0, 1), DataElement.text_string('hello')]
)
basic_check(e)
e = DataElement.alternative([DataElement.signed_integer(0, 1), DataElement.text_string('hello')])
e = DataElement.alternative(
[DataElement.signed_integer(0, 1), DataElement.text_string('hello')]
)
basic_check(e)
e = DataElement.url('http://foobar.com')

Some files were not shown because too many files have changed in this diff Show More