Compare commits

...

67 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
5b173cb879 add constant for 5.4 2023-10-11 17:47:21 -07:00
zxzxwu
8b04161da3 Merge pull request #317 from zxzxwu/pytest
Add missing @pytest.mark.asyncio decorator
2023-10-11 15:16:35 +08:00
Josh Wu
333940919b Add missing @pytest.mark.asyncio decorator 2023-10-11 13:52:06 +08:00
Gilles Boccon-Gibod
b9476be9ad Merge pull request #315 from google/gbg/company-ids
update to latest list of company ids
2023-10-10 22:13:16 -07:00
Gilles Boccon-Gibod
704c60491c Merge pull request #313 from benquike/pair_fix
Allow turning on BLE in classic pairing mode
2023-10-10 21:30:24 -07:00
Gilles Boccon-Gibod
4a8e612c6e update rust list 2023-10-10 21:29:39 -07:00
Gilles Boccon-Gibod
4e71ec5738 remove stale comment 2023-10-10 20:36:48 -07:00
uael
7255a09705 ci: add python avatar tests 2023-10-09 23:37:23 +02:00
zxzxwu
c2bf6b5f13 Merge pull request #289 from zxzxwu/l2cap_refactor
Refactor L2CAP API
2023-10-09 23:27:25 +08:00
Gilles Boccon-Gibod
d8e699b588 use the new yaml file instead of the previous CSV file 2023-10-07 23:10:49 -07:00
zxzxwu
3e4d4705f5 Merge pull request #314 from zxzxwu/sec_pandora
Pandora: Handle exception in WaitSecurity()
2023-10-08 01:42:45 +08:00
Josh Wu
c8b2804446 Pandora: Handle exception in WaitSecurity() 2023-10-07 21:17:01 +08:00
Josh Wu
e732f2589f Refactor L2CAP API 2023-10-07 20:01:15 +08:00
zxzxwu
aec5543081 Merge pull request #310 from zxzxwu/avdtp
Typing AVDTP
2023-10-07 19:50:56 +08:00
Josh Wu
e03d90ca57 Add typing for MediaCodecCapabilities members 2023-10-07 19:32:19 +08:00
Josh Wu
495ce62d9c Typing AVDTP 2023-10-07 19:32:19 +08:00
Hui Peng
fbc3959a5a Allow turning on BLE in classic pairing mode 2023-10-06 19:54:18 -07:00
Gilles Boccon-Gibod
dfa9131192 Merge pull request #311 from zxzxwu/rust
Fix Rust lints
2023-10-06 13:37:47 -07:00
Josh Wu
88c801b4c2 Replace or_insert_with with or_default 2023-10-06 18:02:46 +08:00
Gilles Boccon-Gibod
a1b55b94e0 Merge pull request #301 from whitevegagabriel/simplify-event-loop-copy
Remove unncecesary steps for injecting Python event loop
2023-10-02 12:12:41 -07:00
Gilles Boccon-Gibod
80db9e2e2f Merge pull request #303 from whitevegagabriel/hci-command-rs
Ability to send HCI commands from Rust
2023-10-02 12:12:05 -07:00
Gabriel White-Vega
ce74690420 Update pdl to 0.2.0
- Allows removing impl PartialEq for pdl Error
2023-10-02 11:20:44 -04:00
Gilles Boccon-Gibod
50de4dfb5d Merge pull request #307 from google/gbg/hotfix-001
don't delete advertising prefs on disconnection
2023-09-30 17:46:53 -07:00
Gilles Boccon-Gibod
9bcdf860f4 don't delete advertising prefs on disconnection 2023-09-30 17:41:18 -07:00
Gabriel White-Vega
511ab4b630 Add python async wrapper, move hci non-wrapper to internal, add hci::internal tests 2023-09-29 10:23:19 -04:00
Gilles Boccon-Gibod
6f2b623e3c Merge pull request #290 from google/gbg/netsim-transport-injectable-channels
make grpc channels injectable
2023-09-27 22:16:05 -07:00
Gilles Boccon-Gibod
fa12165cd3 Merge pull request #298 from google/gbg/use-address-to-string
use Address.to_string instead of manual suffix replacement
2023-09-27 21:59:32 -07:00
Gilles Boccon-Gibod
c0c6f3329d minor cleanup 2023-09-27 21:53:54 -07:00
Gilles Boccon-Gibod
406a932467 make grpc channels injectable 2023-09-27 21:37:36 -07:00
Gilles Boccon-Gibod
cc96d4245f address PR comments 2023-09-27 21:25:13 -07:00
Sparkling Diva
c6cdca8923 device: return the psm value from register_l2cap 2023-09-27 16:41:38 -07:00
Gabriel White-Vega
7e331c2944 Ability to send HCI commands from Rust
* Autogenerate packet code in Rust from PDL (packet file copied from rootcanal)
* Implement parsing of packets that have a type header
* Expose Python APIs for sending HCI commands
* Expose Python APIs for instantiating a local controller
2023-09-27 11:17:47 -04:00
Gilles Boccon-Gibod
10347765cb Merge pull request #302 from google/gbg/netsim-with-instance-num
support netsim instance numbers
2023-09-26 09:34:28 -07:00
Gilles Boccon-Gibod
c12dee4e76 Merge pull request #294 from mauricelam/wasm-cryptography
Make cryptography a valid dependency for emscripten targets
2023-09-25 19:29:09 -07:00
Maurice Lam
772c188674 Fix typo 2023-09-25 18:08:52 -07:00
Maurice Lam
7c1a3bb8f9 Separate version specifier for cryptography in Emscripten builds 2023-09-22 16:43:40 -07:00
Maurice Lam
8c3c0b1e13 Make cryptography a valid dependency for emscripten targets
Since only the special cryptography package bundled with pyodide can be
used, relax the version requirement to anything that's version 39.*.

Fix #284
2023-09-22 16:43:40 -07:00
Gilles Boccon-Gibod
1ad84ad51c fix linter errors 2023-09-22 15:08:10 -07:00
Gilles Boccon-Gibod
64937c3f77 support netsim instance numbers 2023-09-22 14:22:04 -07:00
Gabriel White-Vega
50fd2218fa Remove unncecesary steps for injecting Python event loop
* Context vars can be injected directly into Rust future and spawned with tokio
2023-09-22 15:23:01 -04:00
Gilles Boccon-Gibod
4c29a16271 Merge pull request #297 from google/gbg/websocket-full-url
ws-client: make implementation match the doc
2023-09-22 11:41:24 -07:00
Gilles Boccon-Gibod
762d3e92de Merge pull request #300 from google/gbg/issue-299
use correct own_address_type when restarting advertising
2023-09-22 11:41:04 -07:00
uael
2f97531d78 pandora: use public identity address for public addresses 2023-09-22 20:08:34 +02:00
Gilles Boccon-Gibod
f6c7cae661 use correct own_address_type when restarting advertising 2023-09-22 10:33:36 -07:00
Gilles Boccon-Gibod
f1777a5bd2 use .to_string instead of a manual suffix replacement 2023-09-21 19:03:54 -07:00
Gilles Boccon-Gibod
78a06ae8cf make implementation match the doc 2023-09-21 19:01:40 -07:00
zxzxwu
d290df4aa9 Merge pull request #278 from zxzxwu/gatt2
Typing GATT
2023-09-21 16:09:36 +08:00
Josh Wu
e559744f32 Typing att 2023-09-21 15:52:07 +08:00
zxzxwu
67418e649a Merge pull request #288 from zxzxwu/l2cap_states
L2CAP: Refactor states to enums
2023-09-21 15:42:21 +08:00
Gilles Boccon-Gibod
5adf9fab53 Merge pull request #275 from whitevegagabriel/file-header
Add license header check for rust files
2023-09-20 16:21:38 -07:00
Josh Wu
2491b686fa Handle SMP_Security_Request 2023-09-20 23:13:08 +02:00
Josh Wu
efd02b2f3e Adopt reviews 2023-09-20 23:03:23 +02:00
Josh Wu
3b14078646 Overload signatures 2023-09-20 23:03:23 +02:00
Josh Wu
eb9d5632bc Add utils_test type hint 2023-09-20 23:03:23 +02:00
Josh Wu
45f60edbb6 Pyee watcher context 2023-09-20 23:03:23 +02:00
David Duarte
393ea6a7bb pandora_server: Load server config
Pandora server has it's own config that we load from the 'server'
property of the current bumble config file
2023-09-18 14:28:42 -07:00
Gabriel White-Vega
6ec6f1efe5 Add license header check for rust files
Added binary that can check for and add Apache 2.0 licenses.
Run this binary during the build-rust workflow.
2023-09-14 14:29:47 -04:00
Josh Wu
5d9598ea51 L2CAP: Refactor states to enums 2023-09-14 20:52:33 +08:00
Gilles Boccon-Gibod
0d36d99a73 Merge pull request #287 from google/revert-286-gbg/package-depencencies-for-wasm
Revert "make cryptography a valid dependency for emscripten targets"
2023-09-13 23:37:42 -07:00
Gilles Boccon-Gibod
d8a9f5a724 Revert "make cryptography a valid dependency for emscripten targets" 2023-09-13 23:36:33 -07:00
Gilles Boccon-Gibod
2c66e1a042 Merge pull request #285 from google/gbg/fix-mypy-errors
mypy: ignore false positive errors
2023-09-13 23:30:50 -07:00
Gilles Boccon-Gibod
d5eccdb00f Merge pull request #286 from google/gbg/package-depencencies-for-wasm
make cryptography a valid dependency for emscripten targets
2023-09-13 23:30:28 -07:00
Gilles Boccon-Gibod
32626573a6 ignore false positive errors 2023-09-13 23:17:00 -07:00
Gilles Boccon-Gibod
caa82b8f7e make cryptography a valid dependency for emscripten targets 2023-09-13 22:38:28 -07:00
Gilles Boccon-Gibod
5af347b499 Merge pull request #282 from google/gbg/multi-python-pre-commit-check
run pre-commit tests with all supported Python versions
2023-09-13 07:47:32 -07:00
zxzxwu
4ed5bb5a9e Merge pull request #281 from zxzxwu/cleanup-transport
Replace | typing usage with Optional and Union
2023-09-13 13:31:41 +08:00
Josh Wu
f39f5f531c Replace | typing usage with Optional and Union 2023-09-12 15:50:51 +08:00
67 changed files with 10570 additions and 1164 deletions

43
.github/workflows/python-avatar.yml vendored Normal file
View File

@@ -0,0 +1,43 @@
name: Python Avatar
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
test:
name: Avatar [${{ matrix.shard }}]
runs-on: ubuntu-latest
strategy:
matrix:
shard: [
1/24, 2/24, 3/24, 4/24,
5/24, 6/24, 7/24, 8/24,
9/24, 10/24, 11/24, 12/24,
13/24, 14/24, 15/24, 16/24,
17/24, 18/24, 19/24, 20/24,
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install .[avatar]
- name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log &
- name: Test
run: |
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
run: cat rootcanal.log

View File

@@ -65,6 +65,8 @@ jobs:
with: with:
components: clippy,rustfmt components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }} toolchain: ${{ matrix.rust-version }}
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build - name: Rust Build
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built # Lints after build so what clippy needs is already built

View File

@@ -39,10 +39,12 @@
"libusb", "libusb",
"MITM", "MITM",
"NDIS", "NDIS",
"netsim",
"NONBLOCK", "NONBLOCK",
"NONCONN", "NONCONN",
"OXIMETER", "OXIMETER",
"popleft", "popleft",
"protobuf",
"psms", "psms",
"pyee", "pyee",
"pyusb", "pyusb",

View File

@@ -1172,7 +1172,7 @@ class ScanResult:
name = '' name = ''
# Remove any '/P' qualifier suffix from the address string # Remove any '/P' qualifier suffix from the address string
address_str = str(self.address).replace('/P', '') address_str = self.address.to_string(with_type_qualifier=False)
# RSSI bar # RSSI bar
bar_string = rssi_bar(self.rssi) bar_string = rssi_bar(self.rssi)

View File

@@ -63,7 +63,8 @@ async def get_classic_info(host):
if command_succeeded(response): if command_succeeded(response):
print() print()
print( print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr color('Classic Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
) )
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):

View File

@@ -306,6 +306,7 @@ async def pair(
# Expose a GATT characteristic that can be used to trigger pairing by # Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read # responding with an authentication error when read
if mode == 'le': if mode == 'le':
device.le_enabled = True
device.add_service( device.add_service(
Service( Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5', '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
@@ -326,7 +327,6 @@ async def pair(
# Select LE or Classic # Select LE or Classic
if mode == 'classic': if mode == 'classic':
device.classic_enabled = True device.classic_enabled = True
device.le_enabled = False
device.classic_smp_enabled = ctkd device.classic_smp_enabled = ctkd
# Get things going # Get things going

View File

@@ -3,7 +3,7 @@ import click
import logging import logging
import json import json
from bumble.pandora import PandoraDevice, serve from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999 BUMBLE_SERVER_GRPC_PORT = 7999
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
transport = transport.replace('<rootcanal-port>', str(rootcanal_port)) transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
bumble_config = retrieve_config(config) bumble_config = retrieve_config(config)
if 'transport' not in bumble_config.keys(): bumble_config.setdefault('transport', transport)
bumble_config.update({'transport': transport})
device = PandoraDevice(bumble_config) device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port)) asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]: def retrieve_config(config: str) -> Dict[str, Any]:

View File

@@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput):
except HCI_StatusError: except HCI_StatusError:
pass pass
peer_name = '' if connection.peer_name is None else connection.peer_name peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = str(connection.peer_address).replace('/P', '') peer_address = connection.peer_address.to_string(False)
await self.send_message( await self.send_message(
'connection', 'connection',
peer_address=peer_address, peer_address=peer_address,
@@ -376,7 +376,7 @@ class UiServer:
if connection := self.speaker().connection: if connection := self.speaker().connection:
await self.send_message( await self.send_message(
'connection', 'connection',
peer_address=str(connection.peer_address).replace('/P', ''), peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name, peer_name=connection.peer_name,
) )

View File

@@ -23,13 +23,14 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum
import functools import functools
import struct import struct
from pyee import EventEmitter from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value, HCI_Constant from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color from bumble.colors import color
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long # pylint: enable=line-too-long
# pylint: disable=invalid-name # pylint: disable=invalid-name
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -209,7 +211,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {} pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0 op_code = 0
name = None name: str
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu):
@@ -719,9 +721,18 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
''' '''
# -----------------------------------------------------------------------------
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def write(self, connection, value: bytes) -> None:
...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(EventEmitter): class Attribute(EventEmitter):
# Permission flags class Permissions(enum.IntFlag):
READABLE = 0x01 READABLE = 0x01
WRITEABLE = 0x02 WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04 READ_REQUIRES_ENCRYPTION = 0x04
@@ -731,36 +742,47 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = 0x40 READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80 WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = { @classmethod
READABLE: 'READABLE', def from_string(cls, permissions_str: str) -> Attribute.Permissions:
WRITEABLE: 'WRITEABLE',
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
}
@staticmethod
def string_to_permissions(permissions_str: str):
try: try:
return functools.reduce( return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y), lambda x, y: x | Attribute.Permissions[y],
permissions_str.split(","), permissions_str.replace('|', ',').split(","),
0, Attribute.Permissions(0),
) )
except TypeError as exc: except TypeError as exc:
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError( raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}" f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc ) from exc
def __init__(self, attribute_type, permissions, value=b''): # Permission flags(legacy-use only)
READABLE = Permissions.READABLE
WRITEABLE = Permissions.WRITEABLE
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
self.end_group_handle = 0 self.end_group_handle = 0
if isinstance(permissions, str): if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions) self.permissions = Attribute.Permissions.from_string(permissions)
else: else:
self.permissions = permissions self.permissions = permissions
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
else: else:
self.value = value self.value = value
def encode_value(self, value): def encode_value(self, value: Any) -> bytes:
return value return value
def decode_value(self, value_bytes): def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def read_value(self, connection: Connection): def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
self.permissions & self.READ_REQUIRES_ENCRYPTION (self.permissions & self.READ_REQUIRES_ENCRYPTION)
) and not connection.encryption: and connection is not None
and not connection.encryption
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
) )
if ( if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION (self.permissions & self.READ_REQUIRES_AUTHENTICATION)
) and not connection.authenticated: and connection is not None
and not connection.authenticated
):
raise ATT_Error( raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
) )
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
) )
if read := getattr(self.value, 'read', None): if hasattr(self.value, 'read'):
try: try:
value = read(connection) # pylint: disable=not-callable value = self.value.read(connection)
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value) return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes): def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption: ) and not connection.encryption:
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes) value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None): if hasattr(self.value, 'write'):
try: try:
write(connection, value) # pylint: disable=not-callable self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -80,7 +80,7 @@ class BaseError(Exception):
def __init__( def __init__(
self, self,
error_code: int | None, error_code: Optional[int],
error_namespace: str = '', error_namespace: str = '',
error_name: str = '', error_name: str = '',
details: str = '', details: str = '',

View File

@@ -33,6 +33,8 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
overload,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -151,6 +153,7 @@ from .utils import (
CompositeEventEmitter, CompositeEventEmitter,
setup_event_forwarding, setup_event_forwarding,
composite_listener, composite_listener,
deprecated,
) )
from .keys import ( from .keys import (
KeyStore, KeyStore,
@@ -670,9 +673,7 @@ class Connection(CompositeEventEmitter):
def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(self.handle, cid, pdu) self.device.send_l2cap_pdu(self.handle, cid, pdu)
def create_l2cap_connector(self, psm): @deprecated("Please use create_l2cap_channel()")
return self.device.create_l2cap_connector(self, psm)
async def open_l2cap_channel( async def open_l2cap_channel(
self, self,
psm, psm,
@@ -682,6 +683,23 @@ class Connection(CompositeEventEmitter):
): ):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps) return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
@overload
async def create_l2cap_channel(
self, spec: l2cap.ClassicChannelSpec
) -> l2cap.ClassicChannel:
...
@overload
async def create_l2cap_channel(
self, spec: l2cap.LeCreditBasedChannelSpec
) -> l2cap.LeCreditBasedChannel:
...
async def create_l2cap_channel(
self, spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec]
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
return await self.device.create_l2cap_channel(connection=self, spec=spec)
async def disconnect( async def disconnect(
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None: ) -> None:
@@ -1180,15 +1198,11 @@ class Device(CompositeEventEmitter):
return None return None
def create_l2cap_connector(self, connection, psm): @deprecated("Please use create_l2cap_server()")
return lambda: self.l2cap_channel_manager.connect(connection, psm) def register_l2cap_server(self, psm, server) -> int:
return self.l2cap_channel_manager.register_server(psm, server)
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
@deprecated("Please use create_l2cap_server()")
def register_l2cap_channel_server( def register_l2cap_channel_server(
self, self,
psm, psm,
@@ -1201,6 +1215,7 @@ class Device(CompositeEventEmitter):
psm, server, max_credits, mtu, mps psm, server, max_credits, mtu, mps
) )
@deprecated("Please use create_l2cap_channel()")
async def open_l2cap_channel( async def open_l2cap_channel(
self, self,
connection, connection,
@@ -1213,6 +1228,74 @@ class Device(CompositeEventEmitter):
connection, psm, max_credits, mtu, mps connection, psm, max_credits, mtu, mps
) )
@overload
async def create_l2cap_channel(
self,
connection: Connection,
spec: l2cap.ClassicChannelSpec,
) -> l2cap.ClassicChannel:
...
@overload
async def create_l2cap_channel(
self,
connection: Connection,
spec: l2cap.LeCreditBasedChannelSpec,
) -> l2cap.LeCreditBasedChannel:
...
async def create_l2cap_channel(
self,
connection: Connection,
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
if isinstance(spec, l2cap.ClassicChannelSpec):
return await self.l2cap_channel_manager.create_classic_channel(
connection=connection, spec=spec
)
if isinstance(spec, l2cap.LeCreditBasedChannelSpec):
return await self.l2cap_channel_manager.create_le_credit_based_channel(
connection=connection, spec=spec
)
@overload
def create_l2cap_server(
self,
spec: l2cap.ClassicChannelSpec,
handler: Optional[Callable[[l2cap.ClassicChannel], Any]] = None,
) -> l2cap.ClassicChannelServer:
...
@overload
def create_l2cap_server(
self,
spec: l2cap.LeCreditBasedChannelSpec,
handler: Optional[Callable[[l2cap.LeCreditBasedChannel], Any]] = None,
) -> l2cap.LeCreditBasedChannelServer:
...
def create_l2cap_server(
self,
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
handler: Union[
Callable[[l2cap.ClassicChannel], Any],
Callable[[l2cap.LeCreditBasedChannel], Any],
None,
] = None,
) -> Union[l2cap.ClassicChannelServer, l2cap.LeCreditBasedChannelServer]:
if isinstance(spec, l2cap.ClassicChannelSpec):
return self.l2cap_channel_manager.create_classic_server(
spec=spec,
handler=cast(Callable[[l2cap.ClassicChannel], Any], handler),
)
elif isinstance(spec, l2cap.LeCreditBasedChannelSpec):
return self.l2cap_channel_manager.create_le_credit_based_server(
handler=cast(Callable[[l2cap.LeCreditBasedChannel], Any], handler),
spec=spec,
)
else:
raise ValueError(f'Unexpected mode {spec}')
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu) self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -1425,10 +1508,10 @@ class Device(CompositeEventEmitter):
check_result=True, check_result=True,
) )
self.advertising_own_address_type = own_address_type
self.auto_restart_advertising = auto_restart
self.advertising_type = advertising_type self.advertising_type = advertising_type
self.advertising_own_address_type = own_address_type
self.advertising = True self.advertising = True
self.auto_restart_advertising = auto_restart
async def stop_advertising(self) -> None: async def stop_advertising(self) -> None:
# Disable advertising # Disable advertising
@@ -1438,9 +1521,9 @@ class Device(CompositeEventEmitter):
check_result=True, check_result=True,
) )
self.advertising_type = None
self.advertising_own_address_type = None self.advertising_own_address_type = None
self.advertising = False self.advertising = False
self.advertising_type = None
self.auto_restart_advertising = False self.auto_restart_advertising = False
@property @property
@@ -2630,7 +2713,6 @@ class Device(CompositeEventEmitter):
own_address_type = self.advertising_own_address_type own_address_type = self.advertising_own_address_type
# We are no longer advertising # We are no longer advertising
self.advertising_own_address_type = None
self.advertising = False self.advertising = False
if own_address_type in ( if own_address_type in (
@@ -2687,7 +2769,6 @@ class Device(CompositeEventEmitter):
and self.advertising and self.advertising
and self.advertising_type.is_directed and self.advertising_type.is_directed
): ):
self.advertising_own_address_type = None
self.advertising = False self.advertising = False
# Notify listeners # Notify listeners
@@ -2758,7 +2839,9 @@ class Device(CompositeEventEmitter):
self.abort_on( self.abort_on(
'flush', 'flush',
self.start_advertising( self.start_advertising(
advertising_type=self.advertising_type, auto_restart=True advertising_type=self.advertising_type,
own_address_type=self.advertising_own_address_type,
auto_restart=True,
), ),
) )

View File

@@ -28,7 +28,7 @@ import enum
import functools import functools
import logging import logging
import struct import struct
from typing import Optional, Sequence, List from typing import Optional, Sequence, Iterable, List, Union
from .colors import color from .colors import color
from .core import UUID, get_dict_key_by_value from .core import UUID, get_dict_key_by_value
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services): def show_services(services: Iterable[Service]) -> None:
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -210,11 +210,11 @@ class Service(Attribute):
def __init__( def __init__(
self, self,
uuid, uuid: Union[str, UUID],
characteristics: List[Characteristic], characteristics: List[Characteristic],
primary=True, primary=True,
included_services: List[Service] = [], included_services: List[Service] = [],
): ) -> None:
# Convert the uuid to a UUID object if it isn't already # Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str): if isinstance(uuid, str):
uuid = UUID(uuid) uuid = UUID(uuid)
@@ -239,7 +239,7 @@ class Service(Attribute):
""" """
return None return None
def __str__(self): def __str__(self) -> str:
return ( return (
f'Service(handle=0x{self.handle:04X}, ' f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, ' f'end=0x{self.end_group_handle:04X}, '
@@ -255,9 +255,11 @@ class TemplateService(Service):
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID: Optional[UUID] = None UUID: UUID
def __init__(self, characteristics, primary=True): def __init__(
self, characteristics: List[Characteristic], primary: bool = True
) -> None:
super().__init__(self.UUID, characteristics, primary) super().__init__(self.UUID, characteristics, primary)
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service service: Service
def __init__(self, service): def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack( declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes() '<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
) )
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
) )
self.service = service self.service = service
def __str__(self): def __str__(self) -> str:
return ( return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, ' f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, ' f'group_starting_handle=0x{self.service.handle:04X}, '
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}" f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
) )
def __str__(self): def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python # NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11. # versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join( return '|'.join(
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
def __init__( def __init__(
self, self,
uuid, uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties, properties: Characteristic.Properties,
permissions, permissions: Union[str, Attribute.Permissions],
value=b'', value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (), descriptors: Sequence[Descriptor] = (),
): ):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool: def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties return self.properties & properties == properties
def __str__(self): def __str__(self) -> str:
return ( return (
f'Characteristic(handle=0x{self.handle:04X}, ' f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, ' f'end=0x{self.end_group_handle:04X}, '
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic characteristic: Characteristic
def __init__(self, characteristic, value_handle): def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = ( declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle) struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes() + characteristic.uuid.to_pdu_bytes()
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle self.value_handle = value_handle
self.characteristic = characteristic self.characteristic = characteristic
def __str__(self): def __str__(self) -> str:
return ( return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, ' f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, ' f'value_handle=0x{self.value_handle:04X}, '
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber) return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self): def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic) wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})' return f'{self.__class__.__name__}({wrapped})'
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding Adapter that converts strings to/from bytes using UTF-8 encoding
''' '''
def encode_value(self, value): def encode_value(self, value: str) -> bytes:
return value.encode('utf-8') return value.encode('utf-8')
def decode_value(self, value): def decode_value(self, value: bytes) -> str:
return value.decode('utf-8') return value.decode('utf-8')
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
''' '''
def __str__(self): def __str__(self) -> str:
return ( return (
f'Descriptor(handle=0x{self.handle:04X}, ' f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, ' f'type={self.type}, '

View File

@@ -28,7 +28,18 @@ import asyncio
import logging import logging
import struct import struct
from datetime import datetime from datetime import datetime
from typing import List, Optional, Dict, Tuple, Callable, Union, Any from typing import (
List,
Optional,
Dict,
Tuple,
Callable,
Union,
Any,
Iterable,
Type,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE, GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
ClientCharacteristicConfigurationBits, ClientCharacteristicConfigurationBits,
TemplateService,
) )
if TYPE_CHECKING:
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies # Proxies
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter): class AttributeProxy(EventEmitter):
client: Client def __init__(
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
def __init__(self, client, handle, end_group_handle, attribute_type): ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.client = client self.client = client
self.handle = handle self.handle = handle
self.end_group_handle = end_group_handle self.end_group_handle = end_group_handle
self.type = attribute_type self.type = attribute_type
async def read_value(self, no_long_read=False): async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value( return self.decode_value(
await self.client.read_value(self.handle, no_long_read) await self.client.read_value(self.handle, no_long_read)
) )
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response self.handle, self.encode_value(value), with_response
) )
def encode_value(self, value): def encode_value(self, value: Any) -> bytes:
return value return value
def decode_value(self, value_bytes): def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def __str__(self): def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})' return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid): def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self) return self.client.get_characteristics_by_uuid(uuid, self)
def __str__(self): def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})' return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy): class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties properties: Characteristic.Properties
descriptors: List[DescriptorProxy] descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable] subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__( def __init__(
self, self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self) return await self.client.discover_descriptors(self)
async def subscribe( async def subscribe(
self, subscriber: Optional[Callable] = None, prefer_notify=True self,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
): ):
if subscriber is not None: if subscriber is not None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber) return await self.client.unsubscribe(self, subscriber)
def __str__(self): def __str__(self) -> str:
return ( return (
f'Characteristic(handle=0x{self.handle:04X}, ' f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, ' f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type): def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type) super().__init__(client, handle, 0, descriptor_type)
def __str__(self): def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})' return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies Base class for profile-specific service proxies
''' '''
SERVICE_CLASS: Type[TemplateService]
@classmethod @classmethod
def from_client(cls, client): def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client: class Client:
services: List[ServiceProxy] services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
def __init__(self, connection): def __init__(self, connection: Connection) -> None:
self.connection = connection self.connection = connection
self.mtu_exchange_done = False self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
def send_gatt_pdu(self, pdu): def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command): async def send_command(self, command: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}' f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
) )
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request): async def send_request(self, request: ATT_PDU):
logger.debug( logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}' f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
) )
@@ -279,14 +302,14 @@ class Client:
return response return response
def send_confirmation(self, confirmation): def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug( logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] ' f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}' f'{confirmation}'
) )
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu): async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
if mtu < ATT_DEFAULT_MTU: if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu return self.connection.att_mtu
def get_services_by_uuid(self, uuid): def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service=None): def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> List[CharacteristicProxy]:
services = [service] if service else self.services services = [service] if service else self.services
return [ return [
c c
@@ -363,7 +388,7 @@ class Client:
if not already_known: if not already_known:
self.services.append(service) self.services.append(service)
async def discover_services(self, uuids=None) -> List[ServiceProxy]: async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.1 Discover All Primary Services See Vol 3, Part G - 4.4.1 Discover All Primary Services
''' '''
@@ -435,7 +460,7 @@ class Client:
return services return services
async def discover_service(self, uuid): async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
''' '''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}' f'{HCI_Constant.error_name(response.error_code)}'
) )
# TODO raise appropriate exception # TODO raise appropriate exception
return return []
break break
for attribute_handle, end_group_handle in response.handles_information: for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning( logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}' f'bogus handle values: {attribute_handle} {end_group_handle}'
) )
return return []
# Create a service proxy for this service # Create a service proxy for this service
service = ServiceProxy( service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors return descriptors
async def discover_attributes(self): async def discover_attributes(self) -> List[AttributeProxy]:
''' '''
Discover all attributes, regardless of type Discover all attributes, regardless of type
''' '''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left # No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True) await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False): async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes # Return the value as bytes
return attribute_value return attribute_value
async def read_characteristics_by_uuid(self, uuid, service): async def read_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy]
) -> List[bytes]:
''' '''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
''' '''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values return characteristics_values
async def write_value(self, attribute, value, with_response=False): async def write_value(
self,
attribute: Union[int, AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
''' '''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value Value
@@ -990,7 +1024,7 @@ class Client:
) )
) )
def on_gatt_pdu(self, att_pdu): def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
) )
@@ -1013,6 +1047,7 @@ class Client:
return return
# Return the response to the coroutine that is waiting for it # Return the response to the coroutine that is waiting for it
assert self.pending_response is not None
self.pending_response.set_result(att_pdu) self.pending_response.set_result(att_pdu)
else: else:
handler_name = f'on_{att_pdu.name.lower()}' handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication # Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation()) self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes): def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = ( self.cached_values[attribute_handle] = (
datetime.now(), datetime.now(),
value, value,

View File

@@ -23,11 +23,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
import struct import struct
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR, ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR, ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS, ATT_REQUESTS,
ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR, ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR, ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error, ATT_Error,
@@ -73,6 +75,8 @@ from .gatt import (
Service, Service,
) )
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(EventEmitter):
attributes: List[Attribute] attributes: List[Attribute]
services: List[Service]
attributes_by_handle: Dict[int, Attribute]
subscribers: Dict[int, Dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
def __init__(self, device): def __init__(self, device: Device) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.services = [] self.services = []
@@ -107,16 +116,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None) self.pending_confirmations = defaultdict(lambda: None)
def __str__(self): def __str__(self) -> str:
return "\n".join(map(str, self.attributes)) return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu): def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self): def next_handle(self) -> int:
return 1 + len(self.attributes) return 1 + len(self.attributes)
def get_advertising_service_data(self): def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return { return {
attribute: data attribute: data
for attribute in self.attributes for attribute in self.attributes
@@ -124,7 +133,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data()) and (data := attribute.get_advertising_data())
} }
def get_attribute(self, handle): def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle) attribute = self.attributes_by_handle.get(handle)
if attribute: if attribute:
return attribute return attribute
@@ -173,12 +182,17 @@ class Server(EventEmitter):
return next( return next(
( (
(attribute, self.get_attribute(attribute.characteristic.handle)) (
attribute,
self.get_attribute(attribute.characteristic.handle),
) # type: ignore
for attribute in map( for attribute in map(
self.get_attribute, self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1), range(service_handle.handle, service_handle.end_group_handle + 1),
) )
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE if attribute is not None
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid and attribute.characteristic.uuid == characteristic_uuid
), ),
None, None,
@@ -197,7 +211,7 @@ class Server(EventEmitter):
return next( return next(
( (
attribute attribute # type: ignore
for attribute in map( for attribute in map(
self.get_attribute, self.get_attribute,
range( range(
@@ -205,12 +219,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1, characteristic_value.end_group_handle + 1,
), ),
) )
if attribute.type == descriptor_uuid if attribute is not None and attribute.type == descriptor_uuid
), ),
None, None,
) )
def add_attribute(self, attribute): def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute # Assign a handle to this attribute
attribute.handle = self.next_handle() attribute.handle = self.next_handle()
attribute.end_group_handle = ( attribute.end_group_handle = (
@@ -220,7 +234,7 @@ class Server(EventEmitter):
# Add this attribute to the list # Add this attribute to the list
self.attributes.append(attribute) self.attributes.append(attribute)
def add_service(self, service: Service): def add_service(self, service: Service) -> None:
# Add the service attribute to the DB # Add the service attribute to the DB
self.add_attribute(service) self.add_attribute(service)
@@ -285,11 +299,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle service.end_group_handle = self.attributes[-1].handle
self.services.append(service) self.services.append(service)
def add_services(self, services): def add_services(self, services: Iterable[Service]) -> None:
for service in services: for service in services:
self.add_service(service) self.add_service(service)
def read_cccd(self, connection, characteristic): def read_cccd(
self, connection: Optional[Connection], characteristic: Characteristic
) -> bytes:
if connection is None: if connection is None:
return bytes([0, 0]) return bytes([0, 0])
@@ -300,7 +316,12 @@ class Server(EventEmitter):
return cccd or bytes([0, 0]) return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value): def write_cccd(
self,
connection: Connection,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug( logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, ' f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}' f'handle=0x{characteristic.handle:04X}: {value.hex()}'
@@ -327,13 +348,19 @@ class Server(EventEmitter):
indicate_enabled, indicate_enabled,
) )
def send_response(self, connection, response): def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug( logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}' f'GATT Response from server: [0x{connection.handle:04X}] {response}'
) )
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False): async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -370,7 +397,13 @@ class Server(EventEmitter):
) )
self.send_gatt_pdu(connection.handle, bytes(notification)) self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False): async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber # Check if there's a subscriber
if not force: if not force:
subscribers = self.subscribers.get(connection.handle) subscribers = self.subscribers.get(connection.handle)
@@ -411,15 +444,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_confirmations[ pending_confirmation = self.pending_confirmations[
connection.handle connection.handle
] = asyncio.get_running_loop().create_future() ] = asyncio.get_running_loop().create_future()
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for( await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error: except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +458,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers( async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False self,
): indicate: bool,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription # Get all the connections for which there's at least one subscription
connections = [ connections = [
connection connection
@@ -450,13 +485,23 @@ class Server(EventEmitter):
] ]
) )
async def notify_subscribers(self, attribute, value=None, force=False): async def notify_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(False, attribute, value, force) return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False): async def indicate_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(True, attribute, value, force) return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection): def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers: if connection.handle in self.subscribers:
del self.subscribers[connection.handle] del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores: if connection.handle in self.indication_semaphores:
@@ -464,7 +509,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations: if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle] del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, connection, att_pdu): def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}') logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}' handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None) handler = getattr(self, handler_name, None)
@@ -506,7 +551,7 @@ class Server(EventEmitter):
####################################################### #######################################################
# ATT handlers # ATT handlers
####################################################### #######################################################
def on_att_request(self, connection, pdu): def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
''' '''
Handler for requests without a more specific handler Handler for requests without a more specific handler
''' '''
@@ -679,7 +724,6 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:

View File

@@ -121,6 +121,7 @@ HCI_VERSION_BLUETOOTH_CORE_5_0 = 9
HCI_VERSION_BLUETOOTH_CORE_5_1 = 10 HCI_VERSION_BLUETOOTH_CORE_5_1 = 10
HCI_VERSION_BLUETOOTH_CORE_5_2 = 11 HCI_VERSION_BLUETOOTH_CORE_5_2 = 11
HCI_VERSION_BLUETOOTH_CORE_5_3 = 12 HCI_VERSION_BLUETOOTH_CORE_5_3 = 12
HCI_VERSION_BLUETOOTH_CORE_5_4 = 13
HCI_VERSION_NAMES = { HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B', HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
@@ -135,7 +136,8 @@ HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0', HCI_VERSION_BLUETOOTH_CORE_5_0: 'HCI_VERSION_BLUETOOTH_CORE_5_0',
HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1', HCI_VERSION_BLUETOOTH_CORE_5_1: 'HCI_VERSION_BLUETOOTH_CORE_5_1',
HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2', HCI_VERSION_BLUETOOTH_CORE_5_2: 'HCI_VERSION_BLUETOOTH_CORE_5_2',
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3' HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
} }
# LMP Version # LMP Version
@@ -4397,7 +4399,7 @@ class HCI_Event(HCI_Packet):
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise ValueError('invalid packet length')
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None cls: Any
if event_code == HCI_LE_META_EVENT: if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call # We do this dispatch here and not in the subclass in order to avoid call
# loops # loops

View File

@@ -17,6 +17,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import enum
import logging import logging
import struct import struct
@@ -37,6 +39,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
) )
from .utils import deprecated
from .colors import color from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .hci import ( from .hci import (
@@ -166,6 +169,34 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
# pylint: disable=invalid-name # pylint: disable=invalid-name
@dataclasses.dataclass
class ClassicChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass
class LeCreditBasedChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
def __post_init__(self):
if (
self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
class L2CAP_PDU: class L2CAP_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
@@ -675,7 +706,8 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Channel(EventEmitter): class ClassicChannel(EventEmitter):
class State(enum.IntEnum):
# States # States
CLOSED = 0x00 CLOSED = 0x00
WAIT_CONNECT = 0x01 WAIT_CONNECT = 0x01
@@ -699,33 +731,11 @@ class Channel(EventEmitter):
WAIT_FINAL_RSP = 0x16 WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17 WAIT_CONTROL_IND = 0x17
STATE_NAMES = {
CLOSED: 'CLOSED',
WAIT_CONNECT: 'WAIT_CONNECT',
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
OPEN: 'OPEN',
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
WAIT_CREATE: 'WAIT_CREATE',
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
WAIT_MOVE: 'WAIT_MOVE',
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
WAIT_CONFIG: 'WAIT_CONFIG',
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
connection_result: Optional[asyncio.Future[None]] connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]] response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]] sink: Optional[Callable[[bytes], Any]]
state: int state: State
connection: Connection connection: Connection
def __init__( def __init__(
@@ -741,7 +751,7 @@ class Channel(EventEmitter):
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.signaling_cid = signaling_cid self.signaling_cid = signaling_cid
self.state = Channel.CLOSED self.state = self.State.CLOSED
self.mtu = mtu self.mtu = mtu
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
@@ -751,13 +761,11 @@ class Channel(EventEmitter):
self.disconnection_result = None self.disconnection_result = None
self.sink = None self.sink = None
def change_state(self, new_state: int) -> None: def _change_state(self, new_state: State) -> None:
logger.debug( logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
self.state = new_state self.state = new_state
def send_pdu(self, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
@@ -767,7 +775,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending # Check that there isn't already a request pending
if self.response: if self.response:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
if self.state != Channel.OPEN: if self.state != self.State.OPEN:
raise InvalidStateError('channel not open') raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future() self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +795,14 @@ class Channel(EventEmitter):
) )
async def connect(self) -> None: async def connect(self) -> None:
if self.state != Channel.CLOSED: if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
# Check that we can start a new connection # Check that we can start a new connection
if self.connection_result: if self.connection_result:
raise RuntimeError('connection already pending') raise RuntimeError('connection already pending')
self.change_state(Channel.WAIT_CONNECT_RSP) self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame( self.send_control_frame(
L2CAP_Connection_Request( L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection), identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +822,10 @@ class Channel(EventEmitter):
self.connection_result = None self.connection_result = None
async def disconnect(self) -> None: async def disconnect(self) -> None:
if self.state != Channel.OPEN: if self.state != self.State.OPEN:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
self.change_state(Channel.WAIT_DISCONNECT) self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Request( L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection), identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +840,8 @@ class Channel(EventEmitter):
return await self.disconnection_result return await self.disconnection_result
def abort(self) -> None: def abort(self) -> None:
if self.state == self.OPEN: if self.state == self.State.OPEN:
self.change_state(self.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit('close')
def send_configure_request(self) -> None: def send_configure_request(self) -> None:
@@ -856,7 +864,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None: def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT) self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame( self.send_control_frame(
L2CAP_Connection_Response( L2CAP_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
@@ -866,24 +874,24 @@ class Channel(EventEmitter):
status=0x0000, status=0x0000,
) )
) )
self.change_state(Channel.WAIT_CONFIG) self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP) self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response): def on_connection_response(self, response):
if self.state != Channel.WAIT_CONNECT_RSP: if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid self.destination_cid = response.destination_cid
self.change_state(Channel.WAIT_CONFIG) self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP) self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING: elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass pass
else: else:
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
self.connection_result.set_exception( self.connection_result.set_exception(
ProtocolError( ProtocolError(
response.result, response.result,
@@ -895,9 +903,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None: def on_configure_request(self, request) -> None:
if self.state not in ( if self.state not in (
Channel.WAIT_CONFIG, self.State.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ, self.State.WAIT_CONFIG_REQ,
Channel.WAIT_CONFIG_REQ_RSP, self.State.WAIT_CONFIG_REQ_RSP,
): ):
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -918,25 +926,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly options=request.options, # TODO: don't accept everything blindly
) )
) )
if self.state == Channel.WAIT_CONFIG: if self.state == self.State.WAIT_CONFIG:
self.change_state(Channel.WAIT_SEND_CONFIG) self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request() self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_RSP) self._change_state(self.State.WAIT_CONFIG_RSP)
elif self.state == Channel.WAIT_CONFIG_REQ: elif self.state == self.State.WAIT_CONFIG_REQ:
self.change_state(Channel.OPEN) self._change_state(self.State.OPEN)
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
self.emit('open') self.emit('open')
elif self.state == Channel.WAIT_CONFIG_REQ_RSP: elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP) self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None: def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS: if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP: if self.state == self.State.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ) self._change_state(self.State.WAIT_CONFIG_REQ)
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND): elif self.state in (
self.change_state(Channel.OPEN) self.State.WAIT_CONFIG_RSP,
self.State.WAIT_CONTROL_IND,
):
self._change_state(self.State.OPEN)
if self.connection_result: if self.connection_result:
self.connection_result.set_result(None) self.connection_result.set_result(None)
self.connection_result = None self.connection_result = None
@@ -966,7 +977,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully # TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None: def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT): if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Response( L2CAP_Disconnection_Response(
identifier=request.identifier, identifier=request.identifier,
@@ -974,14 +985,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid, source_cid=request.source_cid,
) )
) )
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
self.emit('close') self.emit('close')
self.manager.on_channel_closed(self) self.manager.on_channel_closed(self)
else: else:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None: def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT: if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -992,7 +1003,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID') logger.warning('unexpected source or destination CID')
return return
self.change_state(Channel.CLOSED) self._change_state(self.State.CLOSED)
if self.disconnection_result: if self.disconnection_result:
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
@@ -1004,16 +1015,17 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, ' f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, ' f'PSM={self.psm}, '
f'MTU={self.mtu}, ' f'MTU={self.mtu}, '
f'state={Channel.STATE_NAMES[self.state]})' f'state={self.state.name})'
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LeConnectionOrientedChannel(EventEmitter): class LeCreditBasedChannel(EventEmitter):
""" """
LE Credit-based Connection Oriented Channel LE Credit-based Connection Oriented Channel
""" """
class State(enum.IntEnum):
INIT = 0 INIT = 0
CONNECTED = 1 CONNECTED = 1
CONNECTING = 2 CONNECTING = 2
@@ -1021,26 +1033,13 @@ class LeConnectionOrientedChannel(EventEmitter):
DISCONNECTED = 4 DISCONNECTED = 4
CONNECTION_ERROR = 5 CONNECTION_ERROR = 5
STATE_NAMES = {
INIT: 'INIT',
CONNECTED: 'CONNECTED',
CONNECTING: 'CONNECTING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
out_queue: Deque[bytes] out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]] connection_result: Optional[asyncio.Future[LeCreditBasedChannel]]
disconnection_result: Optional[asyncio.Future[None]] disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes] out_sdu: Optional[bytes]
state: int state: State
connection: Connection connection: Connection
@staticmethod
def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__( def __init__(
self, self,
manager: ChannelManager, manager: ChannelManager,
@@ -1083,30 +1082,28 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set() self.drained.set()
if connected: if connected:
self.state = LeConnectionOrientedChannel.CONNECTED self.state = self.State.CONNECTED
else: else:
self.state = LeConnectionOrientedChannel.INIT self.state = self.State.INIT
def change_state(self, new_state: int) -> None: def _change_state(self, new_state: State) -> None:
logger.debug( logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state self.state = new_state
if new_state == self.CONNECTED: if new_state == self.State.CONNECTED:
self.emit('open') self.emit('open')
elif new_state == self.DISCONNECTED: elif new_state == self.State.DISCONNECTED:
self.emit('close') self.emit('close')
def send_pdu(self, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self) -> LeConnectionOrientedChannel: async def connect(self) -> LeCreditBasedChannel:
# Check that we're in the right state # Check that we're in the right state
if self.state != self.INIT: if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state') raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection # Check that we can start a new connection
@@ -1114,7 +1111,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests: if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests') raise RuntimeError('too many concurrent connection requests')
self.change_state(self.CONNECTING) self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request( request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier, identifier=identifier,
le_psm=self.le_psm, le_psm=self.le_psm,
@@ -1134,10 +1131,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None: async def disconnect(self) -> None:
# Check that we're connected # Check that we're connected
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected') raise InvalidStateError('not connected')
self.change_state(self.DISCONNECTING) self._change_state(self.State.DISCONNECTING)
self.flush_output() self.flush_output()
self.send_control_frame( self.send_control_frame(
L2CAP_Disconnection_Request( L2CAP_Disconnection_Request(
@@ -1153,15 +1150,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result return await self.disconnection_result
def abort(self) -> None: def abort(self) -> None:
if self.state == self.CONNECTED: if self.state == self.State.CONNECTED:
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None: def on_pdu(self, pdu: bytes) -> None:
if self.sink is None: if self.sink is None:
logger.warning('received pdu without a sink') logger.warning('received pdu without a sink')
return return
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping') logger.warning('received PDU while not connected, dropping')
# Manage the peer credits # Manage the peer credits
@@ -1240,7 +1237,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits self.credits = response.initial_credits
self.connected = True self.connected = True
self.connection_result.set_result(self) self.connection_result.set_result(self)
self.change_state(self.CONNECTED) self._change_state(self.State.CONNECTED)
else: else:
self.connection_result.set_exception( self.connection_result.set_exception(
ProtocolError( ProtocolError(
@@ -1251,7 +1248,7 @@ class LeConnectionOrientedChannel(EventEmitter):
), ),
) )
) )
self.change_state(self.CONNECTION_ERROR) self._change_state(self.State.CONNECTION_ERROR)
# Cleanup # Cleanup
self.connection_result = None self.connection_result = None
@@ -1271,11 +1268,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid, source_cid=request.source_cid,
) )
) )
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
self.flush_output() self.flush_output()
def on_disconnection_response(self, response) -> None: def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING: if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red')) logger.warning(color('invalid state', 'red'))
return return
@@ -1286,7 +1283,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID') logger.warning('unexpected source or destination CID')
return return
self.change_state(self.DISCONNECTED) self._change_state(self.State.DISCONNECTED)
if self.disconnection_result: if self.disconnection_result:
self.disconnection_result.set_result(None) self.disconnection_result.set_result(None)
self.disconnection_result = None self.disconnection_result = None
@@ -1339,7 +1336,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return return
def write(self, data: bytes) -> None: def write(self, data: bytes) -> None:
if self.state != self.CONNECTED: if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data') logger.warning('not connected, dropping data')
return return
@@ -1367,7 +1364,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f'CoC({self.source_cid}->{self.destination_cid}, ' f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, ' f'State={self.state.name}, '
f'PSM={self.le_psm}, ' f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, ' f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, ' f'MPS={self.mps}/{self.peer_mps}, '
@@ -1375,15 +1372,67 @@ class LeConnectionOrientedChannel(EventEmitter):
) )
# -----------------------------------------------------------------------------
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
# -----------------------------------------------------------------------------
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ChannelManager: class ChannelManager:
identifiers: Dict[int, int] identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]] channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, Callable[[Channel], Any]] servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]] le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[ le_coc_servers: Dict[int, LeCreditBasedChannelServer]
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host] _host: Optional[Host]
@@ -1462,21 +1511,6 @@ class ChannelManager:
raise RuntimeError('no free CID') raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
def next_identifier(self, connection: Connection) -> int: def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier self.identifiers[connection.handle] = identifier
@@ -1491,8 +1525,22 @@ class ChannelManager:
if cid in self.fixed_channels: if cid in self.fixed_channels:
del self.fixed_channels[cid] del self.fixed_channels[cid]
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int: @deprecated("Please use create_classic_channel_server")
if psm == 0: def register_server(
self,
psm: int,
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
).psm
def create_classic_server(
self,
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if spec.psm is None:
# Find a free PSM # Find a free PSM
for candidate in range( for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
@@ -1501,62 +1549,75 @@ class ChannelManager:
continue continue
if candidate in self.servers: if candidate in self.servers:
continue continue
psm = candidate spec.psm = candidate
break break
else: else:
raise InvalidStateError('no free PSM') raise InvalidStateError('no free PSM')
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if psm in self.servers: if spec.psm in self.servers:
raise ValueError('PSM already in use') raise ValueError('PSM already in use')
# Check that the PSM is valid # Check that the PSM is valid
if psm % 2 == 0: if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)') raise ValueError('invalid PSM (not odd)')
check = psm >> 8 check = spec.psm >> 8
while check: while check:
if check % 2 != 0: if check % 2 != 0:
raise ValueError('invalid PSM') raise ValueError('invalid PSM')
check >>= 8 check >>= 8
self.servers[psm] = server self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
return psm return self.servers[spec.psm]
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server( def register_le_coc_server(
self, self,
psm: int, psm: int,
server: Callable[[LeConnectionOrientedChannel], Any], server: Callable[[LeCreditBasedChannel], Any],
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, max_credits: int,
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, mtu: int,
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, mps: int,
) -> int: ) -> int:
self.check_le_coc_parameters(max_credits, mtu, mps) return self.create_le_credit_based_server(
spec=LeCreditBasedChannelSpec(
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
),
handler=server,
).psm
if psm == 0: def create_le_credit_based_server(
self,
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if spec.psm is None:
# Find a free PSM # Find a free PSM
for candidate in range( for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
): ):
if candidate in self.le_coc_servers: if candidate in self.le_coc_servers:
continue continue
psm = candidate spec.psm = candidate
break break
else: else:
raise InvalidStateError('no free PSM') raise InvalidStateError('no free PSM')
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if psm in self.le_coc_servers: if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use') raise ValueError('PSM already in use')
self.le_coc_servers[psm] = ( self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
server, self,
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, spec.psm,
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, handler,
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
) )
return psm return self.le_coc_servers[spec.psm]
def on_disconnection(self, connection_handle: int, _reason: int) -> None: def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels') logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1571,7 +1632,7 @@ class ChannelManager:
if connection_handle in self.identifiers: if connection_handle in self.identifiers:
del self.identifiers[connection_handle] del self.identifiers[connection_handle]
def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} ' f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1683,13 +1744,13 @@ class ChannelManager:
logger.debug( logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}' f'creating server channel with cid={source_cid} for psm {request.psm}'
) )
channel = Channel( channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU self, connection, cid, request.psm, source_cid, server.mtu
) )
connection_channels[source_cid] = channel connection_channels[source_cid] = channel
# Notify # Notify
server(channel) server.on_connection(channel)
channel.on_connection_request(request) channel.on_connection_request(request)
else: else:
logger.warning( logger.warning(
@@ -1911,7 +1972,7 @@ class ChannelManager:
self, connection: Connection, cid: int, request self, connection: Connection, cid: int, request
) -> None: ) -> None:
if request.le_psm in self.le_coc_servers: if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] server = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used # Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault( le_connection_channels = self.le_coc_channels.setdefault(
@@ -1925,8 +1986,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response( L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
destination_cid=0, destination_cid=0,
mtu=mtu, mtu=server.mtu,
mps=mps, mps=server.mps,
initial_credits=0, initial_credits=0,
# pylint: disable=line-too-long # pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
@@ -1944,8 +2005,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response( L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
destination_cid=0, destination_cid=0,
mtu=mtu, mtu=server.mtu,
mps=mps, mps=server.mps,
initial_credits=0, initial_credits=0,
# pylint: disable=line-too-long # pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
@@ -1958,18 +2019,18 @@ class ChannelManager:
f'creating LE CoC server channel with cid={source_cid} for psm ' f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}' f'{request.le_psm}'
) )
channel = LeConnectionOrientedChannel( channel = LeCreditBasedChannel(
self, self,
connection, connection,
request.le_psm, request.le_psm,
source_cid, source_cid,
request.source_cid, request.source_cid,
mtu, server.mtu,
mps, server.mps,
request.initial_credits, request.initial_credits,
request.mtu, request.mtu,
request.mps, request.mps,
max_credits, server.max_credits,
True, True,
) )
connection_channels[source_cid] = channel connection_channels[source_cid] = channel
@@ -1982,16 +2043,16 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response( L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier, identifier=request.identifier,
destination_cid=source_cid, destination_cid=source_cid,
mtu=mtu, mtu=server.mtu,
mps=mps, mps=server.mps,
initial_credits=max_credits, initial_credits=server.max_credits,
# pylint: disable=line-too-long # pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
), ),
) )
# Notify # Notify
server(channel) server.on_connection(channel)
else: else:
logger.info( logger.info(
f'No LE server for connection 0x{connection.handle:04X} ' f'No LE server for connection 0x{connection.handle:04X} '
@@ -2046,37 +2107,51 @@ class ChannelManager:
channel.on_credits(credit.credits) channel.on_credits(credit.credits)
def on_channel_closed(self, channel: Channel) -> None: def on_channel_closed(self, channel: ClassicChannel) -> None:
connection_channels = self.channels.get(channel.connection.handle) connection_channels = self.channels.get(channel.connection.handle)
if connection_channels: if connection_channels:
if channel.source_cid in connection_channels: if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid] del connection_channels[channel.source_cid]
@deprecated("Please use create_le_credit_based_channel()")
async def open_le_coc( async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeConnectionOrientedChannel: ) -> LeCreditBasedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps) return await self.create_le_credit_based_channel(
connection=connection,
spec=LeCreditBasedChannelSpec(
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
),
)
async def create_le_credit_based_channel(
self,
connection: Connection,
spec: LeCreditBasedChannelSpec,
) -> LeCreditBasedChannel:
# Find a free CID for the new channel # Find a free CID for the new channel
connection_channels = self.channels.setdefault(connection.handle, {}) connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels) source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}') logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
channel = LeConnectionOrientedChannel( channel = LeCreditBasedChannel(
manager=self, manager=self,
connection=connection, connection=connection,
le_psm=psm, le_psm=spec.psm,
source_cid=source_cid, source_cid=source_cid,
destination_cid=0, destination_cid=0,
mtu=mtu, mtu=spec.mtu,
mps=mps, mps=spec.mps,
credits=0, credits=0,
peer_mtu=0, peer_mtu=0,
peer_mps=0, peer_mps=0,
peer_credits=max_credits, peer_credits=spec.max_credits,
connected=False, connected=False,
) )
connection_channels[source_cid] = channel connection_channels[source_cid] = channel
@@ -2095,7 +2170,15 @@ class ChannelManager:
return channel return channel
async def connect(self, connection: Connection, psm: int) -> Channel: @deprecated("Please use create_classic_channel()")
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
return await self.create_classic_channel(
connection=connection, spec=ClassicChannelSpec(psm=psm)
)
async def create_classic_channel(
self, connection: Connection, spec: ClassicChannelSpec
) -> ClassicChannel:
# NOTE: this implementation hard-codes BR/EDR # NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel # Find a free CID for a new channel
@@ -2104,10 +2187,20 @@ class ChannelManager:
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}') logger.debug(
channel = Channel( f'creating client channel with cid={source_cid} for psm {spec.psm}'
self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU )
channel = ClassicChannel(
self,
connection,
L2CAP_SIGNALING_CID,
spec.psm,
source_cid,
spec.mtu,
) )
connection_channels[source_cid] = channel connection_channels[source_cid] = channel
@@ -2119,3 +2212,20 @@ class ChannelManager:
raise e raise e
return channel return channel
# -----------------------------------------------------------------------------
# Deprecated Classes
# -----------------------------------------------------------------------------
class Channel(ClassicChannel):
@deprecated("Please use ClassicChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class LeConnectionOrientedChannel(LeCreditBasedChannel):
@deprecated("Please use LeCreditBasedChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import contextlib
import grpc import grpc
import logging import logging
@@ -27,8 +28,8 @@ from bumble.core import (
) )
from bumble.device import Connection as BumbleConnection, Device from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -232,7 +233,11 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable, sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable, mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable, bonding=config.pairing_bonding_enable,
identity_address_type=config.identity_address_type, identity_address_type=(
PairingConfig.AddressType.PUBLIC
if connection.self_address.is_public
else config.identity_address_type
),
delegate=PairingDelegate( delegate=PairingDelegate(
connection, connection,
self, self,
@@ -294,23 +299,35 @@ class SecurityService(SecurityServicer):
try: try:
self.log.debug('Pair...') self.log.debug('Pair...')
security_result = asyncio.get_running_loop().create_future()
with contextlib.closing(EventWatcher()) as watcher:
@watcher.on(connection, 'pairing')
def on_pairing(*_: Any) -> None:
security_result.set_result('success')
@watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if ( if (
connection.transport == BT_LE_TRANSPORT connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE and connection.role == BT_PERIPHERAL_ROLE
): ):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
connection.request_pairing() connection.request_pairing()
await wait_for_security
else: else:
await connection.pair() await connection.pair()
self.log.debug('Paired') result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError: except asyncio.CancelledError:
self.log.warning("Connection died during encryption") self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty()) return SecureResponse(connection_died=empty_pb2.Empty())
@@ -369,6 +386,7 @@ class SecurityService(SecurityServicer):
str str
] = asyncio.get_running_loop().create_future() ] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
assert connection assert connection
@@ -415,6 +433,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None: if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate()) authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = { listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
@@ -425,11 +447,13 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change, 'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success, 'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'), 'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
} }
with contextlib.closing(EventWatcher()) as watcher:
# register event handlers # register event handlers
for event, listener in listeners.items(): for event, listener in listeners.items():
connection.on(event, listener) watcher.on(connection, event, listener)
# security level already reached # security level already reached
if self.reached_security_level(connection, level): if self.reached_security_level(connection, level):
@@ -439,10 +463,6 @@ class SecurityService(SecurityServicer):
kwargs = {} kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty() kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
connection.remove_listener(event, listener) # type: ignore
# wait for `authenticate` to finish if any # wait for `authenticate` to finish if any
if authenticate_task is not None: if authenticate_task is not None:
self.log.debug('Wait for authentication...') self.log.debug('Wait for authentication...')
@@ -452,6 +472,15 @@ class SecurityService(SecurityServicer):
pass pass
self.log.debug('Authenticated') self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
def reached_security_level( def reached_security_level(
@@ -523,7 +552,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}") self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None: if self.device.keystore is not None:
with suppress(KeyError): with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address)) await self.device.keystore.delete(str(address))
return empty_pb2.Empty() return empty_pb2.Empty()

View File

@@ -674,7 +674,7 @@ class Multiplexer(EventEmitter):
acceptor: Optional[Callable[[int], bool]] acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC] dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None: def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
super().__init__() super().__init__()
self.role = role self.role = role
self.l2cap_channel = l2cap_channel self.l2cap_channel = l2cap_channel
@@ -887,7 +887,7 @@ class Multiplexer(EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
multiplexer: Optional[Multiplexer] multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.Channel] l2cap_channel: Optional[l2cap.ClassicChannel]
def __init__(self, device: Device, connection: Connection) -> None: def __init__(self, device: Device, connection: Connection) -> None:
self.device = device self.device = device
@@ -960,11 +960,11 @@ class Server(EventEmitter):
self.acceptors[channel] = acceptor self.acceptors[channel] = acceptor
return channel return channel
def on_connection(self, l2cap_channel: l2cap.Channel) -> None: def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
# Create a new multiplexer for the channel # Create a new multiplexer for the channel

View File

@@ -758,7 +758,7 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
channel: Optional[l2cap.Channel] channel: Optional[l2cap.ClassicChannel]
def __init__(self, device: Device) -> None: def __init__(self, device: Device) -> None:
self.device = device self.device = device
@@ -921,7 +921,7 @@ class Client:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
CONTINUATION_STATE = bytes([0x01, 0x43]) CONTINUATION_STATE = bytes([0x01, 0x43])
channel: Optional[l2cap.Channel] channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', List[ServiceAttribute]) Service = NewType('Service', List[ServiceAttribute])
service_records: Dict[int, Service] service_records: Dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]] current_response: Union[None, bytes, Tuple[int, List[int]]]

View File

@@ -37,6 +37,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Type, Type,
cast,
) )
from pyee import EventEmitter from pyee import EventEmitter
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes()) connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
) -> None:
connection.emit('security_request', request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
return
# Look for a session with this connection, and create one if none exists # Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)): if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE: if connection.role == BT_CENTRAL_ROLE:
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
) )
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session # Delegate the handling of the command to the session
session.on_smp_command(command) session.on_smp_command(command)

View File

@@ -18,6 +18,8 @@
import logging import logging
import grpc.aio import grpc.aio
from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec: str | None) -> Transport: async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
''' '''
Open a transport connection to an Android emulator via its gRPC interface. Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax: The parameter string has this syntax:
@@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
logger.debug(f'connecting to gRPC server at {server_address}') logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host': if mode == 'host':
# Connect as a host # Connect as a host
service = EmulatedBluetoothServiceStub(channel) service = EmulatedBluetoothServiceStub(channel)
@@ -95,10 +97,13 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
raise ValueError('invalid mode') raise ValueError('invalid mode')
# Create the transport object # Create the transport object
transport = PumpedTransport( class EmulatorTransport(PumpedTransport):
PumpedPacketSource(hci_device.read), async def close(self):
PumpedPacketSink(hci_device.write), await super().close()
channel.close, await channel.close()
transport = EmulatorTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
) )
transport.start() transport.start()

View File

@@ -18,11 +18,12 @@
import asyncio import asyncio
import atexit import atexit
import logging import logging
import grpc.aio
import os import os
import pathlib import pathlib
import sys import sys
from typing import Optional from typing import Dict, Optional
import grpc.aio
from .common import ( from .common import (
ParserSource, ParserSource,
@@ -33,8 +34,8 @@ from .common import (
) )
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
from .grpc_protobuf.packet_streamer_pb2_grpc import ( from .grpc_protobuf.packet_streamer_pb2_grpc import (
PacketStreamerStub,
PacketStreamerServicer, PacketStreamerServicer,
add_PacketStreamerServicer_to_server, add_PacketStreamerServicer_to_server,
) )
@@ -43,6 +44,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
from .grpc_protobuf.common_pb2 import ChipKind from .grpc_protobuf.common_pb2 import ChipKind
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -74,14 +76,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def find_grpc_port() -> int: def ini_file_name(instance_number: int) -> str:
suffix = f'_{instance_number}' if instance_number > 0 else ''
return f'netsim{suffix}.ini'
# -----------------------------------------------------------------------------
def find_grpc_port(instance_number: int) -> int:
if not (ini_dir := get_ini_dir()): if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file') logger.debug('no known directory for .ini file')
return 0 return 0
ini_file = ini_dir / 'netsim.ini' ini_file = ini_dir / ini_file_name(instance_number)
logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file(): if ini_file.is_file():
logger.debug(f'Found .ini file at {ini_file}')
with open(ini_file, 'r') as ini_file_data: with open(ini_file, 'r') as ini_file_data:
for line in ini_file_data.readlines(): for line in ini_file_data.readlines():
if '=' in line: if '=' in line:
@@ -90,12 +98,14 @@ def find_grpc_port() -> int:
logger.debug(f'gRPC port = {value}') logger.debug(f'gRPC port = {value}')
return int(value) return int(value)
logger.debug('no grpc.port property found in .ini file')
# Not found # Not found
return 0 return 0
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def publish_grpc_port(grpc_port) -> bool: def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
if not (ini_dir := get_ini_dir()): if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file') logger.debug('no known directory for .ini file')
return False return False
@@ -104,7 +114,7 @@ def publish_grpc_port(grpc_port) -> bool:
logger.debug('ini directory does not exist') logger.debug('ini directory does not exist')
return False return False
ini_file = ini_dir / 'netsim.ini' ini_file = ini_dir / ini_file_name(instance_number)
try: try:
ini_file.write_text(f'grpc.port={grpc_port}\n') ini_file.write_text(f'grpc.port={grpc_port}\n')
logger.debug(f"published gRPC port at {ini_file}") logger.debug(f"published gRPC port at {ini_file}")
@@ -122,14 +132,15 @@ def publish_grpc_port(grpc_port) -> bool:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport( async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport: ) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise ValueError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:
server_host = 'localhost' server_host = 'localhost'
if not publish_grpc_port(server_port): instance_number = int(options.get('instance', "0"))
if not publish_grpc_port(server_port, instance_number):
logger.warning("unable to publish gRPC port") logger.warning("unable to publish gRPC port")
class HciDevice: class HciDevice:
@@ -186,16 +197,13 @@ async def open_android_netsim_controller_transport(
logger.debug(f'<<< PACKET: {data.hex()}') logger.debug(f'<<< PACKET: {data.hex()}')
self.on_data_received(data) self.on_data_received(data)
def send_packet(self, data): async def send_packet(self, data):
async def send(): return await self.context.write(
await self.context.write(
PacketResponse( PacketResponse(
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:]) hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
) )
) )
self.loop.create_task(send())
def terminate(self): def terminate(self):
self.task.cancel() self.task.cancel()
@@ -228,17 +236,17 @@ async def open_android_netsim_controller_transport(
logger.debug('gRPC server cancelled') logger.debug('gRPC server cancelled')
await self.grpc_server.stop(None) await self.grpc_server.stop(None)
def on_packet(self, packet): async def send_packet(self, packet):
if not self.device: if not self.device:
logger.debug('no device, dropping packet') logger.debug('no device, dropping packet')
return return
self.device.send_packet(packet) return await self.device.send_packet(packet)
async def StreamPackets(self, _request_iterator, context): async def StreamPackets(self, _request_iterator, context):
logger.debug('StreamPackets request') logger.debug('StreamPackets request')
# Check that we won't already have a device # Check that we don't already have a device
if self.device: if self.device:
logger.debug('busy, already serving a device') logger.debug('busy, already serving a device')
return PacketResponse(error='Busy') return PacketResponse(error='Busy')
@@ -261,15 +269,42 @@ async def open_android_netsim_controller_transport(
await server.start() await server.start()
asyncio.get_running_loop().create_task(server.serve()) asyncio.get_running_loop().create_task(server.serve())
class GrpcServerTransport(Transport): sink = PumpedPacketSink(server.send_packet)
async def close(self): sink.start()
await super().close() return Transport(server, sink)
return GrpcServerTransport(server, server)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_host_transport(server_host, server_port, options): async def open_android_netsim_host_transport_with_address(
server_host: Optional[str],
server_port: int,
options: Optional[Dict[str, str]] = None,
):
if server_host == '_' or not server_host:
server_host = 'localhost'
if not server_port:
# Look for the gRPC config in a .ini file
instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number)
if not server_port:
raise RuntimeError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'Connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
return await open_android_netsim_host_transport_with_channel(
channel,
options,
)
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_channel(
channel, options: Optional[Dict[str, str]] = None
):
# Wrapper for I/O operations # Wrapper for I/O operations
class HciDevice: class HciDevice:
def __init__(self, name, manufacturer, hci_device): def __init__(self, name, manufacturer, hci_device):
@@ -288,10 +323,12 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
async def read(self): async def read(self):
response = await self.hci_device.read() response = await self.hci_device.read()
response_type = response.WhichOneof('response_type') response_type = response.WhichOneof('response_type')
if response_type == 'error': if response_type == 'error':
logger.warning(f'received error: {response.error}') logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error) raise RuntimeError(response.error)
elif response_type == 'hci_packet':
if response_type == 'hci_packet':
return ( return (
bytes([response.hci_packet.packet_type]) bytes([response.hci_packet.packet_type])
+ response.hci_packet.packet + response.hci_packet.packet
@@ -306,24 +343,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
) )
) )
name = options.get('name', DEFAULT_NAME) name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
manufacturer = DEFAULT_MANUFACTURER manufacturer = DEFAULT_MANUFACTURER
if server_host == '_' or not server_host:
server_host = 'localhost'
if not server_port:
# Look for the gRPC config in a .ini file
server_host = 'localhost'
server_port = find_grpc_port()
if not server_port:
raise RuntimeError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'Connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
# Connect as a host # Connect as a host
service = PacketStreamerStub(channel) service = PacketStreamerStub(channel)
hci_device = HciDevice( hci_device = HciDevice(
@@ -334,10 +356,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
await hci_device.start() await hci_device.start()
# Create the transport object # Create the transport object
transport = PumpedTransport( class GrpcTransport(PumpedTransport):
async def close(self):
await super().close()
await channel.close()
transport = GrpcTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write), PumpedPacketSink(hci_device.write),
channel.close,
) )
transport.start() transport.start()
@@ -345,7 +371,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_transport(spec): async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
''' '''
Open a transport connection as a client or server, implementing Android's `netsim` Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC. simulator protocol over gRPC.
@@ -359,6 +385,11 @@ async def open_android_netsim_transport(spec):
to connect *to* a netsim server (netsim is the controller), or accept to connect *to* a netsim server (netsim is the controller), or accept
connections *as* a netsim-compatible server. connections *as* a netsim-compatible server.
instance=<n>
Specifies an instance number, with <n> > 0. This is used to determine which
.init file to use. In `host` mode, it is ignored when the <host>:<port>
specifier is present, since in that case no .ini file is used.
In `host` mode: In `host` mode:
The <host>:<port> part is optional. When not specified, the transport The <host>:<port> part is optional. When not specified, the transport
looks for a netsim .ini file, from which it will read the `grpc.backend.port` looks for a netsim .ini file, from which it will read the `grpc.backend.port`
@@ -387,14 +418,15 @@ async def open_android_netsim_transport(spec):
params = spec.split(',') if spec else [] params = spec.split(',') if spec else []
if params and ':' in params[0]: if params and ':' in params[0]:
# Explicit <host>:<port> # Explicit <host>:<port>
host, port = params[0].split(':') host, port_str = params[0].split(':')
port = int(port_str)
params_offset = 1 params_offset = 1
else: else:
host = None host = None
port = 0 port = 0
params_offset = 0 params_offset = 0
options = {} options: Dict[str, str] = {}
for param in params[params_offset:]: for param in params[params_offset:]:
if '=' not in param: if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>') raise ValueError('invalid parameter, expected <name>=<value>')
@@ -403,10 +435,12 @@ async def open_android_netsim_transport(spec):
mode = options.get('mode', 'host') mode = options.get('mode', 'host')
if mode == 'host': if mode == 'host':
return await open_android_netsim_host_transport(host, port, options) return await open_android_netsim_host_transport_with_address(
host, port, options
)
if mode == 'controller': if mode == 'controller':
if host is None: if host is None:
raise ValueError('<host>:<port> missing') raise ValueError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port) return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option') raise ValueError('invalid mode option')

View File

@@ -339,8 +339,9 @@ class PumpedPacketSource(ParserSource):
try: try:
packet = await self.receive_function() packet = await self.receive_function()
self.parser.feed_data(packet) self.parser.feed_data(packet)
except asyncio.exceptions.CancelledError: except asyncio.CancelledError:
logger.debug('source pump task done') logger.debug('source pump task done')
self.terminated.set_result(None)
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warning(f'exception while waiting for packet: {error}')
@@ -370,7 +371,7 @@ class PumpedPacketSink:
try: try:
packet = await self.packet_queue.get() packet = await self.packet_queue.get()
await self.send_function(packet) await self.send_function(packet)
except asyncio.exceptions.CancelledError: except asyncio.CancelledError:
logger.debug('sink pump task done') logger.debug('sink pump task done')
break break
except Exception as error: except Exception as error:
@@ -393,19 +394,13 @@ class PumpedTransport(Transport):
self, self,
source: PumpedPacketSource, source: PumpedPacketSource,
sink: PumpedPacketSink, sink: PumpedPacketSink,
close_function,
) -> None: ) -> None:
super().__init__(source, sink) super().__init__(source, sink)
self.close_function = close_function
def start(self) -> None: def start(self) -> None:
self.source.start() self.source.start()
self.sink.start() self.sink.start()
async def close(self) -> None:
await super().close()
await self.close_function()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SnoopingTransport(Transport): class SnoopingTransport(Transport):

View File

@@ -23,6 +23,8 @@ import socket
import ctypes import ctypes
import collections import collections
from typing import Optional
from .common import Transport, ParserSource from .common import Transport, ParserSource
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec: str | None) -> Transport: async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
''' '''
Open an HCI Socket (only available on some platforms). Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter) The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
# Create a raw HCI socket # Create a raw HCI socket
try: try:
hci_socket = socket.socket( hci_socket = socket.socket(
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH, # type: ignore[attr-defined]
socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
socket.BTPROTO_HCI, # type: ignore socket.BTPROTO_HCI, # type: ignore[attr-defined]
) )
except AttributeError as error: except AttributeError as error:
# Not supported on this platform # Not supported on this platform
@@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
bind_address = struct.pack( bind_address = struct.pack(
# pylint: disable=no-member # pylint: disable=no-member
'<HHH', '<HHH',
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH, # type: ignore[attr-defined]
adapter_index, adapter_index,
HCI_CHANNEL_USER, HCI_CHANNEL_USER,
) )

View File

@@ -23,6 +23,8 @@ import atexit
import os import os
import logging import logging
from typing import Optional
from .common import Transport, StreamPacketSource, StreamPacketSink from .common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pty_transport(spec: str | None) -> Transport: async def open_pty_transport(spec: Optional[str]) -> Transport:
''' '''
Open a PTY transport. Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link The parameter string may be empty, or a path name where a symbolic link

View File

@@ -17,6 +17,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from typing import Optional
from .common import Transport from .common import Transport
from .file import open_file_transport from .file import open_file_transport
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_vhci_transport(spec: str | None) -> Transport: async def open_vhci_transport(spec: Optional[str]) -> Transport:
''' '''
Open a VHCI transport (only available on some platforms). Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device The parameter string is either empty (to use the default VHCI device

View File

@@ -31,19 +31,21 @@ async def open_ws_client_transport(spec: str) -> Transport:
''' '''
Open a WebSocket client transport. Open a WebSocket client transport.
The parameter string has this syntax: The parameter string has this syntax:
<remote-host>:<remote-port> <websocket-url>
Example: 127.0.0.1:9001 Example: ws://localhost:7681/v1/websocket/bt
''' '''
remote_host, remote_port = spec.split(':') websocket = await websockets.client.connect(spec)
uri = f'ws://{remote_host}:{remote_port}'
websocket = await websockets.client.connect(uri)
transport = PumpedTransport( class WsTransport(PumpedTransport):
async def close(self):
await super().close()
await websocket.close()
transport = WsTransport(
PumpedPacketSource(websocket.recv), PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send), PumpedPacketSink(websocket.send),
websocket.close,
) )
transport.start() transport.start()
return transport return transport

View File

@@ -15,13 +15,26 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import traceback import traceback
import collections import collections
import sys import sys
from typing import Awaitable, Set, TypeVar import warnings
from functools import wraps from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any,
Optional,
Union,
overload,
)
from functools import wraps, partial
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
@@ -64,6 +77,102 @@ def composite_listener(cls):
return cls return cls
# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)
class EventWatcher:
'''A wrapper class to control the lifecycle of event handlers better.
Usage:
```
watcher = EventWatcher()
def on_foo():
...
watcher.on(emitter, 'foo', on_foo)
@watcher.on(emitter, 'bar')
def on_bar():
...
# Close all event handlers watching through this watcher
watcher.close()
```
As context:
```
with contextlib.closing(EventWatcher()) as context:
@context.on(emitter, 'foo')
def on_foo():
...
# on_foo() has been removed here!
```
'''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None:
self.handlers = []
@overload
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@overload
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
def close(self) -> None:
for emitter, event, handler in self.handlers:
if handler in emitter.listeners(event):
emitter.remove_listener(event, handler)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_T = TypeVar('_T') _T = TypeVar('_T')
@@ -302,3 +411,36 @@ class FlowControlAsyncPipe:
self.resume_source() self.resume_source()
self.check_pump() self.check_pump()
async def async_call(function, *args, **kwargs):
"""
Immediately calls the function with provided args and kwargs, wrapping it in an async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop.
result = await async_call(some_function, ...)
"""
return function(*args, **kwargs)
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return partial(async_call, function)
def deprecated(msg: str):
"""
Throw deprecation warning before execution
"""
def wrapper(function):
@wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return inner
return wrapper

388
rust/Cargo.lock generated
View File

@@ -80,6 +80,37 @@ version = "1.0.75"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
[[package]]
name = "argh"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219"
dependencies = [
"argh_derive",
"argh_shared",
]
[[package]]
name = "argh_derive"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a"
dependencies = [
"argh_shared",
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "argh_shared"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "atty" name = "atty"
version = "0.2.14" version = "0.2.14"
@@ -130,15 +161,37 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
[[package]]
name = "block-buffer"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
]
[[package]]
name = "bstr"
version = "1.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
dependencies = [
"memchr",
"serde",
]
[[package]] [[package]]
name = "bumble" name = "bumble"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes",
"clap 4.4.1", "clap 4.4.1",
"directories", "directories",
"env_logger", "env_logger",
"file-header",
"futures", "futures",
"globset",
"hex", "hex",
"itertools", "itertools",
"lazy_static", "lazy_static",
@@ -146,6 +199,8 @@ dependencies = [
"nix", "nix",
"nom", "nom",
"owo-colors", "owo-colors",
"pdl-derive",
"pdl-runtime",
"pyo3", "pyo3",
"pyo3-asyncio", "pyo3-asyncio",
"rand", "rand",
@@ -166,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.4.0" version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]] [[package]]
name = "cc" name = "cc"
@@ -250,6 +305,16 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961"
[[package]]
name = "codespan-reporting"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
dependencies = [
"termcolor",
"unicode-width",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.0" version = "1.0.0"
@@ -272,6 +337,102 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "cpufeatures"
version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1"
dependencies = [
"libc",
]
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"memoffset 0.9.0",
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
"cfg-if",
]
[[package]]
name = "crypto-common"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
]
[[package]] [[package]]
name = "directories" name = "directories"
version = "5.0.1" version = "5.0.1"
@@ -348,6 +509,19 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
[[package]]
name = "file-header"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f"
dependencies = [
"crossbeam",
"lazy_static",
"license",
"thiserror",
"walkdir",
]
[[package]] [[package]]
name = "fnv" name = "fnv"
version = "1.0.7" version = "1.0.7"
@@ -467,6 +641,16 @@ dependencies = [
"slab", "slab",
] ]
[[package]]
name = "generic-array"
version = "0.14.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
dependencies = [
"typenum",
"version_check",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.2.10" version = "0.2.10"
@@ -484,6 +668,19 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
[[package]]
name = "globset"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
dependencies = [
"aho-corasick",
"bstr",
"fnv",
"log",
"regex",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.21" version = "0.3.21"
@@ -710,6 +907,17 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "license"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946"
dependencies = [
"reword",
"serde",
"serde_json",
]
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.4.5" version = "0.4.5"
@@ -756,6 +964,15 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.17" version = "0.3.17"
@@ -939,12 +1156,100 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "pdl-compiler"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee66995739fb9ddd9155767990a54aadd226ee32408a94f99f94883ff445ceba"
dependencies = [
"argh",
"codespan-reporting",
"heck",
"pest",
"pest_derive",
"prettyplease",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.29",
]
[[package]]
name = "pdl-derive"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "113e4a1215c407466b36d2c2f6a6318819d6b22ccdd3acb7bb35e27a68806034"
dependencies = [
"codespan-reporting",
"pdl-compiler",
"proc-macro2",
"quote",
"syn 2.0.29",
"termcolor",
]
[[package]]
name = "pdl-runtime"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d36c2f9799613babe78eb5cd9a353d527daaba6c3d1f39a1175657a35790732"
dependencies = [
"bytes",
"thiserror",
]
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.0" version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
[[package]]
name = "pest"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c022f1e7b65d6a24c0dbbd5fb344c66881bc01f3e5ae74a1c8100f2f985d98a4"
dependencies = [
"memchr",
"thiserror",
"ucd-trie",
]
[[package]]
name = "pest_derive"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35513f630d46400a977c4cb58f78e1bfbe01434316e60c37d27b9ad6139c66d8"
dependencies = [
"pest",
"pest_generator",
]
[[package]]
name = "pest_generator"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc9fc1b9e7057baba189b5c626e2d6f40681ae5b6eb064dc7c7834101ec8123a"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
"syn 2.0.29",
]
[[package]]
name = "pest_meta"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1df74e9e7ec4053ceb980e7c0c8bd3594e977fde1af91daba9c928e8e8c6708d"
dependencies = [
"once_cell",
"pest",
"sha2",
]
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.13" version = "0.2.13"
@@ -969,6 +1274,16 @@ version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "prettyplease"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c64d9ba0963cdcea2e1b2230fbae2bab30eb25a174be395c41e764bfb65dd62"
dependencies = [
"proc-macro2",
"syn 2.0.29",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.66" version = "1.0.66"
@@ -1200,6 +1515,15 @@ dependencies = [
"winreg", "winreg",
] ]
[[package]]
name = "reword"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765"
dependencies = [
"unicode-segmentation",
]
[[package]] [[package]]
name = "rusb" name = "rusb"
version = "0.9.3" version = "0.9.3"
@@ -1241,6 +1565,15 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]] [[package]]
name = "schannel" name = "schannel"
version = "0.1.22" version = "0.1.22"
@@ -1322,6 +1655,17 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.1" version = "1.4.1"
@@ -1568,6 +1912,18 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "ucd-trie"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
[[package]] [[package]]
name = "unicode-bidi" name = "unicode-bidi"
version = "0.3.13" version = "0.3.13"
@@ -1589,6 +1945,18 @@ dependencies = [
"tinyvec", "tinyvec",
] ]
[[package]]
name = "unicode-segmentation"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
[[package]]
name = "unicode-width"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
[[package]] [[package]]
name = "unindent" name = "unindent"
version = "0.1.11" version = "0.1.11"
@@ -1618,6 +1986,22 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
dependencies = [
"same-file",
"winapi-util",
]
[[package]] [[package]]
name = "want" name = "want"
version = "0.3.1" version = "0.3.1"

View File

@@ -23,6 +23,13 @@ hex = "0.4.3"
itertools = "0.11.0" itertools = "0.11.0"
lazy_static = "1.4.0" lazy_static = "1.4.0"
thiserror = "1.0.41" thiserror = "1.0.41"
bytes = "1.5.0"
pdl-derive = "0.2.0"
pdl-runtime = "0.2.0"
# Dev tools
file-header = { version = "0.1.2", optional = true }
globset = { version = "0.4.13", optional = true }
# CLI # CLI
anyhow = { version = "1.0.71", optional = true } anyhow = { version = "1.0.71", optional = true }
@@ -52,10 +59,15 @@ env_logger = "0.10.0"
[package.metadata.docs.rs] [package.metadata.docs.rs]
rustdoc-args = ["--generate-link-to-definition"] rustdoc-args = ["--generate-link-to-definition"]
[[bin]]
name = "file-header"
path = "tools/file_header.rs"
required-features = ["dev-tools"]
[[bin]] [[bin]]
name = "gen-assigned-numbers" name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs" path = "tools/gen_assigned_numbers.rs"
required-features = ["bumble-codegen"] required-features = ["dev-tools"]
[[bin]] [[bin]]
name = "bumble" name = "bumble"
@@ -71,7 +83,7 @@ harness = false
[features] [features]
anyhow = ["pyo3/anyhow"] anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"] pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
bumble-codegen = ["dep:anyhow"] dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
# separate feature for CLI so that dependencies don't spend time building these # separate feature for CLI so that dependencies don't spend time building these
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"] bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
default = [] default = []

View File

@@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`.
To regenerate the assigned number tables based on the Python codebase: To regenerate the assigned number tables based on the Python codebase:
``` ```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
``` ```

View File

@@ -20,7 +20,8 @@
use bumble::{ use bumble::{
adv::CommonDataType, adv::CommonDataType,
wrapper::{ wrapper::{
core::AdvertisementDataUnit, device::Device, hci::AddressType, transport::Transport, core::AdvertisementDataUnit, device::Device, hci::packets::AddressType,
transport::Transport,
}, },
}; };
use clap::Parser as _; use clap::Parser as _;
@@ -69,9 +70,7 @@ async fn main() -> PyResult<()> {
let mut seen_adv_cache = seen_adv_clone.lock().unwrap(); let mut seen_adv_cache = seen_adv_clone.lock().unwrap();
let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs); let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs);
let advs_from_addr = seen_adv_cache let advs_from_addr = seen_adv_cache.entry(addr_bytes).or_default();
.entry(addr_bytes)
.or_insert_with(collections::HashMap::new);
// we expect cache hits to be the norm, so we do a separate lookup to avoid cloning // we expect cache hits to be the norm, so we do a separate lookup to avoid cloning
// on every lookup with entry() // on every lookup with entry()
let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) { let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) {
@@ -102,7 +101,9 @@ async fn main() -> PyResult<()> {
}; };
let (type_style, qualifier) = match adv.address()?.address_type()? { let (type_style, qualifier) = match adv.address()?.address_type()? {
AddressType::PublicIdentity | AddressType::PublicDevice => (Style::new().cyan(), ""), AddressType::PublicIdentityAddress | AddressType::PublicDeviceAddress => {
(Style::new().cyan(), "")
}
_ => { _ => {
if addr.is_static()? { if addr.is_static()? {
(Style::new().green(), "(static)") (Style::new().green(), "(static)")

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::{self, core::Uuid16}; use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict}; use pyo3::{intern, prelude::*, types::PyDict};
use std::collections; use std::collections;

View File

@@ -12,9 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use bumble::wrapper::{drivers::rtk::DriverInfo, transport::Transport}; use bumble::wrapper::{
controller::Controller,
device::Device,
drivers::rtk::DriverInfo,
hci::{
packets::{
AddressType, ErrorCode, ReadLocalVersionInformationBuilder,
ReadLocalVersionInformationComplete,
},
Address, Error,
},
host::Host,
link::Link,
transport::Transport,
};
use nix::sys::stat::Mode; use nix::sys::stat::Mode;
use pyo3::PyResult; use pyo3::{
exceptions::PyException,
{PyErr, PyResult},
};
#[pyo3_asyncio::tokio::test] #[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> { async fn fifo_transport_can_open() -> PyResult<()> {
@@ -35,3 +52,26 @@ async fn realtek_driver_info_all_drivers() -> PyResult<()> {
assert_eq!(12, DriverInfo::all_drivers()?.len()); assert_eq!(12, DriverInfo::all_drivers()?.len());
Ok(()) Ok(())
} }
#[pyo3_asyncio::tokio::test]
async fn hci_command_wrapper_has_correct_methods() -> PyResult<()> {
let address = Address::new("F0:F1:F2:F3:F4:F5", &AddressType::RandomDeviceAddress)?;
let link = Link::new_local_link()?;
let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?;
let host = Host::new(controller.clone().into(), controller.into()).await?;
let device = Device::new(None, Some(address), None, Some(host), None)?;
device.power_on().await?;
// Send some simple command. A successful response means [HciCommandWrapper] has the minimum
// required interface for the Python code to think its an [HCI_Command] object.
let command = ReadLocalVersionInformationBuilder {};
let event: ReadLocalVersionInformationComplete = device
.send_command(&command.into(), true)
.await?
.try_into()
.map_err(|e: Error| PyErr::new::<PyException, _>(e.to_string()))?;
assert_eq!(ErrorCode::Success, event.get_status());
Ok(())
}

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! BLE advertisements. //! BLE advertisements.
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS}; use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};

View File

@@ -179,7 +179,7 @@ pub(crate) fn parse(firmware_path: &path::Path) -> PyResult<()> {
pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> { pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> {
let transport = Transport::open(transport).await?; let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?; let mut host = Host::new(transport.source()?, transport.sink()?).await?;
host.reset(DriverFactory::None).await?; host.reset(DriverFactory::None).await?;
if !force && !Driver::check(&host).await? { if !force && !Driver::check(&host).await? {
@@ -203,7 +203,7 @@ pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> {
pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> { pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> {
let transport = Transport::open(transport).await?; let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?; let mut host = Host::new(transport.source()?, transport.sink()?).await?;
host.reset(DriverFactory::None).await?; host.reset(DriverFactory::None).await?;
match Driver::for_host(&host, force).await? { match Driver::for_host(&host, force).await? {
@@ -219,7 +219,7 @@ pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> {
pub(crate) async fn drop(transport: &str) -> PyResult<()> { pub(crate) async fn drop(transport: &str) -> PyResult<()> {
let transport = Transport::open(transport).await?; let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?; let mut host = Host::new(transport.source()?, transport.sink()?).await?;
host.reset(DriverFactory::None).await?; host.reset(DriverFactory::None).await?;
Driver::drop_firmware(&mut host).await?; Driver::drop_firmware(&mut host).await?;

View File

@@ -21,8 +21,7 @@
/// TCP client to connect. /// TCP client to connect.
/// When the L2CAP CoC channel is closed, the TCP connection is closed as well. /// When the L2CAP CoC channel is closed, the TCP connection is closed as well.
use crate::cli::l2cap::{ use crate::cli::l2cap::{
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals, inject_py_event_loop, proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, BridgeData,
BridgeData,
}; };
use bumble::wrapper::{ use bumble::wrapper::{
device::{Connection, Device}, device::{Connection, Device},
@@ -85,11 +84,12 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
let mtu = args.mtu; let mtu = args.mtu;
let mps = args.mps; let mps = args.mps;
let ble_connection = Arc::new(Mutex::new(ble_connection)); let ble_connection = Arc::new(Mutex::new(ble_connection));
// Ensure Python event loop is available to l2cap `disconnect` // spawn thread to handle incoming tcp connections
let _ = run_future_with_current_task_locals(async move { tokio::spawn(inject_py_event_loop(async move {
while let Ok((tcp_stream, addr)) = listener.accept().await { while let Ok((tcp_stream, addr)) = listener.accept().await {
let ble_connection = ble_connection.clone(); let ble_connection = ble_connection.clone();
let _ = run_future_with_current_task_locals(proxy_data_between_tcp_and_l2cap( // spawn thread to handle this specific tcp connection
if let Ok(future) = inject_py_event_loop(proxy_data_between_tcp_and_l2cap(
ble_connection, ble_connection,
tcp_stream, tcp_stream,
addr, addr,
@@ -97,10 +97,11 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
max_credits, max_credits,
mtu, mtu,
mps, mps,
)); )) {
tokio::spawn(future);
} }
Ok(()) }
}); })?);
Ok(()) Ok(())
} }

View File

@@ -18,7 +18,7 @@ use crate::L2cap;
use anyhow::anyhow; use anyhow::anyhow;
use bumble::wrapper::{device::Device, l2cap::LeConnectionOrientedChannel, transport::Transport}; use bumble::wrapper::{device::Device, l2cap::LeConnectionOrientedChannel, transport::Transport};
use owo_colors::{colors::css::Orange, OwoColorize}; use owo_colors::{colors::css::Orange, OwoColorize};
use pyo3::{PyObject, PyResult, Python}; use pyo3::{PyResult, Python};
use std::{future::Future, path::PathBuf, sync::Arc}; use std::{future::Future, path::PathBuf, sync::Arc};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
@@ -170,21 +170,12 @@ async fn proxy_tcp_rx_to_l2cap_tx(
} }
} }
/// Copies the current thread's TaskLocals into a Python "awaitable" and encapsulates it in a Rust /// Copies the current thread's Python even loop (contained in `TaskLocals`) into the given future.
/// future, running it as a Python Task. /// Useful when sending work to another thread that calls Python code which calls `get_running_loop()`.
/// `TaskLocals` stores the current event loop, and allows the user to copy the current Python pub fn inject_py_event_loop<F, R>(fut: F) -> PyResult<impl Future<Output = R>>
/// context if necessary. In this case, the python event loop is used when calling `disconnect` on
/// an l2cap connection, or else the call will fail.
pub fn run_future_with_current_task_locals<F>(
fut: F,
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send>
where where
F: Future<Output = PyResult<()>> + Send + 'static, F: Future<Output = R> + Send + 'static,
{ {
Python::with_gil(|py| { let locals = Python::with_gil(pyo3_asyncio::tokio::get_current_locals)?;
let locals = pyo3_asyncio::tokio::get_current_locals(py)?; Ok(pyo3_asyncio::tokio::scope(locals, fut))
let future = pyo3_asyncio::tokio::scope(locals.clone(), fut);
pyo3_asyncio::tokio::future_into_py_with_locals(py, locals, future)
.and_then(pyo3_asyncio::tokio::into_future)
})
} }

View File

@@ -19,10 +19,7 @@
/// When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket /// When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
/// and waits for a new L2CAP CoC channel to be connected. /// and waits for a new L2CAP CoC channel to be connected.
/// When the TCP connection is closed by the TCP server, the L2CAP connection is closed as well. /// When the TCP connection is closed by the TCP server, the L2CAP connection is closed as well.
use crate::cli::l2cap::{ use crate::cli::l2cap::{proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, BridgeData};
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals,
BridgeData,
};
use bumble::wrapper::{device::Device, hci::HciConstant, l2cap::LeConnectionOrientedChannel}; use bumble::wrapper::{device::Device, hci::HciConstant, l2cap::LeConnectionOrientedChannel};
use futures::executor::block_on; use futures::executor::block_on;
use owo_colors::OwoColorize; use owo_colors::OwoColorize;
@@ -49,19 +46,19 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
let port = args.tcp_port; let port = args.tcp_port;
device.register_l2cap_channel_server( device.register_l2cap_channel_server(
args.psm, args.psm,
move |_py, l2cap_channel| { move |py, l2cap_channel| {
let channel_info = l2cap_channel let channel_info = l2cap_channel
.debug_string() .debug_string()
.unwrap_or_else(|e| format!("failed to get l2cap channel info ({e})")); .unwrap_or_else(|e| format!("failed to get l2cap channel info ({e})"));
println!("{} {channel_info}", "*** L2CAP channel:".cyan()); println!("{} {channel_info}", "*** L2CAP channel:".cyan());
let host = host.clone(); let host = host.clone();
// Ensure Python event loop is available to l2cap `disconnect` // Handles setting up a tokio runtime that runs this future to completion while also
let _ = run_future_with_current_task_locals(proxy_data_between_l2cap_and_tcp( // containing the necessary context vars.
l2cap_channel, pyo3_asyncio::tokio::future_into_py(
host, py,
port, proxy_data_between_l2cap_and_tcp(l2cap_channel, host, port),
)); )?;
Ok(()) Ok(())
}, },
args.max_credits, args.max_credits,

View File

@@ -143,10 +143,7 @@ pub(crate) fn probe(verbose: bool) -> anyhow::Result<()> {
); );
if let Some(s) = serial { if let Some(s) = serial {
println!("{:26}{}", " Serial:".green(), s); println!("{:26}{}", " Serial:".green(), s);
device_serials_by_id device_serials_by_id.entry(device_id).or_default().insert(s);
.entry(device_id)
.or_insert(HashSet::new())
.insert(s);
} }
if let Some(m) = mfg { if let Some(m) = mfg {
println!("{:26}{}", " Manufacturer:".green(), m); println!("{:26}{}", " Manufacturer:".green(), m);
@@ -314,7 +311,7 @@ impl ClassInfo {
self.protocol, self.protocol,
self.protocol_name() self.protocol_name()
.map(|s| format!(" [{}]", s)) .map(|s| format!(" [{}]", s))
.unwrap_or_else(String::new) .unwrap_or_default()
) )
} }
} }

View File

@@ -0,0 +1,161 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub use pdl_runtime::{Error, Packet};
use crate::internal::hci::packets::{Acl, Command, Event, Sco};
use pdl_derive::pdl;
#[allow(missing_docs, warnings, clippy::all)]
#[pdl("src/internal/hci/packets.pdl")]
pub mod packets {}
#[cfg(test)]
mod tests;
/// HCI Packet type, prepended to the packet.
/// Rootcanal's PDL declaration excludes this from ser/deser and instead is implemented in code.
/// To maintain the ability to easily use future versions of their packet PDL, packet type is
/// implemented here.
#[derive(Debug, PartialEq)]
pub(crate) enum PacketType {
Command = 0x01,
Acl = 0x02,
Sco = 0x03,
Event = 0x04,
}
impl TryFrom<u8> for PacketType {
type Error = PacketTypeParseError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0x01 => Ok(PacketType::Command),
0x02 => Ok(PacketType::Acl),
0x03 => Ok(PacketType::Sco),
0x04 => Ok(PacketType::Event),
_ => Err(PacketTypeParseError::InvalidPacketType { value }),
}
}
}
impl From<PacketType> for u8 {
fn from(packet_type: PacketType) -> Self {
match packet_type {
PacketType::Command => 0x01,
PacketType::Acl => 0x02,
PacketType::Sco => 0x03,
PacketType::Event => 0x04,
}
}
}
/// Allows for smoother interoperability between a [Packet] and a bytes representation of it that
/// includes its type as a header
pub(crate) trait WithPacketType<T: Packet> {
/// Converts the [Packet] into bytes, prefixed with its type
fn to_vec_with_packet_type(self) -> Vec<u8>;
/// Parses a [Packet] out of bytes that are prefixed with the packet's type
fn parse_with_packet_type(bytes: &[u8]) -> Result<T, PacketTypeParseError>;
}
/// Errors that may arise when parsing a packet that is prefixed with its type
#[derive(Debug, PartialEq, thiserror::Error)]
pub(crate) enum PacketTypeParseError {
#[error("The slice being parsed was empty")]
EmptySlice,
#[error("Packet type ({value:#X}) is invalid")]
InvalidPacketType { value: u8 },
#[error("Expected packet type: {expected:?}, but got: {actual:?}")]
PacketTypeMismatch {
expected: PacketType,
actual: PacketType,
},
#[error("Failed to parse packet after header: {error}")]
PacketParse { error: Error },
}
impl From<Error> for PacketTypeParseError {
fn from(error: Error) -> Self {
Self::PacketParse { error }
}
}
impl WithPacketType<Self> for Command {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Command, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
parse_with_expected_packet_type(Command::parse, PacketType::Command, bytes)
}
}
impl WithPacketType<Self> for Acl {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Acl, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
parse_with_expected_packet_type(Acl::parse, PacketType::Acl, bytes)
}
}
impl WithPacketType<Self> for Sco {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Sco, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
parse_with_expected_packet_type(Sco::parse, PacketType::Sco, bytes)
}
}
impl WithPacketType<Self> for Event {
fn to_vec_with_packet_type(self) -> Vec<u8> {
prepend_packet_type(PacketType::Event, self.to_vec())
}
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
parse_with_expected_packet_type(Event::parse, PacketType::Event, bytes)
}
}
fn prepend_packet_type(packet_type: PacketType, mut packet_bytes: Vec<u8>) -> Vec<u8> {
packet_bytes.insert(0, packet_type.into());
packet_bytes
}
fn parse_with_expected_packet_type<T: Packet, F, E>(
parser: F,
expected_packet_type: PacketType,
bytes: &[u8],
) -> Result<T, PacketTypeParseError>
where
F: Fn(&[u8]) -> Result<T, E>,
PacketTypeParseError: From<E>,
{
let (first_byte, packet_bytes) = bytes
.split_first()
.ok_or(PacketTypeParseError::EmptySlice)?;
let actual_packet_type = PacketType::try_from(*first_byte)?;
if actual_packet_type == expected_packet_type {
Ok(parser(packet_bytes)?)
} else {
Err(PacketTypeParseError::PacketTypeMismatch {
expected: expected_packet_type,
actual: actual_packet_type,
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,94 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::internal::hci::{
packets::{Event, EventBuilder, EventCode, Sco},
parse_with_expected_packet_type, prepend_packet_type, Error, Packet, PacketType,
PacketTypeParseError, WithPacketType,
};
use bytes::Bytes;
#[test]
fn prepends_packet_type() {
let packet_type = PacketType::Event;
let packet_bytes = vec![0x00, 0x00, 0x00, 0x00];
let actual = prepend_packet_type(packet_type, packet_bytes);
assert_eq!(vec![0x04, 0x00, 0x00, 0x00, 0x00], actual);
}
#[test]
fn parse_empty_slice_should_error() {
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Event, &[]);
assert_eq!(Err(PacketTypeParseError::EmptySlice), actual);
}
#[test]
fn parse_invalid_packet_type_should_error() {
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Event, &[0xFF]);
assert_eq!(
Err(PacketTypeParseError::InvalidPacketType { value: 0xFF }),
actual
);
}
#[test]
fn parse_mismatched_packet_type_should_error() {
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Acl, &[0x01]);
assert_eq!(
Err(PacketTypeParseError::PacketTypeMismatch {
expected: PacketType::Acl,
actual: PacketType::Command
}),
actual
);
}
#[test]
fn parse_invalid_packet_should_error() {
let actual = parse_with_expected_packet_type(Sco::parse, PacketType::Sco, &[0x03]);
assert!(actual.is_err());
}
#[test]
fn test_packet_roundtrip_with_type() {
let event_packet = EventBuilder {
event_code: EventCode::InquiryComplete,
payload: None,
}
.build();
let event_packet_bytes = event_packet.clone().to_vec_with_packet_type();
let actual =
parse_with_expected_packet_type(Event::parse, PacketType::Event, &event_packet_bytes)
.unwrap();
assert_eq!(event_packet, actual);
}
#[derive(Debug, PartialEq)]
struct FakePacket;
impl FakePacket {
fn parse(_bytes: &[u8]) -> Result<Self, Error> {
Ok(Self)
}
}
impl Packet for FakePacket {
fn to_bytes(self) -> Bytes {
Bytes::new()
}
fn to_vec(self) -> Vec<u8> {
Vec::new()
}
}

View File

@@ -18,3 +18,4 @@
//! to discover. //! to discover.
pub(crate) mod drivers; pub(crate) mod drivers;
pub(crate) mod hci;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,34 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Shared resources found under bumble's common.py
use pyo3::{PyObject, Python, ToPyObject};
/// Represents the sink for some transport mechanism
pub struct TransportSink(pub(crate) PyObject);
impl ToPyObject for TransportSink {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.0.clone()
}
}
/// Represents the source for some transport mechanism
pub struct TransportSource(pub(crate) PyObject);
impl ToPyObject for TransportSource {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.0.clone()
}
}

View File

@@ -0,0 +1,66 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Controller components
use crate::wrapper::{
common::{TransportSink, TransportSource},
hci::Address,
link::Link,
wrap_python_async, PyDictExt,
};
use pyo3::{
intern,
types::{PyDict, PyModule},
PyObject, PyResult, Python,
};
use pyo3_asyncio::tokio::into_future;
/// A controller that can send and receive HCI frames via some link
#[derive(Clone)]
pub struct Controller(pub(crate) PyObject);
impl Controller {
/// Creates a new [Controller] object. When optional arguments are not specified, the Python
/// module specifies the defaults. Must be called from a thread with a Python event loop, which
/// should be true on `tokio::main` and `async_std::main`.
///
/// For more info, see https://awestlake87.github.io/pyo3-asyncio/master/doc/pyo3_asyncio/#event-loop-references-and-contextvars.
pub async fn new(
name: &str,
host_source: Option<TransportSource>,
host_sink: Option<TransportSink>,
link: Option<Link>,
public_address: Option<Address>,
) -> PyResult<Self> {
Python::with_gil(|py| {
let controller_ctr = PyModule::import(py, intern!(py, "bumble.controller"))?
.getattr(intern!(py, "Controller"))?;
let kwargs = PyDict::new(py);
kwargs.set_item("name", name)?;
kwargs.set_opt_item("host_source", host_source)?;
kwargs.set_opt_item("host_sink", host_sink)?;
kwargs.set_opt_item("link", link)?;
kwargs.set_opt_item("public_address", public_address)?;
// Controller constructor (`__init__`) is not (and can't be) marked async, but calls
// `get_running_loop`, and thus needs wrapped in an async function.
wrap_python_async(py, controller_ctr)?
.call((), Some(kwargs))
.and_then(into_future)
})?
.await
.map(Self)
}
}

View File

@@ -14,12 +14,16 @@
//! Devices and connections to them //! Devices and connections to them
use crate::internal::hci::WithPacketType;
use crate::{ use crate::{
adv::AdvertisementDataBuilder, adv::AdvertisementDataBuilder,
wrapper::{ wrapper::{
core::AdvertisingData, core::AdvertisingData,
gatt_client::{ProfileServiceProxy, ServiceProxy}, gatt_client::{ProfileServiceProxy, ServiceProxy},
hci::{Address, HciErrorCode}, hci::{
packets::{Command, ErrorCode, Event},
Address, HciCommandWrapper,
},
host::Host, host::Host,
l2cap::LeConnectionOrientedChannel, l2cap::LeConnectionOrientedChannel,
transport::{Sink, Source}, transport::{Sink, Source},
@@ -27,18 +31,73 @@ use crate::{
}, },
}; };
use pyo3::{ use pyo3::{
exceptions::PyException,
intern, intern,
types::{PyDict, PyModule}, types::{PyDict, PyModule},
IntoPy, PyObject, PyResult, Python, ToPyObject, IntoPy, PyErr, PyObject, PyResult, Python, ToPyObject,
}; };
use pyo3_asyncio::tokio::into_future; use pyo3_asyncio::tokio::into_future;
use std::path; use std::path;
/// Represents the various properties of some device
pub struct DeviceConfiguration(PyObject);
impl DeviceConfiguration {
/// Creates a new configuration, letting the internal Python object set all the defaults
pub fn new() -> PyResult<DeviceConfiguration> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "DeviceConfiguration"))?
.call0()
.map(|any| Self(any.into()))
})
}
/// Creates a new configuration from the specified file
pub fn load_from_file(&mut self, device_config: &path::Path) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "load_from_file"), (device_config,))
})
.map(|_| ())
}
}
impl ToPyObject for DeviceConfiguration {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.0.clone()
}
}
/// A device that can send/receive HCI frames. /// A device that can send/receive HCI frames.
#[derive(Clone)] #[derive(Clone)]
pub struct Device(PyObject); pub struct Device(PyObject);
impl Device { impl Device {
/// Creates a Device. When optional arguments are not specified, the Python object specifies the
/// defaults.
pub fn new(
name: Option<&str>,
address: Option<Address>,
config: Option<DeviceConfiguration>,
host: Option<Host>,
generic_access_service: Option<bool>,
) -> PyResult<Self> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_opt_item("name", name)?;
kwargs.set_opt_item("address", address)?;
kwargs.set_opt_item("config", config)?;
kwargs.set_opt_item("host", host)?;
kwargs.set_opt_item("generic_access_service", generic_access_service)?;
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call((), Some(kwargs))
.map(|any| Self(any.into()))
})
}
/// Create a Device per the provided file configured to communicate with a controller through an HCI source/sink /// Create a Device per the provided file configured to communicate with a controller through an HCI source/sink
pub fn from_config_file_with_hci( pub fn from_config_file_with_hci(
device_config: &path::Path, device_config: &path::Path,
@@ -66,6 +125,29 @@ impl Device {
}) })
} }
/// Sends an HCI command on this Device, returning the command's event result.
pub async fn send_command(&self, command: &Command, check_result: bool) -> PyResult<Event> {
Python::with_gil(|py| {
self.0
.call_method1(
py,
intern!(py, "send_command"),
(HciCommandWrapper(command.clone()), check_result),
)
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.and_then(|event| {
Python::with_gil(|py| {
let py_bytes = event.call_method0(py, intern!(py, "__bytes__"))?;
let bytes: &[u8] = py_bytes.extract(py)?;
let event = Event::parse_with_packet_type(bytes)
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))?;
Ok(event)
})
})
}
/// Turn the device on /// Turn the device on
pub async fn power_on(&self) -> PyResult<()> { pub async fn power_on(&self) -> PyResult<()> {
Python::with_gil(|py| { Python::with_gil(|py| {
@@ -236,7 +318,7 @@ impl Connection {
kwargs.set_opt_item("mps", mps)?; kwargs.set_opt_item("mps", mps)?;
self.0 self.0
.call_method(py, intern!(py, "open_l2cap_channel"), (), Some(kwargs)) .call_method(py, intern!(py, "open_l2cap_channel"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py))) .and_then(|coroutine| into_future(coroutine.as_ref(py)))
})? })?
.await .await
.map(LeConnectionOrientedChannel::from) .map(LeConnectionOrientedChannel::from)
@@ -244,13 +326,13 @@ impl Connection {
/// Disconnect from device with provided reason. When optional arguments are not specified, the /// Disconnect from device with provided reason. When optional arguments are not specified, the
/// Python module specifies the defaults. /// Python module specifies the defaults.
pub async fn disconnect(&mut self, reason: Option<HciErrorCode>) -> PyResult<()> { pub async fn disconnect(&mut self, reason: Option<ErrorCode>) -> PyResult<()> {
Python::with_gil(|py| { Python::with_gil(|py| {
let kwargs = PyDict::new(py); let kwargs = PyDict::new(py);
kwargs.set_opt_item("reason", reason)?; kwargs.set_opt_item("reason", reason)?;
self.0 self.0
.call_method(py, intern!(py, "disconnect"), (), Some(kwargs)) .call_method(py, intern!(py, "disconnect"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py))) .and_then(|coroutine| into_future(coroutine.as_ref(py)))
})? })?
.await .await
.map(|_| ()) .map(|_| ())
@@ -259,7 +341,7 @@ impl Connection {
/// Register a callback to be called on disconnection. /// Register a callback to be called on disconnection.
pub fn on_disconnection( pub fn on_disconnection(
&mut self, &mut self,
callback: impl Fn(Python, HciErrorCode) -> PyResult<()> + Send + 'static, callback: impl Fn(Python, ErrorCode) -> PyResult<()> + Send + 'static,
) -> PyResult<()> { ) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| { let boxed = ClosureCallback::new(move |py, args, _kwargs| {
callback(py, args.get_item(0)?.extract()?) callback(py, args.get_item(0)?.extract()?)

View File

@@ -14,84 +14,62 @@
//! HCI //! HCI
pub use crate::internal::hci::{packets, Error, Packet};
use crate::{
internal::hci::WithPacketType,
wrapper::hci::packets::{AddressType, Command, ErrorCode},
};
use itertools::Itertools as _; use itertools::Itertools as _;
use pyo3::{ use pyo3::{
exceptions::PyException, intern, types::PyModule, FromPyObject, PyAny, PyErr, PyObject, exceptions::PyException,
PyResult, Python, ToPyObject, intern, pyclass, pymethods,
types::{PyBytes, PyModule},
FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
}; };
/// HCI error code.
pub struct HciErrorCode(u8);
impl<'source> FromPyObject<'source> for HciErrorCode {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
Ok(HciErrorCode(ob.extract()?))
}
}
impl ToPyObject for HciErrorCode {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.0.to_object(py)
}
}
/// Provides helpers for interacting with HCI /// Provides helpers for interacting with HCI
pub struct HciConstant; pub struct HciConstant;
impl HciConstant { impl HciConstant {
/// Human-readable error name /// Human-readable error name
pub fn error_name(status: HciErrorCode) -> PyResult<String> { pub fn error_name(status: ErrorCode) -> PyResult<String> {
Python::with_gil(|py| { Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.hci"))? PyModule::import(py, intern!(py, "bumble.hci"))?
.getattr(intern!(py, "HCI_Constant"))? .getattr(intern!(py, "HCI_Constant"))?
.call_method1(intern!(py, "error_name"), (status.0,))? .call_method1(intern!(py, "error_name"), (status.to_object(py),))?
.extract() .extract()
}) })
} }
} }
/// A Bluetooth address /// A Bluetooth address
#[derive(Clone)]
pub struct Address(pub(crate) PyObject); pub struct Address(pub(crate) PyObject);
impl Address { impl Address {
/// Creates a new [Address] object
pub fn new(address: &str, address_type: &AddressType) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Address"))?
.call1((address, address_type.to_object(py)))
.map(|any| Self(any.into()))
})
}
/// The type of address /// The type of address
pub fn address_type(&self) -> PyResult<AddressType> { pub fn address_type(&self) -> PyResult<AddressType> {
Python::with_gil(|py| { Python::with_gil(|py| {
let addr_type = self self.0
.0
.getattr(py, intern!(py, "address_type"))? .getattr(py, intern!(py, "address_type"))?
.extract::<u32>(py)?; .extract::<u8>(py)?
.try_into()
let module = PyModule::import(py, intern!(py, "bumble.hci"))?; .map_err(|addr_type| {
let klass = module.getattr(intern!(py, "Address"))?; PyErr::new::<PyException, _>(format!(
"Failed to convert {addr_type} to AddressType"
if addr_type ))
== klass })
.getattr(intern!(py, "PUBLIC_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicDevice)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomDevice)
} else if addr_type
== klass
.getattr(intern!(py, "PUBLIC_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicIdentity)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomIdentity)
} else {
Err(PyErr::new::<PyException, _>("Invalid address type"))
}
}) })
} }
@@ -134,12 +112,45 @@ impl Address {
} }
} }
/// BT address types impl ToPyObject for Address {
#[allow(missing_docs)] fn to_object(&self, _py: Python<'_>) -> PyObject {
#[derive(PartialEq, Eq, Debug)] self.0.clone()
pub enum AddressType { }
PublicDevice, }
RandomDevice,
PublicIdentity, /// Implements minimum necessary interface to be treated as bumble's [HCI_Command].
RandomIdentity, /// While pyo3's macros do not support generics, this could probably be refactored to allow multiple
/// implementations of the HCI_Command methods in the future, if needed.
#[pyclass]
pub(crate) struct HciCommandWrapper(pub(crate) Command);
#[pymethods]
impl HciCommandWrapper {
fn __bytes__(&self, py: Python) -> PyResult<PyObject> {
let bytes = PyBytes::new(py, &self.0.clone().to_vec_with_packet_type());
Ok(bytes.into_py(py))
}
#[getter]
fn op_code(&self) -> u16 {
self.0.get_op_code().into()
}
}
impl ToPyObject for AddressType {
fn to_object(&self, py: Python<'_>) -> PyObject {
u8::from(self).to_object(py)
}
}
impl<'source> FromPyObject<'source> for ErrorCode {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
ob.extract()
}
}
impl ToPyObject for ErrorCode {
fn to_object(&self, py: Python<'_>) -> PyObject {
u8::from(self).to_object(py)
}
} }

View File

@@ -14,8 +14,12 @@
//! Host-side types //! Host-side types
use crate::wrapper::transport::{Sink, Source}; use crate::wrapper::{
use pyo3::{intern, prelude::PyModule, types::PyDict, PyObject, PyResult, Python}; transport::{Sink, Source},
wrap_python_async,
};
use pyo3::{intern, prelude::PyModule, types::PyDict, PyObject, PyResult, Python, ToPyObject};
use pyo3_asyncio::tokio::into_future;
/// Host HCI commands /// Host HCI commands
pub struct Host { pub struct Host {
@@ -29,13 +33,23 @@ impl Host {
} }
/// Create a new Host /// Create a new Host
pub fn new(source: Source, sink: Sink) -> PyResult<Self> { pub async fn new(source: Source, sink: Sink) -> PyResult<Self> {
Python::with_gil(|py| { Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.host"))? let host_ctr =
.getattr(intern!(py, "Host"))? PyModule::import(py, intern!(py, "bumble.host"))?.getattr(intern!(py, "Host"))?;
.call((source.0, sink.0), None)
.map(|any| Self { obj: any.into() }) let kwargs = PyDict::new(py);
}) kwargs.set_item("controller_source", source.0)?;
kwargs.set_item("controller_sink", sink.0)?;
// Needed for Python 3.8-3.9, in which the Semaphore object, when constructed, calls
// `get_event_loop`.
wrap_python_async(py, host_ctr)?
.call((), Some(kwargs))
.and_then(into_future)
})?
.await
.map(|any| Self { obj: any })
} }
/// Send a reset command and perform other reset tasks. /// Send a reset command and perform other reset tasks.
@@ -61,6 +75,12 @@ impl Host {
} }
} }
impl ToPyObject for Host {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.obj.clone()
}
}
/// Driver factory to use when initializing a host /// Driver factory to use when initializing a host
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum DriverFactory { pub enum DriverFactory {

38
rust/src/wrapper/link.rs Normal file
View File

@@ -0,0 +1,38 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Link components
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
/// Link bus for controllers to communicate with each other
#[derive(Clone)]
pub struct Link(pub(crate) PyObject);
impl Link {
/// Creates a [Link] object that transports messages locally
pub fn new_local_link() -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.link"))?
.getattr(intern!(py, "LocalLink"))?
.call0()
.map(|any| Self(any.into()))
})
}
}
impl ToPyObject for Link {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.0.clone()
}
}

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Bumble & Python logging //! Bumble & Python logging
use pyo3::types::PyDict; use pyo3::types::PyDict;

View File

@@ -22,13 +22,17 @@
// Re-exported to make it easy for users to depend on the same `PyObject`, etc // Re-exported to make it easy for users to depend on the same `PyObject`, etc
pub use pyo3; pub use pyo3;
pub use pyo3_asyncio;
use pyo3::{ use pyo3::{
intern,
prelude::*, prelude::*,
types::{PyDict, PyTuple}, types::{PyDict, PyTuple},
}; };
pub use pyo3_asyncio;
pub mod assigned_numbers; pub mod assigned_numbers;
pub mod common;
pub mod controller;
pub mod core; pub mod core;
pub mod device; pub mod device;
pub mod drivers; pub mod drivers;
@@ -36,6 +40,7 @@ pub mod gatt_client;
pub mod hci; pub mod hci;
pub mod host; pub mod host;
pub mod l2cap; pub mod l2cap;
pub mod link;
pub mod logging; pub mod logging;
pub mod profile; pub mod profile;
pub mod transport; pub mod transport;
@@ -119,3 +124,11 @@ impl ClosureCallback {
(self.callback)(py, args, kwargs).map(|_| py.None()) (self.callback)(py, args, kwargs).map(|_| py.None())
} }
} }
/// Wraps the Python function in a Python async function. `pyo3_asyncio` needs functions to be
/// marked async to properly inject a running loop.
pub(crate) fn wrap_python_async<'a>(py: Python<'a>, function: &'a PyAny) -> PyResult<&'a PyAny> {
PyModule::import(py, intern!(py, "bumble.utils"))?
.getattr(intern!(py, "wrap_async"))?
.call1((function,))
}

View File

@@ -14,6 +14,7 @@
//! HCI packet transport //! HCI packet transport
use crate::wrapper::controller::Controller;
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python}; use pyo3::{intern, types::PyModule, PyObject, PyResult, Python};
/// A source/sink pair for HCI packet I/O. /// A source/sink pair for HCI packet I/O.
@@ -67,6 +68,18 @@ impl Drop for Transport {
#[derive(Clone)] #[derive(Clone)]
pub struct Source(pub(crate) PyObject); pub struct Source(pub(crate) PyObject);
impl From<Controller> for Source {
fn from(value: Controller) -> Self {
Self(value.0)
}
}
/// The sink side of a [Transport]. /// The sink side of a [Transport].
#[derive(Clone)] #[derive(Clone)]
pub struct Sink(pub(crate) PyObject); pub struct Sink(pub(crate) PyObject);
impl From<Controller> for Sink {
fn from(value: Controller) -> Self {
Self(value.0)
}
}

78
rust/tools/file_header.rs Normal file
View File

@@ -0,0 +1,78 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::anyhow;
use clap::Parser as _;
use file_header::{
add_headers_recursively, check_headers_recursively,
license::spdx::{YearCopyrightOwnerValue, APACHE_2_0},
};
use globset::{Glob, GlobSet, GlobSetBuilder};
use std::{env, path::PathBuf};
fn main() -> anyhow::Result<()> {
let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
let ignore_globset = ignore_globset()?;
// Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127)
let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into()));
let cli = Cli::parse();
match cli.subcommand {
Subcommand::CheckAll => {
let result =
check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?;
if result.has_failure() {
return Err(anyhow!(
"The following files do not have headers: {result:?}"
));
}
}
Subcommand::AddAll => {
let files_with_new_header =
add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?;
files_with_new_header
.iter()
.for_each(|path| println!("Added header to: {path:?}"));
}
}
Ok(())
}
fn ignore_globset() -> anyhow::Result<GlobSet> {
Ok(GlobSetBuilder::new()
.add(Glob::new("**/.idea/**")?)
.add(Glob::new("**/target/**")?)
.add(Glob::new("**/.gitignore")?)
.add(Glob::new("**/CHANGELOG.md")?)
.add(Glob::new("**/Cargo.lock")?)
.add(Glob::new("**/Cargo.toml")?)
.add(Glob::new("**/README.md")?)
.add(Glob::new("*.bin")?)
.build()?)
}
#[derive(clap::Parser)]
struct Cli {
#[clap(subcommand)]
subcommand: Subcommand,
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Subcommand {
/// Checks if a license is present in files that are not in the ignore list.
CheckAll,
/// Adds a license as needed to files that are not in the ignore list.
AddAll,
}

View File

@@ -36,6 +36,10 @@ install_requires =
bt-test-interfaces >= 0.0.2; platform_system!='Emscripten' bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
click == 8.1.3; platform_system!='Emscripten' click == 8.1.3; platform_system!='Emscripten'
cryptography == 39; platform_system!='Emscripten' cryptography == 39; platform_system!='Emscripten'
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
# versions available on PyPI. Relax the version requirement since it's better than being
# completely unable to import the package in case of version mismatch.
cryptography >= 39.0; platform_system=='Emscripten'
grpcio == 1.57.0; platform_system!='Emscripten' grpcio == 1.57.0; platform_system!='Emscripten'
humanize >= 4.6.0; platform_system!='Emscripten' humanize >= 4.6.0; platform_system!='Emscripten'
libusb1 >= 2.0.1; platform_system!='Emscripten' libusb1 >= 2.0.1; platform_system!='Emscripten'
@@ -84,12 +88,16 @@ development =
black == 22.10 black == 22.10
grpcio-tools >= 1.57.0 grpcio-tools >= 1.57.0
invoke >= 1.7.3 invoke >= 1.7.3
mypy == 1.2.0 mypy == 1.5.0
nox >= 2022 nox >= 2022
pylint == 2.15.8 pylint == 2.15.8
pyyaml >= 6.0
types-appdirs >= 1.4.3 types-appdirs >= 1.4.3
types-invoke >= 1.7.3 types-invoke >= 1.7.3
types-protobuf >= 4.21.0 types-protobuf >= 4.21.0
avatar =
pandora-avatar == 0.0.5
rootcanal == 1.3.0 ; python_version>='3.10'
documentation = documentation =
mkdocs >= 1.4.0 mkdocs >= 1.4.0
mkdocs-material >= 8.5.6 mkdocs-material >= 8.5.6

View File

@@ -45,12 +45,14 @@ def test_messages():
] ]
message = Get_Capabilities_Response(capabilities) message = Get_Capabilities_Response(capabilities)
parsed = Message.create( parsed = Message.create(
AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload AVDTP_GET_CAPABILITIES, Message.MessageType.RESPONSE_ACCEPT, message.payload
) )
assert message.payload == parsed.payload assert message.payload == parsed.payload
message = Set_Configuration_Command(3, 4, capabilities) message = Set_Configuration_Command(3, 4, capabilities)
parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload) parsed = Message.create(
AVDTP_SET_CONFIGURATION, Message.MessageType.COMMAND, message.payload
)
assert message.payload == parsed.payload assert message.payload == parsed.payload

View File

@@ -891,10 +891,10 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_attribute_string_to_permissions(): def test_permissions_from_string():
assert Attribute.string_to_permissions('READABLE') == 1 assert Attribute.Permissions.from_string('READABLE') == 1
assert Attribute.string_to_permissions('WRITEABLE') == 2 assert Attribute.Permissions.from_string('WRITEABLE') == 2
assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3 assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -94,6 +94,7 @@ def temporary_file():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic(temporary_file): async def test_basic(temporary_file):
with open(temporary_file, mode='w', encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write("{}") file.write("{}")
@@ -125,6 +126,7 @@ async def test_basic(temporary_file):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_parsing(temporary_file): async def test_parsing(temporary_file):
with open(temporary_file, mode='w', encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON1) file.write(JSON1)
@@ -137,6 +139,7 @@ async def test_parsing(temporary_file):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_default_namespace(temporary_file): async def test_default_namespace(temporary_file):
with open(temporary_file, mode='w', encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write(JSON1) file.write(JSON1)

View File

@@ -18,6 +18,7 @@
import asyncio import asyncio
import logging import logging
import os import os
import pytest
from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
from bumble.sdp import ( from bumble.sdp import (
@@ -202,6 +203,7 @@ def sdp_records():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search(): async def test_service_search():
# Setup connections # Setup connections
devices = TwoDevices() devices = TwoDevices()
@@ -224,6 +226,7 @@ async def test_service_search():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_attribute(): async def test_service_attribute():
# Setup connections # Setup connections
devices = TwoDevices() devices = TwoDevices()
@@ -244,6 +247,7 @@ async def test_service_attribute():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_search_attribute(): async def test_service_search_attribute():
# Setup connections # Setup connections
devices = TwoDevices() devices = TwoDevices()

77
tests/utils_test.py Normal file
View File

@@ -0,0 +1,77 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from bumble import utils
from pyee import EventEmitter
from unittest.mock import MagicMock
def test_on() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.on(emitter, 'event', mock)
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_on_decorator() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
@context.on(emitter, 'event')
def on_event(*_) -> None:
mock()
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_multiple_handlers() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.once(emitter, 'a', mock)
context.once(emitter, 'b', mock)
emitter.emit('b', 'b')
assert not emitter.listeners('a')
assert not emitter.listeners('b')
mock.assert_called_once_with('b')
# -----------------------------------------------------------------------------
def run_tests():
test_on()
test_on_decorator()
test_multiple_handlers()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
run_tests()

View File

@@ -14,25 +14,25 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# This script generates a python-syntax list of dictionary entries for the # This script generates a python-syntax list of dictionary entries for the
# company IDs listed at: https://www.bluetooth.com/specifications/assigned-numbers/company-identifiers/ # company IDs listed at:
# The input to this script is the CSV file that can be obtained at that URL # https://bitbucket.org/bluetooth-SIG/public/src/main/assigned_numbers/company_identifiers/company_identifiers.yaml
# The input to this script is the YAML file that can be obtained at that URL
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import sys import sys
import csv import yaml
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
with open(sys.argv[1], newline='') as csvfile: with open(sys.argv[1], "r") as yaml_file:
reader = csv.reader(csvfile, delimiter=',', quotechar='"') root = yaml.safe_load(yaml_file)
lines = [] companies = {}
for row in reader: for company in root["company_identifiers"]:
if len(row) == 3 and row[1].startswith('0x'): companies[company["value"]] = company["name"]
company_id = row[1]
company_name = row[2]
escaped_company_name = company_name.replace('"', '\\"')
lines.append(f' {company_id}: "{escaped_company_name}"')
print(',\n'.join(reversed(lines))) for company_id in sorted(companies.keys()):
company_name = companies[company_id]
escaped_company_name = company_name.replace('"', '\\"')
print(f' 0x{company_id:04X}: "{escaped_company_name}",')

View File

@@ -74,7 +74,6 @@ export async function loadBumble(pyodide, bumblePackage) {
await pyodide.loadPackage("micropip"); await pyodide.loadPackage("micropip");
await pyodide.runPythonAsync(` await pyodide.runPythonAsync(`
import micropip import micropip
await micropip.install("cryptography")
await micropip.install("${bumblePackage}") await micropip.install("${bumblePackage}")
package_list = micropip.list() package_list = micropip.list()
print(package_list) print(package_list)

View File

@@ -23,7 +23,7 @@ from bumble.device import Device
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ScanEntry: class ScanEntry:
def __init__(self, advertisement): def __init__(self, advertisement):
self.address = str(advertisement.address).replace("/P", "") self.address = advertisement.address.to_string(False)
self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[ self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[
advertisement.address.address_type advertisement.address.address_type
] ]

View File

@@ -171,7 +171,7 @@ class Speaker:
self.connection = connection self.connection = connection
connection.on('disconnection', self.on_bluetooth_disconnection) connection.on('disconnection', self.on_bluetooth_disconnection)
peer_name = '' if connection.peer_name is None else connection.peer_name peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = str(connection.peer_address).replace('/P', '') peer_address = connection.peer_address.to_string(False)
self.emit_event( self.emit_event(
'connection', {'peer_name': peer_name, 'peer_address': peer_address} 'connection', {'peer_name': peer_name, 'peer_address': peer_address}
) )