Compare commits

...

15 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
cc96d4245f address PR comments 2023-09-27 21:25:13 -07:00
Gilles Boccon-Gibod
f1777a5bd2 use .to_string instead of a manual suffix replacement 2023-09-21 19:03:54 -07:00
Gilles Boccon-Gibod
78a06ae8cf make implementation match the doc 2023-09-21 19:01:40 -07:00
zxzxwu
d290df4aa9 Merge pull request #278 from zxzxwu/gatt2
Typing GATT
2023-09-21 16:09:36 +08:00
Josh Wu
e559744f32 Typing att 2023-09-21 15:52:07 +08:00
zxzxwu
67418e649a Merge pull request #288 from zxzxwu/l2cap_states
L2CAP: Refactor states to enums
2023-09-21 15:42:21 +08:00
Gilles Boccon-Gibod
5adf9fab53 Merge pull request #275 from whitevegagabriel/file-header
Add license header check for rust files
2023-09-20 16:21:38 -07:00
Josh Wu
2491b686fa Handle SMP_Security_Request 2023-09-20 23:13:08 +02:00
Josh Wu
efd02b2f3e Adopt reviews 2023-09-20 23:03:23 +02:00
Josh Wu
3b14078646 Overload signatures 2023-09-20 23:03:23 +02:00
Josh Wu
eb9d5632bc Add utils_test type hint 2023-09-20 23:03:23 +02:00
Josh Wu
45f60edbb6 Pyee watcher context 2023-09-20 23:03:23 +02:00
David Duarte
393ea6a7bb pandora_server: Load server config
Pandora server has it's own config that we load from the 'server'
property of the current bumble config file
2023-09-18 14:28:42 -07:00
Gabriel White-Vega
6ec6f1efe5 Add license header check for rust files
Added binary that can check for and add Apache 2.0 licenses.
Run this binary during the build-rust workflow.
2023-09-14 14:29:47 -04:00
Josh Wu
5d9598ea51 L2CAP: Refactor states to enums 2023-09-14 20:52:33 +08:00
25 changed files with 893 additions and 302 deletions

View File

@@ -65,6 +65,8 @@ jobs:
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built

View File

@@ -1172,7 +1172,7 @@ class ScanResult:
name = ''
# Remove any '/P' qualifier suffix from the address string
address_str = str(self.address).replace('/P', '')
address_str = self.address.to_string(with_type_qualifier=False)
# RSSI bar
bar_string = rssi_bar(self.rssi)

View File

@@ -63,7 +63,8 @@ async def get_classic_info(host):
if command_succeeded(response):
print()
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
color('Classic Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):

View File

@@ -3,7 +3,7 @@ import click
import logging
import json
from bumble.pandora import PandoraDevice, serve
from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
bumble_config = retrieve_config(config)
if 'transport' not in bumble_config.keys():
bumble_config.update({'transport': transport})
bumble_config.setdefault('transport', transport)
device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port))
asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]:

View File

@@ -195,7 +195,7 @@ class WebSocketOutput(QueuedOutput):
except HCI_StatusError:
pass
peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = str(connection.peer_address).replace('/P', '')
peer_address = connection.peer_address.to_string(False)
await self.send_message(
'connection',
peer_address=peer_address,
@@ -376,7 +376,7 @@ class UiServer:
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=str(connection.peer_address).replace('/P', ''),
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)

View File

@@ -23,13 +23,14 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import functools
import struct
from pyee import EventEmitter
from typing import Dict, Type, TYPE_CHECKING
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, get_dict_key_by_value, ProtocolError
from bumble.hci import HCI_Object, key_with_value, HCI_Constant
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
from bumble.colors import color
if TYPE_CHECKING:
@@ -182,6 +183,7 @@ UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
@@ -209,7 +211,7 @@ class ATT_PDU:
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
name = None
name: str
@staticmethod
def from_bytes(pdu):
@@ -719,48 +721,68 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
'''
# -----------------------------------------------------------------------------
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def write(self, connection, value: bytes) -> None:
...
# -----------------------------------------------------------------------------
class Attribute(EventEmitter):
# Permission flags
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
class Permissions(enum.IntFlag):
READABLE = 0x01
WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80
PERMISSION_NAMES = {
READABLE: 'READABLE',
WRITEABLE: 'WRITEABLE',
READ_REQUIRES_ENCRYPTION: 'READ_REQUIRES_ENCRYPTION',
WRITE_REQUIRES_ENCRYPTION: 'WRITE_REQUIRES_ENCRYPTION',
READ_REQUIRES_AUTHENTICATION: 'READ_REQUIRES_AUTHENTICATION',
WRITE_REQUIRES_AUTHENTICATION: 'WRITE_REQUIRES_AUTHENTICATION',
READ_REQUIRES_AUTHORIZATION: 'READ_REQUIRES_AUTHORIZATION',
WRITE_REQUIRES_AUTHORIZATION: 'WRITE_REQUIRES_AUTHORIZATION',
}
@classmethod
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
try:
return functools.reduce(
lambda x, y: x | Attribute.Permissions[y],
permissions_str.replace('|', ',').split(","),
Attribute.Permissions(0),
)
except TypeError as exc:
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc
@staticmethod
def string_to_permissions(permissions_str: str):
try:
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Attribute.PERMISSION_NAMES, y),
permissions_str.split(","),
0,
)
except TypeError as exc:
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {','.join(Attribute.PERMISSION_NAMES.values())}\nGot: {permissions_str}"
) from exc
# Permission flags(legacy-use only)
READABLE = Permissions.READABLE
WRITEABLE = Permissions.WRITEABLE
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
def __init__(self, attribute_type, permissions, value=b''):
value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
self.end_group_handle = 0
if isinstance(permissions, str):
self.permissions = self.string_to_permissions(permissions)
self.permissions = Attribute.Permissions.from_string(permissions)
else:
self.permissions = permissions
@@ -778,22 +800,26 @@ class Attribute(EventEmitter):
else:
self.value = value
def encode_value(self, value):
def encode_value(self, value: Any) -> bytes:
return value
def decode_value(self, value_bytes):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def read_value(self, connection: Connection):
def read_value(self, connection: Optional[Connection]) -> bytes:
if (
self.permissions & self.READ_REQUIRES_ENCRYPTION
) and not connection.encryption:
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
and not connection.encryption
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.READ_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
and connection is not None
and not connection.authenticated
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
@@ -803,9 +829,9 @@ class Attribute(EventEmitter):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
if read := getattr(self.value, 'read', None):
if hasattr(self.value, 'read'):
try:
value = read(connection) # pylint: disable=not-callable
value = self.value.read(connection)
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -815,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes):
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -836,9 +862,9 @@ class Attribute(EventEmitter):
value = self.decode_value(value_bytes)
if write := getattr(self.value, 'write', None):
if hasattr(self.value, 'write'):
try:
write(connection, value) # pylint: disable=not-callable
self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle

View File

@@ -28,7 +28,7 @@ import enum
import functools
import logging
import struct
from typing import Optional, Sequence, List
from typing import Optional, Sequence, Iterable, List, Union
from .colors import color
from .core import UUID, get_dict_key_by_value
@@ -187,7 +187,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
# -----------------------------------------------------------------------------
def show_services(services):
def show_services(services: Iterable[Service]) -> None:
for service in services:
print(color(str(service), 'cyan'))
@@ -210,11 +210,11 @@ class Service(Attribute):
def __init__(
self,
uuid,
uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: List[Service] = [],
):
) -> None:
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
uuid = UUID(uuid)
@@ -239,7 +239,7 @@ class Service(Attribute):
"""
return None
def __str__(self):
def __str__(self) -> str:
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -255,9 +255,11 @@ class TemplateService(Service):
to expose their UUID as a class property
'''
UUID: Optional[UUID] = None
UUID: UUID
def __init__(self, characteristics, primary=True):
def __init__(
self, characteristics: List[Characteristic], primary: bool = True
) -> None:
super().__init__(self.UUID, characteristics, primary)
@@ -269,7 +271,7 @@ class IncludedServiceDeclaration(Attribute):
service: Service
def __init__(self, service):
def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
@@ -278,7 +280,7 @@ class IncludedServiceDeclaration(Attribute):
)
self.service = service
def __str__(self):
def __str__(self) -> str:
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
@@ -326,7 +328,7 @@ class Characteristic(Attribute):
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
def __str__(self):
def __str__(self) -> str:
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
@@ -348,10 +350,10 @@ class Characteristic(Attribute):
def __init__(
self,
uuid,
uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
permissions,
value=b'',
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -369,7 +371,7 @@ class Characteristic(Attribute):
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
def __str__(self):
def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
@@ -386,7 +388,7 @@ class CharacteristicDeclaration(Attribute):
characteristic: Characteristic
def __init__(self, characteristic, value_handle):
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
@@ -397,7 +399,7 @@ class CharacteristicDeclaration(Attribute):
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
def __str__(self) -> str:
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
@@ -520,7 +522,7 @@ class CharacteristicAdapter:
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self):
def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})'
@@ -600,10 +602,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8')
def decode_value(self, value):
def decode_value(self, value: bytes) -> str:
return value.decode('utf-8')
@@ -613,7 +615,7 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __str__(self):
def __str__(self) -> str:
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '

View File

@@ -28,7 +28,18 @@ import asyncio
import logging
import struct
from datetime import datetime
from typing import List, Optional, Dict, Tuple, Callable, Union, Any
from typing import (
List,
Optional,
Dict,
Tuple,
Callable,
Union,
Any,
Iterable,
Type,
TYPE_CHECKING,
)
from pyee import EventEmitter
@@ -66,8 +77,12 @@ from .gatt import (
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
TemplateService,
)
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -78,16 +93,16 @@ logger = logging.getLogger(__name__)
# Proxies
# -----------------------------------------------------------------------------
class AttributeProxy(EventEmitter):
client: Client
def __init__(self, client, handle, end_group_handle, attribute_type):
def __init__(
self, client: Client, handle: int, end_group_handle: int, attribute_type: UUID
) -> None:
EventEmitter.__init__(self)
self.client = client
self.handle = handle
self.end_group_handle = end_group_handle
self.type = attribute_type
async def read_value(self, no_long_read=False):
async def read_value(self, no_long_read: bool = False) -> bytes:
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
@@ -97,13 +112,13 @@ class AttributeProxy(EventEmitter):
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
def encode_value(self, value: Any) -> bytes:
return value
def decode_value(self, value_bytes):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def __str__(self):
def __str__(self) -> str:
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
@@ -136,14 +151,14 @@ class ServiceProxy(AttributeProxy):
def get_characteristics_by_uuid(self, uuid):
return self.client.get_characteristics_by_uuid(uuid, self)
def __str__(self):
def __str__(self) -> str:
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
class CharacteristicProxy(AttributeProxy):
properties: Characteristic.Properties
descriptors: List[DescriptorProxy]
subscribers: Dict[Any, Callable]
subscribers: Dict[Any, Callable[[bytes], Any]]
def __init__(
self,
@@ -171,7 +186,9 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.discover_descriptors(self)
async def subscribe(
self, subscriber: Optional[Callable] = None, prefer_notify=True
self,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
):
if subscriber is not None:
if subscriber in self.subscribers:
@@ -195,7 +212,7 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
@@ -207,7 +224,7 @@ class DescriptorProxy(AttributeProxy):
def __init__(self, client, handle, descriptor_type):
super().__init__(client, handle, 0, descriptor_type)
def __str__(self):
def __str__(self) -> str:
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type})'
@@ -216,8 +233,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies
'''
SERVICE_CLASS: Type[TemplateService]
@classmethod
def from_client(cls, client):
def from_client(cls, client: Client) -> ProfileServiceProxy:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -227,8 +246,12 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
def __init__(self, connection):
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
@@ -241,16 +264,16 @@ class Client:
self.services = []
self.cached_values = {}
def send_gatt_pdu(self, pdu):
def send_gatt_pdu(self, pdu: bytes) -> None:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command):
async def send_command(self, command: ATT_PDU) -> None:
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request):
async def send_request(self, request: ATT_PDU):
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
@@ -279,14 +302,14 @@ class Client:
return response
def send_confirmation(self, confirmation):
def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu):
async def request_mtu(self, mtu: int) -> int:
# Check the range
if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}')
@@ -313,10 +336,12 @@ class Client:
return self.connection.att_mtu
def get_services_by_uuid(self, uuid):
def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(self, uuid, service=None):
def get_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> List[CharacteristicProxy]:
services = [service] if service else self.services
return [
c
@@ -363,7 +388,7 @@ class Client:
if not already_known:
self.services.append(service)
async def discover_services(self, uuids=None) -> List[ServiceProxy]:
async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.1 Discover All Primary Services
'''
@@ -435,7 +460,7 @@ class Client:
return services
async def discover_service(self, uuid):
async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -468,7 +493,7 @@ class Client:
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
return []
break
for attribute_handle, end_group_handle in response.handles_information:
@@ -480,7 +505,7 @@ class Client:
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
return []
# Create a service proxy for this service
service = ServiceProxy(
@@ -721,7 +746,7 @@ class Client:
return descriptors
async def discover_attributes(self):
async def discover_attributes(self) -> List[AttributeProxy]:
'''
Discover all attributes, regardless of type
'''
@@ -844,7 +869,9 @@ class Client:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(self, attribute, no_long_read=False):
async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> Any:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -905,7 +932,9 @@ class Client:
# Return the value as bytes
return attribute_value
async def read_characteristics_by_uuid(self, uuid, service):
async def read_characteristics_by_uuid(
self, uuid: UUID, service: Optional[ServiceProxy]
) -> List[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
'''
@@ -960,7 +989,12 @@ class Client:
return characteristics_values
async def write_value(self, attribute, value, with_response=False):
async def write_value(
self,
attribute: Union[int, AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
@@ -990,7 +1024,7 @@ class Client:
)
)
def on_gatt_pdu(self, att_pdu):
def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
@@ -1013,6 +1047,7 @@ class Client:
return
# Return the response to the coroutine that is waiting for it
assert self.pending_response is not None
self.pending_response.set_result(att_pdu)
else:
handler_name = f'on_{att_pdu.name.lower()}'
@@ -1060,7 +1095,7 @@ class Client:
# Confirm that we received the indication
self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes):
def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = (
datetime.now(),
value,

View File

@@ -23,11 +23,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
from collections import defaultdict
import struct
from typing import List, Tuple, Optional, TypeVar, Type
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from .colors import color
@@ -42,6 +43,7 @@ from .att import (
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_PDU,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
@@ -73,6 +75,8 @@ from .gatt import (
Service,
)
if TYPE_CHECKING:
from bumble.device import Device, Connection
# -----------------------------------------------------------------------------
# Logging
@@ -91,8 +95,13 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
services: List[Service]
attributes_by_handle: Dict[int, Attribute]
subscribers: Dict[int, Dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
def __init__(self, device):
def __init__(self, device: Device) -> None:
super().__init__()
self.device = device
self.services = []
@@ -107,16 +116,16 @@ class Server(EventEmitter):
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu):
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self):
def next_handle(self) -> int:
return 1 + len(self.attributes)
def get_advertising_service_data(self):
def get_advertising_service_data(self) -> Dict[Attribute, bytes]:
return {
attribute: data
for attribute in self.attributes
@@ -124,7 +133,7 @@ class Server(EventEmitter):
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle):
def get_attribute(self, handle: int) -> Optional[Attribute]:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -173,12 +182,17 @@ class Server(EventEmitter):
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
(
attribute,
self.get_attribute(attribute.characteristic.handle),
) # type: ignore
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
if attribute is not None
and attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and isinstance(attribute, CharacteristicDeclaration)
and attribute.characteristic.uuid == characteristic_uuid
),
None,
@@ -197,7 +211,7 @@ class Server(EventEmitter):
return next(
(
attribute
attribute # type: ignore
for attribute in map(
self.get_attribute,
range(
@@ -205,12 +219,12 @@ class Server(EventEmitter):
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
if attribute is not None and attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute):
def add_attribute(self, attribute: Attribute) -> None:
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = (
@@ -220,7 +234,7 @@ class Server(EventEmitter):
# Add this attribute to the list
self.attributes.append(attribute)
def add_service(self, service: Service):
def add_service(self, service: Service) -> None:
# Add the service attribute to the DB
self.add_attribute(service)
@@ -285,11 +299,13 @@ class Server(EventEmitter):
service.end_group_handle = self.attributes[-1].handle
self.services.append(service)
def add_services(self, services):
def add_services(self, services: Iterable[Service]) -> None:
for service in services:
self.add_service(service)
def read_cccd(self, connection, characteristic):
def read_cccd(
self, connection: Optional[Connection], characteristic: Characteristic
) -> bytes:
if connection is None:
return bytes([0, 0])
@@ -300,7 +316,12 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
def write_cccd(
self,
connection: Connection,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
@@ -327,13 +348,19 @@ class Server(EventEmitter):
indicate_enabled,
)
def send_response(self, connection, response):
def send_response(self, connection: Connection, response: ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -370,7 +397,13 @@ class Server(EventEmitter):
)
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)
@@ -411,15 +444,13 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
self.pending_confirmations[
pending_confirmation = self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
@@ -427,8 +458,12 @@ class Server(EventEmitter):
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
self,
indicate: bool,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Get all the connections for which there's at least one subscription
connections = [
connection
@@ -450,13 +485,23 @@ class Server(EventEmitter):
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
async def notify_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
async def indicate_subscribers(self, attribute, value=None, force=False):
async def indicate_subscribers(
self,
attribute: Attribute,
value: Optional[bytes] = None,
force: bool = False,
):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, connection):
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
@@ -464,7 +509,7 @@ class Server(EventEmitter):
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, connection, att_pdu):
def on_gatt_pdu(self, connection: Connection, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
@@ -506,7 +551,7 @@ class Server(EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, connection, pdu):
def on_att_request(self, connection: Connection, pdu: ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
@@ -679,7 +724,6 @@ class Server(EventEmitter):
and attribute.handle <= request.ending_handle
and pdu_space_available
):
try:
attribute_value = attribute.read_value(connection)
except ATT_Error as error:

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import logging
import struct
@@ -676,56 +677,35 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class Channel(EventEmitter):
# States
CLOSED = 0x00
WAIT_CONNECT = 0x01
WAIT_CONNECT_RSP = 0x02
OPEN = 0x03
WAIT_DISCONNECT = 0x04
WAIT_CREATE = 0x05
WAIT_CREATE_RSP = 0x06
WAIT_MOVE = 0x07
WAIT_MOVE_RSP = 0x08
WAIT_MOVE_CONFIRM = 0x09
WAIT_CONFIRM_RSP = 0x0A
class State(enum.IntEnum):
# States
CLOSED = 0x00
WAIT_CONNECT = 0x01
WAIT_CONNECT_RSP = 0x02
OPEN = 0x03
WAIT_DISCONNECT = 0x04
WAIT_CREATE = 0x05
WAIT_CREATE_RSP = 0x06
WAIT_MOVE = 0x07
WAIT_MOVE_RSP = 0x08
WAIT_MOVE_CONFIRM = 0x09
WAIT_CONFIRM_RSP = 0x0A
# CONFIG substates
WAIT_CONFIG = 0x10
WAIT_SEND_CONFIG = 0x11
WAIT_CONFIG_REQ_RSP = 0x12
WAIT_CONFIG_RSP = 0x13
WAIT_CONFIG_REQ = 0x14
WAIT_IND_FINAL_RSP = 0x15
WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17
STATE_NAMES = {
CLOSED: 'CLOSED',
WAIT_CONNECT: 'WAIT_CONNECT',
WAIT_CONNECT_RSP: 'WAIT_CONNECT_RSP',
OPEN: 'OPEN',
WAIT_DISCONNECT: 'WAIT_DISCONNECT',
WAIT_CREATE: 'WAIT_CREATE',
WAIT_CREATE_RSP: 'WAIT_CREATE_RSP',
WAIT_MOVE: 'WAIT_MOVE',
WAIT_MOVE_RSP: 'WAIT_MOVE_RSP',
WAIT_MOVE_CONFIRM: 'WAIT_MOVE_CONFIRM',
WAIT_CONFIRM_RSP: 'WAIT_CONFIRM_RSP',
WAIT_CONFIG: 'WAIT_CONFIG',
WAIT_SEND_CONFIG: 'WAIT_SEND_CONFIG',
WAIT_CONFIG_REQ_RSP: 'WAIT_CONFIG_REQ_RSP',
WAIT_CONFIG_RSP: 'WAIT_CONFIG_RSP',
WAIT_CONFIG_REQ: 'WAIT_CONFIG_REQ',
WAIT_IND_FINAL_RSP: 'WAIT_IND_FINAL_RSP',
WAIT_FINAL_RSP: 'WAIT_FINAL_RSP',
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
# CONFIG substates
WAIT_CONFIG = 0x10
WAIT_SEND_CONFIG = 0x11
WAIT_CONFIG_REQ_RSP = 0x12
WAIT_CONFIG_RSP = 0x13
WAIT_CONFIG_REQ = 0x14
WAIT_IND_FINAL_RSP = 0x15
WAIT_FINAL_RSP = 0x16
WAIT_CONTROL_IND = 0x17
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
state: int
state: State
connection: Connection
def __init__(
@@ -741,7 +721,7 @@ class Channel(EventEmitter):
self.manager = manager
self.connection = connection
self.signaling_cid = signaling_cid
self.state = Channel.CLOSED
self.state = self.State.CLOSED
self.mtu = mtu
self.psm = psm
self.source_cid = source_cid
@@ -751,10 +731,8 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
def _change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -767,7 +745,7 @@ class Channel(EventEmitter):
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
if self.state != Channel.OPEN:
if self.state != self.State.OPEN:
raise InvalidStateError('channel not open')
self.response = asyncio.get_running_loop().create_future()
@@ -787,14 +765,14 @@ class Channel(EventEmitter):
)
async def connect(self) -> None:
if self.state != Channel.CLOSED:
if self.state != self.State.CLOSED:
raise InvalidStateError('invalid state')
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
self.change_state(Channel.WAIT_CONNECT_RSP)
self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame(
L2CAP_Connection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -814,10 +792,10 @@ class Channel(EventEmitter):
self.connection_result = None
async def disconnect(self) -> None:
if self.state != Channel.OPEN:
if self.state != self.State.OPEN:
raise InvalidStateError('invalid state')
self.change_state(Channel.WAIT_DISCONNECT)
self._change_state(self.State.WAIT_DISCONNECT)
self.send_control_frame(
L2CAP_Disconnection_Request(
identifier=self.manager.next_identifier(self.connection),
@@ -832,8 +810,8 @@ class Channel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
if self.state == self.OPEN:
self.change_state(self.CLOSED)
if self.state == self.State.OPEN:
self._change_state(self.State.CLOSED)
self.emit('close')
def send_configure_request(self) -> None:
@@ -856,7 +834,7 @@ class Channel(EventEmitter):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT)
self._change_state(self.State.WAIT_CONNECT)
self.send_control_frame(
L2CAP_Connection_Response(
identifier=request.identifier,
@@ -866,24 +844,24 @@ class Channel(EventEmitter):
status=0x0000,
)
)
self.change_state(Channel.WAIT_CONFIG)
self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
def on_connection_response(self, response):
if self.state != Channel.WAIT_CONNECT_RSP:
if self.state != self.State.WAIT_CONNECT_RSP:
logger.warning(color('invalid state', 'red'))
return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid
self.change_state(Channel.WAIT_CONFIG)
self._change_state(self.State.WAIT_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_REQ_RSP)
self._change_state(self.State.WAIT_CONFIG_REQ_RSP)
elif response.result == L2CAP_Connection_Response.CONNECTION_PENDING:
pass
else:
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
self.connection_result.set_exception(
ProtocolError(
response.result,
@@ -895,9 +873,9 @@ class Channel(EventEmitter):
def on_configure_request(self, request) -> None:
if self.state not in (
Channel.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ,
Channel.WAIT_CONFIG_REQ_RSP,
self.State.WAIT_CONFIG,
self.State.WAIT_CONFIG_REQ,
self.State.WAIT_CONFIG_REQ_RSP,
):
logger.warning(color('invalid state', 'red'))
return
@@ -918,25 +896,28 @@ class Channel(EventEmitter):
options=request.options, # TODO: don't accept everything blindly
)
)
if self.state == Channel.WAIT_CONFIG:
self.change_state(Channel.WAIT_SEND_CONFIG)
if self.state == self.State.WAIT_CONFIG:
self._change_state(self.State.WAIT_SEND_CONFIG)
self.send_configure_request()
self.change_state(Channel.WAIT_CONFIG_RSP)
elif self.state == Channel.WAIT_CONFIG_REQ:
self.change_state(Channel.OPEN)
self._change_state(self.State.WAIT_CONFIG_RSP)
elif self.state == self.State.WAIT_CONFIG_REQ:
self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.emit('open')
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP)
elif self.state == self.State.WAIT_CONFIG_REQ_RSP:
self._change_state(self.State.WAIT_CONFIG_RSP)
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ)
elif self.state in (Channel.WAIT_CONFIG_RSP, Channel.WAIT_CONTROL_IND):
self.change_state(Channel.OPEN)
if self.state == self.State.WAIT_CONFIG_REQ_RSP:
self._change_state(self.State.WAIT_CONFIG_REQ)
elif self.state in (
self.State.WAIT_CONFIG_RSP,
self.State.WAIT_CONTROL_IND,
):
self._change_state(self.State.OPEN)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
@@ -966,7 +947,7 @@ class Channel(EventEmitter):
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
if self.state in (self.State.OPEN, self.State.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -974,14 +955,14 @@ class Channel(EventEmitter):
source_cid=request.source_cid,
)
)
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
self.emit('close')
self.manager.on_channel_closed(self)
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT:
if self.state != self.State.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -992,7 +973,7 @@ class Channel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
self.change_state(Channel.CLOSED)
self._change_state(self.State.CLOSED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1004,7 +985,7 @@ class Channel(EventEmitter):
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}, '
f'state={Channel.STATE_NAMES[self.state]})'
f'state={self.state.name})'
)
@@ -1014,33 +995,21 @@ class LeConnectionOrientedChannel(EventEmitter):
LE Credit-based Connection Oriented Channel
"""
INIT = 0
CONNECTED = 1
CONNECTING = 2
DISCONNECTING = 3
DISCONNECTED = 4
CONNECTION_ERROR = 5
STATE_NAMES = {
INIT: 'INIT',
CONNECTED: 'CONNECTED',
CONNECTING: 'CONNECTING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
class State(enum.IntEnum):
INIT = 0
CONNECTED = 1
CONNECTING = 2
DISCONNECTING = 3
DISCONNECTED = 4
CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
state: int
state: State
connection: Connection
@staticmethod
def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
self,
manager: ChannelManager,
@@ -1083,19 +1052,17 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
if connected:
self.state = LeConnectionOrientedChannel.CONNECTED
self.state = self.State.CONNECTED
else:
self.state = LeConnectionOrientedChannel.INIT
self.state = self.State.INIT
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
def _change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state
if new_state == self.CONNECTED:
if new_state == self.State.CONNECTED:
self.emit('open')
elif new_state == self.DISCONNECTED:
elif new_state == self.State.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
@@ -1106,7 +1073,7 @@ class LeConnectionOrientedChannel(EventEmitter):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.INIT:
if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection
@@ -1114,7 +1081,7 @@ class LeConnectionOrientedChannel(EventEmitter):
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
self.change_state(self.CONNECTING)
self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier=identifier,
le_psm=self.le_psm,
@@ -1134,10 +1101,10 @@ class LeConnectionOrientedChannel(EventEmitter):
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
raise InvalidStateError('not connected')
self.change_state(self.DISCONNECTING)
self._change_state(self.State.DISCONNECTING)
self.flush_output()
self.send_control_frame(
L2CAP_Disconnection_Request(
@@ -1153,15 +1120,15 @@ class LeConnectionOrientedChannel(EventEmitter):
return await self.disconnection_result
def abort(self) -> None:
if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED)
if self.state == self.State.CONNECTED:
self._change_state(self.State.DISCONNECTED)
def on_pdu(self, pdu: bytes) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
logger.warning('received PDU while not connected, dropping')
# Manage the peer credits
@@ -1240,7 +1207,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.credits = response.initial_credits
self.connected = True
self.connection_result.set_result(self)
self.change_state(self.CONNECTED)
self._change_state(self.State.CONNECTED)
else:
self.connection_result.set_exception(
ProtocolError(
@@ -1251,7 +1218,7 @@ class LeConnectionOrientedChannel(EventEmitter):
),
)
)
self.change_state(self.CONNECTION_ERROR)
self._change_state(self.State.CONNECTION_ERROR)
# Cleanup
self.connection_result = None
@@ -1271,11 +1238,11 @@ class LeConnectionOrientedChannel(EventEmitter):
source_cid=request.source_cid,
)
)
self.change_state(self.DISCONNECTED)
self._change_state(self.State.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING:
if self.state != self.State.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1286,7 +1253,7 @@ class LeConnectionOrientedChannel(EventEmitter):
logger.warning('unexpected source or destination CID')
return
self.change_state(self.DISCONNECTED)
self._change_state(self.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
@@ -1339,7 +1306,7 @@ class LeConnectionOrientedChannel(EventEmitter):
return
def write(self, data: bytes) -> None:
if self.state != self.CONNECTED:
if self.state != self.State.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1367,7 +1334,7 @@ class LeConnectionOrientedChannel(EventEmitter):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, '
f'State={self.state.name}, '
f'PSM={self.le_psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MPS={self.mps}/{self.peer_mps}, '

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
import contextlib
import grpc
import logging
@@ -27,8 +28,8 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
try:
self.log.debug('Pair...')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
security_result = asyncio.get_running_loop().create_future()
connection.request_pairing()
with contextlib.closing(EventWatcher()) as watcher:
await wait_for_security
else:
await connection.pair()
@watcher.on(connection, 'pairing')
def on_pairing(*_: Any) -> None:
security_result.set_result('success')
self.log.debug('Paired')
@watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
connection.request_pairing()
else:
await connection.pair()
result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
}
# register event handlers
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
pass
self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs)
def reached_security_level(
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):
with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()

View File

@@ -37,6 +37,7 @@ from typing import (
Optional,
Tuple,
Type,
cast,
)
from pyee import EventEmitter
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
) -> None:
connection.emit('security_request', request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
return
# Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE:
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
)
self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session
session.on_smp_command(command)

View File

@@ -31,14 +31,12 @@ async def open_ws_client_transport(spec: str) -> Transport:
'''
Open a WebSocket client transport.
The parameter string has this syntax:
<remote-host>:<remote-port>
<websocket-url>
Example: 127.0.0.1:9001
Example: ws://localhost:7681/v1/websocket/bt
'''
remote_host, remote_port = spec.split(':')
uri = f'ws://{remote_host}:{remote_port}'
websocket = await websockets.client.connect(uri)
websocket = await websockets.client.connect(spec)
transport = PumpedTransport(
PumpedPacketSource(websocket.recv),

View File

@@ -15,12 +15,24 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import traceback
import collections
import sys
from typing import Awaitable, Set, TypeVar
from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any,
Optional,
Union,
overload,
)
from functools import wraps
from pyee import EventEmitter
@@ -64,6 +76,102 @@ def composite_listener(cls):
return cls
# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)
class EventWatcher:
'''A wrapper class to control the lifecycle of event handlers better.
Usage:
```
watcher = EventWatcher()
def on_foo():
...
watcher.on(emitter, 'foo', on_foo)
@watcher.on(emitter, 'bar')
def on_bar():
...
# Close all event handlers watching through this watcher
watcher.close()
```
As context:
```
with contextlib.closing(EventWatcher()) as context:
@context.on(emitter, 'foo')
def on_foo():
...
# on_foo() has been removed here!
```
'''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None:
self.handlers = []
@overload
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@overload
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
def close(self) -> None:
for emitter, event, handler in self.handlers:
if handler in emitter.listeners(event):
emitter.remove_listener(event, handler)
# -----------------------------------------------------------------------------
_T = TypeVar('_T')

159
rust/Cargo.lock generated
View File

@@ -130,6 +130,16 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
[[package]]
name = "bstr"
version = "1.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
dependencies = [
"memchr",
"serde",
]
[[package]]
name = "bumble"
version = "0.1.0"
@@ -138,7 +148,9 @@ dependencies = [
"clap 4.4.1",
"directories",
"env_logger",
"file-header",
"futures",
"globset",
"hex",
"itertools",
"lazy_static",
@@ -272,6 +284,73 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]]
name = "crossbeam"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"memoffset 0.9.0",
"scopeguard",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
"cfg-if",
]
[[package]]
name = "directories"
version = "5.0.1"
@@ -348,6 +427,19 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
[[package]]
name = "file-header"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5568149106e77ae33bc3a2c3ef3839cbe63ffa4a8dd4a81612a6f9dfdbc2e9f"
dependencies = [
"crossbeam",
"lazy_static",
"license",
"thiserror",
"walkdir",
]
[[package]]
name = "fnv"
version = "1.0.7"
@@ -484,6 +576,19 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
[[package]]
name = "globset"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759c97c1e17c55525b57192c06a267cda0ac5210b222d6b82189a2338fa1c13d"
dependencies = [
"aho-corasick",
"bstr",
"fnv",
"log",
"regex",
]
[[package]]
name = "h2"
version = "0.3.21"
@@ -710,6 +815,17 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "license"
version = "3.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b66615d42e949152327c402e03cd29dab8bff91ce470381ac2ca6d380d8d9946"
dependencies = [
"reword",
"serde",
"serde_json",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.5"
@@ -756,6 +872,15 @@ dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
@@ -1200,6 +1325,15 @@ dependencies = [
"winreg",
]
[[package]]
name = "reword"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe272098dce9ed76b479995953f748d1851261390b08f8a0ff619c885a1f0765"
dependencies = [
"unicode-segmentation",
]
[[package]]
name = "rusb"
version = "0.9.3"
@@ -1241,6 +1375,15 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.22"
@@ -1589,6 +1732,12 @@ dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-segmentation"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36"
[[package]]
name = "unindent"
version = "0.1.11"
@@ -1618,6 +1767,16 @@ version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "walkdir"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"

View File

@@ -24,6 +24,10 @@ itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
# Dev tools
file-header = { version = "0.1.2", optional = true }
globset = { version = "0.4.13", optional = true }
# CLI
anyhow = { version = "1.0.71", optional = true }
clap = { version = "4.3.3", features = ["derive"], optional = true }
@@ -52,10 +56,15 @@ env_logger = "0.10.0"
[package.metadata.docs.rs]
rustdoc-args = ["--generate-link-to-definition"]
[[bin]]
name = "file-header"
path = "tools/file_header.rs"
required-features = ["dev-tools"]
[[bin]]
name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs"
required-features = ["bumble-codegen"]
required-features = ["dev-tools"]
[[bin]]
name = "bumble"
@@ -71,7 +80,7 @@ harness = false
[features]
anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
bumble-codegen = ["dep:anyhow"]
dev-tools = ["dep:anyhow", "dep:clap", "dep:file-header", "dep:globset"]
# separate feature for CLI so that dependencies don't spend time building these
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
default = []

View File

@@ -62,5 +62,5 @@ in tests at `pytests/assigned_numbers.rs`.
To regenerate the assigned number tables based on the Python codebase:
```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features dev-tools
```

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict};
use std::collections;

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! BLE advertisements.
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};

View File

@@ -1,3 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Bumble & Python logging
use pyo3::types::PyDict;

78
rust/tools/file_header.rs Normal file
View File

@@ -0,0 +1,78 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::anyhow;
use clap::Parser as _;
use file_header::{
add_headers_recursively, check_headers_recursively,
license::spdx::{YearCopyrightOwnerValue, APACHE_2_0},
};
use globset::{Glob, GlobSet, GlobSetBuilder};
use std::{env, path::PathBuf};
fn main() -> anyhow::Result<()> {
let rust_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?);
let ignore_globset = ignore_globset()?;
// Note: when adding headers, there is a bug where the line spacing is off for Apache 2.0 (see https://github.com/spdx/license-list-XML/issues/2127)
let header = APACHE_2_0.build_header(YearCopyrightOwnerValue::new(2023, "Google LLC".into()));
let cli = Cli::parse();
match cli.subcommand {
Subcommand::CheckAll => {
let result =
check_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header, 4)?;
if result.has_failure() {
return Err(anyhow!(
"The following files do not have headers: {result:?}"
));
}
}
Subcommand::AddAll => {
let files_with_new_header =
add_headers_recursively(&rust_dir, |p| !ignore_globset.is_match(p), header)?;
files_with_new_header
.iter()
.for_each(|path| println!("Added header to: {path:?}"));
}
}
Ok(())
}
fn ignore_globset() -> anyhow::Result<GlobSet> {
Ok(GlobSetBuilder::new()
.add(Glob::new("**/.idea/**")?)
.add(Glob::new("**/target/**")?)
.add(Glob::new("**/.gitignore")?)
.add(Glob::new("**/CHANGELOG.md")?)
.add(Glob::new("**/Cargo.lock")?)
.add(Glob::new("**/Cargo.toml")?)
.add(Glob::new("**/README.md")?)
.add(Glob::new("*.bin")?)
.build()?)
}
#[derive(clap::Parser)]
struct Cli {
#[clap(subcommand)]
subcommand: Subcommand,
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Subcommand {
/// Checks if a license is present in files that are not in the ignore list.
CheckAll,
/// Adds a license as needed to files that are not in the ignore list.
AddAll,
}

View File

@@ -891,10 +891,10 @@ async def async_main():
# -----------------------------------------------------------------------------
def test_attribute_string_to_permissions():
assert Attribute.string_to_permissions('READABLE') == 1
assert Attribute.string_to_permissions('WRITEABLE') == 2
assert Attribute.string_to_permissions('READABLE,WRITEABLE') == 3
def test_permissions_from_string():
assert Attribute.Permissions.from_string('READABLE') == 1
assert Attribute.Permissions.from_string('WRITEABLE') == 2
assert Attribute.Permissions.from_string('READABLE,WRITEABLE') == 3
# -----------------------------------------------------------------------------

77
tests/utils_test.py Normal file
View File

@@ -0,0 +1,77 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from bumble import utils
from pyee import EventEmitter
from unittest.mock import MagicMock
def test_on() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.on(emitter, 'event', mock)
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_on_decorator() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
@context.on(emitter, 'event')
def on_event(*_) -> None:
mock()
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_multiple_handlers() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.once(emitter, 'a', mock)
context.once(emitter, 'b', mock)
emitter.emit('b', 'b')
assert not emitter.listeners('a')
assert not emitter.listeners('b')
mock.assert_called_once_with('b')
# -----------------------------------------------------------------------------
def run_tests():
test_on()
test_on_decorator()
test_multiple_handlers()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
run_tests()

View File

@@ -23,7 +23,7 @@ from bumble.device import Device
# -----------------------------------------------------------------------------
class ScanEntry:
def __init__(self, advertisement):
self.address = str(advertisement.address).replace("/P", "")
self.address = advertisement.address.to_string(False)
self.address_type = ('Public', 'Random', 'Public Identity', 'Random Identity')[
advertisement.address.address_type
]

View File

@@ -171,7 +171,7 @@ class Speaker:
self.connection = connection
connection.on('disconnection', self.on_bluetooth_disconnection)
peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = str(connection.peer_address).replace('/P', '')
peer_address = connection.peer_address.to_string(False)
self.emit_event(
'connection', {'peer_name': peer_name, 'peer_address': peer_address}
)