Compare commits

...

15 Commits

Author SHA1 Message Date
Josh Wu
2491b686fa Handle SMP_Security_Request 2023-09-20 23:13:08 +02:00
Josh Wu
efd02b2f3e Adopt reviews 2023-09-20 23:03:23 +02:00
Josh Wu
3b14078646 Overload signatures 2023-09-20 23:03:23 +02:00
Josh Wu
eb9d5632bc Add utils_test type hint 2023-09-20 23:03:23 +02:00
Josh Wu
45f60edbb6 Pyee watcher context 2023-09-20 23:03:23 +02:00
David Duarte
393ea6a7bb pandora_server: Load server config
Pandora server has it's own config that we load from the 'server'
property of the current bumble config file
2023-09-18 14:28:42 -07:00
Gilles Boccon-Gibod
0d36d99a73 Merge pull request #287 from google/revert-286-gbg/package-depencencies-for-wasm
Revert "make cryptography a valid dependency for emscripten targets"
2023-09-13 23:37:42 -07:00
Gilles Boccon-Gibod
d8a9f5a724 Revert "make cryptography a valid dependency for emscripten targets" 2023-09-13 23:36:33 -07:00
Gilles Boccon-Gibod
2c66e1a042 Merge pull request #285 from google/gbg/fix-mypy-errors
mypy: ignore false positive errors
2023-09-13 23:30:50 -07:00
Gilles Boccon-Gibod
d5eccdb00f Merge pull request #286 from google/gbg/package-depencencies-for-wasm
make cryptography a valid dependency for emscripten targets
2023-09-13 23:30:28 -07:00
Gilles Boccon-Gibod
32626573a6 ignore false positive errors 2023-09-13 23:17:00 -07:00
Gilles Boccon-Gibod
caa82b8f7e make cryptography a valid dependency for emscripten targets 2023-09-13 22:38:28 -07:00
Gilles Boccon-Gibod
5af347b499 Merge pull request #282 from google/gbg/multi-python-pre-commit-check
run pre-commit tests with all supported Python versions
2023-09-13 07:47:32 -07:00
zxzxwu
4ed5bb5a9e Merge pull request #281 from zxzxwu/cleanup-transport
Replace | typing usage with Optional and Union
2023-09-13 13:31:41 +08:00
Josh Wu
f39f5f531c Replace | typing usage with Optional and Union 2023-09-12 15:50:51 +08:00
14 changed files with 280 additions and 44 deletions

View File

@@ -3,7 +3,7 @@ import click
import logging import logging
import json import json
from bumble.pandora import PandoraDevice, serve from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999 BUMBLE_SERVER_GRPC_PORT = 7999
@@ -29,12 +29,14 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
transport = transport.replace('<rootcanal-port>', str(rootcanal_port)) transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
bumble_config = retrieve_config(config) bumble_config = retrieve_config(config)
if 'transport' not in bumble_config.keys(): bumble_config.setdefault('transport', transport)
bumble_config.update({'transport': transport})
device = PandoraDevice(bumble_config) device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port)) asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]: def retrieve_config(config: str) -> Dict[str, Any]:

View File

@@ -80,7 +80,7 @@ class BaseError(Exception):
def __init__( def __init__(
self, self,
error_code: int | None, error_code: Optional[int],
error_namespace: str = '', error_namespace: str = '',
error_name: str = '', error_name: str = '',
details: str = '', details: str = '',

View File

@@ -4397,7 +4397,7 @@ class HCI_Event(HCI_Packet):
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise ValueError('invalid packet length')
cls: Type[HCI_Event | HCI_LE_Meta_Event] | None cls: Any
if event_code == HCI_LE_META_EVENT: if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call # We do this dispatch here and not in the subclass in order to avoid call
# loops # loops

View File

@@ -757,7 +757,7 @@ class Channel(EventEmitter):
) )
self.state = new_state self.state = new_state
def send_pdu(self, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
@@ -1098,7 +1098,7 @@ class LeConnectionOrientedChannel(EventEmitter):
elif new_state == self.DISCONNECTED: elif new_state == self.DISCONNECTED:
self.emit('close') self.emit('close')
def send_pdu(self, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu) self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None: def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
@@ -1571,7 +1571,7 @@ class ChannelManager:
if connection_handle in self.identifiers: if connection_handle in self.identifiers:
del self.identifiers[connection_handle] del self.identifiers[connection_handle]
def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None: def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} ' f'{color(">>> Sending L2CAP PDU", "blue")} '

View File

@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import contextlib
import grpc import grpc
import logging import logging
@@ -27,8 +28,8 @@ from bumble.core import (
) )
from bumble.device import Connection as BumbleConnection, Device from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from contextlib import suppress
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
@@ -294,23 +295,35 @@ class SecurityService(SecurityServicer):
try: try:
self.log.debug('Pair...') self.log.debug('Pair...')
if ( security_result = asyncio.get_running_loop().create_future()
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
wait_for_security: asyncio.Future[
bool
] = asyncio.get_running_loop().create_future()
connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
connection.on("pairing_failure", wait_for_security.set_exception)
connection.request_pairing() with contextlib.closing(EventWatcher()) as watcher:
await wait_for_security @watcher.on(connection, 'pairing')
else: def on_pairing(*_: Any) -> None:
await connection.pair() security_result.set_result('success')
self.log.debug('Paired') @watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
connection.request_pairing()
else:
await connection.pair()
result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError: except asyncio.CancelledError:
self.log.warning("Connection died during encryption") self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty()) return SecureResponse(connection_died=empty_pb2.Empty())
@@ -369,6 +382,7 @@ class SecurityService(SecurityServicer):
str str
] = asyncio.get_running_loop().create_future() ] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
assert connection assert connection
@@ -415,6 +429,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None: if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate()) authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = { listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
@@ -425,6 +443,7 @@ class SecurityService(SecurityServicer):
'connection_encryption_change': on_encryption_change, 'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success, 'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'), 'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
} }
# register event handlers # register event handlers
@@ -452,6 +471,15 @@ class SecurityService(SecurityServicer):
pass pass
self.log.debug('Authenticated') self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
def reached_security_level( def reached_security_level(
@@ -523,7 +551,7 @@ class SecurityStorageService(SecurityStorageServicer):
self.log.debug(f"DeleteBond: {address}") self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None: if self.device.keystore is not None:
with suppress(KeyError): with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address)) await self.device.keystore.delete(str(address))
return empty_pb2.Empty() return empty_pb2.Empty()

View File

@@ -37,6 +37,7 @@ from typing import (
Optional, Optional,
Tuple, Tuple,
Type, Type,
cast,
) )
from pyee import EventEmitter from pyee import EventEmitter
@@ -1771,7 +1772,26 @@ class Manager(EventEmitter):
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
connection.send_l2cap_pdu(cid, command.to_bytes()) connection.send_l2cap_pdu(cid, command.to_bytes())
def on_smp_security_request_command(
self, connection: Connection, request: SMP_Security_Request_Command
) -> None:
connection.emit('security_request', request.auth_req)
def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None: def on_smp_pdu(self, connection: Connection, pdu: bytes) -> None:
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Security request is more than just pairing, so let applications handle them
if command.code == SMP_SECURITY_REQUEST_COMMAND:
self.on_smp_security_request_command(
connection, cast(SMP_Security_Request_Command, command)
)
return
# Look for a session with this connection, and create one if none exists # Look for a session with this connection, and create one if none exists
if not (session := self.sessions.get(connection.handle)): if not (session := self.sessions.get(connection.handle)):
if connection.role == BT_CENTRAL_ROLE: if connection.role == BT_CENTRAL_ROLE:
@@ -1782,13 +1802,6 @@ class Manager(EventEmitter):
) )
self.sessions[connection.handle] = session self.sessions[connection.handle] = session
# Parse the L2CAP payload into an SMP Command object
command = SMP_Command.from_bytes(pdu)
logger.debug(
f'<<< Received SMP Command on connection [0x{connection.handle:04X}] '
f'{connection.peer_address}: {command}'
)
# Delegate the handling of the command to the session # Delegate the handling of the command to the session
session.on_smp_command(command) session.on_smp_command(command)

View File

@@ -18,6 +18,8 @@
import logging import logging
import grpc.aio import grpc.aio
from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec: str | None) -> Transport: async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
''' '''
Open a transport connection to an Android emulator via its gRPC interface. Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax: The parameter string has this syntax:
@@ -82,7 +84,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
logger.debug(f'connecting to gRPC server at {server_address}') logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host': if mode == 'host':
# Connect as a host # Connect as a host
service = EmulatedBluetoothServiceStub(channel) service = EmulatedBluetoothServiceStub(channel)

View File

@@ -122,7 +122,7 @@ def publish_grpc_port(grpc_port) -> bool:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport( async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int server_host: Optional[str], server_port: int
) -> Transport: ) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise ValueError('invalid port')

View File

@@ -23,6 +23,8 @@ import socket
import ctypes import ctypes
import collections import collections
from typing import Optional
from .common import Transport, ParserSource from .common import Transport, ParserSource
@@ -33,7 +35,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec: str | None) -> Transport: async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
''' '''
Open an HCI Socket (only available on some platforms). Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter) The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -45,9 +47,9 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
# Create a raw HCI socket # Create a raw HCI socket
try: try:
hci_socket = socket.socket( hci_socket = socket.socket(
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH, # type: ignore[attr-defined]
socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.SOCK_RAW | socket.SOCK_NONBLOCK, # type: ignore[attr-defined]
socket.BTPROTO_HCI, # type: ignore socket.BTPROTO_HCI, # type: ignore[attr-defined]
) )
except AttributeError as error: except AttributeError as error:
# Not supported on this platform # Not supported on this platform
@@ -78,7 +80,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
bind_address = struct.pack( bind_address = struct.pack(
# pylint: disable=no-member # pylint: disable=no-member
'<HHH', '<HHH',
socket.AF_BLUETOOTH, socket.AF_BLUETOOTH, # type: ignore[attr-defined]
adapter_index, adapter_index,
HCI_CHANNEL_USER, HCI_CHANNEL_USER,
) )

View File

@@ -23,6 +23,8 @@ import atexit
import os import os
import logging import logging
from typing import Optional
from .common import Transport, StreamPacketSource, StreamPacketSink from .common import Transport, StreamPacketSource, StreamPacketSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -32,7 +34,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_pty_transport(spec: str | None) -> Transport: async def open_pty_transport(spec: Optional[str]) -> Transport:
''' '''
Open a PTY transport. Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link The parameter string may be empty, or a path name where a symbolic link

View File

@@ -17,6 +17,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from typing import Optional
from .common import Transport from .common import Transport
from .file import open_file_transport from .file import open_file_transport
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_vhci_transport(spec: str | None) -> Transport: async def open_vhci_transport(spec: Optional[str]) -> Transport:
''' '''
Open a VHCI transport (only available on some platforms). Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device The parameter string is either empty (to use the default VHCI device

View File

@@ -15,12 +15,24 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import traceback import traceback
import collections import collections
import sys import sys
from typing import Awaitable, Set, TypeVar from typing import (
Awaitable,
Set,
TypeVar,
List,
Tuple,
Callable,
Any,
Optional,
Union,
overload,
)
from functools import wraps from functools import wraps
from pyee import EventEmitter from pyee import EventEmitter
@@ -64,6 +76,102 @@ def composite_listener(cls):
return cls return cls
# -----------------------------------------------------------------------------
_Handler = TypeVar('_Handler', bound=Callable)
class EventWatcher:
'''A wrapper class to control the lifecycle of event handlers better.
Usage:
```
watcher = EventWatcher()
def on_foo():
...
watcher.on(emitter, 'foo', on_foo)
@watcher.on(emitter, 'bar')
def on_bar():
...
# Close all event handlers watching through this watcher
watcher.close()
```
As context:
```
with contextlib.closing(EventWatcher()) as context:
@context.on(emitter, 'foo')
def on_foo():
...
# on_foo() has been removed here!
```
'''
handlers: List[Tuple[EventEmitter, str, Callable[..., Any]]]
def __init__(self) -> None:
self.handlers = []
@overload
def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@overload
def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
...
@overload
def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
def close(self) -> None:
for emitter, event, handler in self.handlers:
if handler in emitter.listeners(event):
emitter.remove_listener(event, handler)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_T = TypeVar('_T') _T = TypeVar('_T')

View File

@@ -84,7 +84,7 @@ development =
black == 22.10 black == 22.10
grpcio-tools >= 1.57.0 grpcio-tools >= 1.57.0
invoke >= 1.7.3 invoke >= 1.7.3
mypy == 1.2.0 mypy == 1.5.0
nox >= 2022 nox >= 2022
pylint == 2.15.8 pylint == 2.15.8
types-appdirs >= 1.4.3 types-appdirs >= 1.4.3

77
tests/utils_test.py Normal file
View File

@@ -0,0 +1,77 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
from bumble import utils
from pyee import EventEmitter
from unittest.mock import MagicMock
def test_on() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.on(emitter, 'event', mock)
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_on_decorator() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
@context.on(emitter, 'event')
def on_event(*_) -> None:
mock()
emitter.emit('event')
assert not emitter.listeners('event')
assert mock.call_count == 1
def test_multiple_handlers() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
mock = MagicMock()
context.once(emitter, 'a', mock)
context.once(emitter, 'b', mock)
emitter.emit('b', 'b')
assert not emitter.listeners('a')
assert not emitter.listeners('b')
mock.assert_called_once_with('b')
# -----------------------------------------------------------------------------
def run_tests():
test_on()
test_on_decorator()
test_multiple_handlers()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
run_tests()