forked from auracaster/bumble_mirror
Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9476be9ad | ||
|
|
704c60491c | ||
|
|
4a8e612c6e | ||
|
|
4e71ec5738 | ||
|
|
7255a09705 | ||
|
|
c2bf6b5f13 | ||
|
|
d8e699b588 | ||
|
|
3e4d4705f5 | ||
|
|
c8b2804446 | ||
|
|
e732f2589f | ||
|
|
aec5543081 | ||
|
|
e03d90ca57 | ||
|
|
495ce62d9c | ||
|
|
fbc3959a5a | ||
|
|
dfa9131192 | ||
|
|
88c801b4c2 | ||
|
|
a1b55b94e0 | ||
|
|
80db9e2e2f | ||
|
|
ce74690420 | ||
|
|
50de4dfb5d | ||
|
|
9bcdf860f4 | ||
|
|
511ab4b630 | ||
|
|
6f2b623e3c | ||
|
|
fa12165cd3 | ||
|
|
c0c6f3329d | ||
|
|
406a932467 | ||
|
|
cc96d4245f | ||
|
|
c6cdca8923 | ||
|
|
7e331c2944 | ||
|
|
10347765cb | ||
|
|
c12dee4e76 | ||
|
|
772c188674 | ||
|
|
7c1a3bb8f9 | ||
|
|
8c3c0b1e13 | ||
|
|
1ad84ad51c | ||
|
|
64937c3f77 | ||
|
|
50fd2218fa | ||
|
|
4c29a16271 | ||
|
|
762d3e92de | ||
|
|
2f97531d78 | ||
|
|
f6c7cae661 | ||
|
|
f1777a5bd2 | ||
|
|
78a06ae8cf | ||
|
|
d290df4aa9 | ||
|
|
e559744f32 | ||
|
|
67418e649a | ||
|
|
5adf9fab53 | ||
|
|
6ec6f1efe5 | ||
|
|
5d9598ea51 |
43
.github/workflows/python-avatar.yml
vendored
Normal file
43
.github/workflows/python-avatar.yml
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
name: Python Avatar
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Avatar [${{ matrix.shard }}]
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
shard: [
|
||||
1/24, 2/24, 3/24, 4/24,
|
||||
5/24, 6/24, 7/24, 8/24,
|
||||
9/24, 10/24, 11/24, 12/24,
|
||||
13/24, 14/24, 15/24, 16/24,
|
||||
17/24, 18/24, 19/24, 20/24,
|
||||
21/24, 22/24, 23/24, 24/24,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set Up Python 3.11
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
run: |
|
||||
avatar --list | grep -Ev '^=' > test-names.txt
|
||||
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
|
||||
- name: Rootcanal Logs
|
||||
run: cat rootcanal.log
|
||||
2
.github/workflows/python-build-test.yml
vendored
2
.github/workflows/python-build-test.yml
vendored
@@ -65,6 +65,8 @@ jobs:
|
||||
with:
|
||||
components: clippy,rustfmt
|
||||
toolchain: ${{ matrix.rust-version }}
|
||||
- name: Check License Headers
|
||||
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||
- name: Rust Build
|
||||
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
|
||||
# Lints after build so what clippy needs is already built
|
||||
|
||||
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
@@ -39,10 +39,12 @@
|
||||
"libusb",
|
||||
"MITM",
|
||||
"NDIS",
|
||||
"netsim",
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"popleft",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
"pyusb",
|
||||
|
||||
@@ -1172,7 +1172,7 @@ class ScanResult:
|
||||
name = ''
|
||||
|
||||
# Remove any '/P' qualifier suffix from the address string
|
||||
address_str = str(self.address).replace('/P', '')
|
||||
address_str = self.address.to_string(with_type_qualifier=False)
|
||||
|
||||
# RSSI bar
|
||||
bar_string = rssi_bar(self.rssi)
|
||||
|
||||
@@ -63,7 +63,8 @@ async def get_classic_info(host):
|
||||
if command_succeeded(response):
|
||||
print()
|
||||
print(
|
||||
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
|
||||
color('Classic Address:', 'yellow'),
|
||||
response.return_parameters.bd_addr.to_string(False),
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
|
||||
|
||||
@@ -306,6 +306,7 @@ async def pair(
|
||||
# Expose a GATT characteristic that can be used to trigger pairing by
|
||||
# responding with an authentication error when read
|
||||
if mode == 'le':
|
||||
device.le_enabled = True
|
||||
device.add_service(
|
||||
Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
@@ -326,7 +327,6 @@ async def pair(
|
||||
# Select LE or Classic
|
||||
if mode == 'classic':
|
||||
device.classic_enabled = True
|
||||
device.le_enabled = False
|
||||
device.classic_smp_enabled = ctkd
|
||||
|
||||
# Get things going
|
||||
|
||||
@@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput):
|
||||
except HCI_StatusError:
|
||||
pass
|
||||
peer_name = '' if connection.peer_name is None else connection.peer_name
|
||||
peer_address = str(connection.peer_address).replace('/P', '')
|
||||
peer_address = connection.peer_address.to_string(False)
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=peer_address,
|
||||
@@ -376,7 +376,7 @@ class UiServer:
|
||||
if connection := self.speaker().connection:
|
||||
await self.send_message(
|
||||
'connection',
|
||||
peer_address=str(connection.peer_address).replace('/P', ''),
|
||||
peer_address=connection.peer_address.to_string(False),
|
||||
peer_name=connection.peer_name,
|
||||
)
|
||||
|
||||
|
||||
124
bumble/att.py
124
bumble/att.py
@@ -23,13 +23,14 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import struct
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type, TYPE_CHECKING
|
||||
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
|
||||
|
||||
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
|
||||
from bumble.core import UUID, name_or_number, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
|
||||
# pylint: enable=line-too-long
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -209,7 +211,7 @@ class ATT_PDU:
|
||||
|
||||
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
|
||||
op_code = 0
|
||||
name = None
|
||||
name: str
|
||||
|
||||
@staticmethod
|
||||
def from_bytes(pdu):
|
||||
@@ -719,48 +721,68 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ConnectionValue(Protocol):
|
||||
def read(self, connection) -> bytes:
|
||||
...
|
||||
|
||||
def write(self, connection, value: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Attribute(EventEmitter):
|
||||
# Permission flags
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
class Permissions(enum.IntFlag):
|
||||
READABLE = 0x01
|
||||
WRITEABLE = 0x02
|
||||
READ_REQUIRES_ENCRYPTION = 0x04
|
||||
WRITE_REQUIRES_ENCRYPTION = 0x08
|
||||
READ_REQUIRES_AUTHENTICATION = 0x10
|
||||
WRITE_REQUIRES_AUTHENTICATION = 0x20
|
||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||
|
||||
PERMISSION_NAMES = {
|
||||
READABLE: 'READABLE',
|
||||
WRITEABLE: 'WRITEABLE',
|
||||
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
|
||||
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
|
||||
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
|
||||
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
|
||||
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
|
||||
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
|
||||
}
|
||||
@classmethod
|
||||
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | Attribute.Permissions[y],
|
||||
permissions_str.replace('|', ',').split(","),
|
||||
Attribute.Permissions(0),
|
||||
)
|
||||
except TypeError as exc:
|
||||
# The check for `p.name is not None` here is needed because for InFlag
|
||||
# enums, the .name property can be None, when the enum value is 0,
|
||||
# so the type hint for .name is Optional[str].
|
||||
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||
enum_list_str = ",".join(enum_list)
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def string_to_permissions(permissions_str: str):
|
||||
try:
|
||||
return functools.reduce(
|
||||
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
|
||||
permissions_str.split(","),
|
||||
0,
|
||||
)
|
||||
except TypeError as exc:
|
||||
raise TypeError(
|
||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
|
||||
) from exc
|
||||
# Permission flags(legacy-use only)
|
||||
READABLE = Permissions.READABLE
|
||||
WRITEABLE = Permissions.WRITEABLE
|
||||
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
|
||||
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
|
||||
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
def __init__(self, attribute_type, permissions, value=b''):
|
||||
value: Union[str, bytes, ConnectionValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, ConnectionValue] = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
self.end_group_handle = 0
|
||||
if isinstance(permissions, str):
|
||||
self.permissions = self.string_to_permissions(permissions)
|
||||
self.permissions = Attribute.Permissions.from_string(permissions)
|
||||
else:
|
||||
self.permissions = permissions
|
||||
|
||||
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection: Connection):
|
||||
def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
and not connection.encryption
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||
)
|
||||
if (
|
||||
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
||||
) and not connection.authenticated:
|
||||
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
|
||||
and connection is not None
|
||||
and not connection.authenticated
|
||||
):
|
||||
raise ATT_Error(
|
||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
|
||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||
)
|
||||
|
||||
if read := getattr(self.value, 'read', None):
|
||||
if hasattr(self.value, 'read'):
|
||||
try:
|
||||
value = read(connection) # pylint: disable=not-callable
|
||||
value = self.value.read(connection)
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection: Connection, value_bytes):
|
||||
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
|
||||
|
||||
value = self.decode_value(value_bytes)
|
||||
|
||||
if write := getattr(self.value, 'write', None):
|
||||
if hasattr(self.value, 'write'):
|
||||
try:
|
||||
write(connection, value) # pylint: disable=not-callable
|
||||
self.value.write(connection, value) # pylint: disable=not-callable
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
|
||||
480
bumble/avdtp.py
480
bumble/avdtp.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
117
bumble/device.py
117
bumble/device.py
@@ -33,6 +33,8 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
@@ -151,6 +153,7 @@ from .utils import (
|
||||
CompositeEventEmitter,
|
||||
setup_event_forwarding,
|
||||
composite_listener,
|
||||
deprecated,
|
||||
)
|
||||
from .keys import (
|
||||
KeyStore,
|
||||
@@ -670,9 +673,7 @@ class Connection(CompositeEventEmitter):
|
||||
def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(self.handle, cid, pdu)
|
||||
|
||||
def create_l2cap_connector(self, psm):
|
||||
return self.device.create_l2cap_connector(self, psm)
|
||||
|
||||
@deprecated("Please use create_l2cap_channel()")
|
||||
async def open_l2cap_channel(
|
||||
self,
|
||||
psm,
|
||||
@@ -682,6 +683,23 @@ class Connection(CompositeEventEmitter):
|
||||
):
|
||||
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
|
||||
|
||||
@overload
|
||||
async def create_l2cap_channel(
|
||||
self, spec: l2cap.ClassicChannelSpec
|
||||
) -> l2cap.ClassicChannel:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def create_l2cap_channel(
|
||||
self, spec: l2cap.LeCreditBasedChannelSpec
|
||||
) -> l2cap.LeCreditBasedChannel:
|
||||
...
|
||||
|
||||
async def create_l2cap_channel(
|
||||
self, spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec]
|
||||
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
|
||||
return await self.device.create_l2cap_channel(connection=self, spec=spec)
|
||||
|
||||
async def disconnect(
|
||||
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
|
||||
) -> None:
|
||||
@@ -1180,15 +1198,11 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
return None
|
||||
|
||||
def create_l2cap_connector(self, connection, psm):
|
||||
return lambda: self.l2cap_channel_manager.connect(connection, psm)
|
||||
|
||||
def create_l2cap_registrar(self, psm):
|
||||
return lambda handler: self.register_l2cap_server(psm, handler)
|
||||
|
||||
def register_l2cap_server(self, psm, server):
|
||||
self.l2cap_channel_manager.register_server(psm, server)
|
||||
@deprecated("Please use create_l2cap_server()")
|
||||
def register_l2cap_server(self, psm, server) -> int:
|
||||
return self.l2cap_channel_manager.register_server(psm, server)
|
||||
|
||||
@deprecated("Please use create_l2cap_server()")
|
||||
def register_l2cap_channel_server(
|
||||
self,
|
||||
psm,
|
||||
@@ -1201,6 +1215,7 @@ class Device(CompositeEventEmitter):
|
||||
psm, server, max_credits, mtu, mps
|
||||
)
|
||||
|
||||
@deprecated("Please use create_l2cap_channel()")
|
||||
async def open_l2cap_channel(
|
||||
self,
|
||||
connection,
|
||||
@@ -1213,6 +1228,74 @@ class Device(CompositeEventEmitter):
|
||||
connection, psm, max_credits, mtu, mps
|
||||
)
|
||||
|
||||
@overload
|
||||
async def create_l2cap_channel(
|
||||
self,
|
||||
connection: Connection,
|
||||
spec: l2cap.ClassicChannelSpec,
|
||||
) -> l2cap.ClassicChannel:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def create_l2cap_channel(
|
||||
self,
|
||||
connection: Connection,
|
||||
spec: l2cap.LeCreditBasedChannelSpec,
|
||||
) -> l2cap.LeCreditBasedChannel:
|
||||
...
|
||||
|
||||
async def create_l2cap_channel(
|
||||
self,
|
||||
connection: Connection,
|
||||
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
|
||||
) -> Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel]:
|
||||
if isinstance(spec, l2cap.ClassicChannelSpec):
|
||||
return await self.l2cap_channel_manager.create_classic_channel(
|
||||
connection=connection, spec=spec
|
||||
)
|
||||
if isinstance(spec, l2cap.LeCreditBasedChannelSpec):
|
||||
return await self.l2cap_channel_manager.create_le_credit_based_channel(
|
||||
connection=connection, spec=spec
|
||||
)
|
||||
|
||||
@overload
|
||||
def create_l2cap_server(
|
||||
self,
|
||||
spec: l2cap.ClassicChannelSpec,
|
||||
handler: Optional[Callable[[l2cap.ClassicChannel], Any]] = None,
|
||||
) -> l2cap.ClassicChannelServer:
|
||||
...
|
||||
|
||||
@overload
|
||||
def create_l2cap_server(
|
||||
self,
|
||||
spec: l2cap.LeCreditBasedChannelSpec,
|
||||
handler: Optional[Callable[[l2cap.LeCreditBasedChannel], Any]] = None,
|
||||
) -> l2cap.LeCreditBasedChannelServer:
|
||||
...
|
||||
|
||||
def create_l2cap_server(
|
||||
self,
|
||||
spec: Union[l2cap.ClassicChannelSpec, l2cap.LeCreditBasedChannelSpec],
|
||||
handler: Union[
|
||||
Callable[[l2cap.ClassicChannel], Any],
|
||||
Callable[[l2cap.LeCreditBasedChannel], Any],
|
||||
None,
|
||||
] = None,
|
||||
) -> Union[l2cap.ClassicChannelServer, l2cap.LeCreditBasedChannelServer]:
|
||||
if isinstance(spec, l2cap.ClassicChannelSpec):
|
||||
return self.l2cap_channel_manager.create_classic_server(
|
||||
spec=spec,
|
||||
handler=cast(Callable[[l2cap.ClassicChannel], Any], handler),
|
||||
)
|
||||
elif isinstance(spec, l2cap.LeCreditBasedChannelSpec):
|
||||
return self.l2cap_channel_manager.create_le_credit_based_server(
|
||||
handler=cast(Callable[[l2cap.LeCreditBasedChannel], Any], handler),
|
||||
spec=spec,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Unexpected mode {spec}')
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||
|
||||
@@ -1425,10 +1508,10 @@ class Device(CompositeEventEmitter):
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_own_address_type = own_address_type
|
||||
self.auto_restart_advertising = auto_restart
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising_own_address_type = own_address_type
|
||||
self.advertising = True
|
||||
self.auto_restart_advertising = auto_restart
|
||||
|
||||
async def stop_advertising(self) -> None:
|
||||
# Disable advertising
|
||||
@@ -1438,9 +1521,9 @@ class Device(CompositeEventEmitter):
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_type = None
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_advertising = False
|
||||
|
||||
@property
|
||||
@@ -2630,7 +2713,6 @@ class Device(CompositeEventEmitter):
|
||||
own_address_type = self.advertising_own_address_type
|
||||
|
||||
# We are no longer advertising
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
|
||||
if own_address_type in (
|
||||
@@ -2687,7 +2769,6 @@ class Device(CompositeEventEmitter):
|
||||
and self.advertising
|
||||
and self.advertising_type.is_directed
|
||||
):
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
|
||||
# Notify listeners
|
||||
@@ -2758,7 +2839,9 @@ class Device(CompositeEventEmitter):
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_advertising(
|
||||
advertising_type=self.advertising_type, auto_restart=True
|
||||
advertising_type=self.advertising_type,
|
||||
own_address_type=self.advertising_own_address_type,
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, List
|
||||
from typing import Optional, Sequence, Iterable, List, Union
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services):
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
|
||||
@@ -210,11 +210,11 @@ class Service(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
uuid: Union[str, UUID],
|
||||
characteristics: List[Characteristic],
|
||||
primary=True,
|
||||
included_services: List[Service] = [],
|
||||
):
|
||||
) -> None:
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
uuid = UUID(uuid)
|
||||
@@ -239,7 +239,7 @@ class Service(Attribute):
|
||||
"""
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Service(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
@@ -255,9 +255,11 @@ class TemplateService(Service):
|
||||
to expose their UUID as a class property
|
||||
'''
|
||||
|
||||
UUID: Optional[UUID] = None
|
||||
UUID: UUID
|
||||
|
||||
def __init__(self, characteristics, primary=True):
|
||||
def __init__(
|
||||
self, characteristics: List[Characteristic], primary: bool = True
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
|
||||
|
||||
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
service: Service
|
||||
|
||||
def __init__(self, service):
|
||||
def __init__(self, service: Service) -> None:
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
)
|
||||
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
)
|
||||
self.service = service
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
||||
f'group_starting_handle=0x{self.service.handle:04X}, '
|
||||
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
|
||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
# NOTE: we override this method to offer a consistent result between python
|
||||
# versions: the value returned by IntFlag.__str__() changed in version 11.
|
||||
return '|'.join(
|
||||
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uuid,
|
||||
uuid: Union[str, bytes, UUID],
|
||||
properties: Characteristic.Properties,
|
||||
permissions,
|
||||
value=b'',
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, CharacteristicValue] = b'',
|
||||
descriptors: Sequence[Descriptor] = (),
|
||||
):
|
||||
super().__init__(uuid, permissions, value)
|
||||
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
|
||||
def has_properties(self, properties: Characteristic.Properties) -> bool:
|
||||
return self.properties & properties == properties
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'end=0x{self.end_group_handle:04X}, '
|
||||
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
characteristic: Characteristic
|
||||
|
||||
def __init__(self, characteristic, value_handle):
|
||||
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
|
||||
declaration_bytes = (
|
||||
struct.pack('<BH', characteristic.properties, value_handle)
|
||||
+ characteristic.uuid.to_pdu_bytes()
|
||||
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
|
||||
self.value_handle = value_handle
|
||||
self.characteristic = characteristic
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
|
||||
f'value_handle=0x{self.value_handle:04X}, '
|
||||
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
|
||||
|
||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
wrapped = str(self.wrapped_characteristic)
|
||||
return f'{self.__class__.__name__}({wrapped})'
|
||||
|
||||
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||
'''
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: str) -> bytes:
|
||||
return value.encode('utf-8')
|
||||
|
||||
def decode_value(self, value):
|
||||
def decode_value(self, value: bytes) -> str:
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
|
||||
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
|
||||
'''
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||
f'type={self.type}, '
|
||||
|
||||
@@ -28,7 +28,18 @@ import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
|
||||
from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Dict,
|
||||
Tuple,
|
||||
Callable,
|
||||
Union,
|
||||
Any,
|
||||
Iterable,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -66,8 +77,12 @@ from .gatt import (
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
TemplateService,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
|
||||
# Proxies
|
||||
# -----------------------------------------------------------------------------
|
||||
class AttributeProxy(EventEmitter):
|
||||
client: Client
|
||||
|
||||
def __init__(self, client, handle, end_group_handle, attribute_type):
|
||||
def __init__(
|
||||
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.client = client
|
||||
self.handle = handle
|
||||
self.end_group_handle = end_group_handle
|
||||
self.type = attribute_type
|
||||
|
||||
async def read_value(self, no_long_read=False):
|
||||
async def read_value(self, no_long_read: bool = False) -> bytes:
|
||||
return self.decode_value(
|
||||
await self.client.read_value(self.handle, no_long_read)
|
||||
)
|
||||
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
|
||||
self.handle, self.encode_value(value), with_response
|
||||
)
|
||||
|
||||
def encode_value(self, value):
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
def decode_value(self, value_bytes):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
|
||||
def get_characteristics_by_uuid(self, uuid):
|
||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||
|
||||
|
||||
class CharacteristicProxy(AttributeProxy):
|
||||
properties: Characteristic.Properties
|
||||
descriptors: List[DescriptorProxy]
|
||||
subscribers: Dict[Any, Callable]
|
||||
subscribers: Dict[Any, Callable[[bytes], Any]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
|
||||
return await self.client.discover_descriptors(self)
|
||||
|
||||
async def subscribe(
|
||||
self, subscriber: Optional[Callable] = None, prefer_notify=True
|
||||
self,
|
||||
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||
prefer_notify: bool = True,
|
||||
):
|
||||
if subscriber is not None:
|
||||
if subscriber in self.subscribers:
|
||||
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
|
||||
|
||||
return await self.client.unsubscribe(self, subscriber)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||
f'uuid={self.uuid}, '
|
||||
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
|
||||
def __init__(self, client, handle, descriptor_type):
|
||||
super().__init__(client, handle, 0, descriptor_type)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
|
||||
|
||||
|
||||
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
|
||||
Base class for profile-specific service proxies
|
||||
'''
|
||||
|
||||
SERVICE_CLASS: Type[TemplateService]
|
||||
|
||||
@classmethod
|
||||
def from_client(cls, client):
|
||||
def from_client(cls, client: Client) -> ProfileServiceProxy:
|
||||
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
||||
|
||||
|
||||
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
|
||||
class Client:
|
||||
services: List[ServiceProxy]
|
||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||
notification_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||
indication_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||
pending_request: Optional[ATT_PDU]
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
self.mtu_exchange_done = False
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
@@ -241,16 +264,16 @@ class Client:
|
||||
self.services = []
|
||||
self.cached_values = {}
|
||||
|
||||
def send_gatt_pdu(self, pdu):
|
||||
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||
|
||||
async def send_command(self, command):
|
||||
async def send_command(self, command: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
self.send_gatt_pdu(command.to_bytes())
|
||||
|
||||
async def send_request(self, request):
|
||||
async def send_request(self, request: ATT_PDU):
|
||||
logger.debug(
|
||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||
)
|
||||
@@ -279,14 +302,14 @@ class Client:
|
||||
|
||||
return response
|
||||
|
||||
def send_confirmation(self, confirmation):
|
||||
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
|
||||
logger.debug(
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
self.send_gatt_pdu(confirmation.to_bytes())
|
||||
|
||||
async def request_mtu(self, mtu):
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
if mtu < ATT_DEFAULT_MTU:
|
||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||
@@ -313,10 +336,12 @@ class Client:
|
||||
|
||||
return self.connection.att_mtu
|
||||
|
||||
def get_services_by_uuid(self, uuid):
|
||||
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
|
||||
return [service for service in self.services if service.uuid == uuid]
|
||||
|
||||
def get_characteristics_by_uuid(self, uuid, service=None):
|
||||
def get_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy] = None
|
||||
) -> List[CharacteristicProxy]:
|
||||
services = [service] if service else self.services
|
||||
return [
|
||||
c
|
||||
@@ -363,7 +388,7 @@ class Client:
|
||||
if not already_known:
|
||||
self.services.append(service)
|
||||
|
||||
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
|
||||
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||
'''
|
||||
@@ -435,7 +460,7 @@ class Client:
|
||||
|
||||
return services
|
||||
|
||||
async def discover_service(self, uuid):
|
||||
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
||||
'''
|
||||
@@ -468,7 +493,7 @@ class Client:
|
||||
f'{HCI_Constant.error_name(response.error_code)}'
|
||||
)
|
||||
# TODO raise appropriate exception
|
||||
return
|
||||
return []
|
||||
break
|
||||
|
||||
for attribute_handle, end_group_handle in response.handles_information:
|
||||
@@ -480,7 +505,7 @@ class Client:
|
||||
logger.warning(
|
||||
f'bogus handle values: {attribute_handle} {end_group_handle}'
|
||||
)
|
||||
return
|
||||
return []
|
||||
|
||||
# Create a service proxy for this service
|
||||
service = ServiceProxy(
|
||||
@@ -721,7 +746,7 @@ class Client:
|
||||
|
||||
return descriptors
|
||||
|
||||
async def discover_attributes(self):
|
||||
async def discover_attributes(self) -> List[AttributeProxy]:
|
||||
'''
|
||||
Discover all attributes, regardless of type
|
||||
'''
|
||||
@@ -844,7 +869,9 @@ class Client:
|
||||
# No more subscribers left
|
||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||
|
||||
async def read_value(self, attribute, no_long_read=False):
|
||||
async def read_value(
|
||||
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
|
||||
) -> Any:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
||||
|
||||
@@ -905,7 +932,9 @@ class Client:
|
||||
# Return the value as bytes
|
||||
return attribute_value
|
||||
|
||||
async def read_characteristics_by_uuid(self, uuid, service):
|
||||
async def read_characteristics_by_uuid(
|
||||
self, uuid: UUID, service: Optional[ServiceProxy]
|
||||
) -> List[bytes]:
|
||||
'''
|
||||
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
||||
'''
|
||||
@@ -960,7 +989,12 @@ class Client:
|
||||
|
||||
return characteristics_values
|
||||
|
||||
async def write_value(self, attribute, value, with_response=False):
|
||||
async def write_value(
|
||||
self,
|
||||
attribute: Union[int, AttributeProxy],
|
||||
value: bytes,
|
||||
with_response: bool = False,
|
||||
) -> None:
|
||||
'''
|
||||
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
|
||||
Value
|
||||
@@ -990,7 +1024,7 @@ class Client:
|
||||
)
|
||||
)
|
||||
|
||||
def on_gatt_pdu(self, att_pdu):
|
||||
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||
)
|
||||
@@ -1013,6 +1047,7 @@ class Client:
|
||||
return
|
||||
|
||||
# Return the response to the coroutine that is waiting for it
|
||||
assert self.pending_response is not None
|
||||
self.pending_response.set_result(att_pdu)
|
||||
else:
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
@@ -1060,7 +1095,7 @@ class Client:
|
||||
# Confirm that we received the indication
|
||||
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
||||
|
||||
def cache_value(self, attribute_handle: int, value: bytes):
|
||||
def cache_value(self, attribute_handle: int, value: bytes) -> None:
|
||||
self.cached_values[attribute_handle] = (
|
||||
datetime.now(),
|
||||
value,
|
||||
|
||||
@@ -23,11 +23,12 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
@@ -42,6 +43,7 @@ from .att import (
|
||||
ATT_INVALID_OFFSET_ERROR,
|
||||
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||
ATT_REQUESTS,
|
||||
ATT_PDU,
|
||||
ATT_UNLIKELY_ERROR_ERROR,
|
||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||
ATT_Error,
|
||||
@@ -73,6 +75,8 @@ from .gatt import (
|
||||
Service,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
attributes: List[Attribute]
|
||||
services: List[Service]
|
||||
attributes_by_handle: Dict[int, Attribute]
|
||||
subscribers: Dict[int, Dict[int, bytes]]
|
||||
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
|
||||
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: Device) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.services = []
|
||||
@@ -107,16 +116,16 @@ class Server(EventEmitter):
|
||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||
self.pending_confirmations = defaultdict(lambda: None)
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "\n".join(map(str, self.attributes))
|
||||
|
||||
def send_gatt_pdu(self, connection_handle, pdu):
|
||||
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||
|
||||
def next_handle(self):
|
||||
def next_handle(self) -> int:
|
||||
return 1 + len(self.attributes)
|
||||
|
||||
def get_advertising_service_data(self):
|
||||
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
|
||||
return {
|
||||
attribute: data
|
||||
for attribute in self.attributes
|
||||
@@ -124,7 +133,7 @@ class Server(EventEmitter):
|
||||
and (data := attribute.get_advertising_data())
|
||||
}
|
||||
|
||||
def get_attribute(self, handle):
|
||||
def get_attribute(self, handle: int) -> Optional[Attribute]:
|
||||
attribute = self.attributes_by_handle.get(handle)
|
||||
if attribute:
|
||||
return attribute
|
||||
@@ -173,12 +182,17 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
(attribute, self.get_attribute(attribute.characteristic.handle))
|
||||
(
|
||||
attribute,
|
||||
self.get_attribute(attribute.characteristic.handle),
|
||||
) # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(service_handle.handle, service_handle.end_group_handle + 1),
|
||||
)
|
||||
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
if attribute is not None
|
||||
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||
and isinstance(attribute, CharacteristicDeclaration)
|
||||
and attribute.characteristic.uuid == characteristic_uuid
|
||||
),
|
||||
None,
|
||||
@@ -197,7 +211,7 @@ class Server(EventEmitter):
|
||||
|
||||
return next(
|
||||
(
|
||||
attribute
|
||||
attribute # type: ignore
|
||||
for attribute in map(
|
||||
self.get_attribute,
|
||||
range(
|
||||
@@ -205,12 +219,12 @@ class Server(EventEmitter):
|
||||
characteristic_value.end_group_handle + 1,
|
||||
),
|
||||
)
|
||||
if attribute.type == descriptor_uuid
|
||||
if attribute is not None and attribute.type == descriptor_uuid
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def add_attribute(self, attribute):
|
||||
def add_attribute(self, attribute: Attribute) -> None:
|
||||
# Assign a handle to this attribute
|
||||
attribute.handle = self.next_handle()
|
||||
attribute.end_group_handle = (
|
||||
@@ -220,7 +234,7 @@ class Server(EventEmitter):
|
||||
# Add this attribute to the list
|
||||
self.attributes.append(attribute)
|
||||
|
||||
def add_service(self, service: Service):
|
||||
def add_service(self, service: Service) -> None:
|
||||
# Add the service attribute to the DB
|
||||
self.add_attribute(service)
|
||||
|
||||
@@ -285,11 +299,13 @@ class Server(EventEmitter):
|
||||
service.end_group_handle = self.attributes[-1].handle
|
||||
self.services.append(service)
|
||||
|
||||
def add_services(self, services):
|
||||
def add_services(self, services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
self.add_service(service)
|
||||
|
||||
def read_cccd(self, connection, characteristic):
|
||||
def read_cccd(
|
||||
self, connection: Optional[Connection], characteristic: Characteristic
|
||||
) -> bytes:
|
||||
if connection is None:
|
||||
return bytes([0, 0])
|
||||
|
||||
@@ -300,7 +316,12 @@ class Server(EventEmitter):
|
||||
|
||||
return cccd or bytes([0, 0])
|
||||
|
||||
def write_cccd(self, connection, characteristic, value):
|
||||
def write_cccd(
|
||||
self,
|
||||
connection: Connection,
|
||||
characteristic: Characteristic,
|
||||
value: bytes,
|
||||
) -> None:
|
||||
logger.debug(
|
||||
f'Subscription update for connection=0x{connection.handle:04X}, '
|
||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||
@@ -327,13 +348,19 @@ class Server(EventEmitter):
|
||||
indicate_enabled,
|
||||
)
|
||||
|
||||
def send_response(self, connection, response):
|
||||
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
|
||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -370,7 +397,13 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def indicate_subscriber(
|
||||
self,
|
||||
connection: Connection,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
@@ -411,15 +444,13 @@ class Server(EventEmitter):
|
||||
assert self.pending_confirmations[connection.handle] is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_confirmations[
|
||||
pending_confirmation = self.pending_confirmations[
|
||||
connection.handle
|
||||
] = asyncio.get_running_loop().create_future()
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
await asyncio.wait_for(
|
||||
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
||||
@@ -427,8 +458,12 @@ class Server(EventEmitter):
|
||||
self.pending_confirmations[connection.handle] = None
|
||||
|
||||
async def notify_or_indicate_subscribers(
|
||||
self, indicate, attribute, value=None, force=False
|
||||
):
|
||||
self,
|
||||
indicate: bool,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
# Get all the connections for which there's at least one subscription
|
||||
connections = [
|
||||
connection
|
||||
@@ -450,13 +485,23 @@ class Server(EventEmitter):
|
||||
]
|
||||
)
|
||||
|
||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
||||
async def notify_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||
|
||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
||||
async def indicate_subscribers(
|
||||
self,
|
||||
attribute: Attribute,
|
||||
value: Optional[bytes] = None,
|
||||
force: bool = False,
|
||||
):
|
||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||
|
||||
def on_disconnection(self, connection):
|
||||
def on_disconnection(self, connection: Connection) -> None:
|
||||
if connection.handle in self.subscribers:
|
||||
del self.subscribers[connection.handle]
|
||||
if connection.handle in self.indication_semaphores:
|
||||
@@ -464,7 +509,7 @@ class Server(EventEmitter):
|
||||
if connection.handle in self.pending_confirmations:
|
||||
del self.pending_confirmations[connection.handle]
|
||||
|
||||
def on_gatt_pdu(self, connection, att_pdu):
|
||||
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
|
||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||
handler_name = f'on_{att_pdu.name.lower()}'
|
||||
handler = getattr(self, handler_name, None)
|
||||
@@ -506,7 +551,7 @@ class Server(EventEmitter):
|
||||
#######################################################
|
||||
# ATT handlers
|
||||
#######################################################
|
||||
def on_att_request(self, connection, pdu):
|
||||
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
|
||||
'''
|
||||
Handler for requests without a more specific handler
|
||||
'''
|
||||
@@ -679,7 +724,6 @@ class Server(EventEmitter):
|
||||
and attribute.handle <= request.ending_handle
|
||||
and pdu_space_available
|
||||
):
|
||||
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
|
||||
510
bumble/l2cap.py
510
bumble/l2cap.py
@@ -17,6 +17,8 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
|
||||
@@ -37,6 +39,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .utils import deprecated
|
||||
from .colors import color
|
||||
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
|
||||
from .hci import (
|
||||
@@ -166,6 +169,34 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ClassicChannelSpec:
|
||||
psm: Optional[int] = None
|
||||
mtu: int = L2CAP_MIN_BR_EDR_MTU
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LeCreditBasedChannelSpec:
|
||||
psm: Optional[int] = None
|
||||
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
|
||||
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
|
||||
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
|
||||
|
||||
def __post_init__(self):
|
||||
if (
|
||||
self.max_credits < 1
|
||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if (
|
||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
):
|
||||
raise ValueError('MPS out of range')
|
||||
|
||||
|
||||
class L2CAP_PDU:
|
||||
'''
|
||||
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
|
||||
@@ -675,57 +706,36 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Channel(EventEmitter):
|
||||
# States
|
||||
CLOSED = 0x00
|
||||
WAIT_CONNECT = 0x01
|
||||
WAIT_CONNECT_RSP = 0x02
|
||||
OPEN = 0x03
|
||||
WAIT_DISCONNECT = 0x04
|
||||
WAIT_CREATE = 0x05
|
||||
WAIT_CREATE_RSP = 0x06
|
||||
WAIT_MOVE = 0x07
|
||||
WAIT_MOVE_RSP = 0x08
|
||||
WAIT_MOVE_CONFIRM = 0x09
|
||||
WAIT_CONFIRM_RSP = 0x0A
|
||||
class ClassicChannel(EventEmitter):
|
||||
class State(enum.IntEnum):
|
||||
# States
|
||||
CLOSED = 0x00
|
||||
WAIT_CONNECT = 0x01
|
||||
WAIT_CONNECT_RSP = 0x02
|
||||
OPEN = 0x03
|
||||
WAIT_DISCONNECT = 0x04
|
||||
WAIT_CREATE = 0x05
|
||||
WAIT_CREATE_RSP = 0x06
|
||||
WAIT_MOVE = 0x07
|
||||
WAIT_MOVE_RSP = 0x08
|
||||
WAIT_MOVE_CONFIRM = 0x09
|
||||
WAIT_CONFIRM_RSP = 0x0A
|
||||
|
||||
# CONFIG substates
|
||||
WAIT_CONFIG = 0x10
|
||||
WAIT_SEND_CONFIG = 0x11
|
||||
WAIT_CONFIG_REQ_RSP = 0x12
|
||||
WAIT_CONFIG_RSP = 0x13
|
||||
WAIT_CONFIG_REQ = 0x14
|
||||
WAIT_IND_FINAL_RSP = 0x15
|
||||
WAIT_FINAL_RSP = 0x16
|
||||
WAIT_CONTROL_IND = 0x17
|
||||
|
||||
STATE_NAMES = {
|
||||
CLOSED: 'CLOSED',
|
||||
WAIT_CONNECT: 'WAIT_CONNECT',
|
||||
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
|
||||
OPEN: 'OPEN',
|
||||
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
|
||||
WAIT_CREATE: 'WAIT_CREATE',
|
||||
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
|
||||
WAIT_MOVE: 'WAIT_MOVE',
|
||||
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
|
||||
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
|
||||
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
|
||||
WAIT_CONFIG: 'WAIT_CONFIG',
|
||||
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
|
||||
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
|
||||
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
|
||||
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
|
||||
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
|
||||
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
|
||||
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
|
||||
}
|
||||
# CONFIG substates
|
||||
WAIT_CONFIG = 0x10
|
||||
WAIT_SEND_CONFIG = 0x11
|
||||
WAIT_CONFIG_REQ_RSP = 0x12
|
||||
WAIT_CONFIG_RSP = 0x13
|
||||
WAIT_CONFIG_REQ = 0x14
|
||||
WAIT_IND_FINAL_RSP = 0x15
|
||||
WAIT_FINAL_RSP = 0x16
|
||||
WAIT_CONTROL_IND = 0x17
|
||||
|
||||
connection_result: Optional[asyncio.Future[None]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
response: Optional[asyncio.Future[bytes]]
|
||||
sink: Optional[Callable[[bytes], Any]]
|
||||
state: int
|
||||
state: State
|
||||
connection: Connection
|
||||
|
||||
def __init__(
|
||||
@@ -741,7 +751,7 @@ class Channel(EventEmitter):
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.signaling_cid = signaling_cid
|
||||
self.state = Channel.CLOSED
|
||||
self.state = self.State.CLOSED
|
||||
self.mtu = mtu
|
||||
self.psm = psm
|
||||
self.source_cid = source_cid
|
||||
@@ -751,10 +761,8 @@ class Channel(EventEmitter):
|
||||
self.disconnection_result = None
|
||||
self.sink = None
|
||||
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
|
||||
)
|
||||
def _change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
@@ -767,7 +775,7 @@ class Channel(EventEmitter):
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.state != Channel.OPEN:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
|
||||
self.response = asyncio.get_running_loop().create_future()
|
||||
@@ -787,14 +795,14 @@ class Channel(EventEmitter):
|
||||
)
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self.state != Channel.CLOSED:
|
||||
if self.state != self.State.CLOSED:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
if self.connection_result:
|
||||
raise RuntimeError('connection already pending')
|
||||
|
||||
self.change_state(Channel.WAIT_CONNECT_RSP)
|
||||
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||
self.send_control_frame(
|
||||
L2CAP_Connection_Request(
|
||||
identifier=self.manager.next_identifier(self.connection),
|
||||
@@ -814,10 +822,10 @@ class Channel(EventEmitter):
|
||||
self.connection_result = None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.state != Channel.OPEN:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
self.change_state(Channel.WAIT_DISCONNECT)
|
||||
self._change_state(self.State.WAIT_DISCONNECT)
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Request(
|
||||
identifier=self.manager.next_identifier(self.connection),
|
||||
@@ -832,8 +840,8 @@ class Channel(EventEmitter):
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.OPEN:
|
||||
self.change_state(self.CLOSED)
|
||||
if self.state == self.State.OPEN:
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.emit('close')
|
||||
|
||||
def send_configure_request(self) -> None:
|
||||
@@ -856,7 +864,7 @@ class Channel(EventEmitter):
|
||||
|
||||
def on_connection_request(self, request) -> None:
|
||||
self.destination_cid = request.source_cid
|
||||
self.change_state(Channel.WAIT_CONNECT)
|
||||
self._change_state(self.State.WAIT_CONNECT)
|
||||
self.send_control_frame(
|
||||
L2CAP_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
@@ -866,24 +874,24 @@ class Channel(EventEmitter):
|
||||
status=0x0000,
|
||||
)
|
||||
)
|
||||
self.change_state(Channel.WAIT_CONFIG)
|
||||
self._change_state(self.State.WAIT_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||
|
||||
def on_connection_response(self, response):
|
||||
if self.state != Channel.WAIT_CONNECT_RSP:
|
||||
if self.state != self.State.WAIT_CONNECT_RSP:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
|
||||
self.destination_cid = response.destination_cid
|
||||
self.change_state(Channel.WAIT_CONFIG)
|
||||
self._change_state(self.State.WAIT_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
|
||||
pass
|
||||
else:
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.connection_result.set_exception(
|
||||
ProtocolError(
|
||||
response.result,
|
||||
@@ -895,9 +903,9 @@ class Channel(EventEmitter):
|
||||
|
||||
def on_configure_request(self, request) -> None:
|
||||
if self.state not in (
|
||||
Channel.WAIT_CONFIG,
|
||||
Channel.WAIT_CONFIG_REQ,
|
||||
Channel.WAIT_CONFIG_REQ_RSP,
|
||||
self.State.WAIT_CONFIG,
|
||||
self.State.WAIT_CONFIG_REQ,
|
||||
self.State.WAIT_CONFIG_REQ_RSP,
|
||||
):
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
@@ -918,25 +926,28 @@ class Channel(EventEmitter):
|
||||
options=request.options, # TODO: don't accept everything blindly
|
||||
)
|
||||
)
|
||||
if self.state == Channel.WAIT_CONFIG:
|
||||
self.change_state(Channel.WAIT_SEND_CONFIG)
|
||||
if self.state == self.State.WAIT_CONFIG:
|
||||
self._change_state(self.State.WAIT_SEND_CONFIG)
|
||||
self.send_configure_request()
|
||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
||||
elif self.state == Channel.WAIT_CONFIG_REQ:
|
||||
self.change_state(Channel.OPEN)
|
||||
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||
elif self.state == self.State.WAIT_CONFIG_REQ:
|
||||
self._change_state(self.State.OPEN)
|
||||
if self.connection_result:
|
||||
self.connection_result.set_result(None)
|
||||
self.connection_result = None
|
||||
self.emit('open')
|
||||
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
||||
elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||
|
||||
def on_configure_response(self, response) -> None:
|
||||
if response.result == L2CAP_Configure_Response.SUCCESS:
|
||||
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
||||
self.change_state(Channel.WAIT_CONFIG_REQ)
|
||||
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
|
||||
self.change_state(Channel.OPEN)
|
||||
if self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||
self._change_state(self.State.WAIT_CONFIG_REQ)
|
||||
elif self.state in (
|
||||
self.State.WAIT_CONFIG_RSP,
|
||||
self.State.WAIT_CONTROL_IND,
|
||||
):
|
||||
self._change_state(self.State.OPEN)
|
||||
if self.connection_result:
|
||||
self.connection_result.set_result(None)
|
||||
self.connection_result = None
|
||||
@@ -966,7 +977,7 @@ class Channel(EventEmitter):
|
||||
# TODO: decide how to fail gracefully
|
||||
|
||||
def on_disconnection_request(self, request) -> None:
|
||||
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
|
||||
if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Response(
|
||||
identifier=request.identifier,
|
||||
@@ -974,14 +985,14 @@ class Channel(EventEmitter):
|
||||
source_cid=request.source_cid,
|
||||
)
|
||||
)
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
self.emit('close')
|
||||
self.manager.on_channel_closed(self)
|
||||
else:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != Channel.WAIT_DISCONNECT:
|
||||
if self.state != self.State.WAIT_DISCONNECT:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
@@ -992,7 +1003,7 @@ class Channel(EventEmitter):
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(Channel.CLOSED)
|
||||
self._change_state(self.State.CLOSED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -1004,43 +1015,31 @@ class Channel(EventEmitter):
|
||||
f'Channel({self.source_cid}->{self.destination_cid}, '
|
||||
f'PSM={self.psm}, '
|
||||
f'MTU={self.mtu}, '
|
||||
f'state={Channel.STATE_NAMES[self.state]})'
|
||||
f'state={self.state.name})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeConnectionOrientedChannel(EventEmitter):
|
||||
class LeCreditBasedChannel(EventEmitter):
|
||||
"""
|
||||
LE Credit-based Connection Oriented Channel
|
||||
"""
|
||||
|
||||
INIT = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
DISCONNECTING = 3
|
||||
DISCONNECTED = 4
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
STATE_NAMES = {
|
||||
INIT: 'INIT',
|
||||
CONNECTED: 'CONNECTED',
|
||||
CONNECTING: 'CONNECTING',
|
||||
DISCONNECTING: 'DISCONNECTING',
|
||||
DISCONNECTED: 'DISCONNECTED',
|
||||
CONNECTION_ERROR: 'CONNECTION_ERROR',
|
||||
}
|
||||
class State(enum.IntEnum):
|
||||
INIT = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
DISCONNECTING = 3
|
||||
DISCONNECTED = 4
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
out_queue: Deque[bytes]
|
||||
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
||||
connection_result: Optional[asyncio.Future[LeCreditBasedChannel]]
|
||||
disconnection_result: Optional[asyncio.Future[None]]
|
||||
out_sdu: Optional[bytes]
|
||||
state: int
|
||||
state: State
|
||||
connection: Connection
|
||||
|
||||
@staticmethod
|
||||
def state_name(state: int) -> str:
|
||||
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
@@ -1083,19 +1082,17 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.drained.set()
|
||||
|
||||
if connected:
|
||||
self.state = LeConnectionOrientedChannel.CONNECTED
|
||||
self.state = self.State.CONNECTED
|
||||
else:
|
||||
self.state = LeConnectionOrientedChannel.INIT
|
||||
self.state = self.State.INIT
|
||||
|
||||
def change_state(self, new_state: int) -> None:
|
||||
logger.debug(
|
||||
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
|
||||
)
|
||||
def _change_state(self, new_state: State) -> None:
|
||||
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||
self.state = new_state
|
||||
|
||||
if new_state == self.CONNECTED:
|
||||
if new_state == self.State.CONNECTED:
|
||||
self.emit('open')
|
||||
elif new_state == self.DISCONNECTED:
|
||||
elif new_state == self.State.DISCONNECTED:
|
||||
self.emit('close')
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
@@ -1104,9 +1101,9 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
|
||||
|
||||
async def connect(self) -> LeConnectionOrientedChannel:
|
||||
async def connect(self) -> LeCreditBasedChannel:
|
||||
# Check that we're in the right state
|
||||
if self.state != self.INIT:
|
||||
if self.state != self.State.INIT:
|
||||
raise InvalidStateError('not in a connectable state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
@@ -1114,7 +1111,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise RuntimeError('too many concurrent connection requests')
|
||||
|
||||
self.change_state(self.CONNECTING)
|
||||
self._change_state(self.State.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
identifier=identifier,
|
||||
le_psm=self.le_psm,
|
||||
@@ -1134,10 +1131,10 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Check that we're connected
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
self.change_state(self.DISCONNECTING)
|
||||
self._change_state(self.State.DISCONNECTING)
|
||||
self.flush_output()
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Request(
|
||||
@@ -1153,15 +1150,15 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
return await self.disconnection_result
|
||||
|
||||
def abort(self) -> None:
|
||||
if self.state == self.CONNECTED:
|
||||
self.change_state(self.DISCONNECTED)
|
||||
if self.state == self.State.CONNECTED:
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.sink is None:
|
||||
logger.warning('received pdu without a sink')
|
||||
return
|
||||
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
logger.warning('received PDU while not connected, dropping')
|
||||
|
||||
# Manage the peer credits
|
||||
@@ -1240,7 +1237,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
self.credits = response.initial_credits
|
||||
self.connected = True
|
||||
self.connection_result.set_result(self)
|
||||
self.change_state(self.CONNECTED)
|
||||
self._change_state(self.State.CONNECTED)
|
||||
else:
|
||||
self.connection_result.set_exception(
|
||||
ProtocolError(
|
||||
@@ -1251,7 +1248,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
),
|
||||
)
|
||||
)
|
||||
self.change_state(self.CONNECTION_ERROR)
|
||||
self._change_state(self.State.CONNECTION_ERROR)
|
||||
|
||||
# Cleanup
|
||||
self.connection_result = None
|
||||
@@ -1271,11 +1268,11 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
source_cid=request.source_cid,
|
||||
)
|
||||
)
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response) -> None:
|
||||
if self.state != self.DISCONNECTING:
|
||||
if self.state != self.State.DISCONNECTING:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
@@ -1286,7 +1283,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self._change_state(self.State.DISCONNECTED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
@@ -1339,7 +1336,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
return
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
if self.state != self.CONNECTED:
|
||||
if self.state != self.State.CONNECTED:
|
||||
logger.warning('not connected, dropping data')
|
||||
return
|
||||
|
||||
@@ -1367,7 +1364,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'CoC({self.source_cid}->{self.destination_cid}, '
|
||||
f'State={self.state_name(self.state)}, '
|
||||
f'State={self.state.name}, '
|
||||
f'PSM={self.le_psm}, '
|
||||
f'MTU={self.mtu}/{self.peer_mtu}, '
|
||||
f'MPS={self.mps}/{self.peer_mps}, '
|
||||
@@ -1375,15 +1372,67 @@ class LeConnectionOrientedChannel(EventEmitter):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClassicChannelServer(EventEmitter):
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
psm: int,
|
||||
handler: Optional[Callable[[ClassicChannel], Any]],
|
||||
mtu: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.handler = handler
|
||||
self.psm = psm
|
||||
self.mtu = mtu
|
||||
|
||||
def on_connection(self, channel: ClassicChannel) -> None:
|
||||
self.emit('connection', channel)
|
||||
if self.handler:
|
||||
self.handler(channel)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.psm in self.manager.servers:
|
||||
del self.manager.servers[self.psm]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeCreditBasedChannelServer(EventEmitter):
|
||||
def __init__(
|
||||
self,
|
||||
manager: ChannelManager,
|
||||
psm: int,
|
||||
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
|
||||
max_credits: int,
|
||||
mtu: int,
|
||||
mps: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.handler = handler
|
||||
self.psm = psm
|
||||
self.max_credits = max_credits
|
||||
self.mtu = mtu
|
||||
self.mps = mps
|
||||
|
||||
def on_connection(self, channel: LeCreditBasedChannel) -> None:
|
||||
self.emit('connection', channel)
|
||||
if self.handler:
|
||||
self.handler(channel)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.psm in self.manager.le_coc_servers:
|
||||
del self.manager.le_coc_servers[self.psm]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ChannelManager:
|
||||
identifiers: Dict[int, int]
|
||||
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
|
||||
servers: Dict[int, Callable[[Channel], Any]]
|
||||
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
|
||||
le_coc_servers: Dict[
|
||||
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
|
||||
]
|
||||
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
|
||||
servers: Dict[int, ClassicChannelServer]
|
||||
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
|
||||
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
|
||||
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
|
||||
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
|
||||
_host: Optional[Host]
|
||||
@@ -1462,21 +1511,6 @@ class ChannelManager:
|
||||
|
||||
raise RuntimeError('no free CID')
|
||||
|
||||
@staticmethod
|
||||
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
|
||||
if (
|
||||
max_credits < 1
|
||||
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if (
|
||||
mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
):
|
||||
raise ValueError('MPS out of range')
|
||||
|
||||
def next_identifier(self, connection: Connection) -> int:
|
||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||
self.identifiers[connection.handle] = identifier
|
||||
@@ -1491,8 +1525,22 @@ class ChannelManager:
|
||||
if cid in self.fixed_channels:
|
||||
del self.fixed_channels[cid]
|
||||
|
||||
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
|
||||
if psm == 0:
|
||||
@deprecated("Please use create_classic_channel_server")
|
||||
def register_server(
|
||||
self,
|
||||
psm: int,
|
||||
server: Callable[[ClassicChannel], Any],
|
||||
) -> int:
|
||||
return self.create_classic_server(
|
||||
handler=server, spec=ClassicChannelSpec(psm=psm)
|
||||
).psm
|
||||
|
||||
def create_classic_server(
|
||||
self,
|
||||
spec: ClassicChannelSpec,
|
||||
handler: Optional[Callable[[ClassicChannel], Any]] = None,
|
||||
) -> ClassicChannelServer:
|
||||
if spec.psm is None:
|
||||
# Find a free PSM
|
||||
for candidate in range(
|
||||
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
|
||||
@@ -1501,62 +1549,75 @@ class ChannelManager:
|
||||
continue
|
||||
if candidate in self.servers:
|
||||
continue
|
||||
psm = candidate
|
||||
spec.psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.servers:
|
||||
if spec.psm in self.servers:
|
||||
raise ValueError('PSM already in use')
|
||||
|
||||
# Check that the PSM is valid
|
||||
if psm % 2 == 0:
|
||||
if spec.psm % 2 == 0:
|
||||
raise ValueError('invalid PSM (not odd)')
|
||||
check = psm >> 8
|
||||
check = spec.psm >> 8
|
||||
while check:
|
||||
if check % 2 != 0:
|
||||
raise ValueError('invalid PSM')
|
||||
check >>= 8
|
||||
|
||||
self.servers[psm] = server
|
||||
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
|
||||
|
||||
return psm
|
||||
return self.servers[spec.psm]
|
||||
|
||||
@deprecated("Please use create_le_credit_based_server()")
|
||||
def register_le_coc_server(
|
||||
self,
|
||||
psm: int,
|
||||
server: Callable[[LeConnectionOrientedChannel], Any],
|
||||
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
server: Callable[[LeCreditBasedChannel], Any],
|
||||
max_credits: int,
|
||||
mtu: int,
|
||||
mps: int,
|
||||
) -> int:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
return self.create_le_credit_based_server(
|
||||
spec=LeCreditBasedChannelSpec(
|
||||
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
|
||||
),
|
||||
handler=server,
|
||||
).psm
|
||||
|
||||
if psm == 0:
|
||||
def create_le_credit_based_server(
|
||||
self,
|
||||
spec: LeCreditBasedChannelSpec,
|
||||
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
|
||||
) -> LeCreditBasedChannelServer:
|
||||
if spec.psm is None:
|
||||
# Find a free PSM
|
||||
for candidate in range(
|
||||
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
|
||||
):
|
||||
if candidate in self.le_coc_servers:
|
||||
continue
|
||||
psm = candidate
|
||||
spec.psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.le_coc_servers:
|
||||
if spec.psm in self.le_coc_servers:
|
||||
raise ValueError('PSM already in use')
|
||||
|
||||
self.le_coc_servers[psm] = (
|
||||
server,
|
||||
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
|
||||
self,
|
||||
spec.psm,
|
||||
handler,
|
||||
max_credits=spec.max_credits,
|
||||
mtu=spec.mtu,
|
||||
mps=spec.mps,
|
||||
)
|
||||
|
||||
return psm
|
||||
return self.le_coc_servers[spec.psm]
|
||||
|
||||
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
|
||||
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
|
||||
@@ -1683,13 +1744,13 @@ class ChannelManager:
|
||||
logger.debug(
|
||||
f'creating server channel with cid={source_cid} for psm {request.psm}'
|
||||
)
|
||||
channel = Channel(
|
||||
self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU
|
||||
channel = ClassicChannel(
|
||||
self, connection, cid, request.psm, source_cid, server.mtu
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
# Notify
|
||||
server(channel)
|
||||
server.on_connection(channel)
|
||||
channel.on_connection_request(request)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -1911,7 +1972,7 @@ class ChannelManager:
|
||||
self, connection: Connection, cid: int, request
|
||||
) -> None:
|
||||
if request.le_psm in self.le_coc_servers:
|
||||
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
|
||||
server = self.le_coc_servers[request.le_psm]
|
||||
|
||||
# Check that the CID isn't already used
|
||||
le_connection_channels = self.le_coc_channels.setdefault(
|
||||
@@ -1925,8 +1986,8 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=0,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
|
||||
@@ -1944,8 +2005,8 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=0,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
|
||||
@@ -1958,18 +2019,18 @@ class ChannelManager:
|
||||
f'creating LE CoC server channel with cid={source_cid} for psm '
|
||||
f'{request.le_psm}'
|
||||
)
|
||||
channel = LeConnectionOrientedChannel(
|
||||
channel = LeCreditBasedChannel(
|
||||
self,
|
||||
connection,
|
||||
request.le_psm,
|
||||
source_cid,
|
||||
request.source_cid,
|
||||
mtu,
|
||||
mps,
|
||||
server.mtu,
|
||||
server.mps,
|
||||
request.initial_credits,
|
||||
request.mtu,
|
||||
request.mps,
|
||||
max_credits,
|
||||
server.max_credits,
|
||||
True,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
@@ -1982,16 +2043,16 @@ class ChannelManager:
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier=request.identifier,
|
||||
destination_cid=source_cid,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
initial_credits=max_credits,
|
||||
mtu=server.mtu,
|
||||
mps=server.mps,
|
||||
initial_credits=server.max_credits,
|
||||
# pylint: disable=line-too-long
|
||||
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
|
||||
),
|
||||
)
|
||||
|
||||
# Notify
|
||||
server(channel)
|
||||
server.on_connection(channel)
|
||||
else:
|
||||
logger.info(
|
||||
f'No LE server for connection 0x{connection.handle:04X} '
|
||||
@@ -2046,37 +2107,51 @@ class ChannelManager:
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel: Channel) -> None:
|
||||
def on_channel_closed(self, channel: ClassicChannel) -> None:
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
if connection_channels:
|
||||
if channel.source_cid in connection_channels:
|
||||
del connection_channels[channel.source_cid]
|
||||
|
||||
@deprecated("Please use create_le_credit_based_channel()")
|
||||
async def open_le_coc(
|
||||
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
|
||||
) -> LeConnectionOrientedChannel:
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
) -> LeCreditBasedChannel:
|
||||
return await self.create_le_credit_based_channel(
|
||||
connection=connection,
|
||||
spec=LeCreditBasedChannelSpec(
|
||||
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
|
||||
),
|
||||
)
|
||||
|
||||
async def create_le_credit_based_channel(
|
||||
self,
|
||||
connection: Connection,
|
||||
spec: LeCreditBasedChannelSpec,
|
||||
) -> LeCreditBasedChannel:
|
||||
# Find a free CID for the new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_le_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise ValueError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
|
||||
channel = LeConnectionOrientedChannel(
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
|
||||
channel = LeCreditBasedChannel(
|
||||
manager=self,
|
||||
connection=connection,
|
||||
le_psm=psm,
|
||||
le_psm=spec.psm,
|
||||
source_cid=source_cid,
|
||||
destination_cid=0,
|
||||
mtu=mtu,
|
||||
mps=mps,
|
||||
mtu=spec.mtu,
|
||||
mps=spec.mps,
|
||||
credits=0,
|
||||
peer_mtu=0,
|
||||
peer_mps=0,
|
||||
peer_credits=max_credits,
|
||||
peer_credits=spec.max_credits,
|
||||
connected=False,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
@@ -2095,7 +2170,15 @@ class ChannelManager:
|
||||
|
||||
return channel
|
||||
|
||||
async def connect(self, connection: Connection, psm: int) -> Channel:
|
||||
@deprecated("Please use create_classic_channel()")
|
||||
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
|
||||
return await self.create_classic_channel(
|
||||
connection=connection, spec=ClassicChannelSpec(psm=psm)
|
||||
)
|
||||
|
||||
async def create_classic_channel(
|
||||
self, connection: Connection, spec: ClassicChannelSpec
|
||||
) -> ClassicChannel:
|
||||
# NOTE: this implementation hard-codes BR/EDR
|
||||
|
||||
# Find a free CID for a new channel
|
||||
@@ -2104,10 +2187,20 @@ class ChannelManager:
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
|
||||
if spec.psm is None:
|
||||
raise ValueError('PSM cannot be None')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
|
||||
channel = Channel(
|
||||
self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU
|
||||
logger.debug(
|
||||
f'creating client channel with cid={source_cid} for psm {spec.psm}'
|
||||
)
|
||||
channel = ClassicChannel(
|
||||
self,
|
||||
connection,
|
||||
L2CAP_SIGNALING_CID,
|
||||
spec.psm,
|
||||
source_cid,
|
||||
spec.mtu,
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
@@ -2119,3 +2212,20 @@ class ChannelManager:
|
||||
raise e
|
||||
|
||||
return channel
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Deprecated Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Channel(ClassicChannel):
|
||||
@deprecated("Please use ClassicChannel")
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class LeConnectionOrientedChannel(LeCreditBasedChannel):
|
||||
@deprecated("Please use LeCreditBasedChannel")
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -233,7 +233,11 @@ class SecurityService(SecurityServicer):
|
||||
sc=config.pairing_sc_enable,
|
||||
mitm=config.pairing_mitm_enable,
|
||||
bonding=config.pairing_bonding_enable,
|
||||
identity_address_type=config.identity_address_type,
|
||||
identity_address_type=(
|
||||
PairingConfig.AddressType.PUBLIC
|
||||
if connection.self_address.is_public
|
||||
else config.identity_address_type
|
||||
),
|
||||
delegate=PairingDelegate(
|
||||
connection,
|
||||
self,
|
||||
@@ -446,21 +450,18 @@ class SecurityService(SecurityServicer):
|
||||
'security_request': pair,
|
||||
}
|
||||
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.on(event, listener)
|
||||
with contextlib.closing(EventWatcher()) as watcher:
|
||||
# register event handlers
|
||||
for event, listener in listeners.items():
|
||||
watcher.on(connection, event, listener)
|
||||
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
# security level already reached
|
||||
if self.reached_security_level(connection, level):
|
||||
return WaitSecurityResponse(success=empty_pb2.Empty())
|
||||
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# remove event handlers
|
||||
for event, listener in listeners.items():
|
||||
connection.remove_listener(event, listener) # type: ignore
|
||||
self.log.debug('Wait for security...')
|
||||
kwargs = {}
|
||||
kwargs[await wait_for_security] = empty_pb2.Empty()
|
||||
|
||||
# wait for `authenticate` to finish if any
|
||||
if authenticate_task is not None:
|
||||
|
||||
@@ -674,7 +674,7 @@ class Multiplexer(EventEmitter):
|
||||
acceptor: Optional[Callable[[int], bool]]
|
||||
dlcs: Dict[int, DLC]
|
||||
|
||||
def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None:
|
||||
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.role = role
|
||||
self.l2cap_channel = l2cap_channel
|
||||
@@ -887,7 +887,7 @@ class Multiplexer(EventEmitter):
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
multiplexer: Optional[Multiplexer]
|
||||
l2cap_channel: Optional[l2cap.Channel]
|
||||
l2cap_channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, device: Device, connection: Connection) -> None:
|
||||
self.device = device
|
||||
@@ -960,11 +960,11 @@ class Server(EventEmitter):
|
||||
self.acceptors[channel] = acceptor
|
||||
return channel
|
||||
|
||||
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
# Create a new multiplexer for the channel
|
||||
|
||||
@@ -758,7 +758,7 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
channel: Optional[l2cap.Channel]
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
@@ -921,7 +921,7 @@ class Client:
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
channel: Optional[l2cap.Channel]
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', List[ServiceAttribute])
|
||||
service_records: Dict[int, Service]
|
||||
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||
|
||||
@@ -97,10 +97,13 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
raise ValueError('invalid mode')
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
class EmulatorTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = EmulatorTransport(
|
||||
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
|
||||
)
|
||||
transport.start()
|
||||
|
||||
|
||||
@@ -18,11 +18,12 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
import logging
|
||||
import grpc.aio
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import grpc.aio
|
||||
|
||||
from .common import (
|
||||
ParserSource,
|
||||
@@ -33,8 +34,8 @@ from .common import (
|
||||
)
|
||||
|
||||
# pylint: disable=no-name-in-module
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import PacketStreamerStub
|
||||
from .grpc_protobuf.packet_streamer_pb2_grpc import (
|
||||
PacketStreamerStub,
|
||||
PacketStreamerServicer,
|
||||
add_PacketStreamerServicer_to_server,
|
||||
)
|
||||
@@ -43,6 +44,7 @@ from .grpc_protobuf.hci_packet_pb2 import HCIPacket
|
||||
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
|
||||
from .grpc_protobuf.common_pb2 import ChipKind
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -74,14 +76,20 @@ def get_ini_dir() -> Optional[pathlib.Path]:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port() -> int:
|
||||
def ini_file_name(instance_number: int) -> str:
|
||||
suffix = f'_{instance_number}' if instance_number > 0 else ''
|
||||
return f'netsim{suffix}.ini'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def find_grpc_port(instance_number: int) -> int:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return 0
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
logger.debug(f'Looking for .ini file at {ini_file}')
|
||||
if ini_file.is_file():
|
||||
logger.debug(f'Found .ini file at {ini_file}')
|
||||
with open(ini_file, 'r') as ini_file_data:
|
||||
for line in ini_file_data.readlines():
|
||||
if '=' in line:
|
||||
@@ -90,12 +98,14 @@ def find_grpc_port() -> int:
|
||||
logger.debug(f'gRPC port = {value}')
|
||||
return int(value)
|
||||
|
||||
logger.debug('no grpc.port property found in .ini file')
|
||||
|
||||
# Not found
|
||||
return 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def publish_grpc_port(grpc_port) -> bool:
|
||||
def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
|
||||
if not (ini_dir := get_ini_dir()):
|
||||
logger.debug('no known directory for .ini file')
|
||||
return False
|
||||
@@ -104,7 +114,7 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
logger.debug('ini directory does not exist')
|
||||
return False
|
||||
|
||||
ini_file = ini_dir / 'netsim.ini'
|
||||
ini_file = ini_dir / ini_file_name(instance_number)
|
||||
try:
|
||||
ini_file.write_text(f'grpc.port={grpc_port}\n')
|
||||
logger.debug(f"published gRPC port at {ini_file}")
|
||||
@@ -122,14 +132,15 @@ def publish_grpc_port(grpc_port) -> bool:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_controller_transport(
|
||||
server_host: Optional[str], server_port: int
|
||||
server_host: Optional[str], server_port: int, options: Dict[str, str]
|
||||
) -> Transport:
|
||||
if not server_port:
|
||||
raise ValueError('invalid port')
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not publish_grpc_port(server_port):
|
||||
instance_number = int(options.get('instance', "0"))
|
||||
if not publish_grpc_port(server_port, instance_number):
|
||||
logger.warning("unable to publish gRPC port")
|
||||
|
||||
class HciDevice:
|
||||
@@ -186,15 +197,12 @@ async def open_android_netsim_controller_transport(
|
||||
logger.debug(f'<<< PACKET: {data.hex()}')
|
||||
self.on_data_received(data)
|
||||
|
||||
def send_packet(self, data):
|
||||
async def send():
|
||||
await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
async def send_packet(self, data):
|
||||
return await self.context.write(
|
||||
PacketResponse(
|
||||
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
|
||||
)
|
||||
|
||||
self.loop.create_task(send())
|
||||
)
|
||||
|
||||
def terminate(self):
|
||||
self.task.cancel()
|
||||
@@ -228,17 +236,17 @@ async def open_android_netsim_controller_transport(
|
||||
logger.debug('gRPC server cancelled')
|
||||
await self.grpc_server.stop(None)
|
||||
|
||||
def on_packet(self, packet):
|
||||
async def send_packet(self, packet):
|
||||
if not self.device:
|
||||
logger.debug('no device, dropping packet')
|
||||
return
|
||||
|
||||
self.device.send_packet(packet)
|
||||
return await self.device.send_packet(packet)
|
||||
|
||||
async def StreamPackets(self, _request_iterator, context):
|
||||
logger.debug('StreamPackets request')
|
||||
|
||||
# Check that we won't already have a device
|
||||
# Check that we don't already have a device
|
||||
if self.device:
|
||||
logger.debug('busy, already serving a device')
|
||||
return PacketResponse(error='Busy')
|
||||
@@ -261,15 +269,42 @@ async def open_android_netsim_controller_transport(
|
||||
await server.start()
|
||||
asyncio.get_running_loop().create_task(server.serve())
|
||||
|
||||
class GrpcServerTransport(Transport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
|
||||
return GrpcServerTransport(server, server)
|
||||
sink = PumpedPacketSink(server.send_packet)
|
||||
sink.start()
|
||||
return Transport(server, sink)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
async def open_android_netsim_host_transport_with_address(
|
||||
server_host: Optional[str],
|
||||
server_port: int,
|
||||
options: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
instance_number = 0 if options is None else int(options.get('instance', '0'))
|
||||
server_port = find_grpc_port(instance_number)
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
return await open_android_netsim_host_transport_with_channel(
|
||||
channel,
|
||||
options,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_host_transport_with_channel(
|
||||
channel, options: Optional[Dict[str, str]] = None
|
||||
):
|
||||
# Wrapper for I/O operations
|
||||
class HciDevice:
|
||||
def __init__(self, name, manufacturer, hci_device):
|
||||
@@ -288,10 +323,12 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
async def read(self):
|
||||
response = await self.hci_device.read()
|
||||
response_type = response.WhichOneof('response_type')
|
||||
|
||||
if response_type == 'error':
|
||||
logger.warning(f'received error: {response.error}')
|
||||
raise RuntimeError(response.error)
|
||||
elif response_type == 'hci_packet':
|
||||
|
||||
if response_type == 'hci_packet':
|
||||
return (
|
||||
bytes([response.hci_packet.packet_type])
|
||||
+ response.hci_packet.packet
|
||||
@@ -306,24 +343,9 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
)
|
||||
)
|
||||
|
||||
name = options.get('name', DEFAULT_NAME)
|
||||
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
|
||||
manufacturer = DEFAULT_MANUFACTURER
|
||||
|
||||
if server_host == '_' or not server_host:
|
||||
server_host = 'localhost'
|
||||
|
||||
if not server_port:
|
||||
# Look for the gRPC config in a .ini file
|
||||
server_host = 'localhost'
|
||||
server_port = find_grpc_port()
|
||||
if not server_port:
|
||||
raise RuntimeError('gRPC server port not found')
|
||||
|
||||
# Connect to the gRPC server
|
||||
server_address = f'{server_host}:{server_port}'
|
||||
logger.debug(f'Connecting to gRPC server at {server_address}')
|
||||
channel = grpc.aio.insecure_channel(server_address)
|
||||
|
||||
# Connect as a host
|
||||
service = PacketStreamerStub(channel)
|
||||
hci_device = HciDevice(
|
||||
@@ -334,10 +356,14 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
await hci_device.start()
|
||||
|
||||
# Create the transport object
|
||||
transport = PumpedTransport(
|
||||
class GrpcTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await channel.close()
|
||||
|
||||
transport = GrpcTransport(
|
||||
PumpedPacketSource(hci_device.read),
|
||||
PumpedPacketSink(hci_device.write),
|
||||
channel.close,
|
||||
)
|
||||
transport.start()
|
||||
|
||||
@@ -345,7 +371,7 @@ async def open_android_netsim_host_transport(server_host, server_port, options):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_android_netsim_transport(spec):
|
||||
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
|
||||
'''
|
||||
Open a transport connection as a client or server, implementing Android's `netsim`
|
||||
simulator protocol over gRPC.
|
||||
@@ -359,6 +385,11 @@ async def open_android_netsim_transport(spec):
|
||||
to connect *to* a netsim server (netsim is the controller), or accept
|
||||
connections *as* a netsim-compatible server.
|
||||
|
||||
instance=<n>
|
||||
Specifies an instance number, with <n> > 0. This is used to determine which
|
||||
.init file to use. In `host` mode, it is ignored when the <host>:<port>
|
||||
specifier is present, since in that case no .ini file is used.
|
||||
|
||||
In `host` mode:
|
||||
The <host>:<port> part is optional. When not specified, the transport
|
||||
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
|
||||
@@ -387,14 +418,15 @@ async def open_android_netsim_transport(spec):
|
||||
params = spec.split(',') if spec else []
|
||||
if params and ':' in params[0]:
|
||||
# Explicit <host>:<port>
|
||||
host, port = params[0].split(':')
|
||||
host, port_str = params[0].split(':')
|
||||
port = int(port_str)
|
||||
params_offset = 1
|
||||
else:
|
||||
host = None
|
||||
port = 0
|
||||
params_offset = 0
|
||||
|
||||
options = {}
|
||||
options: Dict[str, str] = {}
|
||||
for param in params[params_offset:]:
|
||||
if '=' not in param:
|
||||
raise ValueError('invalid parameter, expected <name>=<value>')
|
||||
@@ -403,10 +435,12 @@ async def open_android_netsim_transport(spec):
|
||||
|
||||
mode = options.get('mode', 'host')
|
||||
if mode == 'host':
|
||||
return await open_android_netsim_host_transport(host, port, options)
|
||||
return await open_android_netsim_host_transport_with_address(
|
||||
host, port, options
|
||||
)
|
||||
if mode == 'controller':
|
||||
if host is None:
|
||||
raise ValueError('<host>:<port> missing')
|
||||
return await open_android_netsim_controller_transport(host, port)
|
||||
return await open_android_netsim_controller_transport(host, port, options)
|
||||
|
||||
raise ValueError('invalid mode option')
|
||||
|
||||
@@ -339,8 +339,9 @@ class PumpedPacketSource(ParserSource):
|
||||
try:
|
||||
packet = await self.receive_function()
|
||||
self.parser.feed_data(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('source pump task done')
|
||||
self.terminated.set_result(None)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while waiting for packet: {error}')
|
||||
@@ -370,7 +371,7 @@ class PumpedPacketSink:
|
||||
try:
|
||||
packet = await self.packet_queue.get()
|
||||
await self.send_function(packet)
|
||||
except asyncio.exceptions.CancelledError:
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('sink pump task done')
|
||||
break
|
||||
except Exception as error:
|
||||
@@ -393,19 +394,13 @@ class PumpedTransport(Transport):
|
||||
self,
|
||||
source: PumpedPacketSource,
|
||||
sink: PumpedPacketSink,
|
||||
close_function,
|
||||
) -> None:
|
||||
super().__init__(source, sink)
|
||||
self.close_function = close_function
|
||||
|
||||
def start(self) -> None:
|
||||
self.source.start()
|
||||
self.sink.start()
|
||||
|
||||
async def close(self) -> None:
|
||||
await super().close()
|
||||
await self.close_function()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SnoopingTransport(Transport):
|
||||
|
||||
@@ -31,19 +31,21 @@ async def open_ws_client_transport(spec: str) -> Transport:
|
||||
'''
|
||||
Open a WebSocket client transport.
|
||||
The parameter string has this syntax:
|
||||
<remote-host>:<remote-port>
|
||||
<websocket-url>
|
||||
|
||||
Example: 127.0.0.1:9001
|
||||
Example: ws://localhost:7681/v1/websocket/bt
|
||||
'''
|
||||
|
||||
remote_host, remote_port = spec.split(':')
|
||||
uri = f'ws://{remote_host}:{remote_port}'
|
||||
websocket = await websockets.client.connect(uri)
|
||||
websocket = await websockets.client.connect(spec)
|
||||
|
||||
transport = PumpedTransport(
|
||||
class WsTransport(PumpedTransport):
|
||||
async def close(self):
|
||||
await super().close()
|
||||
await websocket.close()
|
||||
|
||||
transport = WsTransport(
|
||||
PumpedPacketSource(websocket.recv),
|
||||
PumpedPacketSink(websocket.send),
|
||||
websocket.close,
|
||||
)
|
||||
transport.start()
|
||||
return transport
|
||||
|
||||
@@ -21,6 +21,7 @@ import logging
|
||||
import traceback
|
||||
import collections
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Set,
|
||||
@@ -33,7 +34,7 @@ from typing import (
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from functools import wraps
|
||||
from functools import wraps, partial
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
@@ -410,3 +411,36 @@ class FlowControlAsyncPipe:
|
||||
self.resume_source()
|
||||
|
||||
self.check_pump()
|
||||
|
||||
|
||||
async def async_call(function, *args, **kwargs):
|
||||
"""
|
||||
Immediately calls the function with provided args and kwargs, wrapping it in an async function.
|
||||
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop.
|
||||
|
||||
result = await async_call(some_function, ...)
|
||||
"""
|
||||
return function(*args, **kwargs)
|
||||
|
||||
|
||||
def wrap_async(function):
|
||||
"""
|
||||
Wraps the provided function in an async function.
|
||||
"""
|
||||
return partial(async_call, function)
|
||||
|
||||
|
||||
def deprecated(msg: str):
|
||||
"""
|
||||
Throw deprecation warning before execution
|
||||
"""
|
||||
|
||||
def wrapper(function):
|
||||
@wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
388
rust/Cargo.lock
generated
388
rust/Cargo.lock
generated
@@ -80,6 +80,37 @@ version = "1.0.75"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6"
|
||||
|
||||
[[package]]
|
||||
name = "argh"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7af5ba06967ff7214ce4c7419c7d185be7ecd6cc4965a8f6e1d8ce0398aad219"
|
||||
dependencies = [
|
||||
"argh_derive",
|
||||
"argh_shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "argh_derive"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "56df0aeedf6b7a2fc67d06db35b09684c3e8da0c95f8f27685cb17e08413d87a"
|
||||
dependencies = [
|
||||
"argh_shared",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "argh_shared"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5693f39141bda5760ecc4111ab08da40565d1771038c4a0250f03457ec707531"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
@@ -130,15 +161,37 @@ version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
|
||||
|
||||
[[package]]
|
||||
name = "block-buffer"
|
||||
version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bstr"
|
||||
version = "1.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumble"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"clap 4.4.1",
|
||||
"directories",
|
||||
"env_logger",
|
||||
"file-header",
|
||||
"futures",
|
||||
"globset",
|
||||
"hex",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
@@ -146,6 +199,8 @@ dependencies = [
|
||||
"nix",
|
||||
"nom",
|
||||
"owo-colors",
|
||||
"pdl-derive",
|
||||
"pdl-runtime",
|
||||
"pyo3",
|
||||
"pyo3-asyncio",
|
||||
"rand",
|
||||
@@ -166,9 +221,9 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.4.0"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
|
||||
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
@@ -250,6 +305,16 @@ version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961"
|
||||
|
||||
[[package]]
|
||||
name = "codespan-reporting"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
|
||||
dependencies = [
|
||||
"termcolor",
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.0"
|
||||
@@ -272,6 +337,102 @@ version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-channel",
|
||||
"crossbeam-deque",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-queue",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"memoffset 0.9.0",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crypto-common"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"crypto-common",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "directories"
|
||||
version = "5.0.1"
|
||||
@@ -348,6 +509,19 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
|
||||
|
||||
[[package]]
|
||||
name = "file-header"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f"
|
||||
dependencies = [
|
||||
"crossbeam",
|
||||
"lazy_static",
|
||||
"license",
|
||||
"thiserror",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.7"
|
||||
@@ -467,6 +641,16 @@ dependencies = [
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.10"
|
||||
@@ -484,6 +668,19 @@ version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
|
||||
|
||||
[[package]]
|
||||
name = "globset"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"bstr",
|
||||
"fnv",
|
||||
"log",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.21"
|
||||
@@ -710,6 +907,17 @@ dependencies = [
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "license"
|
||||
version = "3.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946"
|
||||
dependencies = [
|
||||
"reword",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.5"
|
||||
@@ -756,6 +964,15 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -939,12 +1156,100 @@ dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pdl-compiler"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee66995739fb9ddd9155767990a54aadd226ee32408a94f99f94883ff445ceba"
|
||||
dependencies = [
|
||||
"argh",
|
||||
"codespan-reporting",
|
||||
"heck",
|
||||
"pest",
|
||||
"pest_derive",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syn 2.0.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pdl-derive"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "113e4a1215c407466b36d2c2f6a6318819d6b22ccdd3acb7bb35e27a68806034"
|
||||
dependencies = [
|
||||
"codespan-reporting",
|
||||
"pdl-compiler",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.29",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pdl-runtime"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d36c2f9799613babe78eb5cd9a353d527daaba6c3d1f39a1175657a35790732"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94"
|
||||
|
||||
[[package]]
|
||||
name = "pest"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c022f1e7b65d6a24c0dbbd5fb344c66881bc01f3e5ae74a1c8100f2f985d98a4"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"thiserror",
|
||||
"ucd-trie",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_derive"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35513f630d46400a977c4cb58f78e1bfbe01434316e60c37d27b9ad6139c66d8"
|
||||
dependencies = [
|
||||
"pest",
|
||||
"pest_generator",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_generator"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc9fc1b9e7057baba189b5c626e2d6f40681ae5b6eb064dc7c7834101ec8123a"
|
||||
dependencies = [
|
||||
"pest",
|
||||
"pest_meta",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pest_meta"
|
||||
version = "2.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1df74e9e7ec4053ceb980e7c0c8bd3594e977fde1af91daba9c928e8e8c6708d"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"pest",
|
||||
"sha2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.13"
|
||||
@@ -969,6 +1274,16 @@ version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c64d9ba0963cdcea2e1b2230fbae2bab30eb25a174be395c41e764bfb65dd62"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn 2.0.29",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.66"
|
||||
@@ -1200,6 +1515,15 @@ dependencies = [
|
||||
"winreg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "reword"
|
||||
version = "7.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765"
|
||||
dependencies = [
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusb"
|
||||
version = "0.9.3"
|
||||
@@ -1241,6 +1565,15 @@ version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.22"
|
||||
@@ -1322,6 +1655,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.1"
|
||||
@@ -1568,6 +1912,18 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
|
||||
|
||||
[[package]]
|
||||
name = "ucd-trie"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.13"
|
||||
@@ -1589,6 +1945,18 @@ dependencies = [
|
||||
"tinyvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-segmentation"
|
||||
version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
|
||||
|
||||
[[package]]
|
||||
name = "unindent"
|
||||
version = "0.1.11"
|
||||
@@ -1618,6 +1986,22 @@ version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "want"
|
||||
version = "0.3.1"
|
||||
|
||||
@@ -23,6 +23,13 @@ hex = "0.4.3"
|
||||
itertools = "0.11.0"
|
||||
lazy_static = "1.4.0"
|
||||
thiserror = "1.0.41"
|
||||
bytes = "1.5.0"
|
||||
pdl-derive = "0.2.0"
|
||||
pdl-runtime = "0.2.0"
|
||||
|
||||
# Dev tools
|
||||
file-header = { version = "0.1.2", optional = true }
|
||||
globset = { version = "0.4.13", optional = true }
|
||||
|
||||
# CLI
|
||||
anyhow = { version = "1.0.71", optional = true }
|
||||
@@ -52,10 +59,15 @@ env_logger = "0.10.0"
|
||||
[package.metadata.docs.rs]
|
||||
rustdoc-args = ["--generate-link-to-definition"]
|
||||
|
||||
[[bin]]
|
||||
name = "file-header"
|
||||
path = "tools/file_header.rs"
|
||||
required-features = ["dev-tools"]
|
||||
|
||||
[[bin]]
|
||||
name = "gen-assigned-numbers"
|
||||
path = "tools/gen_assigned_numbers.rs"
|
||||
required-features = ["bumble-codegen"]
|
||||
required-features = ["dev-tools"]
|
||||
|
||||
[[bin]]
|
||||
name = "bumble"
|
||||
@@ -71,7 +83,7 @@ harness = false
|
||||
[features]
|
||||
anyhow = ["pyo3/anyhow"]
|
||||
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
|
||||
bumble-codegen = ["dep:anyhow"]
|
||||
dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
|
||||
# separate feature for CLI so that dependencies don't spend time building these
|
||||
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
|
||||
default = []
|
||||
|
||||
@@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`.
|
||||
To regenerate the assigned number tables based on the Python codebase:
|
||||
|
||||
```
|
||||
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen
|
||||
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
|
||||
```
|
||||
@@ -20,7 +20,8 @@
|
||||
use bumble::{
|
||||
adv::CommonDataType,
|
||||
wrapper::{
|
||||
core::AdvertisementDataUnit, device::Device, hci::AddressType, transport::Transport,
|
||||
core::AdvertisementDataUnit, device::Device, hci::packets::AddressType,
|
||||
transport::Transport,
|
||||
},
|
||||
};
|
||||
use clap::Parser as _;
|
||||
@@ -69,9 +70,7 @@ async fn main() -> PyResult<()> {
|
||||
let mut seen_adv_cache = seen_adv_clone.lock().unwrap();
|
||||
let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs);
|
||||
|
||||
let advs_from_addr = seen_adv_cache
|
||||
.entry(addr_bytes)
|
||||
.or_insert_with(collections::HashMap::new);
|
||||
let advs_from_addr = seen_adv_cache.entry(addr_bytes).or_default();
|
||||
// we expect cache hits to be the norm, so we do a separate lookup to avoid cloning
|
||||
// on every lookup with entry()
|
||||
let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) {
|
||||
@@ -102,7 +101,9 @@ async fn main() -> PyResult<()> {
|
||||
};
|
||||
|
||||
let (type_style, qualifier) = match adv.address()?.address_type()? {
|
||||
AddressType::PublicIdentity | AddressType::PublicDevice => (Style::new().cyan(), ""),
|
||||
AddressType::PublicIdentityAddress | AddressType::PublicDeviceAddress => {
|
||||
(Style::new().cyan(), "")
|
||||
}
|
||||
_ => {
|
||||
if addr.is_static()? {
|
||||
(Style::new().green(), "(static)")
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use bumble::wrapper::{self, core::Uuid16};
|
||||
use pyo3::{intern, prelude::*, types::PyDict};
|
||||
use std::collections;
|
||||
|
||||
@@ -12,9 +12,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use bumble::wrapper::{drivers::rtk::DriverInfo, transport::Transport};
|
||||
use bumble::wrapper::{
|
||||
controller::Controller,
|
||||
device::Device,
|
||||
drivers::rtk::DriverInfo,
|
||||
hci::{
|
||||
packets::{
|
||||
AddressType, ErrorCode, ReadLocalVersionInformationBuilder,
|
||||
ReadLocalVersionInformationComplete,
|
||||
},
|
||||
Address, Error,
|
||||
},
|
||||
host::Host,
|
||||
link::Link,
|
||||
transport::Transport,
|
||||
};
|
||||
use nix::sys::stat::Mode;
|
||||
use pyo3::PyResult;
|
||||
use pyo3::{
|
||||
exceptions::PyException,
|
||||
{PyErr, PyResult},
|
||||
};
|
||||
|
||||
#[pyo3_asyncio::tokio::test]
|
||||
async fn fifo_transport_can_open() -> PyResult<()> {
|
||||
@@ -35,3 +52,26 @@ async fn realtek_driver_info_all_drivers() -> PyResult<()> {
|
||||
assert_eq!(12, DriverInfo::all_drivers()?.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyo3_asyncio::tokio::test]
|
||||
async fn hci_command_wrapper_has_correct_methods() -> PyResult<()> {
|
||||
let address = Address::new("F0:F1:F2:F3:F4:F5", &AddressType::RandomDeviceAddress)?;
|
||||
let link = Link::new_local_link()?;
|
||||
let controller = Controller::new("C1", None, None, Some(link), Some(address.clone())).await?;
|
||||
let host = Host::new(controller.clone().into(), controller.into()).await?;
|
||||
let device = Device::new(None, Some(address), None, Some(host), None)?;
|
||||
|
||||
device.power_on().await?;
|
||||
|
||||
// Send some simple command. A successful response means [HciCommandWrapper] has the minimum
|
||||
// required interface for the Python code to think its an [HCI_Command] object.
|
||||
let command = ReadLocalVersionInformationBuilder {};
|
||||
let event: ReadLocalVersionInformationComplete = device
|
||||
.send_command(&command.into(), true)
|
||||
.await?
|
||||
.try_into()
|
||||
.map_err(|e: Error| PyErr::new::<PyException, _>(e.to_string()))?;
|
||||
|
||||
assert_eq!(ErrorCode::Success, event.get_status());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! BLE advertisements.
|
||||
|
||||
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
|
||||
|
||||
@@ -179,7 +179,7 @@ pub(crate) fn parse(firmware_path: &path::Path) -> PyResult<()> {
|
||||
pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> {
|
||||
let transport = Transport::open(transport).await?;
|
||||
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?)?;
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?).await?;
|
||||
host.reset(DriverFactory::None).await?;
|
||||
|
||||
if !force && !Driver::check(&host).await? {
|
||||
@@ -203,7 +203,7 @@ pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> {
|
||||
pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> {
|
||||
let transport = Transport::open(transport).await?;
|
||||
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?)?;
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?).await?;
|
||||
host.reset(DriverFactory::None).await?;
|
||||
|
||||
match Driver::for_host(&host, force).await? {
|
||||
@@ -219,7 +219,7 @@ pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> {
|
||||
pub(crate) async fn drop(transport: &str) -> PyResult<()> {
|
||||
let transport = Transport::open(transport).await?;
|
||||
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?)?;
|
||||
let mut host = Host::new(transport.source()?, transport.sink()?).await?;
|
||||
host.reset(DriverFactory::None).await?;
|
||||
|
||||
Driver::drop_firmware(&mut host).await?;
|
||||
|
||||
@@ -21,8 +21,7 @@
|
||||
/// TCP client to connect.
|
||||
/// When the L2CAP CoC channel is closed, the TCP connection is closed as well.
|
||||
use crate::cli::l2cap::{
|
||||
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals,
|
||||
BridgeData,
|
||||
inject_py_event_loop, proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, BridgeData,
|
||||
};
|
||||
use bumble::wrapper::{
|
||||
device::{Connection, Device},
|
||||
@@ -85,11 +84,12 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
|
||||
let mtu = args.mtu;
|
||||
let mps = args.mps;
|
||||
let ble_connection = Arc::new(Mutex::new(ble_connection));
|
||||
// Ensure Python event loop is available to l2cap `disconnect`
|
||||
let _ = run_future_with_current_task_locals(async move {
|
||||
// spawn thread to handle incoming tcp connections
|
||||
tokio::spawn(inject_py_event_loop(async move {
|
||||
while let Ok((tcp_stream, addr)) = listener.accept().await {
|
||||
let ble_connection = ble_connection.clone();
|
||||
let _ = run_future_with_current_task_locals(proxy_data_between_tcp_and_l2cap(
|
||||
// spawn thread to handle this specific tcp connection
|
||||
if let Ok(future) = inject_py_event_loop(proxy_data_between_tcp_and_l2cap(
|
||||
ble_connection,
|
||||
tcp_stream,
|
||||
addr,
|
||||
@@ -97,10 +97,11 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
|
||||
max_credits,
|
||||
mtu,
|
||||
mps,
|
||||
));
|
||||
)) {
|
||||
tokio::spawn(future);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
})?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::L2cap;
|
||||
use anyhow::anyhow;
|
||||
use bumble::wrapper::{device::Device, l2cap::LeConnectionOrientedChannel, transport::Transport};
|
||||
use owo_colors::{colors::css::Orange, OwoColorize};
|
||||
use pyo3::{PyObject, PyResult, Python};
|
||||
use pyo3::{PyResult, Python};
|
||||
use std::{future::Future, path::PathBuf, sync::Arc};
|
||||
use tokio::{
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
@@ -170,21 +170,12 @@ async fn proxy_tcp_rx_to_l2cap_tx(
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies the current thread's TaskLocals into a Python "awaitable" and encapsulates it in a Rust
|
||||
/// future, running it as a Python Task.
|
||||
/// `TaskLocals` stores the current event loop, and allows the user to copy the current Python
|
||||
/// context if necessary. In this case, the python event loop is used when calling `disconnect` on
|
||||
/// an l2cap connection, or else the call will fail.
|
||||
pub fn run_future_with_current_task_locals<F>(
|
||||
fut: F,
|
||||
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send>
|
||||
/// Copies the current thread's Python even loop (contained in `TaskLocals`) into the given future.
|
||||
/// Useful when sending work to another thread that calls Python code which calls `get_running_loop()`.
|
||||
pub fn inject_py_event_loop<F, R>(fut: F) -> PyResult<impl Future<Output = R>>
|
||||
where
|
||||
F: Future<Output = PyResult<()>> + Send + 'static,
|
||||
F: Future<Output = R> + Send + 'static,
|
||||
{
|
||||
Python::with_gil(|py| {
|
||||
let locals = pyo3_asyncio::tokio::get_current_locals(py)?;
|
||||
let future = pyo3_asyncio::tokio::scope(locals.clone(), fut);
|
||||
pyo3_asyncio::tokio::future_into_py_with_locals(py, locals, future)
|
||||
.and_then(pyo3_asyncio::tokio::into_future)
|
||||
})
|
||||
let locals = Python::with_gil(pyo3_asyncio::tokio::get_current_locals)?;
|
||||
Ok(pyo3_asyncio::tokio::scope(locals, fut))
|
||||
}
|
||||
|
||||
@@ -19,10 +19,7 @@
|
||||
/// When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
|
||||
/// and waits for a new L2CAP CoC channel to be connected.
|
||||
/// When the TCP connection is closed by the TCP server, the L2CAP connection is closed as well.
|
||||
use crate::cli::l2cap::{
|
||||
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals,
|
||||
BridgeData,
|
||||
};
|
||||
use crate::cli::l2cap::{proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, BridgeData};
|
||||
use bumble::wrapper::{device::Device, hci::HciConstant, l2cap::LeConnectionOrientedChannel};
|
||||
use futures::executor::block_on;
|
||||
use owo_colors::OwoColorize;
|
||||
@@ -49,19 +46,19 @@ pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
|
||||
let port = args.tcp_port;
|
||||
device.register_l2cap_channel_server(
|
||||
args.psm,
|
||||
move |_py, l2cap_channel| {
|
||||
move |py, l2cap_channel| {
|
||||
let channel_info = l2cap_channel
|
||||
.debug_string()
|
||||
.unwrap_or_else(|e| format!("failed to get l2cap channel info ({e})"));
|
||||
println!("{} {channel_info}", "*** L2CAP channel:".cyan());
|
||||
|
||||
let host = host.clone();
|
||||
// Ensure Python event loop is available to l2cap `disconnect`
|
||||
let _ = run_future_with_current_task_locals(proxy_data_between_l2cap_and_tcp(
|
||||
l2cap_channel,
|
||||
host,
|
||||
port,
|
||||
));
|
||||
// Handles setting up a tokio runtime that runs this future to completion while also
|
||||
// containing the necessary context vars.
|
||||
pyo3_asyncio::tokio::future_into_py(
|
||||
py,
|
||||
proxy_data_between_l2cap_and_tcp(l2cap_channel, host, port),
|
||||
)?;
|
||||
Ok(())
|
||||
},
|
||||
args.max_credits,
|
||||
|
||||
@@ -143,10 +143,7 @@ pub(crate) fn probe(verbose: bool) -> anyhow::Result<()> {
|
||||
);
|
||||
if let Some(s) = serial {
|
||||
println!("{:26}{}", " Serial:".green(), s);
|
||||
device_serials_by_id
|
||||
.entry(device_id)
|
||||
.or_insert(HashSet::new())
|
||||
.insert(s);
|
||||
device_serials_by_id.entry(device_id).or_default().insert(s);
|
||||
}
|
||||
if let Some(m) = mfg {
|
||||
println!("{:26}{}", " Manufacturer:".green(), m);
|
||||
@@ -314,7 +311,7 @@ impl ClassInfo {
|
||||
self.protocol,
|
||||
self.protocol_name()
|
||||
.map(|s| format!(" [{}]", s))
|
||||
.unwrap_or_else(String::new)
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
161
rust/src/internal/hci/mod.rs
Normal file
161
rust/src/internal/hci/mod.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub use pdl_runtime::{Error, Packet};
|
||||
|
||||
use crate::internal::hci::packets::{Acl, Command, Event, Sco};
|
||||
use pdl_derive::pdl;
|
||||
|
||||
#[allow(missing_docs, warnings, clippy::all)]
|
||||
#[pdl("src/internal/hci/packets.pdl")]
|
||||
pub mod packets {}
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// HCI Packet type, prepended to the packet.
|
||||
/// Rootcanal's PDL declaration excludes this from ser/deser and instead is implemented in code.
|
||||
/// To maintain the ability to easily use future versions of their packet PDL, packet type is
|
||||
/// implemented here.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) enum PacketType {
|
||||
Command = 0x01,
|
||||
Acl = 0x02,
|
||||
Sco = 0x03,
|
||||
Event = 0x04,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for PacketType {
|
||||
type Error = PacketTypeParseError;
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0x01 => Ok(PacketType::Command),
|
||||
0x02 => Ok(PacketType::Acl),
|
||||
0x03 => Ok(PacketType::Sco),
|
||||
0x04 => Ok(PacketType::Event),
|
||||
_ => Err(PacketTypeParseError::InvalidPacketType { value }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PacketType> for u8 {
|
||||
fn from(packet_type: PacketType) -> Self {
|
||||
match packet_type {
|
||||
PacketType::Command => 0x01,
|
||||
PacketType::Acl => 0x02,
|
||||
PacketType::Sco => 0x03,
|
||||
PacketType::Event => 0x04,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows for smoother interoperability between a [Packet] and a bytes representation of it that
|
||||
/// includes its type as a header
|
||||
pub(crate) trait WithPacketType<T: Packet> {
|
||||
/// Converts the [Packet] into bytes, prefixed with its type
|
||||
fn to_vec_with_packet_type(self) -> Vec<u8>;
|
||||
|
||||
/// Parses a [Packet] out of bytes that are prefixed with the packet's type
|
||||
fn parse_with_packet_type(bytes: &[u8]) -> Result<T, PacketTypeParseError>;
|
||||
}
|
||||
|
||||
/// Errors that may arise when parsing a packet that is prefixed with its type
|
||||
#[derive(Debug, PartialEq, thiserror::Error)]
|
||||
pub(crate) enum PacketTypeParseError {
|
||||
#[error("The slice being parsed was empty")]
|
||||
EmptySlice,
|
||||
#[error("Packet type ({value:#X}) is invalid")]
|
||||
InvalidPacketType { value: u8 },
|
||||
#[error("Expected packet type: {expected:?}, but got: {actual:?}")]
|
||||
PacketTypeMismatch {
|
||||
expected: PacketType,
|
||||
actual: PacketType,
|
||||
},
|
||||
#[error("Failed to parse packet after header: {error}")]
|
||||
PacketParse { error: Error },
|
||||
}
|
||||
|
||||
impl From<Error> for PacketTypeParseError {
|
||||
fn from(error: Error) -> Self {
|
||||
Self::PacketParse { error }
|
||||
}
|
||||
}
|
||||
|
||||
impl WithPacketType<Self> for Command {
|
||||
fn to_vec_with_packet_type(self) -> Vec<u8> {
|
||||
prepend_packet_type(PacketType::Command, self.to_vec())
|
||||
}
|
||||
|
||||
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
|
||||
parse_with_expected_packet_type(Command::parse, PacketType::Command, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WithPacketType<Self> for Acl {
|
||||
fn to_vec_with_packet_type(self) -> Vec<u8> {
|
||||
prepend_packet_type(PacketType::Acl, self.to_vec())
|
||||
}
|
||||
|
||||
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
|
||||
parse_with_expected_packet_type(Acl::parse, PacketType::Acl, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WithPacketType<Self> for Sco {
|
||||
fn to_vec_with_packet_type(self) -> Vec<u8> {
|
||||
prepend_packet_type(PacketType::Sco, self.to_vec())
|
||||
}
|
||||
|
||||
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
|
||||
parse_with_expected_packet_type(Sco::parse, PacketType::Sco, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl WithPacketType<Self> for Event {
|
||||
fn to_vec_with_packet_type(self) -> Vec<u8> {
|
||||
prepend_packet_type(PacketType::Event, self.to_vec())
|
||||
}
|
||||
|
||||
fn parse_with_packet_type(bytes: &[u8]) -> Result<Self, PacketTypeParseError> {
|
||||
parse_with_expected_packet_type(Event::parse, PacketType::Event, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
fn prepend_packet_type(packet_type: PacketType, mut packet_bytes: Vec<u8>) -> Vec<u8> {
|
||||
packet_bytes.insert(0, packet_type.into());
|
||||
packet_bytes
|
||||
}
|
||||
|
||||
fn parse_with_expected_packet_type<T: Packet, F, E>(
|
||||
parser: F,
|
||||
expected_packet_type: PacketType,
|
||||
bytes: &[u8],
|
||||
) -> Result<T, PacketTypeParseError>
|
||||
where
|
||||
F: Fn(&[u8]) -> Result<T, E>,
|
||||
PacketTypeParseError: From<E>,
|
||||
{
|
||||
let (first_byte, packet_bytes) = bytes
|
||||
.split_first()
|
||||
.ok_or(PacketTypeParseError::EmptySlice)?;
|
||||
let actual_packet_type = PacketType::try_from(*first_byte)?;
|
||||
if actual_packet_type == expected_packet_type {
|
||||
Ok(parser(packet_bytes)?)
|
||||
} else {
|
||||
Err(PacketTypeParseError::PacketTypeMismatch {
|
||||
expected: expected_packet_type,
|
||||
actual: actual_packet_type,
|
||||
})
|
||||
}
|
||||
}
|
||||
6253
rust/src/internal/hci/packets.pdl
Normal file
6253
rust/src/internal/hci/packets.pdl
Normal file
File diff suppressed because it is too large
Load Diff
94
rust/src/internal/hci/tests.rs
Normal file
94
rust/src/internal/hci/tests.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::internal::hci::{
|
||||
packets::{Event, EventBuilder, EventCode, Sco},
|
||||
parse_with_expected_packet_type, prepend_packet_type, Error, Packet, PacketType,
|
||||
PacketTypeParseError, WithPacketType,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
|
||||
#[test]
|
||||
fn prepends_packet_type() {
|
||||
let packet_type = PacketType::Event;
|
||||
let packet_bytes = vec![0x00, 0x00, 0x00, 0x00];
|
||||
let actual = prepend_packet_type(packet_type, packet_bytes);
|
||||
assert_eq!(vec![0x04, 0x00, 0x00, 0x00, 0x00], actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_empty_slice_should_error() {
|
||||
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Event, &[]);
|
||||
assert_eq!(Err(PacketTypeParseError::EmptySlice), actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_invalid_packet_type_should_error() {
|
||||
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Event, &[0xFF]);
|
||||
assert_eq!(
|
||||
Err(PacketTypeParseError::InvalidPacketType { value: 0xFF }),
|
||||
actual
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_mismatched_packet_type_should_error() {
|
||||
let actual = parse_with_expected_packet_type(FakePacket::parse, PacketType::Acl, &[0x01]);
|
||||
assert_eq!(
|
||||
Err(PacketTypeParseError::PacketTypeMismatch {
|
||||
expected: PacketType::Acl,
|
||||
actual: PacketType::Command
|
||||
}),
|
||||
actual
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_invalid_packet_should_error() {
|
||||
let actual = parse_with_expected_packet_type(Sco::parse, PacketType::Sco, &[0x03]);
|
||||
assert!(actual.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_packet_roundtrip_with_type() {
|
||||
let event_packet = EventBuilder {
|
||||
event_code: EventCode::InquiryComplete,
|
||||
payload: None,
|
||||
}
|
||||
.build();
|
||||
let event_packet_bytes = event_packet.clone().to_vec_with_packet_type();
|
||||
let actual =
|
||||
parse_with_expected_packet_type(Event::parse, PacketType::Event, &event_packet_bytes)
|
||||
.unwrap();
|
||||
assert_eq!(event_packet, actual);
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct FakePacket;
|
||||
|
||||
impl FakePacket {
|
||||
fn parse(_bytes: &[u8]) -> Result<Self, Error> {
|
||||
Ok(Self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Packet for FakePacket {
|
||||
fn to_bytes(self) -> Bytes {
|
||||
Bytes::new()
|
||||
}
|
||||
|
||||
fn to_vec(self) -> Vec<u8> {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
@@ -18,3 +18,4 @@
|
||||
//! to discover.
|
||||
|
||||
pub(crate) mod drivers;
|
||||
pub(crate) mod hci;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
34
rust/src/wrapper/common.rs
Normal file
34
rust/src/wrapper/common.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Shared resources found under bumble's common.py
|
||||
use pyo3::{PyObject, Python, ToPyObject};
|
||||
|
||||
/// Represents the sink for some transport mechanism
|
||||
pub struct TransportSink(pub(crate) PyObject);
|
||||
|
||||
impl ToPyObject for TransportSink {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the source for some transport mechanism
|
||||
pub struct TransportSource(pub(crate) PyObject);
|
||||
|
||||
impl ToPyObject for TransportSource {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
66
rust/src/wrapper/controller.rs
Normal file
66
rust/src/wrapper/controller.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Controller components
|
||||
use crate::wrapper::{
|
||||
common::{TransportSink, TransportSource},
|
||||
hci::Address,
|
||||
link::Link,
|
||||
wrap_python_async, PyDictExt,
|
||||
};
|
||||
use pyo3::{
|
||||
intern,
|
||||
types::{PyDict, PyModule},
|
||||
PyObject, PyResult, Python,
|
||||
};
|
||||
use pyo3_asyncio::tokio::into_future;
|
||||
|
||||
/// A controller that can send and receive HCI frames via some link
|
||||
#[derive(Clone)]
|
||||
pub struct Controller(pub(crate) PyObject);
|
||||
|
||||
impl Controller {
|
||||
/// Creates a new [Controller] object. When optional arguments are not specified, the Python
|
||||
/// module specifies the defaults. Must be called from a thread with a Python event loop, which
|
||||
/// should be true on `tokio::main` and `async_std::main`.
|
||||
///
|
||||
/// For more info, see https://awestlake87.github.io/pyo3-asyncio/master/doc/pyo3_asyncio/#event-loop-references-and-contextvars.
|
||||
pub async fn new(
|
||||
name: &str,
|
||||
host_source: Option<TransportSource>,
|
||||
host_sink: Option<TransportSink>,
|
||||
link: Option<Link>,
|
||||
public_address: Option<Address>,
|
||||
) -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
let controller_ctr = PyModule::import(py, intern!(py, "bumble.controller"))?
|
||||
.getattr(intern!(py, "Controller"))?;
|
||||
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("name", name)?;
|
||||
kwargs.set_opt_item("host_source", host_source)?;
|
||||
kwargs.set_opt_item("host_sink", host_sink)?;
|
||||
kwargs.set_opt_item("link", link)?;
|
||||
kwargs.set_opt_item("public_address", public_address)?;
|
||||
|
||||
// Controller constructor (`__init__`) is not (and can't be) marked async, but calls
|
||||
// `get_running_loop`, and thus needs wrapped in an async function.
|
||||
wrap_python_async(py, controller_ctr)?
|
||||
.call((), Some(kwargs))
|
||||
.and_then(into_future)
|
||||
})?
|
||||
.await
|
||||
.map(Self)
|
||||
}
|
||||
}
|
||||
@@ -14,12 +14,16 @@
|
||||
|
||||
//! Devices and connections to them
|
||||
|
||||
use crate::internal::hci::WithPacketType;
|
||||
use crate::{
|
||||
adv::AdvertisementDataBuilder,
|
||||
wrapper::{
|
||||
core::AdvertisingData,
|
||||
gatt_client::{ProfileServiceProxy, ServiceProxy},
|
||||
hci::{Address, HciErrorCode},
|
||||
hci::{
|
||||
packets::{Command, ErrorCode, Event},
|
||||
Address, HciCommandWrapper,
|
||||
},
|
||||
host::Host,
|
||||
l2cap::LeConnectionOrientedChannel,
|
||||
transport::{Sink, Source},
|
||||
@@ -27,18 +31,73 @@ use crate::{
|
||||
},
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::PyException,
|
||||
intern,
|
||||
types::{PyDict, PyModule},
|
||||
IntoPy, PyObject, PyResult, Python, ToPyObject,
|
||||
IntoPy, PyErr, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
use pyo3_asyncio::tokio::into_future;
|
||||
use std::path;
|
||||
|
||||
/// Represents the various properties of some device
|
||||
pub struct DeviceConfiguration(PyObject);
|
||||
|
||||
impl DeviceConfiguration {
|
||||
/// Creates a new configuration, letting the internal Python object set all the defaults
|
||||
pub fn new() -> PyResult<DeviceConfiguration> {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, intern!(py, "bumble.device"))?
|
||||
.getattr(intern!(py, "DeviceConfiguration"))?
|
||||
.call0()
|
||||
.map(|any| Self(any.into()))
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a new configuration from the specified file
|
||||
pub fn load_from_file(&mut self, device_config: &path::Path) -> PyResult<()> {
|
||||
Python::with_gil(|py| {
|
||||
self.0
|
||||
.call_method1(py, intern!(py, "load_from_file"), (device_config,))
|
||||
})
|
||||
.map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for DeviceConfiguration {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// A device that can send/receive HCI frames.
|
||||
#[derive(Clone)]
|
||||
pub struct Device(PyObject);
|
||||
|
||||
impl Device {
|
||||
/// Creates a Device. When optional arguments are not specified, the Python object specifies the
|
||||
/// defaults.
|
||||
pub fn new(
|
||||
name: Option<&str>,
|
||||
address: Option<Address>,
|
||||
config: Option<DeviceConfiguration>,
|
||||
host: Option<Host>,
|
||||
generic_access_service: Option<bool>,
|
||||
) -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_opt_item("name", name)?;
|
||||
kwargs.set_opt_item("address", address)?;
|
||||
kwargs.set_opt_item("config", config)?;
|
||||
kwargs.set_opt_item("host", host)?;
|
||||
kwargs.set_opt_item("generic_access_service", generic_access_service)?;
|
||||
|
||||
PyModule::import(py, intern!(py, "bumble.device"))?
|
||||
.getattr(intern!(py, "Device"))?
|
||||
.call((), Some(kwargs))
|
||||
.map(|any| Self(any.into()))
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Device per the provided file configured to communicate with a controller through an HCI source/sink
|
||||
pub fn from_config_file_with_hci(
|
||||
device_config: &path::Path,
|
||||
@@ -66,6 +125,29 @@ impl Device {
|
||||
})
|
||||
}
|
||||
|
||||
/// Sends an HCI command on this Device, returning the command's event result.
|
||||
pub async fn send_command(&self, command: &Command, check_result: bool) -> PyResult<Event> {
|
||||
Python::with_gil(|py| {
|
||||
self.0
|
||||
.call_method1(
|
||||
py,
|
||||
intern!(py, "send_command"),
|
||||
(HciCommandWrapper(command.clone()), check_result),
|
||||
)
|
||||
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
|
||||
})?
|
||||
.await
|
||||
.and_then(|event| {
|
||||
Python::with_gil(|py| {
|
||||
let py_bytes = event.call_method0(py, intern!(py, "__bytes__"))?;
|
||||
let bytes: &[u8] = py_bytes.extract(py)?;
|
||||
let event = Event::parse_with_packet_type(bytes)
|
||||
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))?;
|
||||
Ok(event)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Turn the device on
|
||||
pub async fn power_on(&self) -> PyResult<()> {
|
||||
Python::with_gil(|py| {
|
||||
@@ -236,7 +318,7 @@ impl Connection {
|
||||
kwargs.set_opt_item("mps", mps)?;
|
||||
self.0
|
||||
.call_method(py, intern!(py, "open_l2cap_channel"), (), Some(kwargs))
|
||||
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
|
||||
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
|
||||
})?
|
||||
.await
|
||||
.map(LeConnectionOrientedChannel::from)
|
||||
@@ -244,13 +326,13 @@ impl Connection {
|
||||
|
||||
/// Disconnect from device with provided reason. When optional arguments are not specified, the
|
||||
/// Python module specifies the defaults.
|
||||
pub async fn disconnect(&mut self, reason: Option<HciErrorCode>) -> PyResult<()> {
|
||||
pub async fn disconnect(&mut self, reason: Option<ErrorCode>) -> PyResult<()> {
|
||||
Python::with_gil(|py| {
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_opt_item("reason", reason)?;
|
||||
self.0
|
||||
.call_method(py, intern!(py, "disconnect"), (), Some(kwargs))
|
||||
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
|
||||
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
|
||||
})?
|
||||
.await
|
||||
.map(|_| ())
|
||||
@@ -259,7 +341,7 @@ impl Connection {
|
||||
/// Register a callback to be called on disconnection.
|
||||
pub fn on_disconnection(
|
||||
&mut self,
|
||||
callback: impl Fn(Python, HciErrorCode) -> PyResult<()> + Send + 'static,
|
||||
callback: impl Fn(Python, ErrorCode) -> PyResult<()> + Send + 'static,
|
||||
) -> PyResult<()> {
|
||||
let boxed = ClosureCallback::new(move |py, args, _kwargs| {
|
||||
callback(py, args.get_item(0)?.extract()?)
|
||||
|
||||
@@ -14,84 +14,62 @@
|
||||
|
||||
//! HCI
|
||||
|
||||
pub use crate::internal::hci::{packets, Error, Packet};
|
||||
|
||||
use crate::{
|
||||
internal::hci::WithPacketType,
|
||||
wrapper::hci::packets::{AddressType, Command, ErrorCode},
|
||||
};
|
||||
use itertools::Itertools as _;
|
||||
use pyo3::{
|
||||
exceptions::PyException, intern, types::PyModule, FromPyObject, PyAny, PyErr, PyObject,
|
||||
PyResult, Python, ToPyObject,
|
||||
exceptions::PyException,
|
||||
intern, pyclass, pymethods,
|
||||
types::{PyBytes, PyModule},
|
||||
FromPyObject, IntoPy, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
/// HCI error code.
|
||||
pub struct HciErrorCode(u8);
|
||||
|
||||
impl<'source> FromPyObject<'source> for HciErrorCode {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
Ok(HciErrorCode(ob.extract()?))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for HciErrorCode {
|
||||
fn to_object(&self, py: Python<'_>) -> PyObject {
|
||||
self.0.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides helpers for interacting with HCI
|
||||
pub struct HciConstant;
|
||||
|
||||
impl HciConstant {
|
||||
/// Human-readable error name
|
||||
pub fn error_name(status: HciErrorCode) -> PyResult<String> {
|
||||
pub fn error_name(status: ErrorCode) -> PyResult<String> {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, intern!(py, "bumble.hci"))?
|
||||
.getattr(intern!(py, "HCI_Constant"))?
|
||||
.call_method1(intern!(py, "error_name"), (status.0,))?
|
||||
.call_method1(intern!(py, "error_name"), (status.to_object(py),))?
|
||||
.extract()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// A Bluetooth address
|
||||
#[derive(Clone)]
|
||||
pub struct Address(pub(crate) PyObject);
|
||||
|
||||
impl Address {
|
||||
/// Creates a new [Address] object
|
||||
pub fn new(address: &str, address_type: &AddressType) -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, intern!(py, "bumble.device"))?
|
||||
.getattr(intern!(py, "Address"))?
|
||||
.call1((address, address_type.to_object(py)))
|
||||
.map(|any| Self(any.into()))
|
||||
})
|
||||
}
|
||||
|
||||
/// The type of address
|
||||
pub fn address_type(&self) -> PyResult<AddressType> {
|
||||
Python::with_gil(|py| {
|
||||
let addr_type = self
|
||||
.0
|
||||
self.0
|
||||
.getattr(py, intern!(py, "address_type"))?
|
||||
.extract::<u32>(py)?;
|
||||
|
||||
let module = PyModule::import(py, intern!(py, "bumble.hci"))?;
|
||||
let klass = module.getattr(intern!(py, "Address"))?;
|
||||
|
||||
if addr_type
|
||||
== klass
|
||||
.getattr(intern!(py, "PUBLIC_DEVICE_ADDRESS"))?
|
||||
.extract::<u32>()?
|
||||
{
|
||||
Ok(AddressType::PublicDevice)
|
||||
} else if addr_type
|
||||
== klass
|
||||
.getattr(intern!(py, "RANDOM_DEVICE_ADDRESS"))?
|
||||
.extract::<u32>()?
|
||||
{
|
||||
Ok(AddressType::RandomDevice)
|
||||
} else if addr_type
|
||||
== klass
|
||||
.getattr(intern!(py, "PUBLIC_IDENTITY_ADDRESS"))?
|
||||
.extract::<u32>()?
|
||||
{
|
||||
Ok(AddressType::PublicIdentity)
|
||||
} else if addr_type
|
||||
== klass
|
||||
.getattr(intern!(py, "RANDOM_IDENTITY_ADDRESS"))?
|
||||
.extract::<u32>()?
|
||||
{
|
||||
Ok(AddressType::RandomIdentity)
|
||||
} else {
|
||||
Err(PyErr::new::<PyException, _>("Invalid address type"))
|
||||
}
|
||||
.extract::<u8>(py)?
|
||||
.try_into()
|
||||
.map_err(|addr_type| {
|
||||
PyErr::new::<PyException, _>(format!(
|
||||
"Failed to convert {addr_type} to AddressType"
|
||||
))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -134,12 +112,45 @@ impl Address {
|
||||
}
|
||||
}
|
||||
|
||||
/// BT address types
|
||||
#[allow(missing_docs)]
|
||||
#[derive(PartialEq, Eq, Debug)]
|
||||
pub enum AddressType {
|
||||
PublicDevice,
|
||||
RandomDevice,
|
||||
PublicIdentity,
|
||||
RandomIdentity,
|
||||
impl ToPyObject for Address {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements minimum necessary interface to be treated as bumble's [HCI_Command].
|
||||
/// While pyo3's macros do not support generics, this could probably be refactored to allow multiple
|
||||
/// implementations of the HCI_Command methods in the future, if needed.
|
||||
#[pyclass]
|
||||
pub(crate) struct HciCommandWrapper(pub(crate) Command);
|
||||
|
||||
#[pymethods]
|
||||
impl HciCommandWrapper {
|
||||
fn __bytes__(&self, py: Python) -> PyResult<PyObject> {
|
||||
let bytes = PyBytes::new(py, &self.0.clone().to_vec_with_packet_type());
|
||||
Ok(bytes.into_py(py))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn op_code(&self) -> u16 {
|
||||
self.0.get_op_code().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for AddressType {
|
||||
fn to_object(&self, py: Python<'_>) -> PyObject {
|
||||
u8::from(self).to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'source> FromPyObject<'source> for ErrorCode {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
ob.extract()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for ErrorCode {
|
||||
fn to_object(&self, py: Python<'_>) -> PyObject {
|
||||
u8::from(self).to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,12 @@
|
||||
|
||||
//! Host-side types
|
||||
|
||||
use crate::wrapper::transport::{Sink, Source};
|
||||
use pyo3::{intern, prelude::PyModule, types::PyDict, PyObject, PyResult, Python};
|
||||
use crate::wrapper::{
|
||||
transport::{Sink, Source},
|
||||
wrap_python_async,
|
||||
};
|
||||
use pyo3::{intern, prelude::PyModule, types::PyDict, PyObject, PyResult, Python, ToPyObject};
|
||||
use pyo3_asyncio::tokio::into_future;
|
||||
|
||||
/// Host HCI commands
|
||||
pub struct Host {
|
||||
@@ -29,13 +33,23 @@ impl Host {
|
||||
}
|
||||
|
||||
/// Create a new Host
|
||||
pub fn new(source: Source, sink: Sink) -> PyResult<Self> {
|
||||
pub async fn new(source: Source, sink: Sink) -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, intern!(py, "bumble.host"))?
|
||||
.getattr(intern!(py, "Host"))?
|
||||
.call((source.0, sink.0), None)
|
||||
.map(|any| Self { obj: any.into() })
|
||||
})
|
||||
let host_ctr =
|
||||
PyModule::import(py, intern!(py, "bumble.host"))?.getattr(intern!(py, "Host"))?;
|
||||
|
||||
let kwargs = PyDict::new(py);
|
||||
kwargs.set_item("controller_source", source.0)?;
|
||||
kwargs.set_item("controller_sink", sink.0)?;
|
||||
|
||||
// Needed for Python 3.8-3.9, in which the Semaphore object, when constructed, calls
|
||||
// `get_event_loop`.
|
||||
wrap_python_async(py, host_ctr)?
|
||||
.call((), Some(kwargs))
|
||||
.and_then(into_future)
|
||||
})?
|
||||
.await
|
||||
.map(|any| Self { obj: any })
|
||||
}
|
||||
|
||||
/// Send a reset command and perform other reset tasks.
|
||||
@@ -61,6 +75,12 @@ impl Host {
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for Host {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.obj.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Driver factory to use when initializing a host
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DriverFactory {
|
||||
|
||||
38
rust/src/wrapper/link.rs
Normal file
38
rust/src/wrapper/link.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Link components
|
||||
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
|
||||
|
||||
/// Link bus for controllers to communicate with each other
|
||||
#[derive(Clone)]
|
||||
pub struct Link(pub(crate) PyObject);
|
||||
|
||||
impl Link {
|
||||
/// Creates a [Link] object that transports messages locally
|
||||
pub fn new_local_link() -> PyResult<Self> {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, intern!(py, "bumble.link"))?
|
||||
.getattr(intern!(py, "LocalLink"))?
|
||||
.call0()
|
||||
.map(|any| Self(any.into()))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPyObject for Link {
|
||||
fn to_object(&self, _py: Python<'_>) -> PyObject {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,17 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Bumble & Python logging
|
||||
|
||||
use pyo3::types::PyDict;
|
||||
|
||||
@@ -22,13 +22,17 @@
|
||||
|
||||
// Re-exported to make it easy for users to depend on the same `PyObject`, etc
|
||||
pub use pyo3;
|
||||
pub use pyo3_asyncio;
|
||||
|
||||
use pyo3::{
|
||||
intern,
|
||||
prelude::*,
|
||||
types::{PyDict, PyTuple},
|
||||
};
|
||||
pub use pyo3_asyncio;
|
||||
|
||||
pub mod assigned_numbers;
|
||||
pub mod common;
|
||||
pub mod controller;
|
||||
pub mod core;
|
||||
pub mod device;
|
||||
pub mod drivers;
|
||||
@@ -36,6 +40,7 @@ pub mod gatt_client;
|
||||
pub mod hci;
|
||||
pub mod host;
|
||||
pub mod l2cap;
|
||||
pub mod link;
|
||||
pub mod logging;
|
||||
pub mod profile;
|
||||
pub mod transport;
|
||||
@@ -119,3 +124,11 @@ impl ClosureCallback {
|
||||
(self.callback)(py, args, kwargs).map(|_| py.None())
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps the Python function in a Python async function. `pyo3_asyncio` needs functions to be
|
||||
/// marked async to properly inject a running loop.
|
||||
pub(crate) fn wrap_python_async<'a>(py: Python<'a>, function: &'a PyAny) -> PyResult<&'a PyAny> {
|
||||
PyModule::import(py, intern!(py, "bumble.utils"))?
|
||||
.getattr(intern!(py, "wrap_async"))?
|
||||
.call1((function,))
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
//! HCI packet transport
|
||||
|
||||
use crate::wrapper::controller::Controller;
|
||||
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python};
|
||||
|
||||
/// A source/sink pair for HCI packet I/O.
|
||||
@@ -67,6 +68,18 @@ impl Drop for Transport {
|
||||
#[derive(Clone)]
|
||||
pub struct Source(pub(crate) PyObject);
|
||||
|
||||
impl From<Controller> for Source {
|
||||
fn from(value: Controller) -> Self {
|
||||
Self(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// The sink side of a [Transport].
|
||||
#[derive(Clone)]
|
||||
pub struct Sink(pub(crate) PyObject);
|
||||
|
||||
impl From<Controller> for Sink {
|
||||
fn from(value: Controller) -> Self {
|
||||
Self(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
78
rust/tools/file_header.rs
Normal file
78
rust/tools/file_header.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
// Copyright 2023 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::anyhow;
|
||||
use clap::Parser as _;
|
||||
use file_header::{
|
||||
add_headers_recursively, check_headers_recursively,
|
||||
license::spdx::{YearCopyrightOwnerValue, APACHE_2_0},
|
||||
};
|
||||
use globset::{Glob, GlobSet, GlobSetBuilder};
|
||||
use std::{env, path::PathBuf};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
||||
let ignore_globset = ignore_globset()?;
|
||||
// Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127)
|
||||
let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into()));
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.subcommand {
|
||||
Subcommand::CheckAll => {
|
||||
let result =
|
||||
check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?;
|
||||
if result.has_failure() {
|
||||
return Err(anyhow!(
|
||||
"The following files do not have headers: {result:?}"
|
||||
));
|
||||
}
|
||||
}
|
||||
Subcommand::AddAll => {
|
||||
let files_with_new_header =
|
||||
add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?;
|
||||
files_with_new_header
|
||||
.iter()
|
||||
.for_each(|path| println!("Added header to: {path:?}"));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ignore_globset() -> anyhow::Result<GlobSet> {
|
||||
Ok(GlobSetBuilder::new()
|
||||
.add(Glob::new("**/.idea/**")?)
|
||||
.add(Glob::new("**/target/**")?)
|
||||
.add(Glob::new("**/.gitignore")?)
|
||||
.add(Glob::new("**/CHANGELOG.md")?)
|
||||
.add(Glob::new("**/Cargo.lock")?)
|
||||
.add(Glob::new("**/Cargo.toml")?)
|
||||
.add(Glob::new("**/README.md")?)
|
||||
.add(Glob::new("*.bin")?)
|
||||
.build()?)
|
||||
}
|
||||
|
||||
#[derive(clap::Parser)]
|
||||
struct Cli {
|
||||
#[clap(subcommand)]
|
||||
subcommand: Subcommand,
|
||||
}
|
||||
|
||||
#[derive(clap::Subcommand, Debug, Clone)]
|
||||
enum Subcommand {
|
||||
/// Checks if a license is present in files that are not in the ignore list.
|
||||
CheckAll,
|
||||
/// Adds a license as needed to files that are not in the ignore list.
|
||||
AddAll,
|
||||
}
|
||||
@@ -36,6 +36,10 @@ install_requires =
|
||||
bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
|
||||
click == 8.1.3; platform_system!='Emscripten'
|
||||
cryptography == 39; platform_system!='Emscripten'
|
||||
# Pyodide bundles a version of cryptography that is built for wasm, which may not match the
|
||||
# versions available on PyPI. Relax the version requirement since it's better than being
|
||||
# completely unable to import the package in case of version mismatch.
|
||||
cryptography >= 39.0; platform_system=='Emscripten'
|
||||
grpcio == 1.57.0; platform_system!='Emscripten'
|
||||
humanize >= 4.6.0; platform_system!='Emscripten'
|
||||
libusb1 >= 2.0.1; platform_system!='Emscripten'
|
||||
@@ -87,9 +91,13 @@ development =
|
||||
mypy == 1.5.0
|
||||
nox >= 2022
|
||||
pylint == 2.15.8
|
||||
pyyaml >= 6.0
|
||||
types-appdirs >= 1.4.3
|
||||
types-invoke >= 1.7.3
|
||||
types-protobuf >= 4.21.0
|
||||
avatar =
|
||||
pandora-avatar == 0.0.5
|
||||
rootcanal == 1.3.0 ; python_version>='3.10'
|
||||
documentation =
|
||||
mkdocs >= 1.4.0
|
||||
mkdocs-material >= 8.5.6
|
||||
|
||||
@@ -45,12 +45,14 @@ def test_messages():
|
||||
]
|
||||
message = Get_Capabilities_Response(capabilities)
|
||||
parsed = Message.create(
|
||||
AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload
|
||||
AVDTP_GET_CAPABILITIES, Message.MessageType.RESPONSE_ACCEPT, message.payload
|
||||
)
|
||||
assert message.payload == parsed.payload
|
||||
|
||||
message = Set_Configuration_Command(3, 4, capabilities)
|
||||
parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload)
|
||||
parsed = Message.create(
|
||||
AVDTP_SET_CONFIGURATION, Message.MessageType.COMMAND, message.payload
|
||||
)
|
||||
assert message.payload == parsed.payload
|
||||
|
||||
|
||||
|
||||
@@ -891,10 +891,10 @@ async def async_main():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_attribute_string_to_permissions():
|
||||
assert Attribute.string_to_permissions('READABLE') == 1
|
||||
assert Attribute.string_to_permissions('WRITEABLE') == 2
|
||||
assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
|
||||
def test_permissions_from_string():
|
||||
assert Attribute.Permissions.from_string('READABLE') == 1
|
||||
assert Attribute.Permissions.from_string('WRITEABLE') == 2
|
||||
assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -14,25 +14,25 @@
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# This script generates a python-syntax list of dictionary entries for the
|
||||
# company IDs listed at: https://www.bluetooth.com/specifications/assigned-numbers/company-identifiers/
|
||||
# The input to this script is the CSV file that can be obtained at that URL
|
||||
# company IDs listed at:
|
||||
# https://bitbucket.org/bluetooth-SIG/public/src/main/assigned_numbers/company_identifiers/company_identifiers.yaml
|
||||
# The input to this script is the YAML file that can be obtained at that URL
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import sys
|
||||
import csv
|
||||
import yaml
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
with open(sys.argv[1], newline='') as csvfile:
|
||||
reader = csv.reader(csvfile, delimiter=',', quotechar='"')
|
||||
lines = []
|
||||
for row in reader:
|
||||
if len(row) == 3 and row[1].startswith('0x'):
|
||||
company_id = row[1]
|
||||
company_name = row[2]
|
||||
escaped_company_name = company_name.replace('"', '\\"')
|
||||
lines.append(f' {company_id}: "{escaped_company_name}"')
|
||||
with open(sys.argv[1], "r") as yaml_file:
|
||||
root = yaml.safe_load(yaml_file)
|
||||
companies = {}
|
||||
for company in root["company_identifiers"]:
|
||||
companies[company["value"]] = company["name"]
|
||||
|
||||
print(',\n'.join(reversed(lines)))
|
||||
for company_id in sorted(companies.keys()):
|
||||
company_name = companies[company_id]
|
||||
escaped_company_name = company_name.replace('"', '\\"')
|
||||
print(f' 0x{company_id:04X}: "{escaped_company_name}",')
|
||||
|
||||
@@ -74,7 +74,6 @@ export async function loadBumble(pyodide, bumblePackage) {
|
||||
await pyodide.loadPackage("micropip");
|
||||
await pyodide.runPythonAsync(`
|
||||
import micropip
|
||||
await micropip.install("cryptography")
|
||||
await micropip.install("${bumblePackage}")
|
||||
package_list = micropip.list()
|
||||
print(package_list)
|
||||
|
||||
@@ -23,7 +23,7 @@ from bumble.device import Device
|
||||
# -----------------------------------------------------------------------------
|
||||
class ScanEntry:
|
||||
def __init__(self, advertisement):
|
||||
self.address = str(advertisement.address).replace("/P", "")
|
||||
self.address = advertisement.address.to_string(False)
|
||||
self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[
|
||||
advertisement.address.address_type
|
||||
]
|
||||
|
||||
@@ -171,7 +171,7 @@ class Speaker:
|
||||
self.connection = connection
|
||||
connection.on('disconnection', self.on_bluetooth_disconnection)
|
||||
peer_name = '' if connection.peer_name is None else connection.peer_name
|
||||
peer_address = str(connection.peer_address).replace('/P', '')
|
||||
peer_address = connection.peer_address.to_string(False)
|
||||
self.emit_event(
|
||||
'connection', {'peer_name': peer_name, 'peer_address': peer_address}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user