forked from auracaster/bumble_mirror
Compare commits
22 Commits
gbg/multi-
...
gbg/issue-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6c7cae661 | ||
|
|
d290df4aa9 | ||
|
|
e559744f32 | ||
|
|
67418e649a | ||
|
|
5adf9fab53 | ||
|
|
2491b686fa | ||
|
|
efd02b2f3e | ||
|
|
3b14078646 | ||
|
|
eb9d5632bc | ||
|
|
45f60edbb6 | ||
|
|
393ea6a7bb | ||
|
|
6ec6f1efe5 | ||
|
|
5d9598ea51 | ||
|
|
0d36d99a73 | ||
|
|
d8a9f5a724 | ||
|
|
2c66e1a042 | ||
|
|
d5eccdb00f | ||
|
|
32626573a6 | ||
|
|
caa82b8f7e | ||
|
|
5af347b499 | ||
|
|
4ed5bb5a9e | ||
|
|
f39f5f531c |
2
.github/workflows/python-build-test.yml
vendored
2
.github/workflows/python-build-test.yml
vendored
@@ -65,6 +65,8 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
components: clippy,rustfmt
|
components: clippy,rustfmt
|
||||||
toolchain: ${{ matrix.rust-version }}
|
toolchain: ${{ matrix.rust-version }}
|
||||||
|
- name: Check License Headers
|
||||||
|
run: cd rust && cargo run --features dev-tools --bin file-header check-all
|
||||||
- name: Rust Build
|
- name: Rust Build
|
||||||
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
|
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
|
||||||
# Lints after build so what clippy needs is already built
|
# Lints after build so what clippy needs is already built
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import click
|
|||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from bumble.pandora import PandoraDevice, serve
|
from bumble.pandora import PandoraDevice, Config, serve
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
BUMBLE_SERVER_GRPC_PORT = 7999
|
BUMBLE_SERVER_GRPC_PORT = 7999
|
||||||
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
|
|||||||
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
|
||||||
|
|
||||||
bumble_config = retrieve_config(config)
|
bumble_config = retrieve_config(config)
|
||||||
if 'transport' not in bumble_config.keys():
|
bumble_config.setdefault('transport', transport)
|
||||||
bumble_config.update({'transport': transport})
|
|
||||||
device = PandoraDevice(bumble_config)
|
device = PandoraDevice(bumble_config)
|
||||||
|
|
||||||
|
server_config = Config()
|
||||||
|
server_config.load_from_dict(bumble_config.get('server', {}))
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
asyncio.run(serve(device, port=grpc_port))
|
asyncio.run(serve(device, config=server_config, port=grpc_port))
|
||||||
|
|
||||||
|
|
||||||
def retrieve_config(config: str) -> Dict[str, Any]:
|
def retrieve_config(config: str) -> Dict[str, Any]:
|
||||||
|
|||||||
@@ -23,13 +23,14 @@
|
|||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import struct
|
import struct
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
from typing import Dict, Type, TYPE_CHECKING
|
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
|
||||||
|
|
||||||
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
|
from bumble.core import UUID, name_or_number, ProtocolError
|
||||||
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
|
from bumble.hci import HCI_Object, key_with_value
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
|
|||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Exceptions
|
# Exceptions
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -209,7 +211,7 @@ class ATT_PDU:
|
|||||||
|
|
||||||
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
|
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
|
||||||
op_code = 0
|
op_code = 0
|
||||||
name = None
|
name: str
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_bytes(pdu):
|
def from_bytes(pdu):
|
||||||
@@ -719,9 +721,18 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
class ConnectionValue(Protocol):
|
||||||
|
def read(self, connection) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
def write(self, connection, value: bytes) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Attribute(EventEmitter):
|
class Attribute(EventEmitter):
|
||||||
# Permission flags
|
class Permissions(enum.IntFlag):
|
||||||
READABLE = 0x01
|
READABLE = 0x01
|
||||||
WRITEABLE = 0x02
|
WRITEABLE = 0x02
|
||||||
READ_REQUIRES_ENCRYPTION = 0x04
|
READ_REQUIRES_ENCRYPTION = 0x04
|
||||||
@@ -731,36 +742,47 @@ class Attribute(EventEmitter):
|
|||||||
READ_REQUIRES_AUTHORIZATION = 0x40
|
READ_REQUIRES_AUTHORIZATION = 0x40
|
||||||
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
WRITE_REQUIRES_AUTHORIZATION = 0x80
|
||||||
|
|
||||||
PERMISSION_NAMES = {
|
@classmethod
|
||||||
READABLE: 'READABLE',
|
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
|
||||||
WRITEABLE: 'WRITEABLE',
|
|
||||||
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
|
|
||||||
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
|
|
||||||
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
|
|
||||||
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
|
|
||||||
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
|
|
||||||
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def string_to_permissions(permissions_str: str):
|
|
||||||
try:
|
try:
|
||||||
return functools.reduce(
|
return functools.reduce(
|
||||||
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
|
lambda x, y: x | Attribute.Permissions[y],
|
||||||
permissions_str.split(","),
|
permissions_str.replace('|', ',').split(","),
|
||||||
0,
|
Attribute.Permissions(0),
|
||||||
)
|
)
|
||||||
except TypeError as exc:
|
except TypeError as exc:
|
||||||
|
# The check for `p.name is not None` here is needed because for InFlag
|
||||||
|
# enums, the .name property can be None, when the enum value is 0,
|
||||||
|
# so the type hint for .name is Optional[str].
|
||||||
|
enum_list: List[str] = [p.name for p in cls if p.name is not None]
|
||||||
|
enum_list_str = ",".join(enum_list)
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
|
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
def __init__(self, attribute_type, permissions, value=b''):
|
# Permission flags(legacy-use only)
|
||||||
|
READABLE = Permissions.READABLE
|
||||||
|
WRITEABLE = Permissions.WRITEABLE
|
||||||
|
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
|
||||||
|
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||||
|
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
|
||||||
|
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
|
||||||
|
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||||
|
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||||
|
|
||||||
|
value: Union[str, bytes, ConnectionValue]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
attribute_type: Union[str, bytes, UUID],
|
||||||
|
permissions: Union[str, Attribute.Permissions],
|
||||||
|
value: Union[str, bytes, ConnectionValue] = b'',
|
||||||
|
) -> None:
|
||||||
EventEmitter.__init__(self)
|
EventEmitter.__init__(self)
|
||||||
self.handle = 0
|
self.handle = 0
|
||||||
self.end_group_handle = 0
|
self.end_group_handle = 0
|
||||||
if isinstance(permissions, str):
|
if isinstance(permissions, str):
|
||||||
self.permissions = self.string_to_permissions(permissions)
|
self.permissions = Attribute.Permissions.from_string(permissions)
|
||||||
else:
|
else:
|
||||||
self.permissions = permissions
|
self.permissions = permissions
|
||||||
|
|
||||||
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
|
|||||||
else:
|
else:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def encode_value(self, value):
|
def encode_value(self, value: Any) -> bytes:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def decode_value(self, value_bytes):
|
def decode_value(self, value_bytes: bytes) -> Any:
|
||||||
return value_bytes
|
return value_bytes
|
||||||
|
|
||||||
def read_value(self, connection: Connection):
|
def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||||
if (
|
if (
|
||||||
self.permissions & self.READ_REQUIRES_ENCRYPTION
|
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||||
) and not connection.encryption:
|
and connection is not None
|
||||||
|
and not connection.encryption
|
||||||
|
):
|
||||||
raise ATT_Error(
|
raise ATT_Error(
|
||||||
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.permissions & self.READ_REQUIRES_AUTHENTICATION
|
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
|
||||||
) and not connection.authenticated:
|
and connection is not None
|
||||||
|
and not connection.authenticated
|
||||||
|
):
|
||||||
raise ATT_Error(
|
raise ATT_Error(
|
||||||
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
|
||||||
)
|
)
|
||||||
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
|
|||||||
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
|
||||||
)
|
)
|
||||||
|
|
||||||
if read := getattr(self.value, 'read', None):
|
if hasattr(self.value, 'read'):
|
||||||
try:
|
try:
|
||||||
value = read(connection) # pylint: disable=not-callable
|
value = self.value.read(connection)
|
||||||
except ATT_Error as error:
|
except ATT_Error as error:
|
||||||
raise ATT_Error(
|
raise ATT_Error(
|
||||||
error_code=error.error_code, att_handle=self.handle
|
error_code=error.error_code, att_handle=self.handle
|
||||||
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
|
|||||||
|
|
||||||
return self.encode_value(value)
|
return self.encode_value(value)
|
||||||
|
|
||||||
def write_value(self, connection: Connection, value_bytes):
|
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||||
if (
|
if (
|
||||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||||
) and not connection.encryption:
|
) and not connection.encryption:
|
||||||
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
|
|||||||
|
|
||||||
value = self.decode_value(value_bytes)
|
value = self.decode_value(value_bytes)
|
||||||
|
|
||||||
if write := getattr(self.value, 'write', None):
|
if hasattr(self.value, 'write'):
|
||||||
try:
|
try:
|
||||||
write(connection, value) # pylint: disable=not-callable
|
self.value.write(connection, value) # pylint: disable=not-callable
|
||||||
except ATT_Error as error:
|
except ATT_Error as error:
|
||||||
raise ATT_Error(
|
raise ATT_Error(
|
||||||
error_code=error.error_code, att_handle=self.handle
|
error_code=error.error_code, att_handle=self.handle
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class BaseError(Exception):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
error_code: int | None,
|
error_code: Optional[int],
|
||||||
error_namespace: str = '',
|
error_namespace: str = '',
|
||||||
error_name: str = '',
|
error_name: str = '',
|
||||||
details: str = '',
|
details: str = '',
|
||||||
|
|||||||
@@ -2758,7 +2758,9 @@ class Device(CompositeEventEmitter):
|
|||||||
self.abort_on(
|
self.abort_on(
|
||||||
'flush',
|
'flush',
|
||||||
self.start_advertising(
|
self.start_advertising(
|
||||||
advertising_type=self.advertising_type, auto_restart=True
|
advertising_type=self.advertising_type,
|
||||||
|
own_address_type=self.advertising_own_address_type,
|
||||||
|
auto_restart=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import enum
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import Optional, Sequence, List
|
from typing import Optional, Sequence, Iterable, List, Union
|
||||||
|
|
||||||
from .colors import color
|
from .colors import color
|
||||||
from .core import UUID, get_dict_key_by_value
|
from .core import UUID, get_dict_key_by_value
|
||||||
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def show_services(services):
|
def show_services(services: Iterable[Service]) -> None:
|
||||||
for service in services:
|
for service in services:
|
||||||
print(color(str(service), 'cyan'))
|
print(color(str(service), 'cyan'))
|
||||||
|
|
||||||
@@ -210,11 +210,11 @@ class Service(Attribute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uuid,
|
uuid: Union[str, UUID],
|
||||||
characteristics: List[Characteristic],
|
characteristics: List[Characteristic],
|
||||||
primary=True,
|
primary=True,
|
||||||
included_services: List[Service] = [],
|
included_services: List[Service] = [],
|
||||||
):
|
) -> None:
|
||||||
# Convert the uuid to a UUID object if it isn't already
|
# Convert the uuid to a UUID object if it isn't already
|
||||||
if isinstance(uuid, str):
|
if isinstance(uuid, str):
|
||||||
uuid = UUID(uuid)
|
uuid = UUID(uuid)
|
||||||
@@ -239,7 +239,7 @@ class Service(Attribute):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'Service(handle=0x{self.handle:04X}, '
|
f'Service(handle=0x{self.handle:04X}, '
|
||||||
f'end=0x{self.end_group_handle:04X}, '
|
f'end=0x{self.end_group_handle:04X}, '
|
||||||
@@ -255,9 +255,11 @@ class TemplateService(Service):
|
|||||||
to expose their UUID as a class property
|
to expose their UUID as a class property
|
||||||
'''
|
'''
|
||||||
|
|
||||||
UUID: Optional[UUID] = None
|
UUID: UUID
|
||||||
|
|
||||||
def __init__(self, characteristics, primary=True):
|
def __init__(
|
||||||
|
self, characteristics: List[Characteristic], primary: bool = True
|
||||||
|
) -> None:
|
||||||
super().__init__(self.UUID, characteristics, primary)
|
super().__init__(self.UUID, characteristics, primary)
|
||||||
|
|
||||||
|
|
||||||
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
|
|||||||
|
|
||||||
service: Service
|
service: Service
|
||||||
|
|
||||||
def __init__(self, service):
|
def __init__(self, service: Service) -> None:
|
||||||
declaration_bytes = struct.pack(
|
declaration_bytes = struct.pack(
|
||||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||||
)
|
)
|
||||||
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
|
|||||||
)
|
)
|
||||||
self.service = service
|
self.service = service
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
|
||||||
f'group_starting_handle=0x{self.service.handle:04X}, '
|
f'group_starting_handle=0x{self.service.handle:04X}, '
|
||||||
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
|
|||||||
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
# NOTE: we override this method to offer a consistent result between python
|
# NOTE: we override this method to offer a consistent result between python
|
||||||
# versions: the value returned by IntFlag.__str__() changed in version 11.
|
# versions: the value returned by IntFlag.__str__() changed in version 11.
|
||||||
return '|'.join(
|
return '|'.join(
|
||||||
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uuid,
|
uuid: Union[str, bytes, UUID],
|
||||||
properties: Characteristic.Properties,
|
properties: Characteristic.Properties,
|
||||||
permissions,
|
permissions: Union[str, Attribute.Permissions],
|
||||||
value=b'',
|
value: Union[str, bytes, CharacteristicValue] = b'',
|
||||||
descriptors: Sequence[Descriptor] = (),
|
descriptors: Sequence[Descriptor] = (),
|
||||||
):
|
):
|
||||||
super().__init__(uuid, permissions, value)
|
super().__init__(uuid, permissions, value)
|
||||||
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
|
|||||||
def has_properties(self, properties: Characteristic.Properties) -> bool:
|
def has_properties(self, properties: Characteristic.Properties) -> bool:
|
||||||
return self.properties & properties == properties
|
return self.properties & properties == properties
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||||
f'end=0x{self.end_group_handle:04X}, '
|
f'end=0x{self.end_group_handle:04X}, '
|
||||||
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
|
|||||||
|
|
||||||
characteristic: Characteristic
|
characteristic: Characteristic
|
||||||
|
|
||||||
def __init__(self, characteristic, value_handle):
|
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
|
||||||
declaration_bytes = (
|
declaration_bytes = (
|
||||||
struct.pack('<BH', characteristic.properties, value_handle)
|
struct.pack('<BH', characteristic.properties, value_handle)
|
||||||
+ characteristic.uuid.to_pdu_bytes()
|
+ characteristic.uuid.to_pdu_bytes()
|
||||||
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
|
|||||||
self.value_handle = value_handle
|
self.value_handle = value_handle
|
||||||
self.characteristic = characteristic
|
self.characteristic = characteristic
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
|
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
|
||||||
f'value_handle=0x{self.value_handle:04X}, '
|
f'value_handle=0x{self.value_handle:04X}, '
|
||||||
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
|
|||||||
|
|
||||||
return self.wrapped_characteristic.unsubscribe(subscriber)
|
return self.wrapped_characteristic.unsubscribe(subscriber)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
wrapped = str(self.wrapped_characteristic)
|
wrapped = str(self.wrapped_characteristic)
|
||||||
return f'{self.__class__.__name__}({wrapped})'
|
return f'{self.__class__.__name__}({wrapped})'
|
||||||
|
|
||||||
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
|||||||
Adapter that converts strings to/from bytes using UTF-8 encoding
|
Adapter that converts strings to/from bytes using UTF-8 encoding
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def encode_value(self, value):
|
def encode_value(self, value: str) -> bytes:
|
||||||
return value.encode('utf-8')
|
return value.encode('utf-8')
|
||||||
|
|
||||||
def decode_value(self, value):
|
def decode_value(self, value: bytes) -> str:
|
||||||
return value.decode('utf-8')
|
return value.decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
|
|||||||
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
|
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||||
f'type={self.type}, '
|
f'type={self.type}, '
|
||||||
|
|||||||
@@ -28,7 +28,18 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Dict,
|
||||||
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
Union,
|
||||||
|
Any,
|
||||||
|
Iterable,
|
||||||
|
Type,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
)
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
@@ -66,8 +77,12 @@ from .gatt import (
|
|||||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||||
Characteristic,
|
Characteristic,
|
||||||
ClientCharacteristicConfigurationBits,
|
ClientCharacteristicConfigurationBits,
|
||||||
|
TemplateService,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from bumble.device import Connection
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Logging
|
# Logging
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
|
|||||||
# Proxies
|
# Proxies
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class AttributeProxy(EventEmitter):
|
class AttributeProxy(EventEmitter):
|
||||||
client: Client
|
def __init__(
|
||||||
|
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
|
||||||
def __init__(self, client, handle, end_group_handle, attribute_type):
|
) -> None:
|
||||||
EventEmitter.__init__(self)
|
EventEmitter.__init__(self)
|
||||||
self.client = client
|
self.client = client
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.end_group_handle = end_group_handle
|
self.end_group_handle = end_group_handle
|
||||||
self.type = attribute_type
|
self.type = attribute_type
|
||||||
|
|
||||||
async def read_value(self, no_long_read=False):
|
async def read_value(self, no_long_read: bool = False) -> bytes:
|
||||||
return self.decode_value(
|
return self.decode_value(
|
||||||
await self.client.read_value(self.handle, no_long_read)
|
await self.client.read_value(self.handle, no_long_read)
|
||||||
)
|
)
|
||||||
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
|
|||||||
self.handle, self.encode_value(value), with_response
|
self.handle, self.encode_value(value), with_response
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode_value(self, value):
|
def encode_value(self, value: Any) -> bytes:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def decode_value(self, value_bytes):
|
def decode_value(self, value_bytes: bytes) -> Any:
|
||||||
return value_bytes
|
return value_bytes
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
|
||||||
|
|
||||||
|
|
||||||
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
|
|||||||
def get_characteristics_by_uuid(self, uuid):
|
def get_characteristics_by_uuid(self, uuid):
|
||||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||||
|
|
||||||
|
|
||||||
class CharacteristicProxy(AttributeProxy):
|
class CharacteristicProxy(AttributeProxy):
|
||||||
properties: Characteristic.Properties
|
properties: Characteristic.Properties
|
||||||
descriptors: List[DescriptorProxy]
|
descriptors: List[DescriptorProxy]
|
||||||
subscribers: Dict[Any, Callable]
|
subscribers: Dict[Any, Callable[[bytes], Any]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
|
|||||||
return await self.client.discover_descriptors(self)
|
return await self.client.discover_descriptors(self)
|
||||||
|
|
||||||
async def subscribe(
|
async def subscribe(
|
||||||
self, subscriber: Optional[Callable] = None, prefer_notify=True
|
self,
|
||||||
|
subscriber: Optional[Callable[[bytes], Any]] = None,
|
||||||
|
prefer_notify: bool = True,
|
||||||
):
|
):
|
||||||
if subscriber is not None:
|
if subscriber is not None:
|
||||||
if subscriber in self.subscribers:
|
if subscriber in self.subscribers:
|
||||||
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
|
|||||||
|
|
||||||
return await self.client.unsubscribe(self, subscriber)
|
return await self.client.unsubscribe(self, subscriber)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'Characteristic(handle=0x{self.handle:04X}, '
|
f'Characteristic(handle=0x{self.handle:04X}, '
|
||||||
f'uuid={self.uuid}, '
|
f'uuid={self.uuid}, '
|
||||||
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
|
|||||||
def __init__(self, client, handle, descriptor_type):
|
def __init__(self, client, handle, descriptor_type):
|
||||||
super().__init__(client, handle, 0, descriptor_type)
|
super().__init__(client, handle, 0, descriptor_type)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
|
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
|
||||||
|
|
||||||
|
|
||||||
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
|
|||||||
Base class for profile-specific service proxies
|
Base class for profile-specific service proxies
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
SERVICE_CLASS: Type[TemplateService]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_client(cls, client):
|
def from_client(cls, client: Client) -> ProfileServiceProxy:
|
||||||
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
|
||||||
|
|
||||||
|
|
||||||
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
|
|||||||
class Client:
|
class Client:
|
||||||
services: List[ServiceProxy]
|
services: List[ServiceProxy]
|
||||||
cached_values: Dict[int, Tuple[datetime, bytes]]
|
cached_values: Dict[int, Tuple[datetime, bytes]]
|
||||||
|
notification_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||||
|
indication_subscribers: Dict[int, Callable[[bytes], Any]]
|
||||||
|
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
|
||||||
|
pending_request: Optional[ATT_PDU]
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection: Connection) -> None:
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.mtu_exchange_done = False
|
self.mtu_exchange_done = False
|
||||||
self.request_semaphore = asyncio.Semaphore(1)
|
self.request_semaphore = asyncio.Semaphore(1)
|
||||||
@@ -241,16 +264,16 @@ class Client:
|
|||||||
self.services = []
|
self.services = []
|
||||||
self.cached_values = {}
|
self.cached_values = {}
|
||||||
|
|
||||||
def send_gatt_pdu(self, pdu):
|
def send_gatt_pdu(self, pdu: bytes) -> None:
|
||||||
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
self.connection.send_l2cap_pdu(ATT_CID, pdu)
|
||||||
|
|
||||||
async def send_command(self, command):
|
async def send_command(self, command: ATT_PDU) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||||
)
|
)
|
||||||
self.send_gatt_pdu(command.to_bytes())
|
self.send_gatt_pdu(command.to_bytes())
|
||||||
|
|
||||||
async def send_request(self, request):
|
async def send_request(self, request: ATT_PDU):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
|
||||||
)
|
)
|
||||||
@@ -279,14 +302,14 @@ class Client:
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def send_confirmation(self, confirmation):
|
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||||
f'{confirmation}'
|
f'{confirmation}'
|
||||||
)
|
)
|
||||||
self.send_gatt_pdu(confirmation.to_bytes())
|
self.send_gatt_pdu(confirmation.to_bytes())
|
||||||
|
|
||||||
async def request_mtu(self, mtu):
|
async def request_mtu(self, mtu: int) -> int:
|
||||||
# Check the range
|
# Check the range
|
||||||
if mtu < ATT_DEFAULT_MTU:
|
if mtu < ATT_DEFAULT_MTU:
|
||||||
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
|
||||||
@@ -313,10 +336,12 @@ class Client:
|
|||||||
|
|
||||||
return self.connection.att_mtu
|
return self.connection.att_mtu
|
||||||
|
|
||||||
def get_services_by_uuid(self, uuid):
|
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
|
||||||
return [service for service in self.services if service.uuid == uuid]
|
return [service for service in self.services if service.uuid == uuid]
|
||||||
|
|
||||||
def get_characteristics_by_uuid(self, uuid, service=None):
|
def get_characteristics_by_uuid(
|
||||||
|
self, uuid: UUID, service: Optional[ServiceProxy] = None
|
||||||
|
) -> List[CharacteristicProxy]:
|
||||||
services = [service] if service else self.services
|
services = [service] if service else self.services
|
||||||
return [
|
return [
|
||||||
c
|
c
|
||||||
@@ -363,7 +388,7 @@ class Client:
|
|||||||
if not already_known:
|
if not already_known:
|
||||||
self.services.append(service)
|
self.services.append(service)
|
||||||
|
|
||||||
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
|
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
See Vol 3, Part G - 4.4.1 Discover All Primary Services
|
||||||
'''
|
'''
|
||||||
@@ -435,7 +460,7 @@ class Client:
|
|||||||
|
|
||||||
return services
|
return services
|
||||||
|
|
||||||
async def discover_service(self, uuid):
|
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
|
||||||
'''
|
'''
|
||||||
@@ -468,7 +493,7 @@ class Client:
|
|||||||
f'{HCI_Constant.error_name(response.error_code)}'
|
f'{HCI_Constant.error_name(response.error_code)}'
|
||||||
)
|
)
|
||||||
# TODO raise appropriate exception
|
# TODO raise appropriate exception
|
||||||
return
|
return []
|
||||||
break
|
break
|
||||||
|
|
||||||
for attribute_handle, end_group_handle in response.handles_information:
|
for attribute_handle, end_group_handle in response.handles_information:
|
||||||
@@ -480,7 +505,7 @@ class Client:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f'bogus handle values: {attribute_handle} {end_group_handle}'
|
f'bogus handle values: {attribute_handle} {end_group_handle}'
|
||||||
)
|
)
|
||||||
return
|
return []
|
||||||
|
|
||||||
# Create a service proxy for this service
|
# Create a service proxy for this service
|
||||||
service = ServiceProxy(
|
service = ServiceProxy(
|
||||||
@@ -721,7 +746,7 @@ class Client:
|
|||||||
|
|
||||||
return descriptors
|
return descriptors
|
||||||
|
|
||||||
async def discover_attributes(self):
|
async def discover_attributes(self) -> List[AttributeProxy]:
|
||||||
'''
|
'''
|
||||||
Discover all attributes, regardless of type
|
Discover all attributes, regardless of type
|
||||||
'''
|
'''
|
||||||
@@ -844,7 +869,9 @@ class Client:
|
|||||||
# No more subscribers left
|
# No more subscribers left
|
||||||
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
await self.write_value(cccd, b'\x00\x00', with_response=True)
|
||||||
|
|
||||||
async def read_value(self, attribute, no_long_read=False):
|
async def read_value(
|
||||||
|
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
|
||||||
|
) -> Any:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
See Vol 3, Part G - 4.8.1 Read Characteristic Value
|
||||||
|
|
||||||
@@ -905,7 +932,9 @@ class Client:
|
|||||||
# Return the value as bytes
|
# Return the value as bytes
|
||||||
return attribute_value
|
return attribute_value
|
||||||
|
|
||||||
async def read_characteristics_by_uuid(self, uuid, service):
|
async def read_characteristics_by_uuid(
|
||||||
|
self, uuid: UUID, service: Optional[ServiceProxy]
|
||||||
|
) -> List[bytes]:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
|
||||||
'''
|
'''
|
||||||
@@ -960,7 +989,12 @@ class Client:
|
|||||||
|
|
||||||
return characteristics_values
|
return characteristics_values
|
||||||
|
|
||||||
async def write_value(self, attribute, value, with_response=False):
|
async def write_value(
|
||||||
|
self,
|
||||||
|
attribute: Union[int, AttributeProxy],
|
||||||
|
value: bytes,
|
||||||
|
with_response: bool = False,
|
||||||
|
) -> None:
|
||||||
'''
|
'''
|
||||||
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
|
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
|
||||||
Value
|
Value
|
||||||
@@ -990,7 +1024,7 @@ class Client:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_gatt_pdu(self, att_pdu):
|
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
|
||||||
)
|
)
|
||||||
@@ -1013,6 +1047,7 @@ class Client:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Return the response to the coroutine that is waiting for it
|
# Return the response to the coroutine that is waiting for it
|
||||||
|
assert self.pending_response is not None
|
||||||
self.pending_response.set_result(att_pdu)
|
self.pending_response.set_result(att_pdu)
|
||||||
else:
|
else:
|
||||||
handler_name = f'on_{att_pdu.name.lower()}'
|
handler_name = f'on_{att_pdu.name.lower()}'
|
||||||
@@ -1060,7 +1095,7 @@ class Client:
|
|||||||
# Confirm that we received the indication
|
# Confirm that we received the indication
|
||||||
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
self.send_confirmation(ATT_Handle_Value_Confirmation())
|
||||||
|
|
||||||
def cache_value(self, attribute_handle: int, value: bytes):
|
def cache_value(self, attribute_handle: int, value: bytes) -> None:
|
||||||
self.cached_values[attribute_handle] = (
|
self.cached_values[attribute_handle] = (
|
||||||
datetime.now(),
|
datetime.now(),
|
||||||
value,
|
value,
|
||||||
|
|||||||
@@ -23,11 +23,12 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import struct
|
import struct
|
||||||
from typing import List, Tuple, Optional, TypeVar, Type
|
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
from .colors import color
|
from .colors import color
|
||||||
@@ -42,6 +43,7 @@ from .att import (
|
|||||||
ATT_INVALID_OFFSET_ERROR,
|
ATT_INVALID_OFFSET_ERROR,
|
||||||
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
ATT_REQUEST_NOT_SUPPORTED_ERROR,
|
||||||
ATT_REQUESTS,
|
ATT_REQUESTS,
|
||||||
|
ATT_PDU,
|
||||||
ATT_UNLIKELY_ERROR_ERROR,
|
ATT_UNLIKELY_ERROR_ERROR,
|
||||||
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
|
||||||
ATT_Error,
|
ATT_Error,
|
||||||
@@ -73,6 +75,8 @@ from .gatt import (
|
|||||||
Service,
|
Service,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from bumble.device import Device, Connection
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Logging
|
# Logging
|
||||||
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Server(EventEmitter):
|
class Server(EventEmitter):
|
||||||
attributes: List[Attribute]
|
attributes: List[Attribute]
|
||||||
|
services: List[Service]
|
||||||
|
attributes_by_handle: Dict[int, Attribute]
|
||||||
|
subscribers: Dict[int, Dict[int, bytes]]
|
||||||
|
indication_semaphores: defaultdict[int, asyncio.Semaphore]
|
||||||
|
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
|
||||||
|
|
||||||
def __init__(self, device):
|
def __init__(self, device: Device) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.services = []
|
self.services = []
|
||||||
@@ -107,16 +116,16 @@ class Server(EventEmitter):
|
|||||||
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
|
||||||
self.pending_confirmations = defaultdict(lambda: None)
|
self.pending_confirmations = defaultdict(lambda: None)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "\n".join(map(str, self.attributes))
|
return "\n".join(map(str, self.attributes))
|
||||||
|
|
||||||
def send_gatt_pdu(self, connection_handle, pdu):
|
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
|
||||||
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
|
||||||
|
|
||||||
def next_handle(self):
|
def next_handle(self) -> int:
|
||||||
return 1 + len(self.attributes)
|
return 1 + len(self.attributes)
|
||||||
|
|
||||||
def get_advertising_service_data(self):
|
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
|
||||||
return {
|
return {
|
||||||
attribute: data
|
attribute: data
|
||||||
for attribute in self.attributes
|
for attribute in self.attributes
|
||||||
@@ -124,7 +133,7 @@ class Server(EventEmitter):
|
|||||||
and (data := attribute.get_advertising_data())
|
and (data := attribute.get_advertising_data())
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_attribute(self, handle):
|
def get_attribute(self, handle: int) -> Optional[Attribute]:
|
||||||
attribute = self.attributes_by_handle.get(handle)
|
attribute = self.attributes_by_handle.get(handle)
|
||||||
if attribute:
|
if attribute:
|
||||||
return attribute
|
return attribute
|
||||||
@@ -173,12 +182,17 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
return next(
|
return next(
|
||||||
(
|
(
|
||||||
(attribute, self.get_attribute(attribute.characteristic.handle))
|
(
|
||||||
|
attribute,
|
||||||
|
self.get_attribute(attribute.characteristic.handle),
|
||||||
|
) # type: ignore
|
||||||
for attribute in map(
|
for attribute in map(
|
||||||
self.get_attribute,
|
self.get_attribute,
|
||||||
range(service_handle.handle, service_handle.end_group_handle + 1),
|
range(service_handle.handle, service_handle.end_group_handle + 1),
|
||||||
)
|
)
|
||||||
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
if attribute is not None
|
||||||
|
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
|
||||||
|
and isinstance(attribute, CharacteristicDeclaration)
|
||||||
and attribute.characteristic.uuid == characteristic_uuid
|
and attribute.characteristic.uuid == characteristic_uuid
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
@@ -197,7 +211,7 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
return next(
|
return next(
|
||||||
(
|
(
|
||||||
attribute
|
attribute # type: ignore
|
||||||
for attribute in map(
|
for attribute in map(
|
||||||
self.get_attribute,
|
self.get_attribute,
|
||||||
range(
|
range(
|
||||||
@@ -205,12 +219,12 @@ class Server(EventEmitter):
|
|||||||
characteristic_value.end_group_handle + 1,
|
characteristic_value.end_group_handle + 1,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if attribute.type == descriptor_uuid
|
if attribute is not None and attribute.type == descriptor_uuid
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_attribute(self, attribute):
|
def add_attribute(self, attribute: Attribute) -> None:
|
||||||
# Assign a handle to this attribute
|
# Assign a handle to this attribute
|
||||||
attribute.handle = self.next_handle()
|
attribute.handle = self.next_handle()
|
||||||
attribute.end_group_handle = (
|
attribute.end_group_handle = (
|
||||||
@@ -220,7 +234,7 @@ class Server(EventEmitter):
|
|||||||
# Add this attribute to the list
|
# Add this attribute to the list
|
||||||
self.attributes.append(attribute)
|
self.attributes.append(attribute)
|
||||||
|
|
||||||
def add_service(self, service: Service):
|
def add_service(self, service: Service) -> None:
|
||||||
# Add the service attribute to the DB
|
# Add the service attribute to the DB
|
||||||
self.add_attribute(service)
|
self.add_attribute(service)
|
||||||
|
|
||||||
@@ -285,11 +299,13 @@ class Server(EventEmitter):
|
|||||||
service.end_group_handle = self.attributes[-1].handle
|
service.end_group_handle = self.attributes[-1].handle
|
||||||
self.services.append(service)
|
self.services.append(service)
|
||||||
|
|
||||||
def add_services(self, services):
|
def add_services(self, services: Iterable[Service]) -> None:
|
||||||
for service in services:
|
for service in services:
|
||||||
self.add_service(service)
|
self.add_service(service)
|
||||||
|
|
||||||
def read_cccd(self, connection, characteristic):
|
def read_cccd(
|
||||||
|
self, connection: Optional[Connection], characteristic: Characteristic
|
||||||
|
) -> bytes:
|
||||||
if connection is None:
|
if connection is None:
|
||||||
return bytes([0, 0])
|
return bytes([0, 0])
|
||||||
|
|
||||||
@@ -300,7 +316,12 @@ class Server(EventEmitter):
|
|||||||
|
|
||||||
return cccd or bytes([0, 0])
|
return cccd or bytes([0, 0])
|
||||||
|
|
||||||
def write_cccd(self, connection, characteristic, value):
|
def write_cccd(
|
||||||
|
self,
|
||||||
|
connection: Connection,
|
||||||
|
characteristic: Characteristic,
|
||||||
|
value: bytes,
|
||||||
|
) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'Subscription update for connection=0x{connection.handle:04X}, '
|
f'Subscription update for connection=0x{connection.handle:04X}, '
|
||||||
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
|
||||||
@@ -327,13 +348,19 @@ class Server(EventEmitter):
|
|||||||
indicate_enabled,
|
indicate_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_response(self, connection, response):
|
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||||
)
|
)
|
||||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||||
|
|
||||||
async def notify_subscriber(self, connection, attribute, value=None, force=False):
|
async def notify_subscriber(
|
||||||
|
self,
|
||||||
|
connection: Connection,
|
||||||
|
attribute: Attribute,
|
||||||
|
value: Optional[bytes] = None,
|
||||||
|
force: bool = False,
|
||||||
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
if not force:
|
if not force:
|
||||||
subscribers = self.subscribers.get(connection.handle)
|
subscribers = self.subscribers.get(connection.handle)
|
||||||
@@ -370,7 +397,13 @@ class Server(EventEmitter):
|
|||||||
)
|
)
|
||||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||||
|
|
||||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
async def indicate_subscriber(
|
||||||
|
self,
|
||||||
|
connection: Connection,
|
||||||
|
attribute: Attribute,
|
||||||
|
value: Optional[bytes] = None,
|
||||||
|
force: bool = False,
|
||||||
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
if not force:
|
if not force:
|
||||||
subscribers = self.subscribers.get(connection.handle)
|
subscribers = self.subscribers.get(connection.handle)
|
||||||
@@ -411,15 +444,13 @@ class Server(EventEmitter):
|
|||||||
assert self.pending_confirmations[connection.handle] is None
|
assert self.pending_confirmations[connection.handle] is None
|
||||||
|
|
||||||
# Create a future value to hold the eventual response
|
# Create a future value to hold the eventual response
|
||||||
self.pending_confirmations[
|
pending_confirmation = self.pending_confirmations[
|
||||||
connection.handle
|
connection.handle
|
||||||
] = asyncio.get_running_loop().create_future()
|
] = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||||
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError as error:
|
except asyncio.TimeoutError as error:
|
||||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||||
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
raise TimeoutError(f'GATT timeout for {indication.name}') from error
|
||||||
@@ -427,8 +458,12 @@ class Server(EventEmitter):
|
|||||||
self.pending_confirmations[connection.handle] = None
|
self.pending_confirmations[connection.handle] = None
|
||||||
|
|
||||||
async def notify_or_indicate_subscribers(
|
async def notify_or_indicate_subscribers(
|
||||||
self, indicate, attribute, value=None, force=False
|
self,
|
||||||
):
|
indicate: bool,
|
||||||
|
attribute: Attribute,
|
||||||
|
value: Optional[bytes] = None,
|
||||||
|
force: bool = False,
|
||||||
|
) -> None:
|
||||||
# Get all the connections for which there's at least one subscription
|
# Get all the connections for which there's at least one subscription
|
||||||
connections = [
|
connections = [
|
||||||
connection
|
connection
|
||||||
@@ -450,13 +485,23 @@ class Server(EventEmitter):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def notify_subscribers(self, attribute, value=None, force=False):
|
async def notify_subscribers(
|
||||||
|
self,
|
||||||
|
attribute: Attribute,
|
||||||
|
value: Optional[bytes] = None,
|
||||||
|
force: bool = False,
|
||||||
|
):
|
||||||
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
|
||||||
|
|
||||||
async def indicate_subscribers(self, attribute, value=None, force=False):
|
async def indicate_subscribers(
|
||||||
|
self,
|
||||||
|
attribute: Attribute,
|
||||||
|
value: Optional[bytes] = None,
|
||||||
|
force: bool = False,
|
||||||
|
):
|
||||||
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
|
||||||
|
|
||||||
def on_disconnection(self, connection):
|
def on_disconnection(self, connection: Connection) -> None:
|
||||||
if connection.handle in self.subscribers:
|
if connection.handle in self.subscribers:
|
||||||
del self.subscribers[connection.handle]
|
del self.subscribers[connection.handle]
|
||||||
if connection.handle in self.indication_semaphores:
|
if connection.handle in self.indication_semaphores:
|
||||||
@@ -464,7 +509,7 @@ class Server(EventEmitter):
|
|||||||
if connection.handle in self.pending_confirmations:
|
if connection.handle in self.pending_confirmations:
|
||||||
del self.pending_confirmations[connection.handle]
|
del self.pending_confirmations[connection.handle]
|
||||||
|
|
||||||
def on_gatt_pdu(self, connection, att_pdu):
|
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
|
||||||
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
|
||||||
handler_name = f'on_{att_pdu.name.lower()}'
|
handler_name = f'on_{att_pdu.name.lower()}'
|
||||||
handler = getattr(self, handler_name, None)
|
handler = getattr(self, handler_name, None)
|
||||||
@@ -506,7 +551,7 @@ class Server(EventEmitter):
|
|||||||
#######################################################
|
#######################################################
|
||||||
# ATT handlers
|
# ATT handlers
|
||||||
#######################################################
|
#######################################################
|
||||||
def on_att_request(self, connection, pdu):
|
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
|
||||||
'''
|
'''
|
||||||
Handler for requests without a more specific handler
|
Handler for requests without a more specific handler
|
||||||
'''
|
'''
|
||||||
@@ -679,7 +724,6 @@ class Server(EventEmitter):
|
|||||||
and attribute.handle <= request.ending_handle
|
and attribute.handle <= request.ending_handle
|
||||||
and pdu_space_available
|
and pdu_space_available
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
attribute_value = attribute.read_value(connection)
|
attribute_value = attribute.read_value(connection)
|
||||||
except ATT_Error as error:
|
except ATT_Error as error:
|
||||||
|
|||||||
@@ -4397,7 +4397,7 @@ class HCI_Event(HCI_Packet):
|
|||||||
if len(parameters) != length:
|
if len(parameters) != length:
|
||||||
raise ValueError('invalid packet length')
|
raise ValueError('invalid packet length')
|
||||||
|
|
||||||
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
|
cls: Any
|
||||||
if event_code == HCI_LE_META_EVENT:
|
if event_code == HCI_LE_META_EVENT:
|
||||||
# We do this dispatch here and not in the subclass in order to avoid call
|
# We do this dispatch here and not in the subclass in order to avoid call
|
||||||
# loops
|
# loops
|
||||||
|
|||||||
167
bumble/l2cap.py
167
bumble/l2cap.py
@@ -17,6 +17,7 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
@@ -676,6 +677,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Channel(EventEmitter):
|
class Channel(EventEmitter):
|
||||||
|
class State(enum.IntEnum):
|
||||||
# States
|
# States
|
||||||
CLOSED = 0x00
|
CLOSED = 0x00
|
||||||
WAIT_CONNECT = 0x01
|
WAIT_CONNECT = 0x01
|
||||||
@@ -699,33 +701,11 @@ class Channel(EventEmitter):
|
|||||||
WAIT_FINAL_RSP = 0x16
|
WAIT_FINAL_RSP = 0x16
|
||||||
WAIT_CONTROL_IND = 0x17
|
WAIT_CONTROL_IND = 0x17
|
||||||
|
|
||||||
STATE_NAMES = {
|
|
||||||
CLOSED: 'CLOSED',
|
|
||||||
WAIT_CONNECT: 'WAIT_CONNECT',
|
|
||||||
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
|
|
||||||
OPEN: 'OPEN',
|
|
||||||
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
|
|
||||||
WAIT_CREATE: 'WAIT_CREATE',
|
|
||||||
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
|
|
||||||
WAIT_MOVE: 'WAIT_MOVE',
|
|
||||||
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
|
|
||||||
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
|
|
||||||
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
|
|
||||||
WAIT_CONFIG: 'WAIT_CONFIG',
|
|
||||||
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
|
|
||||||
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
|
|
||||||
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
|
|
||||||
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
|
|
||||||
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
|
|
||||||
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
|
|
||||||
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
|
|
||||||
}
|
|
||||||
|
|
||||||
connection_result: Optional[asyncio.Future[None]]
|
connection_result: Optional[asyncio.Future[None]]
|
||||||
disconnection_result: Optional[asyncio.Future[None]]
|
disconnection_result: Optional[asyncio.Future[None]]
|
||||||
response: Optional[asyncio.Future[bytes]]
|
response: Optional[asyncio.Future[bytes]]
|
||||||
sink: Optional[Callable[[bytes], Any]]
|
sink: Optional[Callable[[bytes], Any]]
|
||||||
state: int
|
state: State
|
||||||
connection: Connection
|
connection: Connection
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
|
|||||||
self.manager = manager
|
self.manager = manager
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.signaling_cid = signaling_cid
|
self.signaling_cid = signaling_cid
|
||||||
self.state = Channel.CLOSED
|
self.state = self.State.CLOSED
|
||||||
self.mtu = mtu
|
self.mtu = mtu
|
||||||
self.psm = psm
|
self.psm = psm
|
||||||
self.source_cid = source_cid
|
self.source_cid = source_cid
|
||||||
@@ -751,13 +731,11 @@ class Channel(EventEmitter):
|
|||||||
self.disconnection_result = None
|
self.disconnection_result = None
|
||||||
self.sink = None
|
self.sink = None
|
||||||
|
|
||||||
def change_state(self, new_state: int) -> None:
|
def _change_state(self, new_state: State) -> None:
|
||||||
logger.debug(
|
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||||
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
|
|
||||||
)
|
|
||||||
self.state = new_state
|
self.state = new_state
|
||||||
|
|
||||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||||
|
|
||||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||||
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
|
|||||||
# Check that there isn't already a request pending
|
# Check that there isn't already a request pending
|
||||||
if self.response:
|
if self.response:
|
||||||
raise InvalidStateError('request already pending')
|
raise InvalidStateError('request already pending')
|
||||||
if self.state != Channel.OPEN:
|
if self.state != self.State.OPEN:
|
||||||
raise InvalidStateError('channel not open')
|
raise InvalidStateError('channel not open')
|
||||||
|
|
||||||
self.response = asyncio.get_running_loop().create_future()
|
self.response = asyncio.get_running_loop().create_future()
|
||||||
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
if self.state != Channel.CLOSED:
|
if self.state != self.State.CLOSED:
|
||||||
raise InvalidStateError('invalid state')
|
raise InvalidStateError('invalid state')
|
||||||
|
|
||||||
# Check that we can start a new connection
|
# Check that we can start a new connection
|
||||||
if self.connection_result:
|
if self.connection_result:
|
||||||
raise RuntimeError('connection already pending')
|
raise RuntimeError('connection already pending')
|
||||||
|
|
||||||
self.change_state(Channel.WAIT_CONNECT_RSP)
|
self._change_state(self.State.WAIT_CONNECT_RSP)
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
L2CAP_Connection_Request(
|
L2CAP_Connection_Request(
|
||||||
identifier=self.manager.next_identifier(self.connection),
|
identifier=self.manager.next_identifier(self.connection),
|
||||||
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
|
|||||||
self.connection_result = None
|
self.connection_result = None
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
if self.state != Channel.OPEN:
|
if self.state != self.State.OPEN:
|
||||||
raise InvalidStateError('invalid state')
|
raise InvalidStateError('invalid state')
|
||||||
|
|
||||||
self.change_state(Channel.WAIT_DISCONNECT)
|
self._change_state(self.State.WAIT_DISCONNECT)
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
L2CAP_Disconnection_Request(
|
L2CAP_Disconnection_Request(
|
||||||
identifier=self.manager.next_identifier(self.connection),
|
identifier=self.manager.next_identifier(self.connection),
|
||||||
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
|
|||||||
return await self.disconnection_result
|
return await self.disconnection_result
|
||||||
|
|
||||||
def abort(self) -> None:
|
def abort(self) -> None:
|
||||||
if self.state == self.OPEN:
|
if self.state == self.State.OPEN:
|
||||||
self.change_state(self.CLOSED)
|
self._change_state(self.State.CLOSED)
|
||||||
self.emit('close')
|
self.emit('close')
|
||||||
|
|
||||||
def send_configure_request(self) -> None:
|
def send_configure_request(self) -> None:
|
||||||
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
|
|||||||
|
|
||||||
def on_connection_request(self, request) -> None:
|
def on_connection_request(self, request) -> None:
|
||||||
self.destination_cid = request.source_cid
|
self.destination_cid = request.source_cid
|
||||||
self.change_state(Channel.WAIT_CONNECT)
|
self._change_state(self.State.WAIT_CONNECT)
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
L2CAP_Connection_Response(
|
L2CAP_Connection_Response(
|
||||||
identifier=request.identifier,
|
identifier=request.identifier,
|
||||||
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
|
|||||||
status=0x0000,
|
status=0x0000,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.change_state(Channel.WAIT_CONFIG)
|
self._change_state(self.State.WAIT_CONFIG)
|
||||||
self.send_configure_request()
|
self.send_configure_request()
|
||||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||||
|
|
||||||
def on_connection_response(self, response):
|
def on_connection_response(self, response):
|
||||||
if self.state != Channel.WAIT_CONNECT_RSP:
|
if self.state != self.State.WAIT_CONNECT_RSP:
|
||||||
logger.warning(color('invalid state', 'red'))
|
logger.warning(color('invalid state', 'red'))
|
||||||
return
|
return
|
||||||
|
|
||||||
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
|
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
|
||||||
self.destination_cid = response.destination_cid
|
self.destination_cid = response.destination_cid
|
||||||
self.change_state(Channel.WAIT_CONFIG)
|
self._change_state(self.State.WAIT_CONFIG)
|
||||||
self.send_configure_request()
|
self.send_configure_request()
|
||||||
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
|
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
|
||||||
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
|
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.change_state(Channel.CLOSED)
|
self._change_state(self.State.CLOSED)
|
||||||
self.connection_result.set_exception(
|
self.connection_result.set_exception(
|
||||||
ProtocolError(
|
ProtocolError(
|
||||||
response.result,
|
response.result,
|
||||||
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
|
|||||||
|
|
||||||
def on_configure_request(self, request) -> None:
|
def on_configure_request(self, request) -> None:
|
||||||
if self.state not in (
|
if self.state not in (
|
||||||
Channel.WAIT_CONFIG,
|
self.State.WAIT_CONFIG,
|
||||||
Channel.WAIT_CONFIG_REQ,
|
self.State.WAIT_CONFIG_REQ,
|
||||||
Channel.WAIT_CONFIG_REQ_RSP,
|
self.State.WAIT_CONFIG_REQ_RSP,
|
||||||
):
|
):
|
||||||
logger.warning(color('invalid state', 'red'))
|
logger.warning(color('invalid state', 'red'))
|
||||||
return
|
return
|
||||||
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
|
|||||||
options=request.options, # TODO: don't accept everything blindly
|
options=request.options, # TODO: don't accept everything blindly
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if self.state == Channel.WAIT_CONFIG:
|
if self.state == self.State.WAIT_CONFIG:
|
||||||
self.change_state(Channel.WAIT_SEND_CONFIG)
|
self._change_state(self.State.WAIT_SEND_CONFIG)
|
||||||
self.send_configure_request()
|
self.send_configure_request()
|
||||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||||
elif self.state == Channel.WAIT_CONFIG_REQ:
|
elif self.state == self.State.WAIT_CONFIG_REQ:
|
||||||
self.change_state(Channel.OPEN)
|
self._change_state(self.State.OPEN)
|
||||||
if self.connection_result:
|
if self.connection_result:
|
||||||
self.connection_result.set_result(None)
|
self.connection_result.set_result(None)
|
||||||
self.connection_result = None
|
self.connection_result = None
|
||||||
self.emit('open')
|
self.emit('open')
|
||||||
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||||
self.change_state(Channel.WAIT_CONFIG_RSP)
|
self._change_state(self.State.WAIT_CONFIG_RSP)
|
||||||
|
|
||||||
def on_configure_response(self, response) -> None:
|
def on_configure_response(self, response) -> None:
|
||||||
if response.result == L2CAP_Configure_Response.SUCCESS:
|
if response.result == L2CAP_Configure_Response.SUCCESS:
|
||||||
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
|
if self.state == self.State.WAIT_CONFIG_REQ_RSP:
|
||||||
self.change_state(Channel.WAIT_CONFIG_REQ)
|
self._change_state(self.State.WAIT_CONFIG_REQ)
|
||||||
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
|
elif self.state in (
|
||||||
self.change_state(Channel.OPEN)
|
self.State.WAIT_CONFIG_RSP,
|
||||||
|
self.State.WAIT_CONTROL_IND,
|
||||||
|
):
|
||||||
|
self._change_state(self.State.OPEN)
|
||||||
if self.connection_result:
|
if self.connection_result:
|
||||||
self.connection_result.set_result(None)
|
self.connection_result.set_result(None)
|
||||||
self.connection_result = None
|
self.connection_result = None
|
||||||
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
|
|||||||
# TODO: decide how to fail gracefully
|
# TODO: decide how to fail gracefully
|
||||||
|
|
||||||
def on_disconnection_request(self, request) -> None:
|
def on_disconnection_request(self, request) -> None:
|
||||||
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
|
if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
L2CAP_Disconnection_Response(
|
L2CAP_Disconnection_Response(
|
||||||
identifier=request.identifier,
|
identifier=request.identifier,
|
||||||
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
|
|||||||
source_cid=request.source_cid,
|
source_cid=request.source_cid,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.change_state(Channel.CLOSED)
|
self._change_state(self.State.CLOSED)
|
||||||
self.emit('close')
|
self.emit('close')
|
||||||
self.manager.on_channel_closed(self)
|
self.manager.on_channel_closed(self)
|
||||||
else:
|
else:
|
||||||
logger.warning(color('invalid state', 'red'))
|
logger.warning(color('invalid state', 'red'))
|
||||||
|
|
||||||
def on_disconnection_response(self, response) -> None:
|
def on_disconnection_response(self, response) -> None:
|
||||||
if self.state != Channel.WAIT_DISCONNECT:
|
if self.state != self.State.WAIT_DISCONNECT:
|
||||||
logger.warning(color('invalid state', 'red'))
|
logger.warning(color('invalid state', 'red'))
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
|
|||||||
logger.warning('unexpected source or destination CID')
|
logger.warning('unexpected source or destination CID')
|
||||||
return
|
return
|
||||||
|
|
||||||
self.change_state(Channel.CLOSED)
|
self._change_state(self.State.CLOSED)
|
||||||
if self.disconnection_result:
|
if self.disconnection_result:
|
||||||
self.disconnection_result.set_result(None)
|
self.disconnection_result.set_result(None)
|
||||||
self.disconnection_result = None
|
self.disconnection_result = None
|
||||||
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
|
|||||||
f'Channel({self.source_cid}->{self.destination_cid}, '
|
f'Channel({self.source_cid}->{self.destination_cid}, '
|
||||||
f'PSM={self.psm}, '
|
f'PSM={self.psm}, '
|
||||||
f'MTU={self.mtu}, '
|
f'MTU={self.mtu}, '
|
||||||
f'state={Channel.STATE_NAMES[self.state]})'
|
f'state={self.state.name})'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1014,6 +995,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
LE Credit-based Connection Oriented Channel
|
LE Credit-based Connection Oriented Channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class State(enum.IntEnum):
|
||||||
INIT = 0
|
INIT = 0
|
||||||
CONNECTED = 1
|
CONNECTED = 1
|
||||||
CONNECTING = 2
|
CONNECTING = 2
|
||||||
@@ -1021,26 +1003,13 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
DISCONNECTED = 4
|
DISCONNECTED = 4
|
||||||
CONNECTION_ERROR = 5
|
CONNECTION_ERROR = 5
|
||||||
|
|
||||||
STATE_NAMES = {
|
|
||||||
INIT: 'INIT',
|
|
||||||
CONNECTED: 'CONNECTED',
|
|
||||||
CONNECTING: 'CONNECTING',
|
|
||||||
DISCONNECTING: 'DISCONNECTING',
|
|
||||||
DISCONNECTED: 'DISCONNECTED',
|
|
||||||
CONNECTION_ERROR: 'CONNECTION_ERROR',
|
|
||||||
}
|
|
||||||
|
|
||||||
out_queue: Deque[bytes]
|
out_queue: Deque[bytes]
|
||||||
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
|
||||||
disconnection_result: Optional[asyncio.Future[None]]
|
disconnection_result: Optional[asyncio.Future[None]]
|
||||||
out_sdu: Optional[bytes]
|
out_sdu: Optional[bytes]
|
||||||
state: int
|
state: State
|
||||||
connection: Connection
|
connection: Connection
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def state_name(state: int) -> str:
|
|
||||||
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
manager: ChannelManager,
|
manager: ChannelManager,
|
||||||
@@ -1083,22 +1052,20 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
self.drained.set()
|
self.drained.set()
|
||||||
|
|
||||||
if connected:
|
if connected:
|
||||||
self.state = LeConnectionOrientedChannel.CONNECTED
|
self.state = self.State.CONNECTED
|
||||||
else:
|
else:
|
||||||
self.state = LeConnectionOrientedChannel.INIT
|
self.state = self.State.INIT
|
||||||
|
|
||||||
def change_state(self, new_state: int) -> None:
|
def _change_state(self, new_state: State) -> None:
|
||||||
logger.debug(
|
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
|
||||||
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
|
|
||||||
)
|
|
||||||
self.state = new_state
|
self.state = new_state
|
||||||
|
|
||||||
if new_state == self.CONNECTED:
|
if new_state == self.State.CONNECTED:
|
||||||
self.emit('open')
|
self.emit('open')
|
||||||
elif new_state == self.DISCONNECTED:
|
elif new_state == self.State.DISCONNECTED:
|
||||||
self.emit('close')
|
self.emit('close')
|
||||||
|
|
||||||
def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||||
|
|
||||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||||
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
|
|
||||||
async def connect(self) -> LeConnectionOrientedChannel:
|
async def connect(self) -> LeConnectionOrientedChannel:
|
||||||
# Check that we're in the right state
|
# Check that we're in the right state
|
||||||
if self.state != self.INIT:
|
if self.state != self.State.INIT:
|
||||||
raise InvalidStateError('not in a connectable state')
|
raise InvalidStateError('not in a connectable state')
|
||||||
|
|
||||||
# Check that we can start a new connection
|
# Check that we can start a new connection
|
||||||
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
if identifier in self.manager.le_coc_requests:
|
if identifier in self.manager.le_coc_requests:
|
||||||
raise RuntimeError('too many concurrent connection requests')
|
raise RuntimeError('too many concurrent connection requests')
|
||||||
|
|
||||||
self.change_state(self.CONNECTING)
|
self._change_state(self.State.CONNECTING)
|
||||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||||
identifier=identifier,
|
identifier=identifier,
|
||||||
le_psm=self.le_psm,
|
le_psm=self.le_psm,
|
||||||
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
# Check that we're connected
|
# Check that we're connected
|
||||||
if self.state != self.CONNECTED:
|
if self.state != self.State.CONNECTED:
|
||||||
raise InvalidStateError('not connected')
|
raise InvalidStateError('not connected')
|
||||||
|
|
||||||
self.change_state(self.DISCONNECTING)
|
self._change_state(self.State.DISCONNECTING)
|
||||||
self.flush_output()
|
self.flush_output()
|
||||||
self.send_control_frame(
|
self.send_control_frame(
|
||||||
L2CAP_Disconnection_Request(
|
L2CAP_Disconnection_Request(
|
||||||
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
return await self.disconnection_result
|
return await self.disconnection_result
|
||||||
|
|
||||||
def abort(self) -> None:
|
def abort(self) -> None:
|
||||||
if self.state == self.CONNECTED:
|
if self.state == self.State.CONNECTED:
|
||||||
self.change_state(self.DISCONNECTED)
|
self._change_state(self.State.DISCONNECTED)
|
||||||
|
|
||||||
def on_pdu(self, pdu: bytes) -> None:
|
def on_pdu(self, pdu: bytes) -> None:
|
||||||
if self.sink is None:
|
if self.sink is None:
|
||||||
logger.warning('received pdu without a sink')
|
logger.warning('received pdu without a sink')
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.state != self.CONNECTED:
|
if self.state != self.State.CONNECTED:
|
||||||
logger.warning('received PDU while not connected, dropping')
|
logger.warning('received PDU while not connected, dropping')
|
||||||
|
|
||||||
# Manage the peer credits
|
# Manage the peer credits
|
||||||
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
self.credits = response.initial_credits
|
self.credits = response.initial_credits
|
||||||
self.connected = True
|
self.connected = True
|
||||||
self.connection_result.set_result(self)
|
self.connection_result.set_result(self)
|
||||||
self.change_state(self.CONNECTED)
|
self._change_state(self.State.CONNECTED)
|
||||||
else:
|
else:
|
||||||
self.connection_result.set_exception(
|
self.connection_result.set_exception(
|
||||||
ProtocolError(
|
ProtocolError(
|
||||||
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.change_state(self.CONNECTION_ERROR)
|
self._change_state(self.State.CONNECTION_ERROR)
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
self.connection_result = None
|
self.connection_result = None
|
||||||
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
source_cid=request.source_cid,
|
source_cid=request.source_cid,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.change_state(self.DISCONNECTED)
|
self._change_state(self.State.DISCONNECTED)
|
||||||
self.flush_output()
|
self.flush_output()
|
||||||
|
|
||||||
def on_disconnection_response(self, response) -> None:
|
def on_disconnection_response(self, response) -> None:
|
||||||
if self.state != self.DISCONNECTING:
|
if self.state != self.State.DISCONNECTING:
|
||||||
logger.warning(color('invalid state', 'red'))
|
logger.warning(color('invalid state', 'red'))
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
logger.warning('unexpected source or destination CID')
|
logger.warning('unexpected source or destination CID')
|
||||||
return
|
return
|
||||||
|
|
||||||
self.change_state(self.DISCONNECTED)
|
self._change_state(self.State.DISCONNECTED)
|
||||||
if self.disconnection_result:
|
if self.disconnection_result:
|
||||||
self.disconnection_result.set_result(None)
|
self.disconnection_result.set_result(None)
|
||||||
self.disconnection_result = None
|
self.disconnection_result = None
|
||||||
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def write(self, data: bytes) -> None:
|
def write(self, data: bytes) -> None:
|
||||||
if self.state != self.CONNECTED:
|
if self.state != self.State.CONNECTED:
|
||||||
logger.warning('not connected, dropping data')
|
logger.warning('not connected, dropping data')
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
|
|||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return (
|
return (
|
||||||
f'CoC({self.source_cid}->{self.destination_cid}, '
|
f'CoC({self.source_cid}->{self.destination_cid}, '
|
||||||
f'State={self.state_name(self.state)}, '
|
f'State={self.state.name}, '
|
||||||
f'PSM={self.le_psm}, '
|
f'PSM={self.le_psm}, '
|
||||||
f'MTU={self.mtu}/{self.peer_mtu}, '
|
f'MTU={self.mtu}/{self.peer_mtu}, '
|
||||||
f'MPS={self.mps}/{self.peer_mps}, '
|
f'MPS={self.mps}/{self.peer_mps}, '
|
||||||
@@ -1571,7 +1538,7 @@ class ChannelManager:
|
|||||||
if connection_handle in self.identifiers:
|
if connection_handle in self.identifiers:
|
||||||
del self.identifiers[connection_handle]
|
del self.identifiers[connection_handle]
|
||||||
|
|
||||||
def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
|
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import grpc
|
import grpc
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -27,8 +28,8 @@ from bumble.core import (
|
|||||||
)
|
)
|
||||||
from bumble.device import Connection as BumbleConnection, Device
|
from bumble.device import Connection as BumbleConnection, Device
|
||||||
from bumble.hci import HCI_Error
|
from bumble.hci import HCI_Error
|
||||||
|
from bumble.utils import EventWatcher
|
||||||
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
|
||||||
from contextlib import suppress
|
|
||||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||||
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
|
||||||
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
|
|||||||
try:
|
try:
|
||||||
self.log.debug('Pair...')
|
self.log.debug('Pair...')
|
||||||
|
|
||||||
|
security_result = asyncio.get_running_loop().create_future()
|
||||||
|
|
||||||
|
with contextlib.closing(EventWatcher()) as watcher:
|
||||||
|
|
||||||
|
@watcher.on(connection, 'pairing')
|
||||||
|
def on_pairing(*_: Any) -> None:
|
||||||
|
security_result.set_result('success')
|
||||||
|
|
||||||
|
@watcher.on(connection, 'pairing_failure')
|
||||||
|
def on_pairing_failure(*_: Any) -> None:
|
||||||
|
security_result.set_result('pairing_failure')
|
||||||
|
|
||||||
|
@watcher.on(connection, 'disconnection')
|
||||||
|
def on_disconnection(*_: Any) -> None:
|
||||||
|
security_result.set_result('connection_died')
|
||||||
|
|
||||||
if (
|
if (
|
||||||
connection.transport == BT_LE_TRANSPORT
|
connection.transport == BT_LE_TRANSPORT
|
||||||
and connection.role == BT_PERIPHERAL_ROLE
|
and connection.role == BT_PERIPHERAL_ROLE
|
||||||
):
|
):
|
||||||
wait_for_security: asyncio.Future[
|
|
||||||
bool
|
|
||||||
] = asyncio.get_running_loop().create_future()
|
|
||||||
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
|
|
||||||
connection.on("pairing_failure", wait_for_security.set_exception)
|
|
||||||
|
|
||||||
connection.request_pairing()
|
connection.request_pairing()
|
||||||
|
|
||||||
await wait_for_security
|
|
||||||
else:
|
else:
|
||||||
await connection.pair()
|
await connection.pair()
|
||||||
|
|
||||||
self.log.debug('Paired')
|
result = await security_result
|
||||||
|
|
||||||
|
self.log.debug(f'Pairing session complete, status={result}')
|
||||||
|
if result != 'success':
|
||||||
|
return SecureResponse(**{result: empty_pb2.Empty()})
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
self.log.warning("Connection died during encryption")
|
self.log.warning("Connection died during encryption")
|
||||||
return SecureResponse(connection_died=empty_pb2.Empty())
|
return SecureResponse(connection_died=empty_pb2.Empty())
|
||||||
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
|
|||||||
str
|
str
|
||||||
] = asyncio.get_running_loop().create_future()
|
] = asyncio.get_running_loop().create_future()
|
||||||
authenticate_task: Optional[asyncio.Future[None]] = None
|
authenticate_task: Optional[asyncio.Future[None]] = None
|
||||||
|
pair_task: Optional[asyncio.Future[None]] = None
|
||||||
|
|
||||||
async def authenticate() -> None:
|
async def authenticate() -> None:
|
||||||
assert connection
|
assert connection
|
||||||
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
|
|||||||
if authenticate_task is None:
|
if authenticate_task is None:
|
||||||
authenticate_task = asyncio.create_task(authenticate())
|
authenticate_task = asyncio.create_task(authenticate())
|
||||||
|
|
||||||
|
def pair(*_: Any) -> None:
|
||||||
|
if self.need_pairing(connection, level):
|
||||||
|
pair_task = asyncio.create_task(connection.pair())
|
||||||
|
|
||||||
listeners: Dict[str, Callable[..., None]] = {
|
listeners: Dict[str, Callable[..., None]] = {
|
||||||
'disconnection': set_failure('connection_died'),
|
'disconnection': set_failure('connection_died'),
|
||||||
'pairing_failure': set_failure('pairing_failure'),
|
'pairing_failure': set_failure('pairing_failure'),
|
||||||
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
|
|||||||
'connection_encryption_change': on_encryption_change,
|
'connection_encryption_change': on_encryption_change,
|
||||||
'classic_pairing': try_set_success,
|
'classic_pairing': try_set_success,
|
||||||
'classic_pairing_failure': set_failure('pairing_failure'),
|
'classic_pairing_failure': set_failure('pairing_failure'),
|
||||||
|
'security_request': pair,
|
||||||
}
|
}
|
||||||
|
|
||||||
# register event handlers
|
# register event handlers
|
||||||
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
|
|||||||
pass
|
pass
|
||||||
self.log.debug('Authenticated')
|
self.log.debug('Authenticated')
|
||||||
|
|
||||||
|
# wait for `pair` to finish if any
|
||||||
|
if pair_task is not None:
|
||||||
|
self.log.debug('Wait for authentication...')
|
||||||
|
try:
|
||||||
|
await pair_task # type: ignore
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.log.debug('paired')
|
||||||
|
|
||||||
return WaitSecurityResponse(**kwargs)
|
return WaitSecurityResponse(**kwargs)
|
||||||
|
|
||||||
def reached_security_level(
|
def reached_security_level(
|
||||||
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
|
|||||||
self.log.debug(f"DeleteBond: {address}")
|
self.log.debug(f"DeleteBond: {address}")
|
||||||
|
|
||||||
if self.device.keystore is not None:
|
if self.device.keystore is not None:
|
||||||
with suppress(KeyError):
|
with contextlib.suppress(KeyError):
|
||||||
await self.device.keystore.delete(str(address))
|
await self.device.keystore.delete(str(address))
|
||||||
|
|
||||||
return empty_pb2.Empty()
|
return empty_pb2.Empty()
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
|
|||||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||||
|
|
||||||
|
def on_smp_security_request_command(
|
||||||
|
self, connection: Connection, request: SMP_Security_Request_Command
|
||||||
|
) -> None:
|
||||||
|
connection.emit('security_request', request.auth_req)
|
||||||
|
|
||||||
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
|
||||||
|
# Parse the L2CAP payload into an SMP Command object
|
||||||
|
command = SMP_Command.from_bytes(pdu)
|
||||||
|
logger.debug(
|
||||||
|
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
||||||
|
f'{connection.peer_address}: {command}'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security request is more than just pairing, so let applications handle them
|
||||||
|
if command.code == SMP_SECURITY_REQUEST_COMMAND:
|
||||||
|
self.on_smp_security_request_command(
|
||||||
|
connection, cast(SMP_Security_Request_Command, command)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Look for a session with this connection, and create one if none exists
|
# Look for a session with this connection, and create one if none exists
|
||||||
if not (session := self.sessions.get(connection.handle)):
|
if not (session := self.sessions.get(connection.handle)):
|
||||||
if connection.role == BT_CENTRAL_ROLE:
|
if connection.role == BT_CENTRAL_ROLE:
|
||||||
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
|
|||||||
)
|
)
|
||||||
self.sessions[connection.handle] = session
|
self.sessions[connection.handle] = session
|
||||||
|
|
||||||
# Parse the L2CAP payload into an SMP Command object
|
|
||||||
command = SMP_Command.from_bytes(pdu)
|
|
||||||
logger.debug(
|
|
||||||
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
|
|
||||||
f'{connection.peer_address}: {command}'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Delegate the handling of the command to the session
|
# Delegate the handling of the command to the session
|
||||||
session.on_smp_command(command)
|
session.on_smp_command(command)
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import grpc.aio
|
import grpc.aio
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
|
||||||
|
|
||||||
# pylint: disable=no-name-in-module
|
# pylint: disable=no-name-in-module
|
||||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_android_emulator_transport(spec: str | None) -> Transport:
|
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a transport connection to an Android emulator via its gRPC interface.
|
Open a transport connection to an Android emulator via its gRPC interface.
|
||||||
The parameter string has this syntax:
|
The parameter string has this syntax:
|
||||||
@@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
|
|||||||
logger.debug(f'connecting to gRPC server at {server_address}')
|
logger.debug(f'connecting to gRPC server at {server_address}')
|
||||||
channel = grpc.aio.insecure_channel(server_address)
|
channel = grpc.aio.insecure_channel(server_address)
|
||||||
|
|
||||||
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
|
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
|
||||||
if mode == 'host':
|
if mode == 'host':
|
||||||
# Connect as a host
|
# Connect as a host
|
||||||
service = EmulatedBluetoothServiceStub(channel)
|
service = EmulatedBluetoothServiceStub(channel)
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ def publish_grpc_port(grpc_port) -> bool:
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_android_netsim_controller_transport(
|
async def open_android_netsim_controller_transport(
|
||||||
server_host: str | None, server_port: int
|
server_host: Optional[str], server_port: int
|
||||||
) -> Transport:
|
) -> Transport:
|
||||||
if not server_port:
|
if not server_port:
|
||||||
raise ValueError('invalid port')
|
raise ValueError('invalid port')
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import socket
|
|||||||
import ctypes
|
import ctypes
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport, ParserSource
|
from .common import Transport, ParserSource
|
||||||
|
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_hci_socket_transport(spec: str | None) -> Transport:
|
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open an HCI Socket (only available on some platforms).
|
Open an HCI Socket (only available on some platforms).
|
||||||
The parameter string is either empty (to use the first/default Bluetooth adapter)
|
The parameter string is either empty (to use the first/default Bluetooth adapter)
|
||||||
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
|
|||||||
# Create a raw HCI socket
|
# Create a raw HCI socket
|
||||||
try:
|
try:
|
||||||
hci_socket = socket.socket(
|
hci_socket = socket.socket(
|
||||||
socket.AF_BLUETOOTH,
|
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||||
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
|
socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
|
||||||
socket.BTPROTO_HCI, # type: ignore
|
socket.BTPROTO_HCI, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
except AttributeError as error:
|
except AttributeError as error:
|
||||||
# Not supported on this platform
|
# Not supported on this platform
|
||||||
@@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
|
|||||||
bind_address = struct.pack(
|
bind_address = struct.pack(
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
'<HHH',
|
'<HHH',
|
||||||
socket.AF_BLUETOOTH,
|
socket.AF_BLUETOOTH, # type: ignore[attr-defined]
|
||||||
adapter_index,
|
adapter_index,
|
||||||
HCI_CHANNEL_USER,
|
HCI_CHANNEL_USER,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import atexit
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport, StreamPacketSource, StreamPacketSink
|
from .common import Transport, StreamPacketSource, StreamPacketSink
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_pty_transport(spec: str | None) -> Transport:
|
async def open_pty_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a PTY transport.
|
Open a PTY transport.
|
||||||
The parameter string may be empty, or a path name where a symbolic link
|
The parameter string may be empty, or a path name where a symbolic link
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from .common import Transport
|
from .common import Transport
|
||||||
from .file import open_file_transport
|
from .file import open_file_transport
|
||||||
|
|
||||||
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
async def open_vhci_transport(spec: str | None) -> Transport:
|
async def open_vhci_transport(spec: Optional[str]) -> Transport:
|
||||||
'''
|
'''
|
||||||
Open a VHCI transport (only available on some platforms).
|
Open a VHCI transport (only available on some platforms).
|
||||||
The parameter string is either empty (to use the default VHCI device
|
The parameter string is either empty (to use the default VHCI device
|
||||||
|
|||||||
110
bumble/utils.py
110
bumble/utils.py
@@ -15,12 +15,24 @@
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Imports
|
# Imports
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
from typing import Awaitable, Set, TypeVar
|
from typing import (
|
||||||
|
Awaitable,
|
||||||
|
Set,
|
||||||
|
TypeVar,
|
||||||
|
List,
|
||||||
|
Tuple,
|
||||||
|
Callable,
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
|
|
||||||
@@ -64,6 +76,102 @@ def composite_listener(cls):
|
|||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
_Handler = TypeVar('_Handler', bound=Callable)
|
||||||
|
|
||||||
|
|
||||||
|
class EventWatcher:
|
||||||
|
'''A wrapper class to control the lifecycle of event handlers better.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
watcher = EventWatcher()
|
||||||
|
|
||||||
|
def on_foo():
|
||||||
|
...
|
||||||
|
watcher.on(emitter, 'foo', on_foo)
|
||||||
|
|
||||||
|
@watcher.on(emitter, 'bar')
|
||||||
|
def on_bar():
|
||||||
|
...
|
||||||
|
|
||||||
|
# Close all event handlers watching through this watcher
|
||||||
|
watcher.close()
|
||||||
|
```
|
||||||
|
|
||||||
|
As context:
|
||||||
|
```
|
||||||
|
with contextlib.closing(EventWatcher()) as context:
|
||||||
|
@context.on(emitter, 'foo')
|
||||||
|
def on_foo():
|
||||||
|
...
|
||||||
|
# on_foo() has been removed here!
|
||||||
|
```
|
||||||
|
'''
|
||||||
|
|
||||||
|
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.handlers = []
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||||
|
...
|
||||||
|
|
||||||
|
def on(
|
||||||
|
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||||
|
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||||
|
'''Watch an event until the context is closed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emitter: EventEmitter to watch
|
||||||
|
event: Event name
|
||||||
|
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def wrapper(f: _Handler) -> _Handler:
|
||||||
|
self.handlers.append((emitter, event, f))
|
||||||
|
emitter.on(event, f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return wrapper if handler is None else wrapper(handler)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
|
||||||
|
...
|
||||||
|
|
||||||
|
def once(
|
||||||
|
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
|
||||||
|
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
|
||||||
|
'''Watch an event for once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emitter: EventEmitter to watch
|
||||||
|
event: Event name
|
||||||
|
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def wrapper(f: _Handler) -> _Handler:
|
||||||
|
self.handlers.append((emitter, event, f))
|
||||||
|
emitter.once(event, f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return wrapper if handler is None else wrapper(handler)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
for emitter, event, handler in self.handlers:
|
||||||
|
if handler in emitter.listeners(event):
|
||||||
|
emitter.remove_listener(event, handler)
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
_T = TypeVar('_T')
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|||||||
159
rust/Cargo.lock
generated
159
rust/Cargo.lock
generated
@@ -130,6 +130,16 @@ version = "2.4.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
|
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bstr"
|
||||||
|
version = "1.6.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
|
||||||
|
dependencies = [
|
||||||
|
"memchr",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bumble"
|
name = "bumble"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -138,7 +148,9 @@ dependencies = [
|
|||||||
"clap 4.4.1",
|
"clap 4.4.1",
|
||||||
"directories",
|
"directories",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
|
"file-header",
|
||||||
"futures",
|
"futures",
|
||||||
|
"globset",
|
||||||
"hex",
|
"hex",
|
||||||
"itertools",
|
"itertools",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
@@ -272,6 +284,73 @@ version = "0.8.4"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam"
|
||||||
|
version = "0.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-channel",
|
||||||
|
"crossbeam-deque",
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-queue",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.5.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-deque"
|
||||||
|
version = "0.8.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-epoch",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-epoch"
|
||||||
|
version = "0.9.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
"memoffset 0.9.0",
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-queue"
|
||||||
|
version = "0.3.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-utils"
|
||||||
|
version = "0.8.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "directories"
|
name = "directories"
|
||||||
version = "5.0.1"
|
version = "5.0.1"
|
||||||
@@ -348,6 +427,19 @@ version = "2.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
|
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "file-header"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam",
|
||||||
|
"lazy_static",
|
||||||
|
"license",
|
||||||
|
"thiserror",
|
||||||
|
"walkdir",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fnv"
|
name = "fnv"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
@@ -484,6 +576,19 @@ version = "0.28.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
|
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "globset"
|
||||||
|
version = "0.4.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
|
||||||
|
dependencies = [
|
||||||
|
"aho-corasick",
|
||||||
|
"bstr",
|
||||||
|
"fnv",
|
||||||
|
"log",
|
||||||
|
"regex",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h2"
|
name = "h2"
|
||||||
version = "0.3.21"
|
version = "0.3.21"
|
||||||
@@ -710,6 +815,17 @@ dependencies = [
|
|||||||
"vcpkg",
|
"vcpkg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "license"
|
||||||
|
version = "3.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946"
|
||||||
|
dependencies = [
|
||||||
|
"reword",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.4.5"
|
version = "0.4.5"
|
||||||
@@ -756,6 +872,15 @@ dependencies = [
|
|||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "memoffset"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mime"
|
name = "mime"
|
||||||
version = "0.3.17"
|
version = "0.3.17"
|
||||||
@@ -1200,6 +1325,15 @@ dependencies = [
|
|||||||
"winreg",
|
"winreg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "reword"
|
||||||
|
version = "7.0.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-segmentation",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rusb"
|
name = "rusb"
|
||||||
version = "0.9.3"
|
version = "0.9.3"
|
||||||
@@ -1241,6 +1375,15 @@ version = "1.0.15"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
|
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "same-file"
|
||||||
|
version = "1.0.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||||
|
dependencies = [
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "schannel"
|
name = "schannel"
|
||||||
version = "0.1.22"
|
version = "0.1.22"
|
||||||
@@ -1589,6 +1732,12 @@ dependencies = [
|
|||||||
"tinyvec",
|
"tinyvec",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-segmentation"
|
||||||
|
version = "1.10.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unindent"
|
name = "unindent"
|
||||||
version = "0.1.11"
|
version = "0.1.11"
|
||||||
@@ -1618,6 +1767,16 @@ version = "0.2.15"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "walkdir"
|
||||||
|
version = "2.4.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
|
||||||
|
dependencies = [
|
||||||
|
"same-file",
|
||||||
|
"winapi-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "want"
|
name = "want"
|
||||||
version = "0.3.1"
|
version = "0.3.1"
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ itertools = "0.11.0"
|
|||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
thiserror = "1.0.41"
|
thiserror = "1.0.41"
|
||||||
|
|
||||||
|
# Dev tools
|
||||||
|
file-header = { version = "0.1.2", optional = true }
|
||||||
|
globset = { version = "0.4.13", optional = true }
|
||||||
|
|
||||||
# CLI
|
# CLI
|
||||||
anyhow = { version = "1.0.71", optional = true }
|
anyhow = { version = "1.0.71", optional = true }
|
||||||
clap = { version = "4.3.3", features = ["derive"], optional = true }
|
clap = { version = "4.3.3", features = ["derive"], optional = true }
|
||||||
@@ -52,10 +56,15 @@ env_logger = "0.10.0"
|
|||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
rustdoc-args = ["--generate-link-to-definition"]
|
rustdoc-args = ["--generate-link-to-definition"]
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "file-header"
|
||||||
|
path = "tools/file_header.rs"
|
||||||
|
required-features = ["dev-tools"]
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "gen-assigned-numbers"
|
name = "gen-assigned-numbers"
|
||||||
path = "tools/gen_assigned_numbers.rs"
|
path = "tools/gen_assigned_numbers.rs"
|
||||||
required-features = ["bumble-codegen"]
|
required-features = ["dev-tools"]
|
||||||
|
|
||||||
[[bin]]
|
[[bin]]
|
||||||
name = "bumble"
|
name = "bumble"
|
||||||
@@ -71,7 +80,7 @@ harness = false
|
|||||||
[features]
|
[features]
|
||||||
anyhow = ["pyo3/anyhow"]
|
anyhow = ["pyo3/anyhow"]
|
||||||
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
|
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
|
||||||
bumble-codegen = ["dep:anyhow"]
|
dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
|
||||||
# separate feature for CLI so that dependencies don't spend time building these
|
# separate feature for CLI so that dependencies don't spend time building these
|
||||||
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
|
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
|
||||||
default = []
|
default = []
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`.
|
|||||||
To regenerate the assigned number tables based on the Python codebase:
|
To regenerate the assigned number tables based on the Python codebase:
|
||||||
|
|
||||||
```
|
```
|
||||||
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen
|
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
|
||||||
```
|
```
|
||||||
@@ -1,3 +1,17 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
use bumble::wrapper::{self, core::Uuid16};
|
use bumble::wrapper::{self, core::Uuid16};
|
||||||
use pyo3::{intern, prelude::*, types::PyDict};
|
use pyo3::{intern, prelude::*, types::PyDict};
|
||||||
use std::collections;
|
use std::collections;
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
//! BLE advertisements.
|
//! BLE advertisements.
|
||||||
|
|
||||||
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
|
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
|
||||||
|
|||||||
@@ -1,3 +1,17 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
//! Bumble & Python logging
|
//! Bumble & Python logging
|
||||||
|
|
||||||
use pyo3::types::PyDict;
|
use pyo3::types::PyDict;
|
||||||
|
|||||||
78
rust/tools/file_header.rs
Normal file
78
rust/tools/file_header.rs
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
// Copyright 2023 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use clap::Parser as _;
|
||||||
|
use file_header::{
|
||||||
|
add_headers_recursively, check_headers_recursively,
|
||||||
|
license::spdx::{YearCopyrightOwnerValue, APACHE_2_0},
|
||||||
|
};
|
||||||
|
use globset::{Glob, GlobSet, GlobSetBuilder};
|
||||||
|
use std::{env, path::PathBuf};
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
|
||||||
|
let ignore_globset = ignore_globset()?;
|
||||||
|
// Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127)
|
||||||
|
let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into()));
|
||||||
|
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.subcommand {
|
||||||
|
Subcommand::CheckAll => {
|
||||||
|
let result =
|
||||||
|
check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?;
|
||||||
|
if result.has_failure() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"The following files do not have headers: {result:?}"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Subcommand::AddAll => {
|
||||||
|
let files_with_new_header =
|
||||||
|
add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?;
|
||||||
|
files_with_new_header
|
||||||
|
.iter()
|
||||||
|
.for_each(|path| println!("Added header to: {path:?}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ignore_globset() -> anyhow::Result<GlobSet> {
|
||||||
|
Ok(GlobSetBuilder::new()
|
||||||
|
.add(Glob::new("**/.idea/**")?)
|
||||||
|
.add(Glob::new("**/target/**")?)
|
||||||
|
.add(Glob::new("**/.gitignore")?)
|
||||||
|
.add(Glob::new("**/CHANGELOG.md")?)
|
||||||
|
.add(Glob::new("**/Cargo.lock")?)
|
||||||
|
.add(Glob::new("**/Cargo.toml")?)
|
||||||
|
.add(Glob::new("**/README.md")?)
|
||||||
|
.add(Glob::new("*.bin")?)
|
||||||
|
.build()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(clap::Parser)]
|
||||||
|
struct Cli {
|
||||||
|
#[clap(subcommand)]
|
||||||
|
subcommand: Subcommand,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(clap::Subcommand, Debug, Clone)]
|
||||||
|
enum Subcommand {
|
||||||
|
/// Checks if a license is present in files that are not in the ignore list.
|
||||||
|
CheckAll,
|
||||||
|
/// Adds a license as needed to files that are not in the ignore list.
|
||||||
|
AddAll,
|
||||||
|
}
|
||||||
@@ -84,7 +84,7 @@ development =
|
|||||||
black == 22.10
|
black == 22.10
|
||||||
grpcio-tools >= 1.57.0
|
grpcio-tools >= 1.57.0
|
||||||
invoke >= 1.7.3
|
invoke >= 1.7.3
|
||||||
mypy == 1.2.0
|
mypy == 1.5.0
|
||||||
nox >= 2022
|
nox >= 2022
|
||||||
pylint == 2.15.8
|
pylint == 2.15.8
|
||||||
types-appdirs >= 1.4.3
|
types-appdirs >= 1.4.3
|
||||||
|
|||||||
@@ -891,10 +891,10 @@ async def async_main():
|
|||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
def test_attribute_string_to_permissions():
|
def test_permissions_from_string():
|
||||||
assert Attribute.string_to_permissions('READABLE') == 1
|
assert Attribute.Permissions.from_string('READABLE') == 1
|
||||||
assert Attribute.string_to_permissions('WRITEABLE') == 2
|
assert Attribute.Permissions.from_string('WRITEABLE') == 2
|
||||||
assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
|
assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|||||||
77
tests/utils_test.py
Normal file
77
tests/utils_test.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Copyright 2021-2023 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from bumble import utils
|
||||||
|
from pyee import EventEmitter
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def test_on() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
context.on(emitter, 'event', mock)
|
||||||
|
|
||||||
|
emitter.emit('event')
|
||||||
|
|
||||||
|
assert not emitter.listeners('event')
|
||||||
|
assert mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_decorator() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
@context.on(emitter, 'event')
|
||||||
|
def on_event(*_) -> None:
|
||||||
|
mock()
|
||||||
|
|
||||||
|
emitter.emit('event')
|
||||||
|
|
||||||
|
assert not emitter.listeners('event')
|
||||||
|
assert mock.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_handlers() -> None:
|
||||||
|
emitter = EventEmitter()
|
||||||
|
with contextlib.closing(utils.EventWatcher()) as context:
|
||||||
|
mock = MagicMock()
|
||||||
|
|
||||||
|
context.once(emitter, 'a', mock)
|
||||||
|
context.once(emitter, 'b', mock)
|
||||||
|
|
||||||
|
emitter.emit('b', 'b')
|
||||||
|
|
||||||
|
assert not emitter.listeners('a')
|
||||||
|
assert not emitter.listeners('b')
|
||||||
|
|
||||||
|
mock.assert_called_once_with('b')
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
def run_tests():
|
||||||
|
test_on()
|
||||||
|
test_on_decorator()
|
||||||
|
test_multiple_handlers()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||||
|
run_tests()
|
||||||
Reference in New Issue
Block a user