Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
34161e5e14 add test.release task to facilitate CI integration 2022-08-16 13:30:45 -07:00
193 changed files with 6699 additions and 20297 deletions

View File

@@ -1,2 +0,0 @@
# Migrate code style to Black
135df0dcc01ab765f432e19b1a5202d29bd55545

View File

@@ -1,35 +0,0 @@
# Check the code against the formatter and linter
name: Code format and lint check
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
check:
name: Check Code
runs-on: ubuntu-latest
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development]"
- name: Check
run: |
invoke project.pre-commit

View File

@@ -14,10 +14,6 @@ jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
fail-fast: false
steps:
- name: Check out from Git
@@ -25,18 +21,18 @@ jobs:
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Test
- name: Test with pytest
run: |
invoke test
pytest
- name: Build
run: |
inv build

8
.gitignore vendored
View File

@@ -3,9 +3,9 @@ build/
dist/
*.egg-info/
*~
bumble/__pycache__
docs/mkdocs/site
tests/__pycache__
test-results.xml
__pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
bumble/transport/__pycache__
bumble/profiles/__pycache__

80
.vscode/settings.json vendored
View File

@@ -1,80 +0,0 @@
{
"cSpell.words": [
"Abortable",
"altsetting",
"ansiblue",
"ansicyan",
"ansigreen",
"ansimagenta",
"ansired",
"ansiyellow",
"appendleft",
"ASHA",
"asyncio",
"ATRAC",
"avdtp",
"bitpool",
"bitstruct",
"BSCP",
"BTPROTO",
"CCCD",
"cccds",
"cmac",
"CONNECTIONLESS",
"csrcs",
"datagram",
"DATALINK",
"delayreport",
"deregisters",
"deregistration",
"dhkey",
"diversifier",
"Fitbit",
"GATTLINK",
"HANDSFREE",
"keydown",
"keyup",
"levelname",
"libc",
"libusb",
"MITM",
"NDIS",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"psms",
"pyee",
"pyusb",
"rfcomm",
"ROHC",
"rssi",
"SEID",
"seids",
"SERV",
"ssrc",
"strerror",
"subband",
"subbands",
"subevent",
"Subrating",
"substates",
"tobytes",
"tsep",
"usbmodem",
"vhci",
"websockets",
"xcursor",
"ycursor"
],
"[python]": {
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

21
LICENSE
View File

@@ -199,23 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---
Files: bumble/colors.py
Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
limitations under the License.

View File

@@ -9,16 +9,16 @@
Bluetooth Stack for Apps, Emulation, Test and Experimentation
=============================================================
<img src="docs/mkdocs/src/images/logo_framed.png" alt="Logo" width="200" height="200"/>
<img src="docs/mkdocs/src/images/logo_framed.png" alt="drawing" width="200" height="200"/>
Bumble is a full-featured Bluetooth stack written entirely in Python. It supports most of the common Bluetooth Low Energy (BLE) and Bluetooth Classic (BR/EDR) protocols and profiles, including GAP, L2CAP, ATT, GATT, SMP, SDP, RFCOMM, HFP, HID and A2DP. The stack can be used with physical radios via HCI over USB, UART, or the Linux VHCI, as well as virtual radios, including the virtual Bluetooth support of the Android emulator.
## Documentation
Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with:
```
mkdocs build -f docs/mkdocs/mkdocs.yml
mkdocs build -f docs/mkdocs/mkdocs.yml
```
## Usage
@@ -29,7 +29,7 @@ For a quick start to using Bumble, see the [Getting Started](docs/mkdocs/src/get
### Dependencies
To install package dependencies needed to run the bumble examples, execute the following commands:
To install package dependencies needed to run the bumble examples execute the following commands:
```
python -m pip install --upgrade pip
@@ -38,20 +38,12 @@ python -m pip install ".[test,development,documentation]"
### Examples
Refer to the [Examples Documentation](examples/README.md) for details on the included example scripts and how to run them.
Refer to the [Example Documentation](examples/README.md) for details on the included example scripts and how to run them.
The complete [list of Examples](/docs/mkdocs/src/examples/index.md), and what they are designed to do is here.
There are also a set of [Apps and Tools](docs/mkdocs/src/apps_and_tools/index.md) that show the utility of Bumble.
### Using Bumble With a USB Dongle
Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License
Licensed under the [Apache 2.0](LICENSE) License.

View File

@@ -47,3 +47,5 @@ NOTE: this assumes you're running a Link Relay on port `10723`.
## `console.py`
A simple text-based-ui interactive Bluetooth device with GATT client capabilities.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -19,103 +19,44 @@ import asyncio
import os
import logging
import click
from colors import color
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_SUCCESS,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_READ_LOCAL_NAME_COMMAND
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
if response.return_parameters.status == HCI_SUCCESS:
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
)
print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if command_succeeded(response):
if response.return_parameters.status == HCI_SUCCESS:
print()
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name))
# -----------------------------------------------------------------------------
async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
@@ -132,19 +73,10 @@ async def async_main(transport):
# Print version
print(color('Version:', 'yellow'))
print(
color(' Manufacturer: ', 'green'),
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier))
print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version))
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version))
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info
@@ -164,7 +96,7 @@ async def async_main(transport):
@click.command()
@click.argument('transport')
def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(transport))

View File

@@ -28,14 +28,11 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) != 3:
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]')
print('example: python controllers.py pty:ble1 pty:ble2')
return
# Create a local link to attach the controllers to
# Create a loccal link to attach the controllers to
link = LocalLink()
# Create a transport and controller for all requested names
@@ -44,12 +41,7 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name)
transports.append(transport)
controller = Controller(
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link)
controllers.append(controller)
# Wait until the user interrupts
@@ -62,7 +54,7 @@ async def async_main():
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())

View File

@@ -19,9 +19,9 @@ import asyncio
import os
import logging
import click
from colors import color
import bumble.core
from bumble.colors import color
from bumble.core import ProtocolError, TimeoutError
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.transport import open_transport_or_link
@@ -49,9 +49,9 @@ async def dump_gatt_db(peer, done):
try:
value = await attribute.read_value()
print(color(f'{value.hex()}', 'green'))
except bumble.core.ProtocolError as error:
except ProtocolError as error:
print(color(error, 'red'))
except bumble.core.TimeoutError:
except TimeoutError:
print(color('read timeout', 'red'))
if done is not None:
@@ -64,13 +64,9 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
await device.power_on()
if address_or_name:
@@ -85,12 +81,7 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on(
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done)))
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))
@@ -108,7 +99,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name))

View File

@@ -17,14 +17,13 @@
# -----------------------------------------------------------------------------
import asyncio
import os
import struct
import logging
import click
from colors import color
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.gatt import Service, Characteristic
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant
@@ -33,73 +32,24 @@ from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
GG_PREFERRED_MTU = 256
# -----------------------------------------------------------------------------
class GattlinkL2capEndpoint:
class GattlinkHubBridge(Device.Listener):
def __init__(self):
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.peer = None
self.rx_socket = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
print(f'=== Connected to {connection}')
self.peer = Peer(connection)
@@ -130,221 +80,115 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX
# Subscribe to TX
if self.tx_characteristic:
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow'))
else:
print(color('!!! No Gattlink TX or PSM found', 'red'))
print(color('!!! Gattlink TX not found', 'red'))
def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason):
print(
color(
f'!!! Disconnected from {self.peer}, '
f'reason={HCI_Constant.error_name(reason)}',
'red',
)
)
print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red'))
self.tx_characteristic = None
self.rx_characteristic = None
self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received
def on_tx_received(self, value):
print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
print(color('>>> TX:', 'magenta'), value.hex())
if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created
def on_l2cap_psm_received(self, value):
psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
class GattlinkNodeBridge(Device.Listener):
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
self.transport = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
self.transport = transport
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, _connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(
f'### [GATT TX] subscription from {peer}: '
f'{"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
async def run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
# Instantiate a bridge object
device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
bridge = GattlinkNodeBridge()
# Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint(
lambda: bridge, local_addr=(receive_host, receive_port)
lambda: bridge,
local_addr=(receive_host, receive_port)
)
# Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
remote_addr=(send_host, send_port),
lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port)
)
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on()
await bridge.start()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
# Wait until the source terminates
await hci_source.wait_for_termination()
@@ -353,44 +197,15 @@ async def run(
@click.command()
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option(
'-rh',
'--receive-host',
type=str,
default='127.0.0.1',
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
main()

View File

@@ -34,29 +34,16 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) < 3:
print(
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> '
'[command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 '
'serial:/dev/tty.usbmodem0006839912171,1000000 '
'0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]')
print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078')
return
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source,
hci_host_sink,
):
async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink):
print('>>> connected')
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source,
hci_controller_sink,
):
async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink):
print('>>> connected')
command_short_circuits = []
@@ -64,43 +51,36 @@ async def async_main():
for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str:
ogf, ocf = op_code_str.split(':')
command_short_circuits.append(
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16)))
else:
command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet):
if (
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits:
# Respond with a success response
logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]),
num_hci_command_packets = 1,
command_opcode = hci_packet.op_code,
return_parameters = bytes([hci.HCI_SUCCESS])
)
# Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True)
return None
_ = HCI_Bridge(
hci_host_source,
hci_host_sink,
hci_controller_source,
hci_controller_sink,
host_to_controller_filter,
None,
None
)
await asyncio.get_running_loop().create_future()
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())

View File

@@ -1,350 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import click
from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:'
f'{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, exc):
print(color(f'!!! TCP connection lost: {exc}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
self.pipe.l2cap_channel.write(data)
try:
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain,
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -16,6 +16,7 @@
# Imports
# ----------------------------------------------------------------------------
import sys
import websockets
import logging
import json
import asyncio
@@ -23,9 +24,7 @@ import argparse
import uuid
import os
from urllib.parse import urlparse
import websockets
from bumble.colors import color
from colors import color
# -----------------------------------------------------------------------------
# Logging
@@ -66,9 +65,9 @@ class Connection:
"""
def __init__(self, room, websocket):
self.room = room
self.room = room
self.websocket = websocket
self.address = str(uuid.uuid4())
self.address = str(uuid.uuid4())
async def send_message(self, message):
try:
@@ -99,11 +98,7 @@ class Connection:
self.address = address
def __str__(self):
return (
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
return f'Connection(address="{self.address}", client={self.websocket.remote_address[0]}:{self.websocket.remote_address[1]})'
# ----------------------------------------------------------------------------
@@ -115,9 +110,9 @@ class Room:
"""
def __init__(self, relay, name):
self.relay = relay
self.name = name
self.observers = []
self.relay = relay
self.name = name
self.observers = []
self.connections = []
async def add_connection(self, connection):
@@ -144,15 +139,13 @@ class Room:
# Parse the message to decide how to handle it
if message.startswith('@'):
# This is a targeted message
await self.on_targeted_message(connection, message)
# This is a targetted message
await self.on_targetted_message(connection, message)
elif message.startswith('/'):
# This is an RPC request
await self.on_rpc_request(connection, message)
else:
await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
await connection.send_message(f'result:{error_to_json("error: invalid message")}')
async def broadcast_message(self, sender, message):
'''
@@ -162,9 +155,7 @@ class Room:
async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1)
if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None):
try:
result = await handler(connection, params)
except Exception as error:
@@ -174,7 +165,7 @@ class Room:
await connection.send_message(result or 'result:{}')
async def on_targeted_message(self, connection, message):
async def on_targetted_message(self, connection, message):
target, *payload = message.split(' ', 1)
if not payload:
return error_to_json('missing arguments')
@@ -183,8 +174,7 @@ class Room:
# Determine what targets to send to
if target == '*':
# Send to all connections in the room except the connection from which the
# message was received
# Send to all connections in the room except the connection from which the message was received
connections = [c for c in self.connections if c != connection]
else:
connections = self.find_connections_by_address(target)
@@ -202,9 +192,7 @@ class Room:
current_address = connection.address
new_address = params[0]
connection.set_address(new_address)
await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}')
# ----------------------------------------------------------------------------
@@ -222,10 +210,9 @@ class Relay:
def start(self):
logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(self, connection):
async def serve_as_controller(connection):
pass
async def serve(self, websocket, path):
@@ -259,24 +246,24 @@ def main():
print('ERROR: Python 3.6.1 or higher is required')
sys.exit(1)
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Parse arguments
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument(
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
)
arg_parser.add_argument('--port',
type = int,
default = DEFAULT_RELAY_PORT,
help = 'Port to listen on')
args = arg_parser.parse_args()
# Setup logger
if args.log_config:
from logging import config # pylint: disable=import-outside-toplevel
from logging import config
config.fileConfig(args.log_config)
else:
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
logging.basicConfig(level = getattr(logging, args.log_level.upper()))
# Start a relay
relay = Relay(args.port)

View File

@@ -19,12 +19,12 @@ import asyncio
import os
import logging
import click
from prompt_toolkit.shortcuts import PromptSession
import aioconsole
from colors import color
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import PairingDelegate, PairingConfig
from bumble.smp import PairingDelegate, PairingConfig
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import ProtocolError
@@ -33,57 +33,30 @@ from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE,
Service,
Characteristic,
CharacteristicValue,
CharacteristicValue
)
from bumble.att import (
ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR
)
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
def terminate(self):
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
}[capability_string.lower()]
)
def __init__(self, mode, connection, capability_string, prompt):
super().__init__({
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT
}[capability_string.lower()])
self.mode = mode
self.peer = Peer(connection)
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.do_prompt = do_prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
response = await session.prompt_async()
return response.lower().strip()
self.prompt = prompt
async def update_peer_name(self):
if self.peer_name is not None:
@@ -98,83 +71,87 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]'
async def accept(self):
if self.do_prompt:
if self.prompt:
await self.update_peer_name()
# Prompt for acceptance
self.print('###-----------------------------------')
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await self.prompt('>>> Accept? ')
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
if response == 'yes':
return True
if response == 'no':
elif response == 'no':
return False
# Accept silently
return True
else:
# Accept silently
return True
async def compare_numbers(self, number, digits):
await self.update_peer_name()
# Prompt for a numeric comparison
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow'))
response = response.lower().strip()
if response == 'yes':
return True
if response == 'no':
elif response == 'no':
return False
async def get_number(self):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN
while True:
try:
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
return int(await self.prompt('>>> Enter PIN: '))
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
except ValueError:
pass
async def display_number(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print(f'### PIN: {number:0{digits}}')
self.print('###-----------------------------------')
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
# -----------------------------------------------------------------------------
async def get_peer_name(peer, mode):
if mode == 'classic':
return await peer.request_name()
else:
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
return None
values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0])
if values:
return values[0].decode('utf-8')
# -----------------------------------------------------------------------------
@@ -187,12 +164,12 @@ def read_with_error(connection):
if AUTHENTICATION_ERROR_RETURNED[0]:
return bytes([1])
AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
else:
AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, _value):
def write_with_error(connection, value):
if not connection.is_encrypted:
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
@@ -206,14 +183,14 @@ def on_connection(connection, request):
print(color(f'<<< Connection: {connection}', 'green'))
# Listen for pairing events
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing_start', on_pairing_start)
connection.on('pairing', on_pairing)
connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes
connection.on(
'connection_encryption_change',
lambda: on_connection_encryption_change(connection),
lambda: on_connection_encryption_change(connection)
)
# Request pairing if needed
@@ -225,12 +202,7 @@ def on_connection(connection, request):
# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue'))
print(
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
@@ -247,7 +219,6 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -255,7 +226,6 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -264,7 +234,6 @@ async def pair(
sc,
mitm,
bond,
ctkd,
io,
prompt,
request,
@@ -272,10 +241,8 @@ async def pair(
keystore_file,
device_config,
hci_transport,
address_or_name,
address_or_name
):
Waiter.instance = Waiter()
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -303,14 +270,11 @@ async def pair(
[
Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.Properties.READ
| Characteristic.Properties.WRITE,
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(
read=read_with_error, write=write_with_error
),
CharacteristicValue(read=read_with_error, write=write_with_error)
)
],
]
)
)
@@ -318,14 +282,16 @@ async def pair(
if mode == 'classic':
device.classic_enabled = True
device.le_enabled = False
device.classic_smp_enabled = ctkd
# Get things going
await device.power_on()
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc, mitm, bond, Delegate(mode, connection, io, prompt)
sc,
mitm,
bond,
Delegate(mode, connection, io, prompt)
)
# Connect to a peer or wait for a connection
@@ -345,114 +311,29 @@ async def pair(
print(color(f'Pairing failed: {error}', 'red'))
return
else:
if mode == 'le':
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)
else:
# Become discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)
# Run until the user asks to exit
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
)
@click.option(
'--sc',
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--ctkd',
type=bool,
default=True,
help='Enable CTKD',
show_default=True,
)
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True)
@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True)
@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True)
@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True)
@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing')
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option(
'--keystore-file',
metavar='<filename>',
help='File in which to store the pairing keys',
)
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
mode,
sc,
mitm,
bond,
ctkd,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run(
pair(
mode,
sc,
mitm,
bond,
ctkd,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name))
# -----------------------------------------------------------------------------

View File

@@ -19,20 +19,20 @@ import asyncio
import os
import logging
import click
from colors import color
from bumble.colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import HCI_LE_Advertising_Report_Event
from bumble.core import AdvertisingData
# -----------------------------------------------------------------------------
def make_rssi_bar(rssi):
DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30
DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30
DEFAULT_RSSI_BAR_WIDTH = 30
blocks = ['', '', '', '', '', '', '', '']
@@ -48,24 +48,19 @@ class AdvertisementPrinter:
self.min_rssi = min_rssi
self.resolver = resolver
def print_advertisement(self, advertisement):
address = advertisement.address
address_color = 'yellow' if advertisement.is_connectable else 'red'
if self.min_rssi is not None and advertisement.rssi < self.min_rssi:
def print_advertisement(self, address, address_color, ad_data, rssi):
if self.min_rssi is not None and rssi < self.min_rssi:
return
address_qualifier = ''
resolution_qualifier = ''
if self.resolver and advertisement.address.is_resolvable:
resolved = self.resolver.resolve(advertisement.address)
if self.resolver and address.is_resolvable:
resolved = self.resolver.resolve(address)
if resolved is not None:
resolution_qualifier = f'(resolved from {advertisement.address})'
resolution_qualifier = f'(resolved from {address})'
address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type]
if address.is_public:
type_color = 'cyan'
else:
@@ -79,32 +74,18 @@ class AdvertisementPrinter:
type_color = 'blue'
address_qualifier = '(non-resolvable)'
rssi_bar = make_rssi_bar(rssi)
separator = '\n '
rssi_bar = make_rssi_bar(advertisement.rssi)
if not advertisement.is_legacy:
phy_info = (
f'PHY: {HCI_Constant.le_phy_name(advertisement.primary_phy)}/'
f'{HCI_Constant.le_phy_name(advertisement.secondary_phy)} '
f'{separator}'
)
else:
phy_info = ''
print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}RSSI:{rssi:4} {rssi_bar}{separator}{ad_data.to_string(separator)}\n')
print(
f'>>> {color(address, address_color)} '
f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertisement(self, address, ad_data, rssi, connectable):
address_color = 'yellow' if connectable else 'red'
self.print_advertisement(address, address_color, ad_data, rssi)
def on_advertisement(self, advertisement):
self.print_advertisement(advertisement)
def on_advertising_report(self, report):
print(f'{color("EVENT", "green")}: {report.event_type_string()}')
self.print_advertisement(Advertisement.from_advertising_report(report))
def on_advertising_report(self, address, ad_data, rssi, event_type):
print(f'{color("EVENT", "green")}: {HCI_LE_Advertising_Report_Event.event_type_name(event_type)}')
ad_data = AdvertisingData.from_bytes(ad_data)
self.print_advertisement(address, 'yellow', ad_data, rssi)
# -----------------------------------------------------------------------------
@@ -113,25 +94,20 @@ async def scan(
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
transport
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
else:
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
@@ -150,18 +126,11 @@ async def scan(
device.on('advertisement', printer.on_advertisement)
await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
scanning_phys = [{'1m': HCI_LE_1M_PHY, 'coded': HCI_LE_CODED_PHY}[phy]]
await device.start_scanning(
active=(not passive),
scan_interval=scan_interval,
scan_window=scan_window,
filter_duplicates=filter_duplicates,
scanning_phys=scanning_phys,
filter_duplicates=filter_duplicates
)
await hci_source.wait_for_termination()
@@ -173,51 +142,14 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option(
'--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY'
)
@click.option(
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level')
@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones')
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport')
def main(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
)
)
def main(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport))
# -----------------------------------------------------------------------------

View File

@@ -17,8 +17,8 @@
# -----------------------------------------------------------------------------
import struct
import click
from colors import color
from bumble.colors import color
from bumble import hci
from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
@@ -27,14 +27,13 @@ from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not
exactly the same...)
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...)
'''
DATALINK_H1 = 1001
DATALINK_H4 = 1002
DATALINK_H1 = 1001
DATALINK_H4 = 1002
DATALINK_BSCP = 1003
DATALINK_H5 = 1004
DATALINK_H5 = 1004
def __init__(self, source):
self.source = source
@@ -42,13 +41,9 @@ class SnoopPacketReader:
# Read the header
identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
raise ValueError('not a valid snoop file, unexpected identification pattern')
(self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8))
if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1:
raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self):
@@ -60,9 +55,9 @@ class SnoopPacketReader:
original_length,
included_length,
packet_flags,
_cumulative_drops,
_timestamp_seconds,
_timestamp_microsecond,
cumulative_drops,
timestamp_seconds,
timestamp_microsecond
) = struct.unpack('>IIIIII', header)
# Abort on truncated packets
@@ -84,34 +79,24 @@ class SnoopPacketReader:
else:
packet_type = hci.HCI_ACL_DATA_PACKET
return (
packet_flags & 1,
bytes([packet_type]) + self.source.read(included_length),
)
return (packet_flags & 1, self.source.read(included_length))
return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length))
else:
return (packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, filename):
def show(format, filename):
input = open(filename, 'rb')
if format == 'h4':
packet_reader = PacketReader(input)
def read_next_packet():
return (0, packet_reader.next_packet())
(0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet
@@ -127,8 +112,9 @@ def main(format, filename):
except Exception as error:
print(color(f'!!! {error}', 'red'))
pass
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter
show()

View File

@@ -54,7 +54,7 @@ async def unbond(keystore_file, device_config, address):
@click.argument('device-config')
@click.argument('address', required=False)
def main(keystore_file, device_config, address):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(unbond(keystore_file, device_config, address))

View File

@@ -1,278 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# This tool lists all the USB devices, with details about each device.
# For each device, the different possible Bumble transport strings that can
# refer to it are listed. If the device is known to be a Bluetooth HCI device,
# its identifier is printed in reverse colors, and the transport names in cyan color.
# For other devices, regardless of their type, the transport names are printed
# in red. Whether that device is actually a Bluetooth device or not depends on
# whether it is a Bluetooth device that uses a non-standard Class, or some other
# type of device (there's no way to tell).
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import os
import logging
import click
import usb1
from bumble.colors import color
from bumble.transport.usb import load_libusb
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_DEVICE_CLASSES = {
0x00: 'Device',
0x01: 'Audio',
0x02: 'Communications and CDC Control',
0x03: 'Human Interface Device',
0x05: 'Physical',
0x06: 'Still Imaging',
0x07: 'Printer',
0x08: 'Mass Storage',
0x09: 'Hub',
0x0A: 'CDC Data',
0x0B: 'Smart Card',
0x0D: 'Content Security',
0x0E: 'Video',
0x0F: 'Personal Healthcare',
0x10: 'Audio/Video',
0x11: 'Billboard',
0x12: 'USB Type-C Bridge',
0x3C: 'I3C',
0xDC: 'Diagnostic',
USB_DEVICE_CLASS_WIRELESS_CONTROLLER: (
'Wireless Controller',
{
0x01: {
0x01: 'Bluetooth',
0x02: 'UWB',
0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP',
}
},
),
0xEF: 'Miscellaneous',
0xFE: 'Application Specific',
0xFF: 'Vendor Specific',
}
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
# -----------------------------------------------------------------------------
def show_device_details(device):
for configuration in device:
print(f' Configuration {configuration.getConfigurationValue()}')
for interface in configuration:
for setting in interface:
alternate_setting = setting.getAlternateSetting()
suffix = (
f'/{alternate_setting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(), setting.getSubClass(), setting.getProtocol()
)
details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = (
'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
)
# -----------------------------------------------------------------------------
def get_class_info(cls, subclass, protocol):
class_info = USB_DEVICE_CLASSES.get(cls)
protocol_string = ''
if class_info is None:
class_string = f'0x{cls:02X}'
else:
if isinstance(class_info, tuple):
class_string = class_info[0]
subclass_info = class_info[1].get(subclass)
if subclass_info:
protocol_string = subclass_info.get(protocol)
if protocol_string is not None:
protocol_string = f' [{protocol_string}]'
else:
class_string = class_info
subclass_string = f'{subclass}/{protocol}{protocol_string}'
return (class_string, subclass_string)
# -----------------------------------------------------------------------------
def is_bluetooth_hci(device):
# Check if the device class indicates a match
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
for configuration in device:
for interface in configuration:
for setting in interface:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
# -----------------------------------------------------------------------------
@click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass()
device_subclass = device.getDeviceSubClass()
device_protocol = device.getDeviceProtocol()
device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info(
device_class, device_subclass, device_protocol
)
try:
device_serial_number = device.getSerialNumber()
except usb1.USBError:
device_serial_number = None
try:
device_manufacturer = device.getManufacturer()
except usb1.USBError:
device_manufacturer = None
try:
device_product = device.getProduct()
except usb1.USBError:
device_product = None
device_is_bluetooth_hci = is_bluetooth_hci(device)
if device_is_bluetooth_hci:
bluetooth_device_count += 1
fg_color = 'black'
bg_color = 'yellow'
else:
fg_color = 'yellow'
bg_color = 'black'
# Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = []
basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
if device_id not in devices:
bumble_transport_names.append(basic_transport_name)
else:
bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None:
if (
device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results
print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names:
print(
color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None:
print(color(' Serial: ', 'green'), device_serial_number)
if device_manufacturer is not None:
print(color(' Manufacturer: ', 'green'), device_manufacturer)
if device_product is not None:
print(color(' Product: ', 'green'), device_product)
if verbose:
show_device_details(device)
print()
devices.setdefault(device_id, []).append(device_serial_number)
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter

View File

@@ -1,4 +0,0 @@
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown version"

View File

@@ -16,8 +16,10 @@
# Imports
# -----------------------------------------------------------------------------
import struct
import bitstruct
import logging
from collections import namedtuple
from colors import color
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -28,7 +30,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
from .core import (
BT_L2CAP_PROTOCOL_ID,
@@ -36,7 +38,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number,
name_or_number
)
@@ -49,7 +51,6 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -126,115 +127,71 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
# fmt: on
# -----------------------------------------------------------------------------
def flags_to_list(flags, values):
result = []
for i, value in enumerate(values):
for i in range(len(values)):
if flags & (1 << (len(values) - i - 1)):
result.append(value)
result.append(values[i])
return result
# -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
]
# -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SINK_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
]
@@ -249,46 +206,45 @@ class SbcMediaCodecInformation(
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
'maximum_bitpool_value'
]
)
):
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = {
16000: 1 << 3,
32000: 1 << 2,
44100: 1 << 1,
48000: 1
}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1
}
BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1
}
@staticmethod
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return SbcMediaCodecInformation(
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
)
def from_bytes(data):
return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data))
@classmethod
def from_discrete_values(
@@ -299,16 +255,16 @@ class SbcMediaCodecInformation(
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
maximum_bitpool_value
):
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode = cls.CHANNEL_MODE_BITS[channel_mode],
block_length = cls.BLOCK_LENGTH_BITS[block_length],
subbands = cls.SUBBANDS_BITS[subbands],
allocation_method = cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value = minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
)
@classmethod
@@ -320,71 +276,63 @@ class SbcMediaCodecInformation(
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value,
maximum_bitpool_value
):
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
channel_mode = sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length = sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands = sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods),
minimum_bitpool_value = minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
)
def __bytes__(self) -> bytes:
return bytes(
[
(self.sampling_frequency << 4) | self.channel_mode,
(self.block_length << 4)
| (self.subbands << 2)
| self.allocation_method,
self.minimum_bitpool_value,
self.maximum_bitpool_value,
]
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
return '\n'.join([
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}'
')'
])
# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
[
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
)
):
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
BIT_FIELDS = 'u8u12u2p2u1u23'
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
8000: 1 << 11,
11025: 1 << 10,
12000: 1 << 9,
16000: 1 << 8,
@@ -395,82 +343,66 @@ class AacMediaCodecInformation(
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1,
96000: 1
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
def from_bytes(data):
return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data))
@classmethod
def from_discrete_values(
cls, object_type, sampling_frequency, channels, vbr, bitrate
cls,
object_type,
sampling_frequency,
channels,
vbr,
bitrate
):
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
object_type = cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels = cls.CHANNELS_BITS[channels],
vbr = vbr,
bitrate = bitrate
)
@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
def from_lists(
cls,
object_types,
sampling_frequencies,
channels,
vbr,
bitrate
):
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
vbr=vbr,
bitrate=bitrate,
object_type = sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
channels = sum(cls.CHANNELS_BITS[x] for x in channels),
vbr = vbr,
bitrate = bitrate
)
def __bytes__(self) -> bytes:
return bytes(
[
self.object_type & 0xFF,
(self.sampling_frequency >> 4) & 0xFF,
(((self.sampling_frequency & 0x0F) << 4) | (self.channels << 2)) & 0xFF,
((self.vbr << 7) | ((self.bitrate >> 16) & 0x7F)) & 0xFF,
((self.bitrate >> 8) & 0xFF) & 0xFF,
self.bitrate & 0xFF,
]
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self):
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]']
channels = [1, 2]
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' ')',
]
)
return '\n'.join([
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}'
')'
])
# -----------------------------------------------------------------------------
@@ -486,34 +418,37 @@ class VendorSpecificMediaCodecInformation:
def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
self.codec_id = codec_id
self.value = value
def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self):
# pylint: disable=line-too-long
return '\n'.join(
[
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}' ')',
]
)
return '\n'.join([
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}'
')'
])
# -----------------------------------------------------------------------------
class SbcFrame:
def __init__(
self, sampling_frequency, block_count, channel_mode, subband_count, payload
self,
sampling_frequency,
block_count,
channel_mode,
subband_count,
payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
@property
def sample_count(self):
@@ -528,13 +463,7 @@ class SbcFrame:
return self.sample_count / self.sampling_frequency
def __str__(self):
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'size={len(self.payload)})'
)
return f'SBC(sf={self.sampling_frequency},cm={self.channel_mode},br={self.bitrate},sc={self.sample_count},size={len(self.payload)})'
# -----------------------------------------------------------------------------
@@ -558,30 +487,24 @@ class SbcParser:
# Extract some of the header fields
sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3]
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2]
# Compute the frame length
frame_length = 4 + (4 * subbands * channels) // 8
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8
else:
frame_length += (
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8
# Read the rest of the frame
payload = header + await self.read(frame_length - 4)
# Emit the next frame
yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
)
yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload)
return generate_frames()
@@ -589,20 +512,19 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities):
self.read = read
self.mtu = mtu
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
timestamp = 0
frames = []
frames_size = 0
timestamp = 0
frames = []
frames_size = 0
max_rtp_payload = self.mtu - 12 - 1
# NOTE: this doesn't support frame fragments
@@ -610,25 +532,18 @@ class SbcPacketSource:
async for frame in sbc_parser.frames:
print(frame)
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16:
# Need to flush what has been accumulated so far
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join(
[frame.payload for frame in frames]
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames])
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet
# Prepare for next packets
sequence_number += 1
timestamp += sum((frame.sample_count for frame in frames))
timestamp += sum([frame.sample_count for frame in frames])
frames = [frame]
frames_size = len(frame.payload)
else:

View File

@@ -22,25 +22,15 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import functools
import struct
from colors import color
from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
from bumble.colors import color
if TYPE_CHECKING:
from bumble.device import Connection
from .core import *
from .hci import *
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01
@@ -173,30 +163,30 @@ ATT_ERROR_NAMES = {
ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y)
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def key_with_value(dictionary, target_value):
for key, value in dictionary.items():
if value == target_value:
return key
return None
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
class ATT_Error(ProtocolError):
def __init__(self, error_code, att_handle=0x0000, message=''):
super().__init__(
error_code,
error_namespace='att',
error_name=ATT_PDU.error_name(error_code),
)
class ATT_Error(Exception):
def __init__(self, error_code, att_handle=0x0000):
self.error_code = error_code
self.att_handle = att_handle
self.message = message
def __str__(self):
return f'ATT_Error(error={self.error_name}, handle={self.att_handle:04X}): {self.message}'
return f'ATT_Error({ATT_PDU.error_name(self.error_code)})'
# -----------------------------------------------------------------------------
@@ -206,10 +196,8 @@ class ATT_PDU:
'''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
'''
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
pdu_classes = {}
op_code = 0
name = None
@staticmethod
def from_bytes(pdu):
@@ -286,13 +274,11 @@ class ATT_PDU:
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}),
]
)
@ATT_PDU.subclass([
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name})
])
class ATT_Error_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -300,7 +286,9 @@ class ATT_Error_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('client_rx_mtu', 2)])
@ATT_PDU.subclass([
('client_rx_mtu', 2)
])
class ATT_Exchange_MTU_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -308,7 +296,9 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('server_rx_mtu', 2)])
@ATT_PDU.subclass([
('server_rx_mtu', 2)
])
class ATT_Exchange_MTU_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -316,9 +306,10 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)]
)
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC)
])
class ATT_Find_Information_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -326,7 +317,10 @@ class ATT_Find_Information_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('format', 1), ('information_data', '*')])
@ATT_PDU.subclass([
('format', 1),
('information_data', '*')
])
class ATT_Find_Information_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -338,7 +332,7 @@ class ATT_Find_Information_Response(ATT_PDU):
uuid_size = 2 if self.format == 1 else 16
while offset + uuid_size <= len(self.information_data):
handle = struct.unpack_from('<H', self.information_data, offset)[0]
uuid = self.information_data[2 + offset : 2 + offset + uuid_size]
uuid = self.information_data[2 + offset:2 + offset + uuid_size]
self.information.append((handle, uuid))
offset += 2 + uuid_size
@@ -352,33 +346,20 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('format', 1),
(
'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('format', 1),
('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])})
], ' ')
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*'),
]
)
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Find_By_Type_Value_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -386,7 +367,9 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('handles_information_list', '*')])
@ATT_PDU.subclass([
('handles_information_list', '*')
])
class ATT_Find_By_Type_Value_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -396,9 +379,7 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = []
offset = 0
while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from(
'<HH', self.handles_information_list, offset
)
found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset)
self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4
@@ -412,34 +393,18 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])})
], ' ')
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC),
]
)
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC)
])
class ATT_Read_By_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -447,7 +412,10 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -456,15 +424,9 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 2:offset + self.length]
self.attributes.append((attribute_handle, attribute_value))
offset += self.length
@@ -478,26 +440,17 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])})
], ' ')
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC)
])
class ATT_Read_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -505,7 +458,9 @@ class ATT_Read_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_value', '*')])
@ATT_PDU.subclass([
('attribute_value', '*')
])
class ATT_Read_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -513,7 +468,10 @@ class ATT_Read_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
class ATT_Read_Blob_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -521,7 +479,9 @@ class ATT_Read_Blob_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('part_attribute_value', '*')])
@ATT_PDU.subclass([
('part_attribute_value', '*')
])
class ATT_Read_Blob_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -529,7 +489,9 @@ class ATT_Read_Blob_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('set_of_handles', '*')])
@ATT_PDU.subclass([
('set_of_handles', '*')
])
class ATT_Read_Multiple_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -537,7 +499,9 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('set_of_values', '*')])
@ATT_PDU.subclass([
('set_of_values', '*')
])
class ATT_Read_Multiple_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -545,13 +509,11 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC),
]
)
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC)
])
class ATT_Read_By_Group_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -559,7 +521,10 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Group_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -568,18 +533,10 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
attribute_handle, end_group_handle = struct.unpack_from(
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 4:offset + self.length]
self.attributes.append((attribute_handle, end_group_handle, attribute_value))
offset += self.length
def __init__(self, *args, **kwargs):
@@ -592,29 +549,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
(
'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])})
], ' ')
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -630,7 +576,10 @@ class ATT_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -638,13 +587,11 @@ class ATT_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
# ('authentication_signature', 'TODO')
]
)
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
# ('authentication_signature', 'TODO')
])
class ATT_Signed_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -652,13 +599,11 @@ class ATT_Signed_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*'),
]
)
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
class ATT_Prepare_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -666,13 +611,11 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*'),
]
)
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
class ATT_Prepare_Write_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -696,7 +639,10 @@ class ATT_Execute_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Notification(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -704,7 +650,10 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Indication(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
@@ -722,143 +671,58 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class Attribute(EventEmitter):
# Permission flags
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = {
READABLE: 'READABLE',
WRITEABLE: 'WRITEABLE',
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
}
@staticmethod
def string_to_permissions(permissions_str: str):
try:
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
permissions_str.split(","),
0,
)
except TypeError as exc:
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
) from exc
def __init__(self, attribute_type, permissions, value=b''):
def __init__(self, attribute_type, permissions, value = b''):
EventEmitter.__init__(self)
self.handle = 0
self.handle = 0
self.end_group_handle = 0
if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions)
else:
self.permissions = permissions
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str):
if type(attribute_type) is str:
self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes):
elif type(attribute_type) is bytes:
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
# Convert the value to a byte array
if isinstance(value, str):
if type(value) is str:
self.value = bytes(value, 'utf-8')
else:
self.value = value
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
def read_value(self, connection: Connection):
if (
self.permissions & self.READ_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
def read_value(self, connection):
if read := getattr(self.value, 'read', None):
try:
value = read(connection) # pylint: disable=not-callable
return read(connection)
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
else:
value = self.value
return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes):
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value = self.decode_value(value_bytes)
return self.value
def write_value(self, connection, value):
if write := getattr(self.value, 'write', None):
try:
write(connection, value) # pylint: disable=not-callable
write(connection, value)
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
else:
self.value = value
self.emit('write', connection, value)
def __repr__(self):
if isinstance(self.value, bytes):
value_str = self.value.hex()
else:
value_str = str(self.value)
if value_str:
if len(self.value) > 0:
value_string = f', value={self.value.hex()}'
else:
value_string = ''
return (
f'Attribute(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'permissions={self.permissions}{value_string})'
)
return f'Attribute(handle=0x{self.handle:04X}, type={self.type}, permissions={self.permissions}{value_string})'

File diff suppressed because it is too large Load Diff

View File

@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
class HCI_Bridge:
class Forwarder:
def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace):
self.hci_sink = hci_sink
self.hci_sink = hci_sink
self.sender_hci_sink = sender_hci_sink
self.packet_filter = packet_filter
self.trace = trace
self.packet_filter = packet_filter
self.trace = trace
def on_packet(self, packet):
# Convert the packet bytes to an object
@@ -61,15 +61,15 @@ class HCI_Bridge:
hci_host_sink,
hci_controller_source,
hci_controller_sink,
host_to_controller_filter=None,
controller_to_host_filter=None,
host_to_controller_filter = None,
controller_to_host_filter = None
):
tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink,
hci_host_sink,
host_to_controller_filter,
lambda packet: tracer.trace(packet, 0),
lambda packet: tracer.trace(packet, 0)
)
hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink,
hci_controller_sink,
controller_to_host_filter,
lambda packet: tracer.trace(packet, 1),
lambda packet: tracer.trace(packet, 1)
)
hci_controller_source.set_packet_sink(controller_to_host_forwarder)

View File

@@ -1,103 +0,0 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial
from typing import List, Optional, Union
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
# ANSI style names
STYLES = (
'none',
'bold',
'faint',
'italic',
'underline',
'blink',
'blink2',
'negative',
'concealed',
'crossed',
)
ColorSpec = Union[str, int]
def _join(*values: ColorSpec) -> str:
return ';'.join(str(v) for v in values)
def _color_code(spec: ColorSpec, base: int) -> str:
if isinstance(spec, str):
spec = spec.strip().lower()
if spec == 'default':
return _join(base + 9)
elif spec in COLORS:
return _join(base + COLORS.index(spec))
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
def color(
s: str,
fg: Optional[ColorSpec] = None,
bg: Optional[ColorSpec] = None,
style: Optional[str] = None,
) -> str:
codes: List[ColorSpec] = []
if fg:
codes.append(_color_code(fg, 30))
if bg:
codes.append(_color_code(bg, 40))
if style:
for style_part in style.split('+'):
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
else:
return s
# Foreground color shortcuts
black = partial(color, fg='black')
red = partial(color, fg='red')
green = partial(color, fg='green')
yellow = partial(color, fg='yellow')
blue = partial(color, fg='blue')
magenta = partial(color, fg='magenta')
cyan = partial(color, fg='cyan')
white = partial(color, fg='white')
# Style shortcuts
bold = partial(color, style='bold')
none = partial(color, style='none')
faint = partial(color, style='faint')
italic = partial(color, style='italic')
underline = partial(color, style='underline')
blink = partial(color, style='blink')
blink2 = partial(color, style='blink2')
negative = partial(color, style='negative')
concealed = partial(color, style='concealed')
crossed = partial(color, style='crossed')

View File

@@ -17,7 +17,6 @@
# the `generate_company_id_list.py` script
# -----------------------------------------------------------------------------
# pylint: disable=line-too-long
COMPANY_IDENTIFIERS = {
0x0000: "Ericsson Technology Licensing",
0x0001: "Nokia Mobile Phones",
@@ -197,28 +196,28 @@ COMPANY_IDENTIFIERS = {
0x00AF: "Cinetix",
0x00B0: "Passif Semiconductor Corp",
0x00B1: "Saris Cycling Group, Inc",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B5: "Swirl Networks",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00C8: "GeLo Inc",
0x00C9: "Evluma",
0x00CA: "MC10",
@@ -250,10 +249,10 @@ COMPANY_IDENTIFIERS = {
0x00E4: "Laird Connectivity, Inc. formerly L.S. Research Inc.",
0x00E5: "Eden Software Consultants Ltd.",
0x00E6: "Freshtemp",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00EB: "Server Technology Inc.",
0x00EC: "BioResearch Associates",
0x00ED: "Jolly Logic, LLC",
@@ -2705,5 +2704,5 @@ COMPANY_IDENTIFIERS = {
0x0A7C: "WAFERLOCK",
0x0A7D: "Freedman Electronics Pty Ltd",
0x0A7E: "Keba AG",
0x0A7F: "Intuity Medical",
}
0x0A7F: "Intuity Medical"
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,9 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List, Optional, Tuple, Union, cast
from .company_ids import COMPANY_IDENTIFIERS
@@ -25,8 +23,6 @@ from .company_ids import COMPANY_IDENTIFIERS
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
BT_CENTRAL_ROLE = 0
BT_PERIPHERAL_ROLE = 1
@@ -34,9 +30,6 @@ BT_BR_EDR_TRANSPORT = 0
BT_LE_TRANSPORT = 1
# fmt: on
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -65,25 +58,17 @@ def padded_bytes(buffer, size):
return buffer + bytes(padding_size)
def get_dict_key_by_value(dictionary, value):
for key, val in dictionary.items():
if val == value:
return key
return None
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
class BaseError(Exception):
"""Base class for errors with an error code, error name and namespace"""
""" Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''):
super().__init__()
self.error_code = error_code
self.error_code = error_code
self.error_namespace = error_namespace
self.error_name = error_name
self.details = details
self.error_name = error_name
self.details = details
def __str__(self):
if self.error_namespace:
@@ -99,40 +84,22 @@ class BaseError(Exception):
class ProtocolError(BaseError):
"""Protocol Error"""
""" Protocol Error """
class TimeoutError(Exception): # pylint: disable=redefined-builtin
"""Timeout Error"""
class CommandTimeoutError(Exception):
"""Command Timeout Error"""
class TimeoutError(Exception):
""" Timeout Error """
class InvalidStateError(Exception):
"""Invalid State Error"""
""" Invalid State Error """
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""
FAILURE = 0x01
class ConnectionError(BaseError):
""" Connection Error """
FAILURE = 0x01
CONNECTION_REFUSED = 0x02
def __init__(
self,
error_code,
transport,
peer_address,
error_namespace='',
error_name='',
details='',
):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
# -----------------------------------------------------------------------------
# UUID
@@ -144,37 +111,27 @@ class ConnectionError(BaseError): # pylint: disable=redefined-builtin
class UUID:
'''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
Note that this class expects and works in little-endian byte-order throughout.
The exception is when interacting with strings, which are in big-endian byte-order.
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')[::-1] # little-endian
UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
if isinstance(uuid_str_or_int, int):
def __init__(self, uuid_str_or_int, name = None):
if type(uuid_str_or_int) is int:
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
if len(uuid_str_or_int) == 36:
if (
uuid_str_or_int[8] != '-'
or uuid_str_or_int[13] != '-'
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
if uuid_str_or_int[8] != '-' or uuid_str_or_int[13] != '-' or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-':
raise ValueError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '')
else:
uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError(f"invalid UUID format: {uuid_str}")
raise ValueError('invalid UUID format')
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name
def register(self):
# Register this object in the class registry, and update the entry's name if
# it wasn't set already
# Register this object in the class registry, and update the entry's name if it wasn't set already
for uuid in self.UUIDS:
if self == uuid:
if uuid.name is None:
@@ -185,47 +142,39 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
if len(uuid_bytes) in (2, 4, 16):
def from_bytes(cls, uuid_bytes, name = None):
if len(uuid_bytes) in {2, 4, 16}:
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
self.name = name
return self.register()
raise ValueError('only 2, 4 and 16 bytes are allowed')
else:
raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16, name=None):
def from_16_bits(cls, uuid_16, name = None):
return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod
def from_32_bits(cls, uuid_32, name=None):
def from_32_bits(cls, uuid_32, name = None):
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
def parse_uuid(cls, uuid_as_bytes, offset):
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
def parse_uuid(cls, bytes, offset):
return len(bytes), cls.from_bytes(bytes[offset:])
@classmethod
def parse_uuid_2(cls, uuid_as_bytes, offset):
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def parse_uuid_2(cls, bytes, offset):
return offset + 2, cls.from_bytes(bytes[offset:offset + 2])
def to_bytes(self, force_128=False):
'''
Serialize UUID in little-endian byte-order
'''
if not force_128:
def to_bytes(self, force_128 = False):
if len(self.uuid_bytes) == 16 or not force_128:
return self.uuid_bytes
if len(self.uuid_bytes) == 2:
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
elif len(self.uuid_bytes) == 4:
return self.BASE_UUID + self.uuid_bytes
elif len(self.uuid_bytes) == 16:
return self.uuid_bytes
return self.uuid_bytes + UUID.BASE_UUID
else:
assert False, "unreachable"
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
def to_pdu_bytes(self):
'''
@@ -234,30 +183,27 @@ class UUID:
"All 32-bit Attribute UUIDs shall be converted to 128-bit UUIDs when the
Attribute UUID is contained in an ATT PDU."
'''
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
return self.to_bytes(force_128 = (len(self.uuid_bytes) == 4))
def to_hex_str(self) -> str:
def to_hex_str(self):
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
return ''.join(
[
else:
return ''.join([
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
def __bytes__(self):
return self.to_bytes()
def __eq__(self, other):
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
if isinstance(other, str):
return self.to_bytes(force_128 = True) == other.to_bytes(force_128 = True)
elif type(other) is str:
return UUID(other) == self
return False
@@ -267,26 +213,23 @@ class UUID:
def __str__(self):
if len(self.uuid_bytes) == 2:
uuid = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{uuid:04X}'
v = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{v:04X}'
elif len(self.uuid_bytes) == 4:
uuid = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{uuid:08X}'
v = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}'
else:
result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
result = '-'.join([
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
if self.name is not None:
return result + f' ({self.name})'
return result
else:
return result
def __repr__(self):
return str(self)
@@ -295,8 +238,6 @@ class UUID:
# -----------------------------------------------------------------------------
# Common UUID constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
# Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -402,17 +343,11 @@ BT_HDP_SERVICE = UUID.from_16_bits(0x1400,
BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source')
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
# DeviceClass
# -----------------------------------------------------------------------------
class DeviceClass:
# fmt: off
# pylint: disable=line-too-long
# Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
LE_AUDIO_SERVICE_CLASS = (1 << 1)
@@ -580,18 +515,11 @@ class DeviceClass:
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
}
# fmt: on
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return (
(class_of_device >> 13 & 0x7FF),
(class_of_device >> 8 & 0x1F),
(class_of_device >> 2 & 0x3F),
)
return ((class_of_device >> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F))
@staticmethod
def pack_class_of_device(service_classes, major_device_class, minor_device_class):
@@ -599,9 +527,7 @@ class DeviceClass:
@staticmethod
def service_class_labels(service_class_flags):
return bit_flags_to_strings(
service_class_flags, DeviceClass.SERVICE_CLASS_LABELS
)
return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS)
@staticmethod
def major_device_class_name(device_class):
@@ -618,15 +544,7 @@ class DeviceClass:
# -----------------------------------------------------------------------------
# Advertising Data
# -----------------------------------------------------------------------------
AdvertisingObject = Union[
List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes]
]
class AdvertisingData:
# fmt: off
# pylint: disable=line-too-long
# This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
@@ -738,14 +656,7 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10
ad_structures: List[Tuple[int, bytes]]
# fmt: on
# pylint: enable=line-too-long
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None:
if ad_structures is None:
ad_structures = []
def __init__(self, ad_structures = []):
self.ad_structures = ad_structures[:]
@staticmethod
@@ -756,36 +667,36 @@ class AdvertisingData:
@staticmethod
def flags_to_string(flags, short=False):
flag_names = (
['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H']
if short
else [
'LE Limited Discoverable Mode',
'LE General Discoverable Mode',
'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)',
]
)
flag_names = [
'LE Limited',
'LE General',
'No BR/EDR',
'BR/EDR C',
'BR/EDR H'
] if short else [
'LE Limited Discoverable Mode',
'LE General Discoverable Mode',
'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)'
]
return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
def uuid_list_to_objects(ad_data, uuid_size):
uuids = []
offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data):
uuids.append(UUID.from_bytes(ad_data[offset : offset + uuid_size]))
uuids.append(UUID.from_bytes(ad_data[offset:offset + uuid_size]))
offset += uuid_size
return uuids
@staticmethod
def uuid_list_to_string(ad_data, uuid_size):
return ', '.join(
[
str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
]
)
return ', '.join([
str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
])
@staticmethod
def ad_data_to_string(ad_type, ad_data):
@@ -845,65 +756,40 @@ class AdvertisingData:
return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements
@staticmethod
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject:
if ad_type in (
def ad_data_to_object(ad_type, ad_data):
if ad_type in {
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
):
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 2)
if ad_type in (
elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
):
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 4)
if ad_type in (
elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
):
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS
}:
return AdvertisingData.uuid_list_to_objects(ad_data, 16)
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
if ad_type in (
elif ad_type in {
AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI,
):
AdvertisingData.COMPLETE_LOCAL_NAME
}:
return ad_data.decode("utf-8")
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
):
return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return cast(Tuple[int, int], struct.unpack('<HH', ad_data))
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
return ad_data
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
return ad_data[0]
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
else:
return ad_data
def append(self, data):
offset = 0
@@ -912,41 +798,30 @@ class AdvertisingData:
offset += 1
if length > 0:
ad_type = data[offset]
ad_data = data[offset + 1 : offset + length]
ad_data = data[offset + 1:offset + length]
self.ad_structures.append((ad_type, ad_data))
offset += length
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingObject]:
def get(self, type_id, return_all=False, raw=True):
'''
Get Advertising Data Structure(s) with a given type
Returns a (possibly empty) list of matches.
If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches.
'''
def process_ad_data(ad_data: bytes) -> AdvertisingObject:
def process_ad_data(ad_data):
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]:
'''
Get Advertising Data Structure(s) with a given type
Returns the first entry, or None if no structure matches.
'''
all = self.get_all(type_id, raw=raw)
return all[0] if all else None
if return_all:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
else:
return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None)
def __bytes__(self):
return b''.join(
[bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]
)
return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures])
def to_string(self, separator=', '):
return separator.join(
[AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]
)
return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures])
def __str__(self):
return self.to_string()
@@ -956,17 +831,13 @@ class AdvertisingData:
# Connection Parameters
# -----------------------------------------------------------------------------
class ConnectionParameters:
def __init__(self, connection_interval, peripheral_latency, supervision_timeout):
def __init__(self, connection_interval, connection_latency, supervision_timeout):
self.connection_interval = connection_interval
self.peripheral_latency = peripheral_latency
self.connection_latency = connection_latency
self.supervision_timeout = supervision_timeout
def __str__(self):
return (
f'ConnectionParameters(connection_interval={self.connection_interval}, '
f'peripheral_latency={self.peripheral_latency}, '
f'supervision_timeout={self.supervision_timeout}'
)
return f'ConnectionParameters(connection_interval={self.connection_interval}, connection_latency={self.connection_latency}, supervision_timeout={self.supervision_timeout}'
# -----------------------------------------------------------------------------

View File

@@ -24,16 +24,19 @@
import logging
import operator
import platform
if platform.system() != 'Emscripten':
import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes
)
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1,
SECP256R1
)
from cryptography.hazmat.primitives import cmac
else:
@@ -63,26 +66,16 @@ class EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key()
return cls(private_key)
@property
def x(self):
return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big')
@property
def y(self):
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big')
def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
@@ -99,7 +92,7 @@ class EccKey:
# -----------------------------------------------------------------------------
def xor(x, y):
assert len(x) == len(y)
assert(len(x) == len(y))
return bytes(map(operator.xor, x, y))
@@ -125,7 +118,7 @@ def e(key, data):
# -----------------------------------------------------------------------------
def ah(k, r): # pylint: disable=redefined-outer-name
def ah(k, r):
'''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
'''
@@ -136,10 +129,9 @@ def ah(k, r): # pylint: disable=redefined-outer-name
# -----------------------------------------------------------------------------
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
def c1(k, r, preq, pres, iat, rat, ia, ra):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for LE Legacy Pairing
'''
p1 = bytes([iat, rat]) + preq + pres
@@ -150,8 +142,7 @@ def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-n
# -----------------------------------------------------------------------------
def s1(k, r1, r2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy Pairing
'''
return e(k, r2[0:8] + r1[0:8])
@@ -172,95 +163,71 @@ def aes_cmac(m, k):
# -----------------------------------------------------------------------------
def f4(u, v, x, z):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
'''
return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))))
# -----------------------------------------------------------------------------
def f5(w, n1, n2, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation Function f5
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
'''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
key_id = bytes([0x62, 0x74, 0x6c, 0x65])
return (
bytes(
reversed(
aes_cmac(
bytes([0])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
bytes(
reversed(
aes_cmac(
bytes([1])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
bytes(reversed(aes_cmac(
bytes([0]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
))),
bytes(reversed(aes_cmac(
bytes([1]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
)))
)
# -----------------------------------------------------------------------------
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
def f6(w, n1, n2, r, io_cap, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
'''
return bytes(
reversed(
aes_cmac(
bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(r))
+ bytes(reversed(io_cap))
+ bytes(reversed(a1))
+ bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
return bytes(reversed(aes_cmac(
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(r)) +
bytes(reversed(io_cap)) +
bytes(reversed(a1)) +
bytes(reversed(a2)),
bytes(reversed(w))
)))
# -----------------------------------------------------------------------------
def g2(u, v, x, y):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:],
byteorder='big'
)
# -----------------------------------------------------------------------------
def h6(w, key_id):
'''
@@ -268,7 +235,6 @@ def h6(w, key_id):
'''
return aes_cmac(key_id, w)
# -----------------------------------------------------------------------------
def h7(salt, w):
'''

View File

@@ -1,416 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
WL = [-60, -30, 58, 172, 334, 538, 1198, 3042]
RL42 = [0, 7, 6, 5, 4, 3, 2, 1, 7, 6, 5, 4, 3, 2, 1, 0]
ILB = [
2048,
2093,
2139,
2186,
2233,
2282,
2332,
2383,
2435,
2489,
2543,
2599,
2656,
2714,
2774,
2834,
2896,
2960,
3025,
3091,
3158,
3228,
3298,
3371,
3444,
3520,
3597,
3676,
3756,
3838,
3922,
4008,
]
WH = [0, -214, 798]
RH2 = [2, 1, 2, 1]
# Values in QM2/QM4/QM6 left shift three bits than original g722 specification.
QM2 = [-7408, -1616, 7408, 1616]
QM4 = [
0,
-20456,
-12896,
-8968,
-6288,
-4240,
-2584,
-1200,
20456,
12896,
8968,
6288,
4240,
2584,
1200,
0,
]
QM6 = [
-136,
-136,
-136,
-136,
-24808,
-21904,
-19008,
-16704,
-14984,
-13512,
-12280,
-11192,
-10232,
-9360,
-8576,
-7856,
-7192,
-6576,
-6000,
-5456,
-4944,
-4464,
-4008,
-3576,
-3168,
-2776,
-2400,
-2032,
-1688,
-1360,
-1040,
-728,
24808,
21904,
19008,
16704,
14984,
13512,
12280,
11192,
10232,
9360,
8576,
7856,
7192,
6576,
6000,
5456,
4944,
4464,
4008,
3576,
3168,
2776,
2400,
2032,
1688,
1360,
1040,
728,
432,
136,
-432,
-136,
]
QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# fmt: on
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder(object):
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
specification for the required information. G722 specification:
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self):
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
self._band[0].det = 32
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
for code in encoded_data:
higher_bits = (code >> 6) & 0x03
lower_bits = code & 0x3F
rlow = self.lower_sub_band_decoder(lower_bits)
rhigh = self.higher_sub_band_decoder(higher_bits)
# Apply the receive QMF
self._x[:22] = self._x[2:]
self._x[22] = rlow + rhigh
self._x[23] = rlow - rhigh
xout2 = sum(self._x[2 * i] * QMF_COEFFS[i] for i in range(12))
xout1 = sum(self._x[2 * i + 1] * QMF_COEFFS[11 - i] for i in range(12))
result_length = self.update_decoded_result(
xout1, result_length, result_array
)
result_length = self.update_decoded_result(
xout2, result_length, result_array
)
return result_length
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
# INVQBL
wd1 = lower_bits
wd2 = QM6[wd1]
wd1 >>= 2
wd2 = (self._band[0].det * wd2) >> 15
# RECONS
rlow = self._band[0].s + wd2
# Block 6L
# LIMIT
if rlow > 16383:
rlow = 16383
elif rlow < -16384:
rlow = -16384
# Block 2L
# INVQAL
wd2 = QM4[wd1]
dlowt = (self._band[0].det * wd2) >> 15
# Block 3L
# LOGSCL
wd2 = RL42[wd1]
wd1 = (self._band[0].nb * 127) >> 7
wd1 += WL[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 18432:
wd1 = 18432
self._band[0].nb = wd1
# SCALEL
wd1 = (self._band[0].nb >> 6) & 31
wd2 = 8 - (self._band[0].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[0].det = wd3 << 2
# Block 4L
self._band[0].block4(dlowt)
return rlow
def higher_sub_band_decoder(self, higher_bits) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
# INVQAH
wd2 = QM2[higher_bits]
dhigh = (self._band[1].det * wd2) >> 15
# Block 5H
# RECONS
rhigh = dhigh + self._band[1].s
# Block 6H
# LIMIT
if rhigh > 16383:
rhigh = 16383
elif rhigh < -16384:
rhigh = -16384
# Block 3H
# LOGSCH
wd2 = RH2[higher_bits]
wd1 = (self._band[1].nb * 127) >> 7
wd1 += WH[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 22528:
wd1 = 22528
self._band[1].nb = wd1
# SCALEH
wd1 = (self._band[1].nb >> 6) & 31
wd2 = 10 - (self._band[1].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[1].det = wd3 << 2
# Block 4H
self._band[1].block4(dhigh)
return rhigh
# -----------------------------------------------------------------------------
class Band(object):
"""Structure for G722 decode proccessing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self):
self._sp = 0
self._sz = 0
self._r = [0] * 3
self._a = [0] * 3
self._ap = [0] * 3
self._p = [0] * 3
self._d = [0] * 7
self._b = [0] * 7
self._bp = [0] * 7
self._sg = [0] * 7
def saturate(self, amp: int) -> int:
if amp > 32767:
return 32767
elif amp < -32768:
return -32768
else:
return amp
def block4(self, d: int) -> None:
"""Block4 for both lower and higher sub-band decoder."""
wd1 = 0
wd2 = 0
wd3 = 0
# RECONS
self._d[0] = d
self._r[0] = self.saturate(self.s + d)
# PARREC
self._p[0] = self.saturate(self._sz + d)
# UPPOL2
for i in range(3):
self._sg[i] = (self._p[i]) >> 15
wd1 = self.saturate((self._a[1]) << 2)
wd2 = -wd1 if self._sg[0] == self._sg[1] else wd1
if wd2 > 32767:
wd2 = 32767
wd3 = 128 if self._sg[0] == self._sg[2] else -128
wd3 += wd2 >> 7
wd3 += (self._a[2] * 32512) >> 15
if wd3 > 12288:
wd3 = 12288
elif wd3 < -12288:
wd3 = -12288
self._ap[2] = wd3
# UPPOL1
self._sg[0] = (self._p[0]) >> 15
self._sg[1] = (self._p[1]) >> 15
wd1 = 192 if self._sg[0] == self._sg[1] else -192
wd2 = (self._a[1] * 32640) >> 15
self._ap[1] = self.saturate(wd1 + wd2)
wd3 = self.saturate(15360 - self._ap[2])
if self._ap[1] > wd3:
self._ap[1] = wd3
elif self._ap[1] < -wd3:
self._ap[1] = -wd3
# UPZERO
wd1 = 0 if d == 0 else 128
self._sg[0] = d >> 15
for i in range(1, 7):
self._sg[i] = (self._d[i]) >> 15
wd2 = wd1 if self._sg[i] == self._sg[0] else -wd1
wd3 = (self._b[i] * 32640) >> 15
self._bp[i] = self.saturate(wd2 + wd3)
# DELAYA
for i in range(6, 0, -1):
self._d[i] = self._d[i - 1]
self._b[i] = self._bp[i]
for i in range(2, 0, -1):
self._r[i] = self._r[i - 1]
self._p[i] = self._p[i - 1]
self._a[i] = self._ap[i]
# FILTEP
self._sp = 0
for i in range(1, 3):
wd1 = self.saturate(self._r[i] + self._r[i])
self._sp += (self._a[i] * wd1) >> 15
self._sp = self.saturate(self._sp)
# FILTEZ
self._sz = 0
for i in range(6, 0, -1):
wd1 = self.saturate(self._d[i] + self._d[i])
self._sz += (self._b[i] * wd1) >> 15
self._sz = self.saturate(self._sz)
# PREDIC
self.s = self.saturate(self._sp + self._sz)

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ from .gatt import (
Characteristic,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC
)
# -----------------------------------------------------------------------------
@@ -38,22 +38,22 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):
def __init__(self, device_name, appearance = (0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
device_name.encode('utf-8')[:248]
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
struct.pack('<H', (appearance[0] << 6) | appearance[1])
)
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)
super().__init__(GATT_GENERIC_ACCESS_SERVICE, [
device_name_characteristic,
appearance_characteristic
])

View File

@@ -22,18 +22,14 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import types
import logging
import struct
from typing import Optional, Sequence, List
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
from colors import color
from .core import *
from .hci import *
from .att import *
# -----------------------------------------------------------------------------
# Logging
@@ -43,9 +39,6 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -156,14 +149,6 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
@@ -178,15 +163,11 @@ GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bi
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services):
for service in services:
print(color(str(service), 'cyan'))
@@ -204,40 +185,23 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION
'''
uuid: UUID
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
def __init__(self, uuid, characteristics, primary=True):
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
if type(uuid) is str:
uuid = UUID(uuid)
super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Attribute.READABLE,
uuid.to_pdu_bytes(),
uuid.to_pdu_bytes()
)
self.uuid = uuid
# self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
def get_advertising_data(self) -> Optional[bytes]:
"""
Get Service specific advertising data
Defined by each Service, default value is empty
:return Service data for advertising
"""
return None
self.uuid = uuid
self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
def __str__(self):
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid})'
f'{"" if self.primary else "*"}'
)
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
# -----------------------------------------------------------------------------
@@ -246,7 +210,6 @@ class TemplateService(Service):
Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property
'''
UUID = None
def __init__(self, characteristics, primary=True):
@@ -259,115 +222,51 @@ class Characteristic(Attribute):
See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION
'''
uuid: UUID
properties: Characteristic.Properties
# Property flags
BROADCAST = 0x01
READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0X20
AUTHENTICATED_SIGNED_WRITES = 0X40
EXTENDED_PROPERTIES = 0X80
class Properties(enum.IntFlag):
"""Property flags"""
PROPERTY_NAMES = {
BROADCAST: 'BROADCAST',
READ: 'READ',
WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE',
WRITE: 'WRITE',
NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES'
}
BROADCAST = 0x01
READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
@staticmethod
def property_name(property):
return Characteristic.PROPERTY_NAMES.get(property, '')
@staticmethod
def from_string(properties_str: str) -> Characteristic.Properties:
property_names: List[str] = []
for property in Characteristic.Properties:
if property.name is None:
raise TypeError()
property_names.append(property.name)
@staticmethod
def properties_as_string(properties):
return ','.join([
Characteristic.property_name(p) for p in Characteristic.PROPERTY_NAMES.keys()
if properties & p
])
def string_to_property(property_string) -> Characteristic.Properties:
for property in zip(Characteristic.Properties, property_names):
if property_string == property[1]:
return property[0]
raise TypeError(f"Unable to convert {property_string} to Property")
try:
return functools.reduce(
lambda x, y: x | string_to_property(y),
properties_str.split(","),
Characteristic.Properties(0),
)
except TypeError:
raise TypeError(
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by commas: {','.join(property_names)}\nGot: {properties_str}"
)
# For backwards compatibility these are defined here
# For new code, please use Characteristic.Properties.X
BROADCAST = Properties.BROADCAST
READ = Properties.READ
WRITE_WITHOUT_RESPONSE = Properties.WRITE_WITHOUT_RESPONSE
WRITE = Properties.WRITE
NOTIFY = Properties.NOTIFY
INDICATE = Properties.INDICATE
AUTHENTICATED_SIGNED_WRITES = Properties.AUTHENTICATED_SIGNED_WRITES
EXTENDED_PROPERTIES = Properties.EXTENDED_PROPERTIES
def __init__(
self,
uuid,
properties: Characteristic.Properties,
permissions,
value=b'',
descriptors: Sequence[Descriptor] = (),
):
def __init__(self, uuid, properties, permissions, value = b'', descriptors = []):
super().__init__(uuid, permissions, value)
self.uuid = self.type
self.properties = properties
self.uuid = self.type
self.properties = properties
self.descriptors = descriptors
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
if descriptor.type == descriptor_type:
if descriptor.uuid == descriptor_type:
return descriptor
return None
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
def __str__(self):
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'{self.properties!s})'
)
# -----------------------------------------------------------------------------
class CharacteristicDeclaration(Attribute):
'''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
'''
characteristic: Characteristic
def __init__(self, characteristic, value_handle):
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, '
f'{self.characteristic.properties!s})'
)
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
# -----------------------------------------------------------------------------
@@ -376,7 +275,6 @@ class CharacteristicValue:
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
@@ -403,38 +301,27 @@ class CharacteristicAdapter:
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
if (
asyncio.iscoroutinefunction(characteristic.read_value) and
asyncio.iscoroutinefunction(characteristic.write_value)
):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in (
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe',
):
if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
@@ -443,17 +330,13 @@ class CharacteristicAdapter:
return self.encode_value(self.wrapped_characteristic.read_value(connection))
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
return self.wrapped_characteristic.write_value(connection, self.decode_value(value))
async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(
self.encode_value(value), with_response
)
async def write_decoded_value(self, value):
return await self.wrapped_characteristic.write_value(self.encode_value(value))
def encode_value(self, value):
return value
@@ -462,27 +345,9 @@ class CharacteristicAdapter:
return value
def wrapped_subscribe(self, subscriber=None):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
return self.wrapped_characteristic.subscribe(
None if subscriber is None else lambda value: subscriber(self.decode_value(value))
)
def __str__(self):
wrapped = str(self.wrapped_characteristic)
@@ -494,7 +359,6 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic)
self.encode = encode
@@ -517,10 +381,9 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic, pack_format):
def __init__(self, characteristic, format):
super().__init__(characteristic)
self.struct = struct.Struct(pack_format)
self.struct = struct.Struct(format)
def pack(self, *values):
return self.struct.pack(*values)
@@ -529,7 +392,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
return self.struct.unpack(buffer)
def encode_value(self, value):
return self.pack(*value if isinstance(value, tuple) else (value,))
return self.pack(*value if type(value) is tuple else (value,))
def decode_value(self, value):
unpacked = self.unpack(value)
@@ -542,15 +405,13 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
is packed/unpacked according to format, with the arguments extracted from the dictionary
by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, pack_format, keys):
super().__init__(characteristic, pack_format)
def __init__(self, characteristic, format, keys):
super().__init__(characteristic, format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values):
return super().pack(*(values[key] for key in self.keys))
@@ -563,7 +424,6 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
return value.encode('utf-8')
@@ -577,20 +437,8 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __init__(self, descriptor_type, permissions, value = b''):
super().__init__(descriptor_type, permissions, value)
def __str__(self):
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
)
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})'

View File

@@ -23,48 +23,21 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import struct
from datetime import datetime
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
from colors import color
from pyee import EventEmitter
from .colors import color
from .hci import HCI_Constant
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_ERROR_RESPONSE,
ATT_INVALID_OFFSET_ERROR,
ATT_PDU,
ATT_RESPONSES,
ATT_Exchange_MTU_Request,
ATT_Find_By_Type_Value_Request,
ATT_Find_Information_Request,
ATT_Handle_Value_Confirmation,
ATT_Read_Blob_Request,
ATT_Read_By_Group_Type_Request,
ATT_Read_By_Type_Request,
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
ATT_Error,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
from .core import ProtocolError, TimeoutError
from .hci import *
from .att import *
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Characteristic
)
# -----------------------------------------------------------------------------
@@ -77,58 +50,38 @@ logger = logging.getLogger(__name__)
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
client: Client
def __init__(self, client, handle, end_group_handle, attribute_type):
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
self.type = attribute_type
async def read_value(self, no_long_read=False):
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
return await self.client.read_value(self.handle, no_long_read)
async def write_value(self, value, with_response=False):
return await self.client.write_value(
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
return value
def decode_value(self, value_bytes):
return value_bytes
return await self.client.write_value(self.handle, value, with_response)
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
class ServiceProxy(AttributeProxy):
uuid: UUID
characteristics: List[CharacteristicProxy]
@staticmethod
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
# discovered
def from_client(cls, client, service_uuid):
# The service and its characteristics are considered to have already been discovered
services = client.get_services_by_uuid(service_uuid)
service = services[0] if services else None
return service_class(service) if service else None
return cls(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
)
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
super().__init__(client, handle, end_group_handle, attribute_type)
self.uuid = uuid
self.uuid = uuid
self.characteristics = []
async def discover_characteristics(self, uuids=()):
async def discover_characteristics(self, uuids=[]):
return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid):
@@ -139,66 +92,29 @@ class ServiceProxy(AttributeProxy):
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable]
def __init__(
self,
client,
handle,
end_group_handle,
uuid,
properties: int,
):
def __init__(self, client, handle, end_group_handle, uuid, properties):
super().__init__(client, handle, end_group_handle, uuid)
self.uuid = uuid
self.properties = Characteristic.Properties(properties)
self.descriptors = []
self.uuid = uuid
self.properties = properties
self.descriptors = []
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors:
if descriptor.type == descriptor_type:
return descriptor
return None
async def discover_descriptors(self):
return await self.client.discover_descriptors(self)
async def subscribe(
self, subscriber: Optional[Callable] = None, prefer_notify=True
):
if subscriber is not None:
if subscriber in self.subscribers:
# We already have a proxy subscriber
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return await self.client.subscribe(self, subscriber, prefer_notify)
async def subscribe(self, subscriber=None):
return await self.client.subscribe(self, subscriber)
async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
f'{self.properties!s})'
)
return f'Characteristic(handle=0x{self.handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
class DescriptorProxy(AttributeProxy):
@@ -213,7 +129,6 @@ class ProfileServiceProxy:
'''
Base class for profile-specific service proxies
'''
@classmethod
def from_client(cls, client):
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -223,65 +138,51 @@ class ProfileServiceProxy:
# GATT Client
# -----------------------------------------------------------------------------
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
def __init__(self, connection):
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
self.cached_values = {}
self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = {} # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
def send_gatt_pdu(self, pdu):
self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command):
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
logger.debug(f'GATT Command from client: [0x{self.connection.handle:04X}] {command}')
self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
logger.debug(f'GATT Request from client: [0x{self.connection.handle:04X}] {request}')
# Wait until we can send (only one pending command at a time for the connection)
response = None
async with self.request_semaphore:
assert self.pending_request is None
assert self.pending_response is None
assert(self.pending_request is None)
assert(self.pending_response is None)
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
self.pending_request = request
self.pending_request = request
try:
self.send_gatt_pdu(request.to_bytes())
response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
response = await asyncio.wait_for(self.pending_response, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError:
logger.warning(color('!!! GATT Request timeout', 'red'))
raise core.TimeoutError(f'GATT timeout for {request.name}') from error
raise TimeoutError(f'GATT timeout for {request.name}')
finally:
self.pending_request = None
self.pending_request = None
self.pending_response = None
return response
def send_confirmation(self, confirmation):
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
logger.debug(f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}')
self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu):
@@ -293,66 +194,31 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.connection.att_mtu
return
# Send the request
self.mtu_exchange_done = True
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu = mtu))
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
response
)
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.connection.att_mtu
self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
return self.mtu
def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service=None):
def get_characteristics_by_uuid(self, uuid, service = None):
services = [service] if service else self.services
return [
c
for c in [c for s in services for c in s.characteristics]
if c.uuid == uuid
]
def get_attribute_grouping(
self, attribute_handle: int
) -> Optional[
Union[
ServiceProxy,
Tuple[ServiceProxy, CharacteristicProxy],
Tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
]
]:
"""
Get the attribute(s) associated with an attribute handle
"""
for service in self.services:
if service.handle == attribute_handle:
return service
if service.handle <= attribute_handle <= service.end_group_handle:
for characteristic in service.characteristics:
if characteristic.handle == attribute_handle:
return (service, characteristic)
if (
characteristic.handle
<= attribute_handle
<= characteristic.end_group_handle
):
for descriptor in characteristic.descriptors:
if descriptor.handle == attribute_handle:
return (service, characteristic, descriptor)
return None
return [c for c in [c for s in services for c in s.characteristics] if c.uuid == uuid]
def on_service_discovered(self, service):
'''Add a service to the service list if it wasn't already there'''
''' Add a service to the service list if it wasn't already there '''
already_known = False
for existing_service in self.services:
if existing_service.handle == service.handle:
@@ -361,7 +227,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
async def discover_services(self, uuids = None):
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -370,9 +236,9 @@ class Client:
while starting_handle < 0xFFFF:
response = await self.send_request(
ATT_Read_By_Group_Type_Request(
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
starting_handle = starting_handle,
ending_handle = 0xFFFF,
attribute_group_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
)
)
if response is None:
@@ -383,30 +249,16 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
raise ATT_Error(
error_code=response.error_code,
message='Unexpected error while discovering services',
)
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
for (
attribute_handle,
end_group_handle,
attribute_value,
) in response.attributes:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
for attribute_handle, end_group_handle, attribute_value in response.attributes:
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
# Something's not right
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return []
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
return
# Create a service proxy for this service
service = ServiceProxy(
@@ -414,7 +266,7 @@ class Client:
attribute_handle,
end_group_handle,
UUID.from_bytes(attribute_value),
True,
True
)
# Filter out returned services based on the given uuids list
@@ -439,7 +291,7 @@ class Client:
'''
# Force uuid to be a UUID object
if isinstance(uuid, str):
if type(uuid) is str:
uuid = UUID(uuid)
starting_handle = 0x0001
@@ -447,10 +299,10 @@ class Client:
while starting_handle < 0xFFFF:
response = await self.send_request(
ATT_Find_By_Type_Value_Request(
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value=uuid.to_pdu_bytes(),
starting_handle = starting_handle,
ending_handle = 0xFFFF,
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value = uuid.to_pdu_bytes()
)
)
if response is None:
@@ -461,29 +313,19 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
for attribute_handle, end_group_handle in response.handles_information:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
# Something's not right
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
return
# Create a service proxy for this service
service = ServiceProxy(
self, attribute_handle, end_group_handle, uuid, True
)
service = ServiceProxy(self, attribute_handle, end_group_handle, uuid, True)
# Add the service to the peer's service list
services.append(service)
@@ -502,40 +344,37 @@ class Client:
return services
async def discover_included_services(self, _service):
async def discover_included_services(self, service):
'''
See Vol 3, Part G - 4.5.1 Find Included Services
'''
# TODO
return []
async def discover_characteristics(
self, uuids, service: Optional[ServiceProxy]
) -> List[CharacteristicProxy]:
async def discover_characteristics(self, uuids, service):
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 Discover Characteristics by UUID
'''
# Cast the UUIDs type from string to object if needed
uuids = [UUID(uuid) if isinstance(uuid, str) else uuid for uuid in uuids]
uuids = [UUID(uuid) if type(uuid) is str else uuid for uuid in uuids]
# Decide which services to discover for
services = [service] if service else self.services
# Perform characteristic discovery for each service
discovered_characteristics: List[CharacteristicProxy] = []
discovered_characteristics = []
for service in services:
starting_handle = service.handle
ending_handle = service.end_group_handle
ending_handle = service.end_group_handle
characteristics: List[CharacteristicProxy] = []
characteristics = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
starting_handle = starting_handle,
ending_handle = ending_handle,
attribute_type = GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
)
)
if response is None:
@@ -546,14 +385,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
raise ATT_Error(
error_code=response.error_code,
message='Unexpected error while discovering characteristics',
)
logger.warning(f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
# Stop if for some reason the list was empty
@@ -569,9 +403,7 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = CharacteristicProxy(
self, handle, 0, characteristic_uuid, properties
)
characteristic = CharacteristicProxy(self, handle, 0, characteristic_uuid, properties)
# Set the previous characteristic's end handle
if characteristics:
@@ -587,37 +419,31 @@ class Client:
characteristics[-1].end_group_handle = service.end_group_handle
# Set the service's characteristics
characteristics = [
c for c in characteristics if not uuids or c.uuid in uuids
]
characteristics = [c for c in characteristics if not uuids or c.uuid in uuids]
service.characteristics = characteristics
discovered_characteristics.extend(characteristics)
return discovered_characteristics
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle=None,
end_handle=None,
) -> List[DescriptorProxy]:
async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None):
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
'''
if characteristic:
starting_handle = characteristic.handle + 1
ending_handle = characteristic.end_group_handle
ending_handle = characteristic.end_group_handle
elif start_handle and end_handle:
starting_handle = start_handle
ending_handle = end_handle
ending_handle = end_handle
else:
return []
descriptors: List[DescriptorProxy] = []
descriptors = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle=starting_handle, ending_handle=ending_handle
starting_handle = starting_handle,
ending_handle = ending_handle
)
)
if response is None:
@@ -628,10 +454,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering descriptors: '
f'{HCI_Constant.error_name(response.error_code)}'
)
logger.warning(f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return []
break
@@ -647,9 +470,7 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
descriptor = DescriptorProxy(
self, attribute_handle, UUID.from_bytes(attribute_uuid)
)
descriptor = DescriptorProxy(self, attribute_handle, UUID.from_bytes(attribute_uuid))
descriptors.append(descriptor)
# TODO: read descriptor value
@@ -667,12 +488,13 @@ class Client:
Discover all attributes, regardless of type
'''
starting_handle = 0x0001
ending_handle = 0xFFFF
ending_handle = 0xFFFF
attributes = []
while True:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle=starting_handle, ending_handle=ending_handle
starting_handle = starting_handle,
ending_handle = ending_handle
)
)
if response is None:
@@ -682,10 +504,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while discovering attributes: '
f'{HCI_Constant.error_name(response.error_code)}'
)
logger.warning(f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}')
return []
break
@@ -695,9 +514,7 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
attribute = AttributeProxy(
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
)
attribute = AttributeProxy(self, attribute_handle, 0, UUID.from_bytes(attribute_uuid))
attributes.append(attribute)
# Move on to the next attributes
@@ -705,85 +522,60 @@ class Client:
return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic,
# do it now
async def subscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic, do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
if not cccd:
logger.warning('subscribing to characteristic with no CCCD descriptor')
return
if (
characteristic.properties & Characteristic.Properties.NOTIFY
and characteristic.properties & Characteristic.Properties.INDICATE
):
if prefer_notify:
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
subscribers = self.notification_subscribers
else:
bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
elif characteristic.properties & Characteristic.Properties.NOTIFY:
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
subscribers = self.notification_subscribers
elif characteristic.properties & Characteristic.Properties.INDICATE:
bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
else:
raise InvalidStateError("characteristic is not notify or indicate")
# Set the subscription bits and select the subscriber set
bits = 0
subscriber_sets = []
if characteristic.properties & Characteristic.NOTIFY:
bits |= 0x0001
subscriber_sets.append(self.notification_subscribers.setdefault(characteristic.handle, set()))
if characteristic.properties & Characteristic.INDICATE:
bits |= 0x0002
subscriber_sets.append(self.indication_subscribers.setdefault(characteristic.handle, set()))
# Add subscribers to the sets
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
subscriber_set.add(characteristic)
for subscriber_set in subscriber_sets:
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the characteristic
# emitting an 'update' event when a notification or indication is received
subscriber_set.add(characteristic)
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic,
# do it now
# If we haven't already discovered the descriptors for this characteristic, do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
if not cccd:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
for subscriber_set in (self.notification_subscribers, self.indication_subscribers):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
if not self.notification_subscribers and not self.indication_subscribers:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False):
'''
@@ -793,10 +585,8 @@ class Client:
'''
# Send a request to read
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
attribute_handle = attribute if type(attribute) is int else attribute.handle
response = await self.send_request(ATT_Read_Request(attribute_handle = attribute_handle))
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
@@ -804,45 +594,39 @@ class Client:
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
response
)
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
if not no_long_read and len(attribute_value) == self.mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
response = await self.send_request(
ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset
)
ATT_Read_Blob_Request(attribute_handle = attribute_handle, value_offset = offset)
)
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code in (
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_INVALID_OFFSET_ERROR,
):
if response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR or response.error_code == ATT_INVALID_OFFSET_ERROR:
break
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
response
)
part = response.part_attribute_value
attribute_value += part
if len(part) < self.connection.att_mtu - 1:
if len(part) < self.mtu - 1:
break
offset += len(part)
self.cache_value(attribute_handle, attribute_value)
# Return the value as bytes
return attribute_value
@@ -853,18 +637,18 @@ class Client:
if service is None:
starting_handle = 0x0001
ending_handle = 0xFFFF
ending_handle = 0xFFFF
else:
starting_handle = service.handle
ending_handle = service.end_group_handle
ending_handle = service.end_group_handle
characteristics_values = []
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Read_By_Type_Request(
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type=uuid,
starting_handle = starting_handle,
ending_handle = ending_handle,
attribute_type = uuid
)
)
if response is None:
@@ -875,10 +659,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(
'!!! unexpected error while reading characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
logger.warning(f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return []
break
@@ -903,54 +684,47 @@ class Client:
async def write_value(self, attribute, value, with_response=False):
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value
`attribute` can be an Attribute object, or a handle value
'''
# Send a request or command to write
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
attribute_handle = attribute if type(attribute) is int else attribute.handle
if with_response:
response = await self.send_request(
ATT_Write_Request(
attribute_handle=attribute_handle, attribute_value=value
attribute_handle = attribute_handle,
attribute_value = value
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response,
ATT_PDU.error_name(response.error_code), response
)
else:
await self.send_command(
ATT_Write_Command(
attribute_handle=attribute_handle, attribute_value=value
attribute_handle = attribute_handle,
attribute_value = value
)
)
def on_gatt_pdu(self, att_pdu):
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
logger.debug(f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}')
if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
logger.warning('!!! unexpected response, there is no pending request')
return
# Sanity check: the response should match the pending request unless it is
# an error response
# Sanity check: the response should match the pending request unless it is an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
)
expected_response_name = self.pending_request.name.replace('_REQUEST', '_RESPONSE')
if att_pdu.name != expected_response_name:
logger.warning(
f'!!! mismatched response: expected {expected_response_name}'
)
logger.warning(f'!!! mismatched response: expected {expected_response_name}')
return
# Return the response to the coroutine that is waiting for it
@@ -961,24 +735,13 @@ class Client:
if handler is not None:
handler(att_pdu)
else:
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
logger.warning(f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}')
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
)
subscribers = self.notification_subscribers.get(notification.attribute_handle, [])
if not subscribers:
logger.warning('!!! received notification with no subscriber')
self.cache_value(notification.attribute_handle, notification.attribute_value)
for subscriber in subscribers:
if callable(subscriber):
subscriber(notification.attribute_value)
@@ -990,8 +753,6 @@ class Client:
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
if not subscribers:
logger.warning('!!! received indication with no subscriber')
self.cache_value(indication.attribute_handle, indication.attribute_value)
for subscriber in subscribers:
if callable(subscriber):
subscriber(indication.attribute_value)
@@ -1000,9 +761,3 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes):
self.cached_values[attribute_handle] = (
datetime.now(),
value,
)

View File

@@ -26,52 +26,13 @@
import asyncio
import logging
from collections import defaultdict
import struct
from typing import List, Tuple, Optional, TypeVar, Type
from pyee import EventEmitter
from colors import color
from .colors import color
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
ATT_INVALID_HANDLE_ERROR,
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
ATT_Error_Response,
ATT_Exchange_MTU_Response,
ATT_Find_By_Type_Value_Response,
ATT_Find_Information_Response,
ATT_Handle_Value_Indication,
ATT_Handle_Value_Notification,
ATT_Read_Blob_Response,
ATT_Read_By_Group_Type_Response,
ATT_Read_By_Type_Response,
ATT_Read_Response,
ATT_Write_Response,
Attribute,
)
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
Service,
)
from .core import *
from .hci import *
from .att import *
from .gatt import *
# -----------------------------------------------------------------------------
# Logging
@@ -79,49 +40,27 @@ from .gatt import (
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
def __init__(self, device):
super().__init__()
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = (
GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
)
self.subscribers = (
{}
) # Map of subscriber states by connection handle and attribute handle
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu):
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self):
return 1 + len(self.attributes)
def get_advertising_service_data(self):
return {
attribute: data
for attribute in self.attributes
if isinstance(attribute, Service)
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle):
attribute = self.attributes_by_handle.get(handle)
if attribute:
@@ -135,90 +74,15 @@ class Server(EventEmitter):
return attribute
return None
AttributeGroupType = TypeVar('AttributeGroupType', Service, Characteristic)
def get_attribute_group(
self, handle: int, group_type: Type[AttributeGroupType]
) -> Optional[AttributeGroupType]:
return next(
(
attribute
for attribute in self.attributes
if isinstance(attribute, group_type)
and attribute.handle <= handle <= attribute.end_group_handle
),
None,
)
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
return next(
(
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and isinstance(attribute, Service)
and attribute.uuid == service_uuid
),
None,
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and attribute.characteristic.uuid == characteristic_uuid
),
None,
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
if not characteristics:
return None
(_, characteristic_value) = characteristics
return next(
(
attribute
for attribute in map(
self.get_attribute,
range(
characteristic_value.handle + 1,
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute):
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = (
attribute.handle
) # TODO: keep track of descriptors in the group
attribute.end_group_handle = attribute.handle # TODO: keep track of descriptors in the group
# Add this attribute to the list
self.attributes.append(attribute)
def add_service(self, service: Service):
def add_service(self, service):
# Add the service attribute to the DB
self.add_attribute(service)
@@ -226,9 +90,16 @@ class Server(EventEmitter):
# Add all characteristics
for characteristic in service.characteristics:
# Add a Characteristic Declaration
characteristic_declaration = CharacteristicDeclaration(
characteristic, self.next_handle() + 1
# Add a Characteristic Declaration (Vol 3, Part G - 3.3.1 Characteristic Declaration)
declaration_bytes = struct.pack(
'<BH',
characteristic.properties,
self.next_handle() + 1, # The value will be the next attribute after this declaration
) + characteristic.uuid.to_pdu_bytes()
characteristic_declaration = Attribute(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Attribute.READABLE,
declaration_bytes
)
self.add_attribute(characteristic_declaration)
@@ -242,29 +113,17 @@ class Server(EventEmitter):
# If the characteristic supports subscriptions, add a CCCD descriptor
# unless there is one already
if (
characteristic.properties
& (
Characteristic.Properties.NOTIFY
| Characteristic.Properties.INDICATE
)
and characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
is None
characteristic.properties & (Characteristic.NOTIFY | Characteristic.INDICATE) and
characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) is None
):
self.add_attribute(
# pylint: disable=line-too-long
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
read=lambda connection, characteristic=characteristic: self.read_cccd(connection, characteristic),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(connection, characteristic, value)
)
)
)
@@ -291,39 +150,26 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}')
# Sanity check
if len(value) != 2:
logger.warning('CCCD value not 2 bytes long')
logger.warn('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
'subscription', connection, notify_enabled, indicate_enabled
)
self.emit(
'characteristic_subscription',
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
notify_enabled = (value[0] & 0x01 != 0)
indicate_enabled = (value[0] & 0x02 != 0)
characteristic.emit('subscription', connection, notify_enabled, indicate_enabled)
self.emit('characteristic_subscription', connection, characteristic, notify_enabled, indicate_enabled)
def send_response(self, connection, response):
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -332,35 +178,47 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(
f'not notifying, no subscribers for handle {attribute.handle:04X}'
)
logger.debug(f'not notifying, no subscribers for handle {attribute.handle:04X}')
return
if len(cccd) != 2 or (cccd[0] & 0x01 == 0):
logger.debug(f'not notifying, cccd={cccd.hex()}')
return
# Get or encode the value
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Get the value
value = attribute.read_value(connection)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
# Notify
notification = ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle = attribute.handle,
attribute_value = value
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, notification.to_bytes())
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscribers(self, attribute, force=False):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Notify for each connection
if connections:
await asyncio.wait([
self.notify_subscriber(connection, attribute, force)
for connection in connections
])
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -369,84 +227,64 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(
f'not indicating, no subscribers for handle {attribute.handle:04X}'
)
logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}')
return
if len(cccd) != 2 or (cccd[0] & 0x02 == 0):
logger.debug(f'not indicating, cccd={cccd.hex()}')
return
# Get or encode the value
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Get the value
value = attribute.read_value(connection)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
# Indicate
indication = ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
attribute_handle = attribute.handle,
attribute_value = value
)
logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert self.pending_confirmations[connection.handle] is None
assert(self.pending_confirmations[connection.handle] is None)
# Create a future value to hold the eventual response
self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
raise TimeoutError(f'GATT timeout for {indication.name}')
finally:
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
async def indicate_subscribers(self, attribute):
# Get all the connections for which there's at least one subscription
connections = [
connection
for connection in [
connection for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
if subscribers.get(attribute.handle)
]
if connection is not None
]
# Indicate or notify for each connection
# Indicate for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
await asyncio.wait([
self.indicate_subscriber(connection, attribute)
for connection in connections
])
def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -464,17 +302,17 @@ class Server(EventEmitter):
except ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
request_opcode_in_error = att_pdu.op_code,
attribute_handle_in_error = error.att_handle,
error_code = error.error_code
)
self.send_response(connection, response)
except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=0x0000,
error_code=ATT_UNLIKELY_ERROR_ERROR,
request_opcode_in_error = att_pdu.op_code,
attribute_handle_in_error = 0x0000,
error_code = ATT_UNLIKELY_ERROR_ERROR
)
self.send_response(connection, response)
raise error
@@ -485,13 +323,10 @@ class Server(EventEmitter):
self.on_att_request(connection, att_pdu)
else:
# Just ignore
logger.warning(
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)
#######################################################
# ATT handlers
@@ -500,16 +335,11 @@ class Server(EventEmitter):
'''
Handler for requests without a more specific handler
'''
logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
)
+ str(pdu)
)
logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}')
response = ATT_Error_Response(
request_opcode_in_error=pdu.op_code,
attribute_handle_in_error=0x0000,
error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
request_opcode_in_error = pdu.op_code,
attribute_handle_in_error = 0x0000,
error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR
)
self.send_response(connection, response)
@@ -517,18 +347,12 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(
connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
def on_att_find_information_request(self, connection, request):
'''
@@ -536,30 +360,25 @@ class Server(EventEmitter):
'''
# Check the request parameters
if (
request.starting_handle == 0
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
if request.starting_handle == 0 or request.starting_handle > request.ending_handle:
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
return
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = self.get_mtu(connection) - 2
attributes = []
uuid_size = 0
for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle
):
# TODO: check permissions
this_uuid_size = len(attribute.type.to_pdu_bytes())
if attributes:
@@ -583,14 +402,14 @@ class Server(EventEmitter):
for attribute in attributes
]
response = ATT_Find_Information_Response(
format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data=b''.join(information_data_list),
format = 1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data = b''.join(information_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
)
self.send_response(connection, response)
@@ -601,16 +420,15 @@ class Server(EventEmitter):
'''
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
pdu_space_available = self.get_mtu(connection) - 2
attributes = []
for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
attribute.type == request.attribute_type and
attribute.read_value(connection) == request.attribute_value and
pdu_space_available >= 4
):
# TODO: check permissions
@@ -622,27 +440,25 @@ class Server(EventEmitter):
if attributes:
handles_information_list = []
for attribute in attributes:
if attribute.type in (
if attribute.type in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
):
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
}:
# Part of a group
group_end_handle = attribute.end_group_handle
else:
# Not part of a group
group_end_handle = attribute.handle
handles_information_list.append(
struct.pack('<HH', attribute.handle, group_end_handle)
)
handles_information_list.append(struct.pack('<HH', attribute.handle, group_end_handle))
response = ATT_Find_By_Type_Value_Response(
handles_information_list=b''.join(handles_information_list)
handles_information_list = b''.join(handles_information_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
)
self.send_response(connection, response)
@@ -652,39 +468,21 @@ class Server(EventEmitter):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = connection.att_mtu - 2
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
attributes = []
for attribute in (
attribute
for attribute in self.attributes
if attribute.type == request.attribute_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
attribute for attribute in self.attributes if
attribute.type == request.attribute_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
):
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
if not attributes:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=attribute.handle,
error_code=error.error_code,
)
break
# TODO: check permissions
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 4, 253)
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -702,14 +500,17 @@ class Server(EventEmitter):
pdu_space_available -= entry_size
if attributes:
attribute_data_list = [
struct.pack('<H', handle) + value for handle, value in attributes
]
attribute_data_list = [struct.pack('<H', handle) + value for handle, value in attributes]
response = ATT_Read_By_Type_Response(
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
length = entry_size,
attribute_data_list = b''.join(attribute_data_list)
)
else:
logging.debug(f"not found {request}")
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
)
self.send_response(connection, response)
@@ -719,22 +520,17 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
)
else:
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(attribute_value=value[:value_size])
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
)
self.send_response(connection, response)
@@ -744,41 +540,31 @@ class Server(EventEmitter):
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
except ATT_Error as error:
# TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=error.error_code,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
)
elif len(value) <= mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
)
else:
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
)
else:
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
part_size = min(mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
)
self.send_response(connection, response)
@@ -786,33 +572,32 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
if request.attribute_group_type not in (
if request.attribute_group_type not in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
):
GATT_INCLUDE_ATTRIBUTE_TYPE
}:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR
)
self.send_response(connection, response)
return
pdu_space_available = connection.att_mtu - 2
mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
attributes = []
for attribute in (
attribute
for attribute in self.attributes
if attribute.type == request.attribute_group_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
attribute for attribute in self.attributes if
attribute.type == request.attribute_group_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -826,9 +611,7 @@ class Server(EventEmitter):
break
# Add the attribute to the list
attributes.append(
(attribute.handle, attribute.end_group_handle, attribute_value)
)
attributes.append((attribute.handle, attribute.end_group_handle, attribute_value))
pdu_space_available -= entry_size
if attributes:
@@ -837,14 +620,14 @@ class Server(EventEmitter):
for handle, end_group_handle, value in attributes
]
response = ATT_Read_By_Group_Type_Response(
length=len(attribute_data_list[0]),
attribute_data_list=b''.join(attribute_data_list),
length = len(attribute_data_list[0]),
attribute_data_list = b''.join(attribute_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
)
self.send_response(connection, response)
@@ -857,28 +640,22 @@ class Server(EventEmitter):
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
return
# TODO: check permissions
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
),
)
self.send_response(connection, ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR
))
return
# Accept the value
@@ -909,15 +686,13 @@ class Server(EventEmitter):
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation):
def on_att_handle_value_confirmation(self, connection, confirmation):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
if self.pending_confirmations[connection.handle] is None:
# Not expected!
logger.warning(
'!!! unexpected confirmation, there is no pending indication'
)
logger.warning('!!! unexpected confirmation, there is no pending indication')
return
self.pending_confirmations[connection.handle].set_result(None)

File diff suppressed because it is too large Load Diff

View File

@@ -16,11 +16,10 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from colors import color
from .colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
@@ -28,17 +27,20 @@ from .l2cap import (
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Response,
L2CAP_Connection_Response
)
from .hci import (
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler,
HCI_AclDataPacketAssembler
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
from .avdtp import (
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
)
# -----------------------------------------------------------------------------
# Logging
@@ -49,8 +51,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
# TODO: add more PSM values
}
@@ -59,23 +61,19 @@ PSM_NAMES = {
class PacketTracer:
class AclStream:
def __init__(self, analyzer):
self.analyzer = analyzer
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
@@ -83,26 +81,16 @@ class PacketTracer:
if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
if (
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
# For AVDTP connections, create a packet assembler for each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message)
else:
# Try to find the PSM associated with this PDU
@@ -114,42 +102,31 @@ class PacketTracer:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}')
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}')
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}')
def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet)
class Analyzer:
def __init__(self, label, emit_message):
self.label = label
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
self.acl_streams = {} # ACL streams, by connection handle
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle):
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
)
logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}')
stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream
@@ -162,10 +139,7 @@ class PacketTracer:
def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams:
logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
f'0x{connection_handle:04X}'
)
logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}')
del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well
@@ -197,13 +171,9 @@ class PacketTracer:
self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info,
emit_message=logger.info
):
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message)
self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer

View File

@@ -18,8 +18,7 @@
import logging
import asyncio
import collections
from .colors import color
from colors import color
# -----------------------------------------------------------------------------
@@ -35,16 +34,16 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class HfpProtocol:
def __init__(self, dlc):
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
dlc.sink = self.feed
def feed(self, data):
# Convert the data to a string if needed
if isinstance(data, bytes):
if type(data) == bytes:
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
@@ -53,7 +52,7 @@ class HfpProtocol:
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1 :]
self.buffer = self.buffer[separator + 1:]
if len(line) > 0:
self.on_line(line)
@@ -80,16 +79,16 @@ class HfpProtocol:
async def initialize_service(self):
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
await (self.next_line())
await (self.next_line())
line = await(self.next_line())
line = await(self.next_line())
self.send_command_line('AT+CIND=?')
await (self.next_line())
await (self.next_line())
line = await(self.next_line())
line = await(self.next_line())
self.send_command_line('AT+CIND?')
await (self.next_line())
await (self.next_line())
line = await(self.next_line())
line = await(self.next_line())
self.send_command_line('AT+CMER=3,0,0,1')
await (self.next_line())
line = await(self.next_line())

View File

@@ -16,62 +16,16 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import collections
import logging
import struct
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from typing import Optional
from .hci import (
Address,
HCI_ACL_DATA_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET,
HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND,
HCI_RESET_COMMAND,
HCI_SUCCESS,
HCI_SUPPORTED_COMMANDS_FLAGS,
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Constant,
HCI_Error,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Local_Supported_Features_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Write_Suggested_Default_Data_Length_Command,
HCI_Link_Key_Request_Negative_Reply_Command,
HCI_Link_Key_Request_Reply_Command,
HCI_Packet,
HCI_Read_Buffer_Size_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
)
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
)
from .utils import AbortableEventEmitter
from pyee import EventEmitter
from colors import color
from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters
# -----------------------------------------------------------------------------
# Logging
@@ -82,60 +36,58 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# fmt: on
# -----------------------------------------------------------------------------
class Connection:
def __init__(self, host, handle, peer_address, transport):
self.host = host
self.handle = handle
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def __init__(self, host, handle, role, peer_address):
self.host = host
self.handle = handle
self.role = role
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
if l2cap_pdu.cid == ATT_CID:
self.host.on_gatt_pdu(self, l2cap_pdu.payload)
elif l2cap_pdu.cid == SMP_CID:
self.host.on_smp_pdu(self, l2cap_pdu.payload)
else:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
class Host(EventEmitter):
def __init__(self, controller_source = None, controller_sink = None):
super().__init__()
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
self.snooper = None
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
# Connect to the source and sink if specified
if controller_source:
@@ -143,140 +95,55 @@ class Host(AbortableEventEmitter):
if controller_sink:
self.set_packet_sink(controller_sink)
def find_connection_by_bd_addr(
self,
bd_addr: Address,
transport: Optional[int] = None,
check_address_type: bool = False,
) -> Optional[Connection]:
for connection in self.connections.values():
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
if (
check_address_type
and connection.peer_address.address_type != bd_addr.address_type
):
continue
if transport is None or connection.transport == transport:
return connection
return None
async def flush(self) -> None:
# Make sure no command is pending
await self.command_semaphore.acquire()
# Flush current host state, then release command semaphore
self.emit('flush')
self.command_semaphore.release()
async def reset(self):
if self.ready:
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
await self.send_command(HCI_Reset_Command())
self.ready = True
response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True
)
self.local_supported_commands = response.return_parameters.supported_commands
await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFFFF')))
await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = bytes.fromhex('FFFFF00000000000')))
if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(
HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
response = await self.send_command(HCI_Read_Local_Supported_Commands_Command())
if response.return_parameters.status == HCI_SUCCESS:
self.local_supported_commands = response.return_parameters.supported_commands
else:
logger.warn(f'HCI_Read_Local_Supported_Commands_Command failed: {response.return_parameters.status}')
if self.supports_command(HCI_WRITE_LE_HOST_SUPPORT_COMMAND):
await self.send_command(HCI_Write_LE_Host_Support_Command(le_supported_host = 1, simultaneous_le_host = 0))
if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
self.local_version = response.return_parameters
await self.send_command(
HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('FFFFFFFFFFFFFF3F'))
)
if (
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't
# understand
le_event_mask = bytes.fromhex('1F00000000000000')
else:
le_event_mask = bytes.fromhex('FFFFF00000000000')
await self.send_command(
HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
self.hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
)
response = await self.send_command(HCI_Read_Local_Version_Information_Command())
if response.return_parameters.status == HCI_SUCCESS:
self.local_version = response.return_parameters
else:
logger.warn(f'HCI_Read_Local_Version_Information_Command failed: {response.return_parameters.status}')
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command())
if response.return_parameters.status == HCI_SUCCESS:
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={response.return_parameters.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={response.return_parameters.hc_total_num_le_acl_data_packets}')
else:
logger.warn(f'HCI_LE_Read_Buffer_Size_Command failed: {response.return_parameters.status}')
if response.return_parameters.hc_le_acl_data_packet_length == 0 or response.return_parameters.hc_total_num_le_acl_data_packets == 0:
# Read the non-LE-specific values
response = await self.send_command(HCI_Read_Buffer_Size_Command())
if response.return_parameters.status == HCI_SUCCESS:
self.hc_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_le_acl_data_packet_length = self.hc_le_acl_data_packet_length or self.hc_acl_data_packet_length
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
self.hc_total_num_le_acl_data_packets = self.hc_total_num_le_acl_data_packets or self.hc_total_num_acl_data_packets
logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}')
else:
logger.warn(f'HCI_Read_Buffer_Size_Command failed: {response.return_parameters.status}')
logger.debug(
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
) and self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await self.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
)
)
if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(HCI_LE_Read_Local_Supported_Features_Command())
if response.return_parameters.status == HCI_SUCCESS:
self.local_le_features = struct.unpack('<Q', response.return_parameters.le_features)[0]
else:
logger.warn(f'HCI_LE_Read_Supported_Features_Command failed: {response.return_parameters.status}')
self.reset_done = True
@@ -294,18 +161,15 @@ class Host(AbortableEventEmitter):
self.hci_sink = sink
def send_hci_packet(self, packet):
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(packet.to_bytes())
async def send_command(self, command, check_result=False):
async def send_command(self, command):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
assert self.pending_response is None
assert(self.pending_command is None)
assert(self.pending_response is None)
# Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future()
@@ -314,29 +178,11 @@ class Host(AbortableEventEmitter):
try:
self.send_hci_packet(command)
response = await self.pending_response
# Check the return parameters if required
if check_result:
if isinstance(response.return_parameters, int):
status = response.return_parameters
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != HCI_SUCCESS:
logger.warning(
f'{command.name} failed ({HCI_Constant.error_name(status)})'
)
raise HCI_Error(status)
# TODO: check error values
return response
except Exception as error:
logger.warning(
f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
)
raise error
logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
# raise error
finally:
self.pending_command = None
self.pending_response = None
@@ -356,18 +202,15 @@ class Host(AbortableEventEmitter):
offset = 0
pb_flag = 0
while bytes_remaining:
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
acl_packet = HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
bc_flag=0,
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
connection_handle = connection_handle,
pb_flag = pb_flag,
bc_flag = 0,
data_total_length = data_total_length,
data = l2cap_pdu[offset:offset + data_total_length]
)
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
self.queue_acl_packet(acl_packet)
pb_flag = 1
offset += data_total_length
@@ -378,38 +221,30 @@ class Host(AbortableEventEmitter):
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue')
def check_acl_packet_queue(self):
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
# Send all we can
while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets:
packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet)
self.acl_packets_in_flight += 1
def supports_command(self, command):
# Find the support flag position for this command
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
for flag_position, value in enumerate(flags):
for (octet, flags) in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
for (flag_position, value) in enumerate(flags):
if value == command:
# Check if the flag is set
if octet < len(self.local_supported_commands) and flag_position < 8:
return (
self.local_supported_commands[octet] & (1 << flag_position)
) != 0
return (self.local_supported_commands[octet] & (1 << flag_position)) != 0
return False
@property
def supported_commands(self):
commands = []
for octet, flags in enumerate(self.local_supported_commands):
for (octet, flags) in enumerate(self.local_supported_commands):
if octet < len(HCI_SUPPORTED_COMMANDS_FLAGS):
for flag in range(8):
if flags & (1 << flag) != 0:
@@ -424,17 +259,15 @@ class Host(AbortableEventEmitter):
@property
def supported_le_features(self):
return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
]
return [feature for feature in range(64) if self.local_le_features & (1 << feature)]
# Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet):
hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET
and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
and hci_packet.command_opcode == HCI_RESET_COMMAND
hci_packet.hci_packet_type == HCI_EVENT_PACKET and
hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and
hci_packet.command_opcode == HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
@@ -443,9 +276,6 @@ class Host(AbortableEventEmitter):
def on_hci_packet(self, packet):
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:
self.on_hci_command_packet(packet)
@@ -469,6 +299,12 @@ class Host(AbortableEventEmitter):
if connection := self.connections.get(packet.connection_handle):
connection.on_hci_acl_data_packet(packet)
def on_gatt_pdu(self, connection, pdu):
self.emit('gatt_pdu', connection.handle, pdu)
def on_smp_pdu(self, connection, pdu):
self.emit('smp_pdu', connection.handle, pdu)
def on_l2cap_pdu(self, connection, cid, pdu):
self.emit('l2cap_pdu', connection.handle, cid, pdu)
@@ -476,11 +312,7 @@ class Host(AbortableEventEmitter):
if self.pending_response:
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
logger.warning(
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
f'0x{event.command_opcode:X}'
)
logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}')
self.pending_response.set_result(event)
else:
@@ -494,12 +326,10 @@ class Host(AbortableEventEmitter):
def on_hci_command_complete_event(self, event):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
# This is used just for the Num_HCI_Command_Packets field, not related to an actual command
logger.debug('no-command event')
return None
return self.on_command_processed(event)
else:
return self.on_command_processed(event)
def on_hci_command_status_event(self, event):
return self.on_command_processed(event)
@@ -510,64 +340,51 @@ class Host(AbortableEventEmitter):
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(
color(
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
)
)
logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'))
self.acl_packets_in_flight = 0
# Classic only
def on_hci_connection_request_event(self, event):
# Notify the listeners
self.emit(
'connection_request',
event.bd_addr,
event.class_of_device,
event.link_type,
# For now, just accept everything
# TODO: delegate the decision
self.send_command_sync(
HCI_Accept_Connection_Request_Command(
bd_addr = event.bd_addr,
role = 0x01 # Remain the peripheral
)
)
def on_hci_le_connection_complete_event(self, event):
# Check if this is a cancellation
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### LE CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}')
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(
self,
event.connection_handle,
event.peer_address,
BT_LE_TRANSPORT,
)
connection = Connection(self, event.connection_handle, event.role, event.peer_address)
self.connections[event.connection_handle] = connection
# Notify the client
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout,
event.conn_interval,
event.conn_latency,
event.supervision_timeout
)
self.emit(
'connection',
event.connection_handle,
BT_LE_TRANSPORT,
event.peer_address,
None,
event.role,
connection_parameters,
connection_parameters
)
else:
logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
)
self.emit('connection_failure', event.status)
def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now
@@ -576,19 +393,11 @@ class Host(AbortableEventEmitter):
def on_hci_connection_complete_event(self, event):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}')
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(
self,
event.connection_handle,
event.bd_addr,
BT_BR_EDR_TRANSPORT,
)
connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr)
self.connections[event.connection_handle] = connection
# Notify the client
@@ -598,15 +407,14 @@ class Host(AbortableEventEmitter):
BT_BR_EDR_TRANSPORT,
event.bd_addr,
None,
None,
BT_CENTRAL_ROLE,
None
)
else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client
self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
)
self.emit('connection_failure', event.connection_handle, event.status)
def on_hci_disconnection_complete_event(self, event):
# Find the connection
@@ -615,11 +423,7 @@ class Host(AbortableEventEmitter):
return
if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}')
del self.connections[event.connection_handle]
# Notify the listeners
@@ -628,7 +432,7 @@ class Host(AbortableEventEmitter):
logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('disconnection_failure', event.connection_handle, event.status)
self.emit('disconnection_failure', event.status)
def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -638,17 +442,13 @@ class Host(AbortableEventEmitter):
# Notify the client
if event.status == HCI_SUCCESS:
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout,
)
self.emit(
'connection_parameters_update', connection.handle, connection_parameters
event.conn_interval,
event.conn_latency,
event.supervision_timeout
)
self.emit('connection_parameters_update', connection.handle, connection_parameters)
else:
self.emit(
'connection_parameters_update_failure', connection.handle, event.status
)
self.emit('connection_parameters_update_failure', connection.handle, event.status)
def on_hci_le_phy_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -664,10 +464,13 @@ class Host(AbortableEventEmitter):
def on_hci_le_advertising_report_event(self, event):
for report in event.reports:
self.emit('advertising_report', report)
def on_hci_le_extended_advertising_report_event(self, event):
self.on_hci_le_advertising_report_event(event)
self.emit(
'advertising_report',
report.address,
report.data,
report.rssi,
report.event_type
)
def on_hci_le_remote_connection_parameter_request_event(self, event):
if event.connection_handle not in self.connections:
@@ -678,13 +481,13 @@ class Host(AbortableEventEmitter):
# TODO: delegate the decision
self.send_command_sync(
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command(
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,
connection_handle = event.connection_handle,
interval_min = event.interval_min,
interval_max = event.interval_max,
latency = event.latency,
timeout = event.timeout,
minimum_ce_length = 0,
maximum_ce_length = 0
)
)
@@ -698,23 +501,19 @@ class Host(AbortableEventEmitter):
logger.debug('no long term key provider')
long_term_key = None
else:
long_term_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier,
),
long_term_key = await self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier
)
if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle=event.connection_handle,
long_term_key=long_term_key,
connection_handle = event.connection_handle,
long_term_key = long_term_key
)
else:
response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
connection_handle=event.connection_handle
connection_handle = event.connection_handle
)
await self.send_command(response)
@@ -729,17 +528,10 @@ class Host(AbortableEventEmitter):
def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS:
logger.debug(
f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}'
)
self.emit('role_change', event.bd_addr, event.new_role)
logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}')
# TODO: lookup the connection and update the role
else:
logger.debug(
f'role change for {event.bd_addr} failed: '
f'{HCI_Constant.error_name(event.status)}'
)
self.emit('role_change_failure', event.bd_addr, event.status)
logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}')
def on_hci_le_data_length_change_event(self, event):
self.emit(
@@ -748,7 +540,7 @@ class Host(AbortableEventEmitter):
event.max_tx_octets,
event.max_tx_time,
event.max_rx_octets,
event.max_rx_time,
event.max_rx_time
)
def on_hci_authentication_complete_event(self, event):
@@ -756,35 +548,21 @@ class Host(AbortableEventEmitter):
if event.status == HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle)
else:
self.emit(
'connection_authentication_failure',
event.connection_handle,
event.status,
)
self.emit('connection_authentication_failure', event.connection_handle, event.status)
def on_hci_encryption_change_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
)
self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled)
else:
self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
self.emit('connection_encryption_failure', event.connection_handle, event.status)
def on_hci_encryption_key_refresh_complete_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle)
else:
self.emit(
'connection_encryption_key_refresh_failure',
event.connection_handle,
event.status,
)
self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status)
def on_hci_link_supervision_timeout_changed_event(self, event):
pass
@@ -796,20 +574,20 @@ class Host(AbortableEventEmitter):
pass
def on_hci_link_key_notification_event(self, event):
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={HCI_Constant.link_key_type_name(event.key_type)}'
)
logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}')
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
logger.debug(
f'simple pairing complete for {event.bd_addr}: '
f'status={HCI_Constant.status_name(event.status)}'
)
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
def on_hci_pin_code_request_event(self, event):
self.emit('pin_code_request', event.bd_addr)
# For now, just refuse all requests
# TODO: delegate the decision
self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command(
bd_addr = event.bd_addr
)
)
def on_hci_link_key_request_event(self, event):
async def send_link_key():
@@ -817,18 +595,15 @@ class Host(AbortableEventEmitter):
logger.debug('no link key provider')
link_key = None
else:
link_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.link_key_provider(event.bd_addr),
)
link_key = await self.link_key_provider(event.bd_addr)
if link_key:
response = HCI_Link_Key_Request_Reply_Command(
bd_addr=event.bd_addr, link_key=link_key
bd_addr = event.bd_addr,
link_key = link_key
)
else:
response = HCI_Link_Key_Request_Negative_Reply_Command(
bd_addr=event.bd_addr
bd_addr = event.bd_addr
)
await self.send_command(response)
@@ -839,29 +614,15 @@ class Host(AbortableEventEmitter):
self.emit('authentication_io_capability_request', event.bd_addr)
def on_hci_io_capability_response_event(self, event):
self.emit(
'authentication_io_capability_response',
event.bd_addr,
event.io_capability,
event.authentication_requirements,
)
pass
def on_hci_user_confirmation_request_event(self, event):
self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value)
def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, _event):
def on_hci_inquiry_complete_event(self, event):
self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event):
@@ -871,7 +632,7 @@ class Host(AbortableEventEmitter):
response.bd_addr,
response.class_of_device,
b'',
response.rssi,
response.rssi
)
def on_hci_extended_inquiry_result_event(self, event):
@@ -880,7 +641,7 @@ class Host(AbortableEventEmitter):
event.bd_addr,
event.class_of_device,
event.extended_inquiry_response,
event.rssi,
event.rssi
)
def on_hci_remote_name_request_complete_event(self, event):
@@ -888,10 +649,3 @@ class Host(AbortableEventEmitter):
self.emit('remote_name_failure', event.bd_addr, event.status)
else:
self.emit('remote_name', event.bd_addr, event.remote_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)

View File

@@ -20,19 +20,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import os
import json
from typing import TYPE_CHECKING, Optional
from colors import color
from .colors import color
from .hci import Address
if TYPE_CHECKING:
from .device import Device
# -----------------------------------------------------------------------------
# Logging
@@ -44,10 +38,10 @@ logger = logging.getLogger(__name__)
class PairingKeys:
class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None):
self.value = value
self.value = value
self.authenticated = authenticated
self.ediv = ediv
self.rand = rand
self.ediv = ediv
self.rand = rand
@classmethod
def from_dict(cls, key_dict):
@@ -70,33 +64,31 @@ class PairingKeys:
return key_dict
def __init__(self):
self.address_type = None
self.ltk = None
self.ltk_central = None
self.address_type = None
self.ltk = None
self.ltk_central = None
self.ltk_peripheral = None
self.irk = None
self.csrk = None
self.link_key = None # Classic
self.irk = None
self.csrk = None
self.link_key = None # Classic
@staticmethod
def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name)
if key_dict is None:
return None
return PairingKeys.Key.from_dict(key_dict)
if key_dict is not None:
return PairingKeys.Key.from_dict(key_dict)
@staticmethod
def from_dict(keys_dict):
keys = PairingKeys()
keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys
@@ -128,13 +120,13 @@ class PairingKeys:
def print(self, prefix=''):
keys_dict = self.to_dict()
for (container_property, value) in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (property, value) in keys_dict.items():
if type(value) is dict:
print(f'{prefix}{color(property, "cyan")}:')
for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(container_property, "cyan")}: {value}')
print(f'{prefix}{color(property, "cyan")}: {value}')
# -----------------------------------------------------------------------------
@@ -145,16 +137,12 @@ class KeyStore:
async def update(self, name, keys):
pass
async def get(self, _name):
async def get(self, name):
return PairingKeys()
async def get_all(self):
return []
async def delete_all(self):
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self):
all_keys = await self.get_all()
resolving_keys = []
@@ -173,26 +161,26 @@ class KeyStore:
separator = ''
for (name, keys) in entries:
print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ')
keys.print(prefix = prefix + ' ')
separator = '\n'
@staticmethod
def create_for_device(device: Device) -> Optional[KeyStore]:
if device.config.keystore is None:
def create_for_device(device_config):
if device_config.keystore is None:
return None
keystore_type = device.config.keystore.split(':', 1)[0]
keystore_type = device_config.keystore.split(':', 1)[0]
if keystore_type == 'JsonKeyStore':
return JsonKeyStore.from_device(device)
return JsonKeyStore.from_device_config(device_config)
return None
# -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore):
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__'
def __init__(self, namespace, filename=None):
@@ -200,17 +188,12 @@ class JsonKeyStore(KeyStore):
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
json_filename = (
f'{self.namespace}.json'.lower().replace(':', '-').replace('/p', '-p')
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR),
self.KEYS_DIR
)
json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename)
else:
self.filename = filename
@@ -219,21 +202,11 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}')
@staticmethod
def from_device(device: Device) -> Optional[JsonKeyStore]:
if not device.config.keystore:
return None
params = device.config.keystore.split(':', 1)[1:]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
namespace = str(device.public_address)
elif device.random_address != Address.ANY_RANDOM:
namespace = str(device.random_address)
else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE
def from_device_config(device_config):
params = device_config.keystore.split(':', 1)[1:]
namespace = str(device_config.address)
if params:
filename = params[0]
filename = params[1]
else:
filename = None
@@ -241,7 +214,7 @@ class JsonKeyStore(KeyStore):
async def load(self):
try:
with open(self.filename, 'r', encoding='utf-8') as json_file:
with open(self.filename, 'r') as json_file:
return json.load(json_file)
except FileNotFoundError:
return {}
@@ -253,13 +226,13 @@ class JsonKeyStore(KeyStore):
# Save to a temporary file
temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w', encoding='utf-8') as output:
with open(temp_filename, 'w') as output:
json.dump(db, output, sort_keys=True, indent=4)
# Atomically replace the previous file
os.replace(temp_filename, self.filename)
os.rename(temp_filename, self.filename)
async def delete(self, name: str) -> None:
async def delete(self, name):
db = await self.load()
namespace = db.get(self.namespace)
@@ -273,7 +246,7 @@ class JsonKeyStore(KeyStore):
db = await self.load()
namespace = db.setdefault(self.namespace, {})
namespace.setdefault(name, {}).update(keys.to_dict())
namespace[name] = keys.to_dict()
await self.save(db)
@@ -284,18 +257,9 @@ class JsonKeyStore(KeyStore):
if namespace is None:
return []
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]:
async def get(self, name):
db = await self.load()
namespace = db.get(self.namespace)

File diff suppressed because it is too large Load Diff

View File

@@ -17,17 +17,15 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import websockets
from functools import partial
from colors import color
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
HCI_CONNECTION_TIMEOUT_ERROR
)
# -----------------------------------------------------------------------------
@@ -49,8 +47,7 @@ def parse_parameters(params_str):
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# TODO: add more support for various LL exchanges (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
'''
@@ -58,13 +55,8 @@ class LocalLink:
'''
def __init__(self):
self.controllers = set()
self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None
############################################################
# Common utils
############################################################
def add_controller(self, controller):
logger.debug(f'new controller: {controller}')
@@ -79,39 +71,22 @@ class LocalLink:
return controller
return None
def find_classic_controller(self, address):
for controller in self.controllers:
if controller.public_address == address:
return controller
return None
def on_address_changed(self, controller):
pass
def get_pending_connection(self):
return self.pending_connection
############################################################
# LE handlers
############################################################
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
def send_acl_data(self, sender_controller, destination_address, transport, data):
def send_acl_data(self, sender_address, destination_address, data):
# Send the data to the first controller with a matching address
if transport == BT_LE_TRANSPORT:
destination_controller = self.find_controller(destination_address)
source_address = sender_controller.random_address
elif transport == BT_BR_EDR_TRANSPORT:
destination_controller = self.find_classic_controller(destination_address)
source_address = sender_controller.public_address
if destination_controller is not None:
destination_controller.on_link_acl_data(source_address, transport, data)
if controller := self.find_controller(destination_address):
controller.on_link_acl_data(sender_address, data)
def on_connection_complete(self):
# Check that we expect this call
@@ -128,31 +103,23 @@ class LocalLink:
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
if peripheral_controller := self.find_controller(le_create_connection_command.peer_address):
central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
le_create_connection_command,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(self, central_address, le_create_connection_command):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}')
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
@@ -160,26 +127,16 @@ class LocalLink:
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason)
central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
@@ -188,89 +145,6 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
############################################################
# Classic handlers
############################################################
def classic_connect(self, initiator_controller, responder_address):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
initiator_controller.on_classic_connection_complete(
responder_address, HCI_PAGE_TIMEOUT_ERROR
)
return
self.pending_classic_connection = (initiator_controller, responder_controller)
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
def classic_accept_connection(
self, responder_controller, initiator_address, responder_role
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
)
return
async def task():
if responder_role != BT_PERIPHERAL_ROLE:
initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role))
)
initiator_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_SUCCESS
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, responder_role
)
responder_controller.on_classic_connection_complete(
initiator_controller.public_address, HCI_SUCCESS
)
self.pending_classic_connection = None
def classic_disconnect(self, initiator_controller, responder_address, reason):
logger.debug(
f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
async def task():
initiator_controller.on_classic_disconnected(responder_address, reason)
asyncio.create_task(task())
responder_controller.on_classic_disconnected(
initiator_controller.public_address, reason
)
def classic_switch_role(
self, initiator_controller, responder_address, initiator_new_role
):
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
return
async def task():
initiator_controller.on_classic_role_change(
responder_address, initiator_new_role
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, int(not (initiator_new_role))
)
# -----------------------------------------------------------------------------
class RemoteLink:
@@ -278,18 +152,15 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a
WebSocket relay
'''
def __init__(self, uri):
self.controller = None
self.uri = uri
self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = (
set()
) # List of addresses that have connected to us
self.controller = None
self.uri = uri
self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = set() # List of addresses that have connected to us
# Connect and run asynchronously
asyncio.create_task(self.run_connection())
@@ -308,9 +179,6 @@ class RemoteLink:
def get_pending_connection(self):
return self.pending_connection
def get_pending_classic_connection(self):
return self.pending_classic_connection
async def wait_until_connected(self):
await self.websocket
@@ -324,16 +192,11 @@ class RemoteLink:
try:
await item
except Exception as error:
logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}')
async def run_connection(self):
import websockets # lazy import
# Connect to the relay
logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}')
@@ -364,9 +227,7 @@ class RemoteLink:
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR)
self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target):
@@ -383,9 +244,7 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement))
except Exception:
logger.exception('exception')
@@ -404,11 +263,11 @@ class RemoteLink:
self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it
await self.send_targeted_message(sender, 'connected')
await self.send_targetted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _):
if not self.pending_connection:
logger.warning('received a connection ack, but no connection is pending')
logger.warn('received a connection ack, but no connection is pending')
return
# Remember the connection
@@ -416,9 +275,7 @@ class RemoteLink:
# Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS)
async def on_disconnect_message_received(self, sender, message):
# Notify the controller
@@ -430,7 +287,7 @@ class RemoteLink:
if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, _):
async def on_encrypted_message_received(self, sender, message):
# TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
@@ -439,7 +296,7 @@ class RemoteLink:
websocket = await self.websocket
# Create a future value to hold the eventual result
assert self.rpc_result is None
assert(self.rpc_result is None)
self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command
@@ -452,7 +309,7 @@ class RemoteLink:
# TODO: parse the result
async def send_targeted_message(self, target, message):
async def send_targetted_message(self, target, message):
# Ensure we have a connection
websocket = await self.websocket
@@ -469,62 +326,35 @@ class RemoteLink:
self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data):
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
await self.send_targetted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, _, data):
def send_advertising_data(self, sender_address, data):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
await self.send_targetted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, _, peer_address, _transport, data):
# TODO: handle different transport
def send_acl_data(self, sender_address, peer_address, data):
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address):
await self.send_targeted_message(peer_address, 'connect')
await self.send_targetted_message(peer_address, 'connect')
def connect(self, _, le_create_connection_command):
def connect(self, central_address, le_create_connection_command):
if self.pending_connection:
logger.warning('connection already pending')
logger.warn('connection already pending')
return
self.pending_connection = le_create_connection_command
self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address)))
def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'disconnect {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}'))
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command)
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk)
self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}'))

View File

@@ -1,184 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import enum
from typing import Optional, Tuple
from .hci import (
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
)
from .smp import (
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
)
# -----------------------------------------------------------------------------
class PairingDelegate:
"""Abstract base class for Pairing Delegates."""
# I/O Capabilities.
# These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants.
class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
# Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
KEYBOARD_INPUT_ONLY = IoCapability.KEYBOARD_INPUT_ONLY
DISPLAY_OUTPUT_ONLY = IoCapability.DISPLAY_OUTPUT_ONLY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT
# Key Distribution [LE only]
class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: int = (
SMP_ENC_KEY_DISTRIBUTION_FLAG | SMP_ID_KEY_DISTRIBUTION_FLAG
)
# Default mapping from abstract to Classic I/O capabilities.
# Subclasses may override this if they prefer a different mapping.
CLASSIC_IO_CAPABILITIES_MAP = {
NO_OUTPUT_NO_INPUT: HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
KEYBOARD_INPUT_ONLY: HCI_KEYBOARD_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_ONLY: HCI_DISPLAY_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_YES_NO_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
}
io_capability: IoCapability
local_initiator_key_distribution: KeyDistribution
local_responder_key_distribution: KeyDistribution
def __init__(
self,
io_capability=NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution
@property
def classic_io_capability(self) -> int:
"""Map the abstract I/O capability to a Classic constant."""
# pylint: disable=line-too-long
return self.CLASSIC_IO_CAPABILITIES_MAP.get(
self.io_capability, HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
)
@property
def smp_io_capability(self) -> int:
"""Map the abstract I/O capability to an SMP constant."""
# This is just a 1-1 direct mapping
return self.io_capability
async def accept(self) -> bool:
"""Accept or reject a Pairing request."""
return True
async def confirm(self) -> bool:
"""Respond yes or no to a Pairing confirmation question."""
return True
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
"""Compare two numbers."""
return True
async def get_number(self) -> Optional[int]:
"""
Return an optional number as an answer to a passkey request.
Returning `None` will result in a negative reply.
"""
return 0
async def get_string(self, max_length) -> Optional[str]:
"""
Return a string whose utf-8 encoding is up to max_length bytes.
"""
return None
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
"""Display a number."""
# [LE only]
async def key_distribution_response(
self, peer_initiator_key_distribution: int, peer_responder_key_distribution: int
) -> Tuple[int, int]:
"""
Return the key distribution response in an SMP protocol context.
NOTE: since it is only used by the SMP protocol, this method's input and output
are directly as integers, using the SMP constants, rather than the abstract
KeyDistribution enums.
"""
return (
int(
peer_initiator_key_distribution & self.local_initiator_key_distribution
),
int(
peer_responder_key_distribution & self.local_responder_key_distribution
),
)
# -----------------------------------------------------------------------------
class PairingConfig:
"""Configuration for the Pairing protocol."""
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'delegate[{self.delegate.io_capability}])'
)

View File

@@ -1,188 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -23,7 +23,7 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter,
PackedCharacteristicAdapter
)
@@ -36,11 +36,11 @@ class BatteryService(TemplateService):
self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level),
CharacteristicValue(read=read_battery_level)
),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
format=BatteryService.BATTERY_LEVEL_FORMAT
)
super().__init__([self.battery_level_characteristic])
@@ -52,11 +52,10 @@ class BatteryServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BATTERY_LEVEL_CHARACTERISTIC):
self.battery_level = PackedCharacteristicAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
characteristics[0],
format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None

View File

@@ -17,7 +17,7 @@
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional, Tuple
from typing import Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
@@ -33,7 +33,7 @@ from ..gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter,
UTF8CharacteristicAdapter
)
@@ -52,50 +52,49 @@ class DeviceInformationService(TemplateService):
def __init__(
self,
manufacturer_name: Optional[str] = None,
model_number: Optional[str] = None,
serial_number: Optional[str] = None,
hardware_revision: Optional[str] = None,
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None
manufacturer_name: str = None,
model_number: str = None,
serial_number: str = None,
hardware_revision: str = None,
firmware_revision: str = None,
software_revision: str = None,
system_id: Tuple[int, int] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes = None
# TODO: pnp_id
):
characteristics = [
Characteristic(
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field
uuid,
Characteristic.READ,
Characteristic.READABLE,
field
)
for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
)
if field is not None
]
if system_id is not None:
characteristics.append(
Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id),
)
)
characteristics.append(Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id)
))
if ieee_regulatory_certification_data_list is not None:
characteristics.append(
Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list,
)
)
characteristics.append(Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list
))
super().__init__(characteristics)
@@ -109,11 +108,11 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
for (field, uuid) in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0])
@@ -121,20 +120,16 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
characteristic = None
self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_SYSTEM_ID_CHARACTERISTIC):
self.system_id = DelegatedCharacteristicAdapter(
characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id,
decode=DeviceInformationService.unpack_system_id
)
else:
self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC):
self.ieee_regulatory_certification_data_list = characteristics[0]
else:
self.ieee_regulatory_certification_data_list = None

View File

@@ -30,25 +30,25 @@ from ..gatt import (
Characteristic,
CharacteristicValue,
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
PackedCharacteristicAdapter
)
# -----------------------------------------------------------------------------
class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE
UUID = GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
FOOT = 6
OTHER = 0,
CHEST = 1,
WRIST = 2,
FINGER = 3,
HAND = 4,
EAR_LOBE = 5,
FOOT = 6
class HeartRateMeasurement:
def __init__(
@@ -56,14 +56,12 @@ class HeartRateService(TemplateService):
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None,
rr_intervals=None
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF):
raise ValueError('energy_expended out of range')
if rr_intervals:
@@ -71,10 +69,10 @@ class HeartRateService(TemplateService):
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range')
self.heart_rate = heart_rate
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod
def from_bytes(cls, data):
@@ -89,7 +87,7 @@ class HeartRateService(TemplateService):
offset += 1
if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0
sensor_contact_detected = (flags & (1 << 1) != 0)
else:
sensor_contact_detected = None
@@ -121,57 +119,51 @@ class HeartRateService(TemplateService):
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None:
flags |= 1 << 3
flags |= (1 << 3)
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= 1 << 4
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
]
)
flags |= (1 << 4)
data += b''.join([
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
])
return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
return f'HeartRateMeasurement(heart_rate={self.heart_rate},'\
f' sensor_contact_detected={self.sensor_contact_detected},'\
f' energy_expended={self.energy_expended},'\
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None,
reset_energy_expended=None
):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY,
Characteristic.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement),
CharacteristicValue(read=read_heart_rate_measurement)
),
# pylint: disable=unnecessary-lambda
encode=lambda value: bytes(value),
encode=lambda value: bytes(value)
)
characteristics = [self.heart_rate_measurement_characteristic]
if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)]),
bytes([int(body_sensor_location)])
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
@@ -182,11 +174,11 @@ class HeartRateService(TemplateService):
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE,
Characteristic.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value),
CharacteristicValue(write=write_heart_rate_control_point_value)
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -200,38 +192,30 @@ class HeartRateServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC):
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes,
decode=HeartRateService.HeartRateMeasurement.from_bytes
)
else:
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
decode=lambda value: HeartRateService.BodySensorLocation(value[0])
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
)
return await self.heart_rate_control_point.write_value(HeartRateService.RESET_ENERGY_EXPENDED)

View File

View File

@@ -18,11 +18,10 @@
import logging
import asyncio
from colors import color
from pyee import EventEmitter
from . import core
from .colors import color
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
from .core import InvalidStateError, ProtocolError, ConnectionError
# -----------------------------------------------------------------------------
# Logging
@@ -33,8 +32,6 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
RFCOMM_PSM = 0x0003
@@ -101,24 +98,22 @@ RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on
# -----------------------------------------------------------------------------
def compute_fcs(buffer):
result = 0xFF
def fcs(buffer):
fcs = 0xFF
for byte in buffer:
result = CRC_TABLE[result ^ byte]
return 0xFF - result
fcs = CRC_TABLE[fcs ^ byte]
return 0xFF - fcs
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = frame_type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
def __init__(self, type, c_r, dlci, p_f, information = b'', with_credits = False):
self.type = type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
self.information = information
length = len(information)
if with_credits:
@@ -129,19 +124,19 @@ class RFCOMM_Frame:
else:
# 1-byte length indicator
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control]))
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4)
if type == RFCOMM_UIH_FRAME:
self.fcs = fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
self.fcs = fcs(bytes([self.address, self.control]) + self.length)
def type_name(self):
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data):
mcc_type = data[0] >> 2
type = data[0] >> 2
c_r = (data[0] >> 1) & 1
length = data[1]
if data[1] & 1:
@@ -149,16 +144,13 @@ class RFCOMM_Frame:
value = data[2:]
else:
length = (data[3] << 7) & (length >> 1)
value = data[3 : 3 + length]
value = data[3:3 + length]
return (mcc_type, c_r, value)
return (type, c_r, value)
@staticmethod
def make_mcc(mcc_type, c_r, data):
return (
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
def make_mcc(type, c_r, data):
return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data
@staticmethod
def sabm(c_r, dlci):
@@ -177,17 +169,15 @@ class RFCOMM_Frame:
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r, dlci, information, p_f=0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
def uih(c_r, dlci, information, p_f = 0):
return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1))
@staticmethod
def from_bytes(data):
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = data[1] & 0xEF
type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -199,182 +189,132 @@ class RFCOMM_Frame:
fcs = data[-1]
# Construct the frame and check the CRC
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
frame = RFCOMM_Frame(type, c_r, dlci, p_f, information)
if frame.fcs != fcs:
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
logger.warn(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch')
return frame
def __bytes__(self):
return (
bytes([self.address, self.control])
+ self.length
+ self.information
+ bytes([self.fcs])
)
return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs])
def __str__(self):
return (
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
f'length={len(self.information)},'
f'fcs=0x{self.fcs:02X})'
)
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})'
# -----------------------------------------------------------------------------
class RFCOMM_MCC_PN:
def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
):
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size):
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
self.window_size = window_size
@staticmethod
def from_bytes(data):
return RFCOMM_MCC_PN(
dlci=data[0],
cl=data[1],
priority=data[2],
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7],
dlci = data[0],
cl = data[1],
priority = data[2],
ack_timer = data[3],
max_frame_size = data[4] | data[5] << 8,
max_retransmissions = data[6],
window_size = data[7]
)
def __bytes__(self):
return bytes(
[
self.dlci & 0xFF,
self.cl & 0xFF,
self.priority & 0xFF,
self.ack_timer & 0xFF,
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF,
]
)
return bytes([
self.dlci & 0xFF,
self.cl & 0xFF,
self.priority & 0xFF,
self.ack_timer & 0xFF,
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF
])
def __str__(self):
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})'
# -----------------------------------------------------------------------------
class RFCOMM_MCC_MSC:
def __init__(self, dlci, fc, rtc, rtr, ic, dv):
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data):
return RFCOMM_MCC_MSC(
dlci=data[0] >> 2,
fc=data[1] >> 1 & 1,
rtc=data[1] >> 2 & 1,
rtr=data[1] >> 3 & 1,
ic=data[1] >> 6 & 1,
dv=data[1] >> 7 & 1,
dlci = data[0] >> 2,
fc = data[1] >> 1 & 1,
rtc = data[1] >> 2 & 1,
rtr = data[1] >> 3 & 1,
ic = data[1] >> 6 & 1,
dv = data[1] >> 7 & 1
)
def __bytes__(self):
return bytes(
[
(self.dlci << 2) | 3,
1
| self.fc << 1
| self.rtc << 2
| self.rtr << 3
| self.ic << 6
| self.dv << 7,
]
)
return bytes([
(self.dlci << 2) | 3,
1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7
])
def __str__(self):
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})'
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
# States
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
DISCONNECTING = 0x03
DISCONNECTED = 0x04
RESET = 0x05
DISCONNECTED = 0x04
RESET = 0x05
STATE_NAMES = {
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
}
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
self.connection_result = None
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead)
@staticmethod
def state_name(state):
return DLC.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}')
self.state = new_state
def send_frame(self, frame):
@@ -384,40 +324,58 @@ class DLC(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, frame):
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
return
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci))
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED)
self.emit('open')
def on_ua_frame(self, _frame):
def on_ua_frame(self, frame):
if self.state != DLC.CONNECTING:
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
@@ -426,36 +384,29 @@ class DLC(EventEmitter):
# TODO: handle all states
pass
def on_disc_frame(self, _frame):
def on_disc_frame(self, frame):
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
self.send_frame(RFCOMM_Frame.ua(c_r = 1 - self.c_r, dlci = self.dlci))
def on_uih_frame(self, frame):
data = frame.information
if frame.p_f == 1:
# With credits
received_credits = frame.information[0]
self.tx_credits += received_credits
credits = frame.information[0]
self.tx_credits += credits
logger.debug(
f'<<< Credits [{self.dlci}]: '
f'received {received_credits}, total={self.tx_credits}'
)
logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}')
data = data[1:]
logger.debug(
f'{color("<<< Data", "yellow")} '
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}')
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
self.sink(data)
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
logger.warn(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -467,47 +418,69 @@ class DLC(EventEmitter):
if c_r:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
else:
# Response
logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self):
if self.state != DLC.INIT:
if not self.state == DLC.INIT:
raise InvalidStateError('invalid state')
self.change_state(DLC.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
self.send_frame(
RFCOMM_Frame.sabm(
c_r = self.c_r,
dlci = self.dlci
)
)
def accept(self):
if self.state != DLC.INIT:
if not self.state == DLC.INIT:
raise InvalidStateError('invalid state')
pn = RFCOMM_MCC_PN(
dlci=self.dlci,
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
dlci = self.dlci,
cl = 0xE0,
priority = 7,
ack_timer = 0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 0, data = bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTING)
def rx_credits_needed(self):
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0
else:
return 0
def process_tx(self):
# Send anything we can (or an empty frame if we need to send rx credits)
@@ -515,13 +488,13 @@ class DLC(EventEmitter):
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
# Get the next chunk, up to MTU size
if rx_credits_needed > 0:
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
chunk = bytes([rx_credits_needed]) + self.tx_buffer[:self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1:]
self.rx_credits += rx_credits_needed
tx_credit_spent = len(chunk) > 1
tx_credit_spent = (len(chunk) > 1)
else:
chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :]
chunk = self.tx_buffer[:self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk):]
tx_credit_spent = True
# Update the tx credits
@@ -530,17 +503,13 @@ class DLC(EventEmitter):
self.tx_credits -= 1
# Send the frame
logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
f'rx_credits={self.rx_credits}, '
f'tx_credits={self.tx_credits}'
)
logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}')
self.send_frame(
RFCOMM_Frame.uih(
c_r=self.c_r,
dlci=self.dlci,
information=chunk,
p_f=1 if rx_credits_needed > 0 else 0,
c_r = self.c_r,
dlci = self.dlci,
information = chunk,
p_f = 1 if rx_credits_needed > 0 else 0
)
)
@@ -549,8 +518,8 @@ class DLC(EventEmitter):
# Stream protocol
def write(self, data):
# We can only send bytes
if not isinstance(data, bytes):
if isinstance(data, str):
if type(data) != bytes:
if type(data) == str:
# Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8')
else:
@@ -574,34 +543,34 @@ class Multiplexer(EventEmitter):
RESPONDER = 0x01
# States
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
OPENING = 0x03
INIT = 0x00
CONNECTING = 0x01
CONNECTED = 0x02
OPENING = 0x03
DISCONNECTING = 0x04
DISCONNECTED = 0x05
RESET = 0x06
DISCONNECTED = 0x05
RESET = 0x06
STATE_NAMES = {
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
OPENING: 'OPENING',
INIT: 'INIT',
CONNECTING: 'CONNECTING',
CONNECTED: 'CONNECTED',
OPENING: 'OPENING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
}
def __init__(self, l2cap_channel, role):
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
self.state = Multiplexer.INIT
self.dlcs = {} # DLCs, by DLCI
self.connection_result = None
self.role = role
self.l2cap_channel = l2cap_channel
self.state = Multiplexer.INIT
self.dlcs = {} # DLCs, by DLCI
self.connection_result = None
self.disconnection_result = None
self.open_result = None
self.acceptor = None
self.open_result = None
self.acceptor = None
# Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu
@@ -611,9 +580,7 @@ class Multiplexer(EventEmitter):
return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}')
self.state = new_state
def send_frame(self, frame):
@@ -629,14 +596,14 @@ class Multiplexer(EventEmitter):
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
# DM responses are for a DLCI, but since we only create the dlc when we receive
# a PN response (because we need the parameters), we handle DM frames at the Multiplexer
# level
self.on_dm_frame(frame)
else:
dlc = self.dlcs.get(frame.dlci)
if dlc is None:
logger.warning(f'no dlc for DLCI {frame.dlci}')
logger.warn(f'no dlc for DLCI {frame.dlci}')
return
dlc.on_frame(frame)
@@ -644,14 +611,14 @@ class Multiplexer(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame):
def on_sabm_frame(self, frame):
if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
self.send_frame(RFCOMM_Frame.ua(c_r = 1, dlci = 0))
def on_ua_frame(self, _frame):
def on_ua_frame(self, frame):
if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED)
if self.connection_result:
@@ -663,34 +630,25 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_dm_frame(self, _frame):
def on_dm_frame(self, frame):
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
self.open_result.set_exception(
core.ConnectionError(
core.ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm',
)
)
self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED))
else:
logger.warning(f'unexpected state for DM: {self}')
logger.warn(f'unexpected state for DM: {self}')
def on_disc_frame(self, _frame):
def on_disc_frame(self, frame):
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0))
def on_uih_frame(self, frame):
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == RFCOMM_MCC_PN_TYPE:
if type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
elif type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -706,7 +664,7 @@ class Multiplexer(EventEmitter):
if pn.dlci & 1:
# Not expected, this is an initiator-side number
# TODO: error out
logger.warning(f'invalid DLCI: {pn.dlci}')
logger.warn(f'invalid DLCI: {pn.dlci}')
else:
if self.acceptor:
channel_number = pn.dlci >> 1
@@ -722,10 +680,10 @@ class Multiplexer(EventEmitter):
dlc.accept()
else:
# No acceptor, we're in Disconnected Mode
self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
self.send_frame(RFCOMM_Frame.dm(c_r = 1, dlci = pn.dlci))
else:
# No acceptor?? shouldn't happen
logger.warning(color('!!! no acceptor registered', 'red'))
logger.warn(color('!!! no acceptor registered', 'red'))
else:
# Response
logger.debug(f'>>> PN Response: {pn}')
@@ -734,12 +692,12 @@ class Multiplexer(EventEmitter):
self.dlcs[pn.dlci] = dlc
dlc.connect()
else:
logger.warning('ignoring PN response')
logger.warn('ignoring PN response')
def on_mcc_msc(self, c_r, msc):
dlc = self.dlcs.get(msc.dlci)
if dlc is None:
logger.warning(f'no dlc for DLCI {msc.dlci}')
logger.warn(f'no dlc for DLCI {msc.dlci}')
return
dlc.on_mcc_msc(c_r, msc)
@@ -749,7 +707,7 @@ class Multiplexer(EventEmitter):
self.change_state(Multiplexer.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
self.send_frame(RFCOMM_Frame.sabm(c_r = 1, dlci = 0))
return await self.connection_result
async def disconnect(self):
@@ -758,38 +716,34 @@ class Multiplexer(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.DISCONNECTING)
self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
)
)
self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0))
await self.disconnection_result
async def open_dlc(self, channel):
if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress')
raise InvalidStateError('not connected')
else:
raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN(
dlci=channel << 1,
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
dlci = channel << 1,
cl = 0xF0,
priority = 7,
ack_timer = 0,
max_frame_size = RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions = 0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 1, data = bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.OPENING)
self.send_frame(
RFCOMM_Frame.uih(
c_r=1 if self.role == Multiplexer.INITIATOR else 0,
dlci=0,
information=mcc,
c_r = 1 if self.role == Multiplexer.INITIATOR else 0,
dlci = 0,
information = mcc
)
)
result = await self.open_result
@@ -809,19 +763,17 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device, connection):
self.device = device
self.connection = connection
self.device = device
self.connection = connection
self.l2cap_channel = None
self.multiplexer = None
self.multiplexer = None
async def start(self):
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM)
except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}')
logger.warn(f'L2CAP connection failed: {error}')
raise
# Create a mutliplexer to manage DLCs with the server
@@ -845,34 +797,22 @@ class Client:
class Server(EventEmitter):
def __init__(self, device):
super().__init__()
self.device = device
self.device = device
self.multiplexer = None
self.acceptors = {}
self.acceptors = {}
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor, channel=0):
if channel:
if channel in self.acceptors:
# Busy
return 0
else:
# Find a free channel number
for candidate in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START,
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1,
):
if candidate not in self.acceptors:
channel = candidate
break
def listen(self, acceptor):
# Find a free channel number
for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1):
if channel not in self.acceptors:
self.acceptors[channel] = acceptor
return channel
if channel == 0:
# All channels used...
return 0
self.acceptors[channel] = acceptor
return channel
# All channels used...
return 0
def on_connection(self, l2cap_channel):
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')

View File

@@ -15,13 +15,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import struct
from typing import Dict, List, Type
from colors import color
import colors
from . import core
from .colors import color
from .core import InvalidStateError
from .hci import HCI_Object, name_or_number, key_with_value
@@ -34,9 +33,6 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
SDP_PSM = 0x0001
@@ -116,162 +112,137 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
class DataElement:
NIL = 0
NIL = 0
UNSIGNED_INTEGER = 1
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
SIGNED_INTEGER = 2
UUID = 3
TEXT_STRING = 4
BOOLEAN = 5
SEQUENCE = 6
ALTERNATIVE = 7
URL = 8
TYPE_NAMES = {
NIL: 'NIL',
NIL: 'NIL',
UNSIGNED_INTEGER: 'UNSIGNED_INTEGER',
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL',
SIGNED_INTEGER: 'SIGNED_INTEGER',
UUID: 'UUID',
TEXT_STRING: 'TEXT_STRING',
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL'
}
type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y),
SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y),
UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)),
ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8'))
}
def __init__(self, element_type, value, value_size=None):
self.type = element_type
self.value = value
def __init__(self, type, value, value_size=None):
self.type = type
self.value = value
self.value_size = value_size
# Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
if value_size is None:
raise ValueError('integer types must have a value size specified')
@staticmethod
def nil() -> DataElement:
def nil():
return DataElement(DataElement.NIL, None)
@staticmethod
def unsigned_integer(value: int, value_size: int) -> DataElement:
def unsigned_integer(value, value_size):
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value: int) -> DataElement:
def unsigned_integer_8(value):
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value: int) -> DataElement:
def unsigned_integer_16(value):
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value: int) -> DataElement:
def unsigned_integer_32(value):
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value: int, value_size: int) -> DataElement:
def signed_integer(value, value_size):
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value: int) -> DataElement:
def signed_integer_8(value):
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value: int) -> DataElement:
def signed_integer_16(value):
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value: int) -> DataElement:
def signed_integer_32(value):
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value: core.UUID) -> DataElement:
def uuid(value):
return DataElement(DataElement.UUID, value)
@staticmethod
def text_string(value: str) -> DataElement:
def text_string(value):
return DataElement(DataElement.TEXT_STRING, value)
@staticmethod
def boolean(value: bool) -> DataElement:
def boolean(value):
return DataElement(DataElement.BOOLEAN, value)
@staticmethod
def sequence(value: List[DataElement]) -> DataElement:
def sequence(value):
return DataElement(DataElement.SEQUENCE, value)
@staticmethod
def alternative(value: List[DataElement]) -> DataElement:
def alternative(value):
return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod
def url(value: str) -> DataElement:
def url(value):
return DataElement(DataElement.URL, value)
@staticmethod
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
if len(data) == 2:
elif len(data) == 2:
return struct.unpack('>H', data)[0]
if len(data) == 4:
elif len(data) == 4:
return struct.unpack('>I', data)[0]
if len(data) == 8:
elif len(data) == 8:
return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
else:
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
if len(data) == 2:
elif len(data) == 2:
return struct.unpack('>h', data)[0]
if len(data) == 4:
elif len(data) == 4:
return struct.unpack('>i', data)[0]
if len(data) == 8:
elif len(data) == 8:
return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}')
else:
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
def list_from_bytes(data):
@@ -279,7 +250,7 @@ class DataElement:
while data:
element = DataElement.from_bytes(data)
elements.append(element)
data = data[len(bytes(element)) :]
data = data[len(bytes(element)):]
return elements
@staticmethod
@@ -289,11 +260,11 @@ class DataElement:
@staticmethod
def from_bytes(data):
element_type = data[0] >> 3
size_index = data[0] & 7
type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if element_type == DataElement.NIL:
if type == DataElement.NIL:
value_size = 0
else:
value_size = 1
@@ -315,21 +286,16 @@ class DataElement:
value_size = struct.unpack('>I', data[1:5])[0]
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(element_type)
value_data = data[1 + value_offset:1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type)
if constructor:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(element_type, value_data)
result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
result = DataElement(type, value_data)
result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica
return result
def to_bytes(self):
@@ -345,8 +311,7 @@ class DataElement:
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1:
elif self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
@@ -369,11 +334,11 @@ class DataElement:
raise ValueError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
elif self.type == DataElement.TEXT_STRING or self.type == DataElement.URL:
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
@@ -384,11 +349,9 @@ class DataElement:
if size != 0:
raise ValueError('NIL must be empty')
size_index = 0
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
elif (self.type == DataElement.UNSIGNED_INTEGER or
self.type == DataElement.SIGNED_INTEGER or
self.type == DataElement.UUID):
if size <= 1:
size_index = 0
elif size == 2:
@@ -401,12 +364,10 @@ class DataElement:
size_index = 4
else:
raise ValueError('invalid data size')
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
elif (self.type == DataElement.TEXT_STRING or
self.type == DataElement.SEQUENCE or
self.type == DataElement.ALTERNATIVE or
self.type == DataElement.URL):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
@@ -431,19 +392,11 @@ class DataElement:
type_name = name_or_number(self.TYPE_NAMES, self.type)
if self.type == DataElement.NIL:
value_string = ''
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]'
elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER:
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
@@ -457,17 +410,17 @@ class DataElement:
# -----------------------------------------------------------------------------
class ServiceAttribute:
def __init__(self, attribute_id: int, value: DataElement) -> None:
self.id = attribute_id
def __init__(self, id, value):
self.id = id
self.value = value
@staticmethod
def list_from_data_elements(elements):
attribute_list = []
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
attribute_id, attribute_value = elements[2 * i:2 * (i + 1)]
if attribute_id.type != DataElement.UNSIGNED_INTEGER:
logger.warning('attribute ID element is not an integer')
logger.warn('attribute ID element is not an integer')
continue
attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value))
@@ -475,41 +428,30 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(attribute_list, attribute_id):
return next(
(
attribute.value
for attribute in attribute_list
if attribute.id == attribute_id
),
None,
)
return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None)
@staticmethod
def id_name(id_code):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
def id_name(id):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id)
@staticmethod
def is_uuid_in_value(uuid, value):
# Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID:
return value.value == uuid
if value.type == DataElement.SEQUENCE:
elif value.type == DataElement.SEQUENCE:
for element in value.value:
if ServiceAttribute.is_uuid_in_value(uuid, element):
return True
return False
else:
return False
return False
def to_string(self, with_colors=False):
if with_colors:
return (
f'Attribute(id={color(self.id_name(self.id),"magenta")},'
f'value={self.value})'
)
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
def to_string(self, color=False):
if color:
return f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},value={self.value})'
else:
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
def __str__(self):
return self.to_string()
@@ -520,14 +462,11 @@ class SDP_PDU:
'''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
name = None
pdu_id = 0
sdp_pdu_classes = {}
@staticmethod
def from_bytes(pdu):
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None:
@@ -545,15 +484,13 @@ class SDP_PDU:
@staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset):
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)]
return offset + count * 4, handle_list
@staticmethod
def parse_bytes_preceded_by_length(data, offset):
length = struct.unpack_from('>H', data, offset - 2)[0]
return offset + length, data[offset : offset + length]
return offset + length, data[offset:offset + length]
@staticmethod
def error_name(error_code):
@@ -595,10 +532,7 @@ class SDP_PDU:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters
self.pdu = pdu
self.transaction_id = transaction_id
@@ -621,7 +555,9 @@ class SDP_PDU:
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
@SDP_PDU.subclass([
('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})
])
class SDP_ErrorResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
@@ -629,13 +565,11 @@ class SDP_ErrorResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*')
])
class SDP_ServiceSearchRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
@@ -643,17 +577,12 @@ class SDP_ServiceSearchRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
(
'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count),
('continuation_state', '*')
])
class SDP_ServiceSearchResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
@@ -661,14 +590,12 @@ class SDP_ServiceSearchResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
class SDP_ServiceAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
@@ -676,13 +603,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
class SDP_ServiceAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
@@ -690,14 +615,12 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
@@ -705,13 +628,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*'),
]
)
@SDP_PDU.subclass([
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
@@ -721,9 +642,9 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
class Client:
def __init__(self, device):
self.device = device
self.device = device
self.pending_request = None
self.channel = None
self.channel = None
async def connect(self, connection):
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
@@ -738,9 +659,7 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
# Request and accumulate until there's no more continuation
service_record_handle_list = []
@@ -749,10 +668,10 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceSearchRequest(
transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF,
continuation_state=continuation_state,
transaction_id = 0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern,
maximum_service_record_count = 0xFFFF,
continuation_state = continuation_state
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -770,15 +689,11 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -791,11 +706,11 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceSearchAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state=continuation_state,
transaction_id = 0, # Transaction ID TODO: pick a real value
service_search_pattern = service_search_pattern,
maximum_attribute_byte_count = 0xFFFF,
attribute_id_list = attribute_id_list,
continuation_state = continuation_state
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -810,7 +725,7 @@ class Client:
# Parse the result into attribute lists
attribute_lists_sequences = DataElement.from_bytes(accumulator)
if attribute_lists_sequences.type != DataElement.SEQUENCE:
logger.warning('unexpected data type')
logger.warn('unexpected data type')
return []
return [
@@ -825,10 +740,8 @@ class Client:
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -841,11 +754,11 @@ class Client:
while watchdog > 0:
response_pdu = await self.channel.send_request(
SDP_ServiceAttributeRequest(
transaction_id=0, # Transaction ID TODO: pick a real value
service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state=continuation_state,
transaction_id = 0, # Transaction ID TODO: pick a real value
service_record_handle = service_record_handle,
maximum_attribute_byte_count = 0xFFFF,
attribute_id_list = attribute_id_list,
continuation_state = continuation_state
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -860,7 +773,7 @@ class Client:
# Parse the result into a list of attributes
attribute_list_sequence = DataElement.from_bytes(accumulator)
if attribute_list_sequence.type != DataElement.SEQUENCE:
logger.warning('unexpected data type')
logger.warn('unexpected data type')
return []
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
@@ -871,9 +784,8 @@ class Server:
CONTINUATION_STATE = bytes([0x01, 0x43])
def __init__(self, device):
self.device = device
self.service_records = {} # Service records maps, by record handle
self.channel = None
self.device = device
self.service_records = {} # Service records maps, by record handle
self.current_response = None
def register(self, l2cap_channel_manager):
@@ -908,10 +820,11 @@ class Server:
try:
sdp_pdu = SDP_PDU.from_bytes(pdu)
except Exception as error:
logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
transaction_id = 0,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
)
)
@@ -927,16 +840,16 @@ class Server:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
transaction_id = sdp_pdu.transaction_id,
error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR
)
)
else:
logger.error(color('SDP Request not handled???', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
transaction_id = sdp_pdu.transaction_id,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
)
)
@@ -959,18 +872,17 @@ class Server:
if attribute_id.value_size == 4:
# Attribute ID range
id_range_start = attribute_id.value >> 16
id_range_end = attribute_id.value & 0xFFFF
id_range_end = attribute_id.value & 0xFFFF
else:
id_range_start = attribute_id.value
id_range_end = attribute_id.value
id_range_end = attribute_id.value
attributes += [
attribute
for attribute in service
attribute for attribute in service
if attribute.id >= id_range_start and attribute.id <= id_range_end
]
# Return the matching attributes, sorted by attribute id
attributes.sort(key=lambda x: x.id)
# Return the maching attributes, sorted by attribute id
attributes.sort(key = lambda x: x.id)
attribute_list = DataElement.sequence([])
for attribute in attributes:
attribute_list.value.append(DataElement.unsigned_integer_16(attribute.id))
@@ -984,8 +896,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
)
)
return
@@ -998,38 +910,30 @@ class Server:
service_record_handles = list(matching_services.keys())
# Only return up to the maximum requested
service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count
]
service_record_handles_subset = service_record_handles[:request.maximum_service_record_count]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = (
len(service_record_handles),
service_record_handles_subset,
service_record_handles_subset
)
# Respond, keeping any unsent handles for later
service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
service_record_handles = self.current_response[1][:request.maximum_service_record_count]
self.current_response = (
self.current_response[0],
self.current_response[1][request.maximum_service_record_count :],
)
continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
self.current_response[1][request.maximum_service_record_count:]
)
continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles])
self.send_response(
SDP_ServiceSearchResponse(
transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0],
current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list,
continuation_state=continuation_state,
transaction_id = request.transaction_id,
total_service_record_count = self.current_response[0],
current_service_record_count = len(service_record_handles),
service_record_handle_list = service_record_handle_list,
continuation_state = continuation_state
)
)
@@ -1039,8 +943,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
)
)
return
@@ -1053,31 +957,27 @@ class Server:
if service is None:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
transaction_id = request.transaction_id,
error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR
)
)
return
# Get the attributes for the service
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
# Serialize to a byte array
logger.debug(f'Attributes: {attribute_list}')
self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list),
attribute_list=attribute_list,
continuation_state=continuation_state,
transaction_id = request.transaction_id,
attribute_list_byte_count = len(attribute_list),
attribute_list = attribute_list,
continuation_state = continuation_state
)
)
@@ -1087,8 +987,8 @@ class Server:
if not self.current_response:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
transaction_id = request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
)
)
else:
@@ -1096,16 +996,12 @@ class Server:
self.current_response = None
# Find the matching services
matching_services = self.match_services(
request.service_search_pattern
).values()
matching_services = self.match_services(request.service_search_pattern).values()
# Filter the required attributes
attribute_lists = DataElement.sequence([])
for service in matching_services:
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
if attribute_list.value:
attribute_lists.value.append(attribute_list)
@@ -1114,14 +1010,12 @@ class Server:
self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists),
attribute_lists=attribute_lists,
continuation_state=continuation_state,
transaction_id = request.transaction_id,
attribute_lists_byte_count = len(attribute_lists),
attribute_lists = attribute_lists,
continuation_state = continuation_state
)
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,170 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
import logging
import struct
import datetime
from typing import BinaryIO, Generator
import os
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Snooper:
"""
Base class for snooper implementations.
A snooper is an object that will be provided with HCI packets as they are
exchanged between a host and a controller.
"""
class Direction(IntEnum):
HOST_TO_CONTROLLER = 0
CONTROLLER_TO_HOST = 1
class DataLinkType(IntEnum):
H1 = 1001
H4 = 1002
HCI_BSCP = 1003
H5 = 1004
def snoop(self, hci_packet: bytes, direction: Direction) -> None:
"""Snoop on an HCI packet."""
# -----------------------------------------------------------------------------
class BtSnooper(Snooper):
"""
Snooper that saves HCI packets using the BTSnoop format, based on RFC 1761.
"""
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MS = datetime.timedelta(microseconds=1)
def __init__(self, output: BinaryIO):
self.output = output
# Write the header
self.output.write(
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
)
def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
flags = int(direction)
packet_type = hci_packet[0]
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
flags |= 0x10
# Compute the current timestamp
timestamp = (
int((datetime.datetime.utcnow() - self.TIMESTAMP_ANCHOR) / self.ONE_MS)
+ self.TIMESTAMP_DELTA
)
# Emit the record
self.output.write(
struct.pack(
'>IIIIQ',
len(hci_packet), # Original Length
len(hci_packet), # Included Length
flags, # Packet Flags
0, # Cumulative Drops
timestamp, # Timestamp
)
+ hci_packet
)
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@contextmanager
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
Create a snooper given a specification string.
The general syntax for the specification string is:
<snooper-type>:<type-specific-arguments>
Supported snooper types are:
btsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open the file
logger.debug(f'Snoop file: {file_path}')
with open(file_path, 'wb') as snoop_file:
_SNOOPER_INSTANCE_COUNT += 1
yield BtSnooper(snoop_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import asynccontextmanager
import logging
import os
from .common import Transport, AsyncPipeSink, SnoopingTransport
from .common import Transport, AsyncPipeSink
from ..link import RemoteLink
from ..controller import Controller
from ..snoop import create_snooper
# -----------------------------------------------------------------------------
# Logging
@@ -30,148 +28,68 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def _wrap_transport(transport: Transport) -> Transport:
"""
Automatically wrap a Transport instance when a wrapping class can be inferred
from the environment.
If no wrapping class is applicable, the transport argument is returned as-is.
"""
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
try:
return SnoopingTransport.create_with(
transport, create_snooper(snooper_spec)
)
except Exception as exc:
logger.warning(f'Exception while creating snooper: {exc}')
return transport
# -----------------------------------------------------------------------------
async def open_transport(name: str) -> Transport:
"""
async def open_transport(name):
'''
Open a transport by name.
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The supported types are:
* serial
* udp
* tcp-client
* tcp-server
* ws-client
* ws-server
* pty
* file
* vhci
* hci-socket
* usb
* pyusb
* android-emulator
"""
return _wrap_transport(await _open_transport(name))
# -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
The supported types are: serial,udp,tcp,pty,usb
'''
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
if scheme == 'udp' and spec:
elif scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
if scheme == 'tcp-client' and spec:
elif scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
if scheme == 'tcp-server' and spec:
elif scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
if scheme == 'ws-client' and spec:
elif scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
if scheme == 'ws-server' and spec:
elif scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
if scheme == 'pty':
elif scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
if scheme == 'file':
elif scheme == 'file':
from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
if scheme == 'vhci':
elif scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
if scheme == 'hci-socket':
elif scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
if scheme == 'usb':
elif scheme == 'usb':
from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
if scheme == 'pyusb':
elif scheme == 'pyusb':
from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
if scheme == 'android-emulator':
elif scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
raise ValueError('unknown transport scheme')
else:
raise ValueError('unknown transport scheme')
# -----------------------------------------------------------------------------
async def open_transport_or_link(name: str) -> Transport:
"""
Open a transport or a link relay.
Args:
name:
Name of the transport or link relay to open.
When the name starts with "link-relay:", open a link relay (see RemoteLink
for details on what the arguments are).
For other namespaces, see `open_transport`.
"""
async def open_transport_or_link(name):
if name.startswith('link-relay:'):
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link)
controller = Controller('remote', link = link)
class LinkTransport(Transport):
async def close(self):
link.close()
return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller)))
return await open_transport(name)
return LinkTransport(controller, AsyncPipeSink(controller))
else:
return await open_transport(name)

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,10 +20,8 @@ import grpc
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
from .emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# pylint: disable-next=no-name-in-module
from .emulated_bluetooth_packets_pb2 import HCIPacket
from .emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# -----------------------------------------------------------------------------
@@ -61,10 +59,15 @@ async def open_android_emulator_transport(spec):
return bytes([packet.type]) + packet.packet
async def write(self, packet):
await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:]))
await self.hci_device.write(
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
# Parse the parameters
mode = 'host'
mode = 'host'
server_host = 'localhost'
server_port = 8554
if spec is not None:
@@ -97,7 +100,7 @@ async def open_android_emulator_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
channel.close,
channel.close
)
transport.start()

View File

@@ -15,16 +15,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct
import asyncio
import logging
from typing import ContextManager
from colors import color
from .. import hci
from ..colors import color
from ..snoop import Snooper
# -----------------------------------------------------------------------------
@@ -37,10 +33,10 @@ logger = logging.getLogger(__name__)
# For each packet type, the info represents:
# (length-size, length-offset, unpack-type)
HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B')
}
@@ -52,7 +48,7 @@ class PacketPump:
def __init__(self, reader, sink):
self.reader = reader
self.sink = sink
self.sink = sink
async def run(self):
while True:
@@ -69,51 +65,43 @@ class PacketPump:
# -----------------------------------------------------------------------------
class PacketParser:
'''
In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
'''
# pylint: disable=attribute-defined-outside-init
NEED_TYPE = 0
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
NEED_BODY = 2
def __init__(self, sink=None):
def __init__(self, sink = None):
self.sink = sink
self.extended_packet_info = {}
self.reset()
def reset(self):
self.state = PacketParser.NEED_TYPE
self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1
self.packet = bytearray()
self.packet_info = None
self.packet = bytearray()
self.packet_info = None
def feed_data(self, data):
data_offset = 0
data_left = len(data)
while data_left and self.bytes_needed:
consumed = min(self.bytes_needed, data_left)
self.packet.extend(data[data_offset : data_offset + consumed])
data_offset += consumed
data_left -= consumed
self.packet.extend(data[data_offset:data_offset + consumed])
data_offset += consumed
data_left -= consumed
self.bytes_needed -= consumed
if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get(
packet_type
) or self.extended_packet_info.get(packet_type)
self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0]
self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY
@@ -123,9 +111,7 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
self.reset()
def set_packet_sink(self, sink):
@@ -201,7 +187,6 @@ class AsyncPipeSink:
'''
Sink that forwards packets asynchronously to another sink
'''
def __init__(self, sink):
self.sink = sink
self.loop = asyncio.get_running_loop()
@@ -217,7 +202,7 @@ class ParserSource:
"""
def __init__(self):
self.parser = PacketParser()
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink):
@@ -250,23 +235,9 @@ class StreamPacketSink:
# -----------------------------------------------------------------------------
class Transport:
"""
Base class for all transports.
A Transport represents a source and a sink together.
An instance must be closed by calling close() when no longer used. Instances
implement the ContextManager protocol so that they may be used in a `async with`
statement.
An instance is iterable. The iterator yields, in order, its source and sink, so
that it may be used with a convenient call syntax like:
async with create_transport() as (source, sink):
...
"""
def __init__(self, source, sink):
self.source = source
self.sink = sink
self.sink = sink
async def __aenter__(self):
return self
@@ -277,7 +248,7 @@ class Transport:
def __iter__(self):
return iter((self.source, self.sink))
async def close(self) -> None:
async def close(self):
self.source.close()
self.sink.close()
@@ -287,7 +258,7 @@ class PumpedPacketSource(ParserSource):
def __init__(self, receive):
super().__init__()
self.receive_function = receive
self.pump_task = None
self.pump_task = None
def start(self):
async def pump_packets():
@@ -299,11 +270,11 @@ class PumpedPacketSource(ParserSource):
logger.debug('source pump task done')
break
except Exception as error:
logger.warning(f'exception while waiting for packet: {error}')
logger.warn(f'exception while waiting for packet: {error}')
self.terminated.set_result(error)
break
self.pump_task = asyncio.create_task(pump_packets())
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
def close(self):
if self.pump_task:
@@ -314,8 +285,8 @@ class PumpedPacketSource(ParserSource):
class PumpedPacketSink:
def __init__(self, send):
self.send_function = send
self.packet_queue = asyncio.Queue()
self.pump_task = None
self.packet_queue = asyncio.Queue()
self.pump_task = None
def on_packet(self, packet):
self.packet_queue.put_nowait(packet)
@@ -330,10 +301,10 @@ class PumpedPacketSink:
logger.debug('sink pump task done')
break
except Exception as error:
logger.warning(f'exception while sending packet: {error}')
logger.warn(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.create_task(pump_packets())
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
def close(self):
if self.pump_task:
@@ -353,60 +324,3 @@ class PumpedTransport(Transport):
async def close(self):
await super().close()
await self.close_function()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.
The returned instance will exit the snooper context when it is closed.
"""
with contextlib.ExitStack() as exit_stack:
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source:
def __init__(self, source, snooper):
self.source = source
self.snooper = snooper
self.sink = None
def set_packet_sink(self, sink):
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink, snooper):
self.sink = sink
self.snooper = snooper
def on_packet(self, packet):
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(self, transport, snooper, close_snooper=None):
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)
self.transport = transport
self.close_snooper = close_snooper
async def close(self):
await self.transport.close()
if self.close_snooper:
self.close_snooper()

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,30 +16,37 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_packets.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_packets_pb2', globals()
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), {
'DESCRIPTOR' : _HCIPACKET,
'__module__' : 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
})
_sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start = 66
_HCIPACKET._serialized_end = 317
_HCIPACKET_PACKETTYPE._serialized_start = 161
_HCIPACKET_PACKETTYPE._serialized_end = 317
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_HCIPACKET._serialized_start=66
_HCIPACKET._serialized_end=317
_HCIPACKET_PACKETTYPE._serialized_start=161
_HCIPACKET_PACKETTYPE._serialized_end=317
# @@protoc_insertion_point(module_scope)

View File

@@ -1,41 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class HCIPacket(_message.Message):
__slots__ = ["packet", "type"]
class PacketType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
PACKET_FIELD_NUMBER: _ClassVar[int]
PACKET_TYPE_ACL: HCIPacket.PacketType
PACKET_TYPE_EVENT: HCIPacket.PacketType
PACKET_TYPE_HCI_COMMAND: HCIPacket.PacketType
PACKET_TYPE_ISO: HCIPacket.PacketType
PACKET_TYPE_SCO: HCIPacket.PacketType
PACKET_TYPE_UNSPECIFIED: HCIPacket.PacketType
TYPE_FIELD_NUMBER: _ClassVar[int]
packet: bytes
type: HCIPacket.PacketType
def __init__(
self,
type: _Optional[_Union[HCIPacket.PacketType, str]] = ...,
packet: _Optional[bytes] = ...,
) -> None: ...

View File

@@ -1,17 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +16,11 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -29,18 +29,25 @@ _sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'emulated_bluetooth_pb2', globals())
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), {
'DESCRIPTOR' : _RAWDATA,
'__module__' : 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
})
_sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start = 91
_RAWDATA._serialized_end = 116
_EMULATEDBLUETOOTHSERVICE._serialized_start = 119
_EMULATEDBLUETOOTHSERVICE._serialized_end = 450
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start=91
_RAWDATA._serialized_end=116
_EMULATEDBLUETOOTHSERVICE._serialized_start=119
_EMULATEDBLUETOOTHSERVICE._serialized_end=450
# @@protoc_insertion_point(module_scope)

View File

@@ -1,26 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class RawData(_message.Message):
__slots__ = ["packet"]
PACKET_FIELD_NUMBER: _ClassVar[int]
packet: bytes
def __init__(self, packet: _Optional[bytes] = ...) -> None: ...

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,20 +39,20 @@ class EmulatedBluetoothServiceStub(object):
channel: A grpc.Channel.
"""
self.registerClassicPhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
self.registerBlePhy = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
request_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
response_deserializer=emulated__bluetooth__pb2.RawData.FromString,
)
self.registerHCIDevice = channel.stream_stream(
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
class EmulatedBluetoothServiceServicer(object):
@@ -121,29 +121,28 @@ class EmulatedBluetoothServiceServicer(object):
def add_EmulatedBluetoothServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'registerClassicPhy': grpc.stream_stream_rpc_method_handler(
servicer.registerClassicPhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerBlePhy': grpc.stream_stream_rpc_method_handler(
servicer.registerBlePhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerHCIDevice': grpc.stream_stream_rpc_method_handler(
servicer.registerHCIDevice,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
'registerClassicPhy': grpc.stream_stream_rpc_method_handler(
servicer.registerClassicPhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerBlePhy': grpc.stream_stream_rpc_method_handler(
servicer.registerBlePhy,
request_deserializer=emulated__bluetooth__pb2.RawData.FromString,
response_serializer=emulated__bluetooth__pb2.RawData.SerializeToString,
),
'registerHCIDevice': grpc.stream_stream_rpc_method_handler(
servicer.registerHCIDevice,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers
)
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class EmulatedBluetoothService(object):
"""An Emulated Bluetooth Service exposes the emulated bluetooth chip from the
android emulator. It allows you to register emulated bluetooth devices and
@@ -157,88 +156,52 @@ class EmulatedBluetoothService(object):
"""
@staticmethod
def registerClassicPhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
def registerClassicPhy(request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def registerBlePhy(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
def registerBlePhy(request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def registerHCIDevice(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
def registerHCIDevice(request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,31 +16,28 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_vhci.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_vhci_pb2', globals()
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_VHCIFORWARDINGSERVICE._serialized_start = 96
_VHCIFORWARDINGSERVICE._serialized_end = 217
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\037com.android.emulation.bluetoothP\001\370\001\001\242\002\003AEB\252\002\033Android.Emulation.Bluetooth'
_VHCIFORWARDINGSERVICE._serialized_start=96
_VHCIFORWARDINGSERVICE._serialized_end=217
# @@protoc_insertion_point(module_scope)

View File

@@ -1,19 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from typing import ClassVar as _ClassVar
DESCRIPTOR: _descriptor.FileDescriptor

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,10 +35,10 @@ class VhciForwardingServiceStub(object):
channel: A grpc.Channel.
"""
self.attachVhci = channel.stream_stream(
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
request_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
response_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
)
class VhciForwardingServiceServicer(object):
@@ -75,19 +75,18 @@ class VhciForwardingServiceServicer(object):
def add_VhciForwardingServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'attachVhci': grpc.stream_stream_rpc_method_handler(
servicer.attachVhci,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
'attachVhci': grpc.stream_stream_rpc_method_handler(
servicer.attachVhci,
request_deserializer=emulated__bluetooth__packets__pb2.HCIPacket.FromString,
response_serializer=emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers
)
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class VhciForwardingService(object):
"""This is a service which allows you to directly intercept the VHCI packets
that are coming and going to the device before they are delivered to
@@ -98,30 +97,18 @@ class VhciForwardingService(object):
"""
@staticmethod
def attachVhci(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
def attachVhci(request_iterator,
target,
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

View File

@@ -30,9 +30,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_file_transport(spec):
'''
Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files).
The parameter string is the path of the file to open.
Open a File transport (typically not for a real file, but for a PTY or other unix virtual files).
The parameter string is the path of the file to open
'''
# Open the file
@@ -40,12 +39,14 @@ async def open_file_transport(spec):
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
StreamPacketSource, file
lambda: StreamPacketSource(),
file
)
# Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
asyncio.BaseProtocol, file
lambda: asyncio.BaseProtocol(),
file
)
packet_sink = StreamPacketSink(write_transport)
@@ -56,3 +57,4 @@ async def open_file_transport(spec):
file.close()
return FileTransport(packet_source, packet_sink)

View File

@@ -40,21 +40,15 @@ async def open_hci_socket_transport(spec):
or a 0-based integer to indicate the adapter number.
'''
HCI_CHANNEL_USER = 1 # pylint: disable=invalid-name
HCI_CHANNEL_USER = 1
# Create a raw HCI socket
try:
hci_socket = socket.socket(
socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError as error:
hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI)
except AttributeError:
# Not supported on this platform
logger.info("HCI sockets not supported on this platform")
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
raise Exception('Bluetooth HCI sockets not supported on this platform')
# Compute the adapter index
if spec is None:
@@ -68,37 +62,20 @@ async def open_hci_socket_transport(spec):
try:
ctypes.cdll.LoadLibrary('libc.so.6')
libc = ctypes.CDLL('libc.so.6', use_errno=True)
except OSError as error:
except OSError:
logger.info("HCI sockets not supported on this platform")
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
raise Exception('Bluetooth HCI sockets not supported on this platform')
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int
bind_address = struct.pack(
# pylint: disable=no-member
'<HHH',
socket.AF_BLUETOOTH,
adapter_index,
HCI_CHANNEL_USER,
)
if (
libc.bind(
hci_socket.fileno(),
ctypes.create_string_buffer(bind_address),
len(bind_address),
)
!= 0
):
bind_address = struct.pack('<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER)
if libc.bind(hci_socket.fileno(), ctypes.create_string_buffer(bind_address), len(bind_address)) != 0:
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, hci_socket):
def __init__(self, socket):
super().__init__()
self.socket = hci_socket
asyncio.get_running_loop().add_reader(
self.socket.fileno(), self.recv_until_would_block
)
self.socket = socket
asyncio.get_running_loop().add_reader(socket.fileno(), self.recv_until_would_block)
def recv_until_would_block(self):
logger.debug('recv until would block +++')
@@ -115,9 +92,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_reader(self.socket.fileno())
class HciSocketSink:
def __init__(self, hci_socket):
self.socket = hci_socket
self.packets = collections.deque()
def __init__(self, socket):
self.socket = socket
self.packets = collections.deque()
self.writer_added = False
def send_until_would_block(self):
@@ -135,14 +112,9 @@ async def open_hci_socket_transport(spec):
break
if self.packets:
# There's still something to send, ensure that we are monitoring the
# socket
# There's still something to send, ensure that we are monitoring the socket
if not self.writer_added:
asyncio.get_running_loop().add_writer(
# pylint: disable=no-member
self.socket.fileno(),
self.send_until_would_block,
)
asyncio.get_running_loop().add_writer(socket.fileno(), self.send_until_would_block)
self.writer_added = True
else:
# Nothing left to send, stop monitoring the socket
@@ -159,9 +131,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_writer(self.socket.fileno())
class HciSocketTransport(Transport):
def __init__(self, hci_socket, source, sink):
def __init__(self, socket, source, sink):
super().__init__(source, sink)
self.socket = hci_socket
self.socket = socket
async def close(self):
logger.debug('closing HCI socket transport')

View File

@@ -47,11 +47,13 @@ async def open_pty_transport(spec):
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
StreamPacketSource, io.open(primary, 'rb', closefd=False)
lambda: StreamPacketSource(),
io.open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
lambda: asyncio.BaseProtocol(),
io.open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

@@ -17,15 +17,14 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import threading
import time
import usb.core
import usb.util
import threading
import time
from colors import color
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
@@ -49,26 +48,25 @@ async def open_pyusb_transport(spec):
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_SCO_IN = 0x83
USB_ENDPOINT_ACL_OUT = 0x02
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_SCO_IN = 0x83
USB_ENDPOINT_ACL_OUT = 0x02
# USB_ENDPOINT_SCO_OUT = 0x03
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
READ_SIZE = 1024
READ_SIZE = 1024
READ_TIMEOUT = 1000
class UsbPacketSink:
def __init__(self, device):
self.device = device
self.thread = threading.Thread(target=self.run)
self.loop = asyncio.get_running_loop()
self.device = device
self.thread = threading.Thread(target=self.run)
self.loop = asyncio.get_running_loop()
self.stop_event = None
def on_packet(self, packet):
@@ -82,17 +80,9 @@ async def open_pyusb_transport(spec):
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
elif packet_type == hci.HCI_COMMAND_PACKET:
self.device.ctrl_transfer(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
)
self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:])
else:
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
except usb.core.USBTimeoutError:
logger.warning('USB Write Timeout')
except usb.core.USBError as error:
@@ -110,21 +100,22 @@ async def open_pyusb_transport(spec):
def run(self):
while self.stop_event is None:
time.sleep(1)
self.loop.call_soon_threadsafe(self.stop_event.set)
self.loop.call_soon_threadsafe(lambda: self.stop_event.set())
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled):
super().__init__()
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.event_thread = threading.Thread(
target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
target=self.run,
args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
)
self.event_thread.stop_event = None
self.acl_thread = threading.Thread(
target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
target=self.run,
args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
)
self.acl_thread.stop_event = None
@@ -133,12 +124,12 @@ async def open_pyusb_transport(spec):
if sco_enabled:
self.sco_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET)
)
self.sco_thread.stop_event = None
def data_received(self, data):
self.parser.feed_data(data)
def data_received(self, packet):
self.parser.feed_data(packet)
def enqueue(self, packet):
self.queue.put_nowait(packet)
@@ -164,7 +155,7 @@ async def open_pyusb_transport(spec):
# Create stop events and wait for them to be signaled
self.event_thread.stop_event = asyncio.Event()
self.acl_thread.stop_event = asyncio.Event()
self.acl_thread.stop_event = asyncio.Event()
await self.event_thread.stop_event.wait()
await self.acl_thread.stop_event.wait()
if self.sco_enabled:
@@ -182,17 +173,16 @@ async def open_pyusb_transport(spec):
except usb.core.USBTimeoutError:
continue
except usb.core.USBError:
# Don't log this: because pyusb doesn't really support multiple
# threads reading at the same time, we can get occasional
# USBError(errno=5) Input/Output errors reported, but they seem to
# be harmless.
# Don't log this: because pyusb doesn't really support multiple threads
# reading at the same time, we can get occasional USBError(errno=5)
# Input/Output errors reported, but they seem to be harmless.
# Until support for async or multi-thread support is added to pyusb,
# we'll just live with this as is...
# logger.warning(f'USB read error: {error}')
time.sleep(1) # Sleep one second to avoid busy looping
stop_event = current_thread.stop_event
self.loop.call_soon_threadsafe(stop_event.set)
self.loop.call_soon_threadsafe(lambda: stop_event.set())
class UsbTransport(Transport):
def __init__(self, device, source, sink):
@@ -204,28 +194,18 @@ async def open_pyusb_transport(spec):
await self.sink.stop()
usb.util.release_interface(self.device, 0)
usb_find = usb.core.find
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
usb_find = libusb_package.find
# Find the device according to the spec moniker
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
else:
device_index = int(spec)
devices = list(
usb_find(
find_all=1,
bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
)
devices = list(usb.core.find(
find_all = 1,
bDeviceClass = USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass = USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
))
if len(devices) > device_index:
device = devices[device_index]
else:
@@ -252,7 +232,6 @@ async def open_pyusb_transport(spec):
# Select an alternate setting for SCO, if available
sco_enabled = False
# pylint: disable=line-too-long
# NOTE: this is disabled for now, because SCO with alternate settings is broken,
# see: https://github.com/libusb/libusb/issues/36
#
@@ -294,4 +273,4 @@ async def open_pyusb_transport(spec):
packet_source.start()
packet_sink.start()
return UsbTransport(device, packet_source, packet_sink)
return UsbTransport(device, packet_source, packet_sink)

View File

@@ -60,12 +60,13 @@ async def open_serial_transport(spec):
device = spec
serial_transport, packet_source = await serial_asyncio.create_serial_connection(
asyncio.get_running_loop(),
StreamPacketSource,
lambda: StreamPacketSource(),
device,
baudrate=speed,
rtscts=rtscts,
dsrdtr=dsrdtr,
dsrdtr=dsrdtr
)
packet_sink = StreamPacketSink(serial_transport)
return Transport(packet_source, packet_sink)

View File

@@ -37,13 +37,13 @@ async def open_tcp_client_transport(spec):
'''
class TcpPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.terminated.set_result(exc)
def connection_lost(self, error):
logger.debug(f'connection lost: {error}')
self.terminated.set_result(error)
remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
TcpPacketSource,
lambda: TcpPacketSource(),
host=remote_host,
port=int(remote_port),
)

View File

@@ -45,12 +45,12 @@ async def open_tcp_server_transport(spec):
class TcpServerProtocol:
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink
self.packet_sink = packet_sink
# Called when a new connection is established
def connection_made(self, transport):
peer_name = transport.get_extra_info('peer_name')
logger.debug(f'connection from {peer_name}')
peername = transport.get_extra_info('peername')
logger.debug('connection from {}'.format(peername))
self.packet_sink.transport = transport
# Called when the client is disconnected
@@ -78,7 +78,7 @@ async def open_tcp_server_transport(spec):
local_host, local_port = spec.split(':')
packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink()
packet_sink = TcpServerPacketSink()
await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink),
host=local_host if local_host != '_' else None,

View File

@@ -53,13 +53,10 @@ async def open_udp_transport(spec):
local, remote = spec.split(',')
local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':')
(
udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
UdpPacketSource,
udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: UdpPacketSource(),
local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port)),
remote_addr=(remote_host, int(remote_port))
)
packet_sink = UdpPacketSink(udp_transport)

View File

@@ -17,16 +17,13 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import usb1
import threading
import collections
import ctypes
import platform
import usb1
from colors import color
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
@@ -36,88 +33,42 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def load_libusb():
'''
Attempt to load the libusb-1.0 C library from libusb_package in site-packages.
If the library exists, we create a DLL object and initialize the usb1 backend.
This only needs to be done once, but before a usb1.USBContext is created.
If the library does not exists, do nothing and usb1 will search default system paths
when usb1.USBContext is created.
'''
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
if libusb_path := libusb_package.get_library_path():
logger.debug(f'loading libusb library at {libusb_path}')
dll_loader = (
ctypes.WinDLL if platform.system() == 'Windows' else ctypes.CDLL
)
libusb_dll = dll_loader(
str(libusb_path), use_errno=True, use_last_error=True
)
usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec):
'''
Open a USB transport.
The moniker string has this syntax:
either <index> or
<vendor>:<product> or
<vendor>:<product>/<serial-number>] or
<vendor>:<product>#<index>
The parameter string has this syntax:
either <index> or <vendor>:<product>[/<serial-number>]
With <index> as the 0-based index to select amongst all the devices that appear
to be supporting Bluetooth HCI (0 being the first one), or
Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The
/<serial-number> suffix or #<index> suffix max be specified when more than one
device with the same vendor and product identifiers are present.
In addition, if the moniker ends with the symbol "!", the device will be used in
"forced" mode:
the first USB interface of the device will be used, regardless of the interface
class/subclass.
This may be useful for some devices that use a custom class/subclass but may
nonetheless work as-is.
/<serial-number> suffix max be specified when more than one device with the same
vendor and product identifiers are present.
Examples:
0 --> the first BT USB dongle
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and
serial number 00E04C239987
usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and serial number 00E04C239987
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
USB_ENDPOINT_ACL_IN = 0x82
USB_ENDPOINT_ACL_OUT = 0x02
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_ENDPOINT_TRANSFER_TYPE_BULK = 0x02
USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT = 0x03
USB_ENDPOINT_IN = 0x80
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
def __init__(self, device):
self.device = device
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.cancel_done = self.loop.create_future()
self.closed = False
self.closed = False
def start(self):
pass
@@ -141,15 +92,12 @@ async def open_usb_transport(spec):
status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
)
logger.warning(color(f'!!! out transfer not completed: status={status}', 'red'))
def on_packet_sent_(self):
if self.packets:
@@ -164,38 +112,32 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
self.acl_out, packet[1:], callback=self.on_packet_sent
USB_ENDPOINT_ACL_OUT,
packet[1:],
callback=self.on_packet_sent
)
logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0,
packet[1:],
callback=self.on_packet_sent,
callback=self.on_packet_sent
)
logger.debug('submit COMMAND')
self.transfer.submit()
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
def close(self):
async def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data
self.packets.clear()
# If we have a transfer in flight, cancel it
if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have
# already completed
# Try to cancel the transfer, but that may fail because it may have already completed
try:
self.transfer.cancel()
@@ -206,23 +148,18 @@ async def open_usb_transport(spec):
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, context, device, acl_in, events_in):
def __init__(self, context, device):
super().__init__()
self.context = context
self.device = device
self.acl_in = acl_in
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.context = context
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future()
}
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
@@ -231,19 +168,19 @@ async def open_usb_transport(spec):
# Set up transfer objects for input
self.events_in_transfer = device.getTransfer()
self.events_in_transfer.setInterrupt(
self.events_in,
USB_ENDPOINT_EVENTS_IN,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET,
user_data=hci.HCI_EVENT_PACKET
)
self.events_in_transfer.submit()
self.acl_in_transfer = device.getTransfer()
self.acl_in_transfer.setBulk(
self.acl_in,
USB_ENDPOINT_ACL_IN,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET,
user_data=hci.HCI_ACL_DATA_PACKET
)
self.acl_in_transfer.submit()
@@ -253,28 +190,16 @@ async def open_usb_transport(spec):
def on_packet_received(self, transfer):
packet_type = transfer.getUserData()
status = transfer.getStatus()
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
packet = (
bytes([packet_type])
+ transfer.getBuffer()[: transfer.getActualLength()]
)
packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()]
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None)
return
else:
logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
)
logger.warning(color(f'!!! transfer not completed: status={status}', 'red'))
# Re-submit the transfer so we can receive more data
transfer.submit()
@@ -289,11 +214,7 @@ async def open_usb_transport(spec):
def run(self):
logger.debug('starting USB event loop')
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted():
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
@@ -302,130 +223,75 @@ async def open_usb_transport(spec):
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
def close(self):
async def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
self.dequeue_task.cancel()
# Cancel the transfers
for transfer in (self.events_in_transfer, self.acl_in_transfer):
if transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have
# already completed
# Try to cancel the transfer, but that may fail because it may have already completed
packet_type = transfer.getUserData()
try:
transfer.cancel()
logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation '
'to be done...'
)
logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...')
await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError:
logger.debug(
f'IN[{packet_type}] transfer likely already completed'
)
logger.debug(f'IN[{packet_type}] transfer likely already completed')
# Wait for the thread to terminate
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink):
def __init__(self, context, device, interface, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
self.context = context
self.device = device
self.interface = interface
# Get exclusive access
device.claimInterface(interface)
# Set the alternate setting if not the default
if setting != 0:
device.setInterfaceAltSetting(interface, setting)
# The source and sink can now start
source.start()
sink.start()
async def close(self):
self.source.close()
self.sink.close()
await self.source.terminate()
await self.sink.terminate()
await self.source.close()
await self.sink.close()
self.device.releaseInterface(self.interface)
self.device.close()
self.context.close()
# Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext()
context.open()
try:
found = None
if spec.endswith('!'):
spec = spec[:-1]
forced_mode = True
else:
forced_mode = False
if ':' in spec:
vendor_id, product_id = spec.split(':')
serial_number = None
device_index = 0
if '/' in product_id:
product_id, serial_number = product_id.split('/')
elif '#' in product_id:
product_id, device_index_str = product_id.split('#')
device_index = int(device_index_str)
for device in context.getDeviceIterator(skip_on_error=True):
try:
device_serial_number = device.getSerialNumber()
except usb1.USBError:
device_serial_number = None
if (
device.getVendorID() == int(vendor_id, 16)
and device.getProductID() == int(product_id, 16)
and (serial_number is None or serial_number == device_serial_number)
):
if device_index == 0:
for device in context.getDeviceIterator(skip_on_error=True):
if (
device.getVendorID() == int(vendor_id, 16) and
device.getProductID() == int(product_id, 16) and
device.getSerialNumber() == serial_number
):
found = device
break
device_index -= 1
device.close()
device.close()
else:
found = context.getByVendorIDAndProductID(int(vendor_id, 16), int(product_id, 16), skip_on_error=True)
else:
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
# Check if the device class indicates a match
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
for configuration in device:
for interface in configuration:
for setting in interface:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
device_index = int(spec)
for device in context.getDeviceIterator(skip_on_error=True):
if device_is_bluetooth_hci(device):
if (
device.getDeviceClass() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER and
device.getDeviceSubClass() == USB_DEVICE_SUBCLASS_RF_CONTROLLER and
device.getDeviceProtocol() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
):
if device_index == 0:
found = device
break
@@ -437,107 +303,34 @@ async def open_usb_transport(spec):
raise ValueError('device not found')
logger.debug(f'USB Device: {found}')
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
not forced_mode
and (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
)
!= USB_BT_HCI_CLASS_TUPLE
):
continue
events_in = None
acl_in = None
acl_out = None
for endpoint in setting:
attributes = endpoint.getAttributes()
address = endpoint.getAddress()
if attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_BULK:
if address & USB_ENDPOINT_IN and acl_in is None:
acl_in = address
elif acl_out is None:
acl_out = address
elif (
attributes & 0x03
== USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT
):
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if (
acl_in is not None
and acl_out is not None
and events_in is not None
):
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in,
)
logger.debug(
f'skipping configuration {configuration_index + 1} / '
f'interface {setting.getNumber()}'
)
return None
endpoints = find_endpoints(found)
if endpoints is None:
raise ValueError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug(
f'selected endpoints: configuration={configuration}, '
f'interface={interface}, '
f'setting={setting}, '
f'acl_in=0x{acl_in:02X}, '
f'acl_out=0x{acl_out:02X}, '
f'events_in=0x{events_in:02X}, '
)
device = found.open()
# Auto-detach the kernel driver if supported
# pylint: disable=no-member
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
logger.debug('auto-detaching kernel driver')
device.setAutoDetachKernelDriver(True)
except usb1.USBError as error:
logger.warning(f'unable to auto-detach kernel driver: {error}')
# Set the configuration if needed
try:
current_configuration = device.getConfiguration()
logger.debug(f'current configuration = {current_configuration}')
configuration = device.getConfiguration()
logger.debug(f'current configuration = {configuration}')
except usb1.USBError:
current_configuration = 0
if current_configuration != configuration:
try:
logger.debug(f'setting configuration {configuration}')
device.setConfiguration(configuration)
logger.debug('setting configuration 1')
device.setConfiguration(1)
except usb1.USBError:
logger.warning('failed to set configuration')
logger.debug('failed to set configuration 1')
source = UsbPacketSource(context, device, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
# Use the first interface
interface = 0
# Detach the kernel driver if supported and needed
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
if device.kernelDriverActive(interface):
logger.debug("detaching kernel driver")
device.detachKernelDriver(interface)
except usb1.USBError:
pass
source = UsbPacketSource(context, device)
sink = UsbPacketSink(device)
return UsbTransport(context, device, interface, source, sink)
except usb1.USBError as error:
logger.warning(color(f'!!! failed to open USB device: {error}', 'red'))
context.close()

View File

@@ -33,7 +33,7 @@ async def open_vhci_transport(spec):
path at /dev/vhci), or the path of a VHCI device
'''
HCI_VENDOR_PKT = 0xFF
HCI_VENDOR_PKT = 0xff
HCI_BREDR = 0x00 # Controller type
# Open the VHCI device
@@ -56,3 +56,4 @@ async def open_vhci_transport(spec):
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
return transport

View File

@@ -43,7 +43,7 @@ async def open_ws_client_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send),
websocket.close,
websocket.close
)
transport.start()
return transport

View File

@@ -41,36 +41,30 @@ async def open_ws_server_transport(spec):
class WsServerTransport(Transport):
def __init__(self):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future()
self.server = None
super().__init__(source, sink)
async def serve(self, local_host, local_port):
self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.serve(
ws_handler=self.on_connection,
host=local_host if local_host != '_' else None,
port=int(local_port),
ws_handler = self.on_connection,
host = local_host if local_host != '_' else None,
port = int(local_port)
)
logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection):
logger.debug(
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}')
self.connection.set_result(connection)
# pylint: disable=no-member
try:
async for packet in connection:
if isinstance(packet, bytes):
if type(packet) is bytes:
self.source.parser.feed_data(packet)
else:
logger.warning('discarding packet: not a BINARY frame')
logger.warn('discarding packet: not a BINARY frame')
except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}')

View File

@@ -18,13 +18,10 @@
import asyncio
import logging
import traceback
import collections
import sys
from typing import Awaitable, Set, TypeVar
from functools import wraps
from colors import color
from pyee import EventEmitter
from .colors import color
# -----------------------------------------------------------------------------
# Logging
@@ -36,7 +33,6 @@ logger = logging.getLogger(__name__)
def setup_event_forwarding(emitter, forwarder, event_name):
def emit(*args, **kwargs):
forwarder.emit(event_name, *args, **kwargs)
emitter.on(event_name, emit)
@@ -47,8 +43,6 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter.
"""
# pylint: disable=protected-access
def register(self, emitter):
for method_name in dir(cls):
if method_name.startswith('on_'):
@@ -59,47 +53,13 @@ def composite_listener(cls):
if method_name.startswith('on_'):
emitter.remove_listener(method_name[3:], getattr(self, method_name))
cls._bumble_register_composite = register
cls._bumble_register_composite = register
cls._bumble_deregister_composite = deregister
return cls
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
def on_event(*_):
if future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future
# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
class CompositeEventEmitter(EventEmitter):
def __init__(self):
super().__init__()
self._listener = None
@@ -110,7 +70,6 @@ class CompositeEventEmitter(AbortableEventEmitter):
@listener.setter
def listener(self, listener):
# pylint: disable=protected-access
if self._listener:
# Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro():
@@ -150,16 +109,11 @@ class AsyncRunner:
try:
await item
except Exception as error:
logger.warning(
f'{color("!!! Exception in work queue:", "red")} {error}'
)
logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}')
# Shared default queue
default_queue = WorkQueue()
# Shared set of running tasks
running_tasks: Set[Awaitable] = set()
@staticmethod
def run_in_task(queue=None):
"""
@@ -176,10 +130,7 @@ class AsyncRunner:
try:
await coroutine
except Exception:
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}')
asyncio.create_task(run())
else:
@@ -189,116 +140,3 @@ class AsyncRunner:
return wrapper
return decorator
@staticmethod
def spawn(coroutine):
"""
Spawn a task to run a coroutine in a "fire and forget" mode.
Using this method instead of just calling `asyncio.create_task(coroutine)`
is necessary when you don't keep a reference to the task, because `asyncio`
only keeps weak references to alive tasks.
"""
task = asyncio.create_task(coroutine)
AsyncRunner.running_tasks.add(task)
task.add_done_callback(AsyncRunner.running_tasks.remove)
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(
self,
pause_source,
resume_source,
write_to_sink=None,
drain_sink=None,
threshold=0,
):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()

View File

@@ -2,7 +2,7 @@ Bumble Documentation
====================
The documentation consists of a collection of markdown text files, with the root of the file
hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`.
hierarchy at `docs/mkdocs/src`, starting with `docs/mkdocs/src/index.md`.
You can read the documentation as text, with any text viewer or your favorite markdown viewer,
or generate a static HTML "site" using `mkdocs`, which you can then open with any browser.
@@ -14,9 +14,9 @@ The `mkdocs` directory contains all the data (actual documentation) and metadata
`mkdocs/mkdocs.yml` contains the site configuration.
`mkdocs/src/` is the directory where the actual documentation text, in markdown format, is located.
To build, from the project's root directory:
To build, from the project's root directory:
```
$ mkdocs build -f docs/mkdocs/mkdocs.yml
$ mkdocs build -f docs/mkdocs/mkdocs.yml
```
You can then open `docs/mkdocs/site/index.html` with any web browser.

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}
{"date":644900643.85054696,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}
{"date":644900741.09290397,"appVersion":"4.1.5","drawing":{"modificationDate":644894800.328192,"activeArtboardIndex":0,"settings":{"outlineMode":false,"isolateActiveLayer":false,"snapToEdges":false,"snapToPoints":false,"guidesVisible":true,"snapToGrid":false,"units":"Pixels","dimensionsVisible":true,"dynamicGuides":false,"isCMYKColorPreviewEnabled":false,"undoHistoryDisabled":false,"snapToGuides":true,"drawOnlyUsingPencil":false,"whiteBackground":false,"rulersVisible":true,"isTimeLapseWatermarkDisabled":false},"artboardPaths":["Artboard0.json"],"documentVersion":"unknown"}}

View File

@@ -1 +1 @@
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}
{"documentJSONFilename":"Document.json","undoHistoryJSONFilename":"UndoHistory.json","fileFormatVersion":0,"thumbnailImageFilename":"Thumbnail.png"}

File diff suppressed because one or more lines are too long

View File

@@ -7,8 +7,6 @@ nav:
- Getting Started: getting_started.md
- Development:
- Python Environments: development/python_environments.md
- Contributing: development/contributing.md
- Code Style: development/code_style.md
- Use Cases:
- Overview: use_cases/index.md
- Use Case 1: use_cases/use_case_1.md
@@ -43,15 +41,10 @@ nav:
- Apps & Tools:
- Overview: apps_and_tools/index.md
- Console: apps_and_tools/console.md
- Bench: apps_and_tools/bench.md
- Link Relay: apps_and_tools/link_relay.md
- HCI Bridge: apps_and_tools/hci_bridge.md
- Golden Gate Bridge: apps_and_tools/gg_bridge.md
- Show: apps_and_tools/show.md
- GATT Dump: apps_and_tools/gatt_dump.md
- Pair: apps_and_tools/pair.md
- Unbond: apps_and_tools/unbond.md
- USB Probe: apps_and_tools/usb_probe.md
- Link Relay: apps_and_tools/link_relay.md
- Hardware:
- Overview: hardware/index.md
- Platforms:
@@ -63,7 +56,7 @@ nav:
- Examples:
- Overview: examples/index.md
copyright: Copyright 2021-2023 Google LLC
copyright: Copyright 2021-2022 Google LLC
theme:
name: 'material'

View File

@@ -1,6 +1,6 @@
# This requirements file is for python3
mkdocs == 1.4.0
mkdocs-material == 8.5.6
mkdocs-material-extensions == 1.0.3
pymdown-extensions == 9.6
mkdocstrings-python == 0.7.1
mkdocs == 1.2.3
mkdocs-material == 7.1.7
mkdocs-material-extensions == 1.0.1
pymdown-extensions == 8.2
mkdocstrings == 0.15.1

View File

@@ -1,2 +1,2 @@
API EXAMPLES
============
============

View File

@@ -1,2 +1,2 @@
API DEVELOPER GUIDE
===================
===================

View File

@@ -16,3 +16,4 @@ Bumble Python API
### HCI_Disconnect_Command
::: bumble.hci.HCI_Disconnect_Command

View File

@@ -1,158 +0,0 @@
BENCH TOOL
==========
The "bench" tool implements a number of different ways of measuring the
throughput and/or latency between two devices.
# General Usage
```
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options:
--device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
-c, --packet-count COUNT Packet count (server role)
-sd, --start-delay SECONDS Start delay (server role)
--help Show this message and exit.
Commands:
central Run as a central (initiates the connection)
peripheral Run as a peripheral (waits for a connection)
```
## Options for the ``central`` Command
```
Usage: bumble-bench central [OPTIONS] TRANSPORT
Run as a central (initiates the connection)
Options:
--peripheral ADDRESS_OR_NAME Address or name to connect to
--connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms)
--phy [1m|2m|coded] PHY to use
--help Show this message and exit.
```
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device
running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts.
Independently of whether the device is the Central or Peripheral, each device selects a
``mode`` and and ``role`` to run as. The ``mode`` and ``role`` of the Central and Peripheral
must be compatible.
Device 1 mode | Device 2 mode
------------------|------------------
``gatt-client`` | ``gatt-server``
``l2cap-client`` | ``l2cap-server``
``rfcomm-client`` | ``rfcomm-server``
Device 1 role | Device 2 role
--------------|--------------
``sender`` | ``receiver``
``ping`` | ``pong``
# Examples
In the following examples, we have two USB Bluetooth controllers, one on `usb:0` and
the other on `usb:1`, and two consoles/terminals. We will run a command in each.
!!! example "GATT Throughput"
Using the default mode and role for the Central and Peripheral.
In the first console/terminal:
```
$ bumble-bench peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench central usb:1
```
In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-server peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-client central usb:1
```
!!! example "RFComm Throughput"
In the first console/terminal:
```
$ bumble-bench --mode rfcomm-server peripheral usb:0
```
NOTE: the BT address of the Peripheral will be printed out, use it with the
``--peripheral`` option for the Central.
In this example, we use a larger packet size and packet count than the default.
In the second console/terminal:
```
$ bumble-bench --mode rfcomm-client --packet-size 2000 --packet-count 100 central --peripheral 00:16:A4:5A:40:F2 usb:1
```
!!! example "Ping/Pong Latency"
In the first console/terminal:
```
$ bumble-bench --role pong peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --role ping central usb:1
```
!!! example "Reversed modes with GATT and custom connection interval"
In the first console/terminal:
```
$ bumble-bench --mode gatt-client peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode gatt-server central --ci 10 usb:1
```
!!! example "Reversed modes with L2CAP and custom PHY"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-client peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-server central --phy 2m usb:1
```
!!! example "Reversed roles with L2CAP"
In the first console/terminal:
```
$ bumble-bench --mode l2cap-client --role sender peripheral usb:0
```
In the second console/terminal:
```
$ bumble-bench --mode l2cap-server --role receiver central usb:1
```

View File

@@ -1,2 +1,2 @@
GOLDEN GATE BRIDGE
==================
==================

View File

@@ -28,3 +28,5 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`.

View File

@@ -5,10 +5,10 @@ Included in the project are a few apps and tools, built on top of the core libra
These include:
* [Console](console.md) - an interactive text-based console
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic)
* [Pair](pair.md) - Pair/bond two devices (LE and Classic)
* [Unbond](unbond.md) - Remove a previously established bond
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool"
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

Some files were not shown because too many files have changed in this diff Show More