Compare commits

...

23 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
a8beb6b1ff remove stale comment 2023-02-14 16:05:46 -08:00
Gilles Boccon-Gibod
2d44de611f make pylint happy 2023-02-14 16:04:20 -08:00
Gilles Boccon-Gibod
e6fc63b2d8 improve smp compatibility with other OS flows 2023-02-13 10:53:00 -08:00
Gilles Boccon-Gibod
1321c7da81 Merge pull request #125 from google/gbg/gh-124
fix getting the filename from the keystore option.
2023-02-10 20:17:38 -08:00
Gilles Boccon-Gibod
5a1b03fd91 format 2023-02-08 10:54:27 -08:00
Gilles Boccon-Gibod
de47721753 fix typo caused by an earlier refactor. 2023-02-08 09:56:11 -08:00
Gilles Boccon-Gibod
83a76a75d3 fix getting the filename from the keystore option. 2023-02-08 09:40:19 -08:00
Lucas Abel
d5b5ef8313 Merge pull request #122 from google/uael/abort-on-fix-invalid-state
utils: fix possible invalide state error while canceling future for `abort_on`
2023-02-06 17:13:34 -08:00
uael
856a8d53cd utils: fix possible invalide state error while canceling future for abort_on 2023-02-06 16:58:23 +00:00
Gilles Boccon-Gibod
177c273a57 Merge pull request #121 from google/gbg/replace-bitstruct
replace bitstruct with construct
2023-02-05 11:33:36 -08:00
Gilles Boccon-Gibod
24a863983d Merge branch 'gbg/replace-bitstruct' of https://github.com/google/bumble into gbg/replace-bitstruct
# Conflicts:
#	bumble/a2dp.py
#	pyproject.toml
2023-02-04 09:31:18 -08:00
Gilles Boccon-Gibod
b7ef09d4a3 fix format 2023-02-04 09:26:31 -08:00
Gilles Boccon-Gibod
b5b6cd13b8 replace bitstruct with construct 2023-02-04 09:23:13 -08:00
Gilles Boccon-Gibod
ef781bc374 replace bitstruct with construct 2023-02-03 19:41:07 -08:00
Lucas Abel
00978c1d63 Merge pull request #118 from google/uael/type-hints
overall: add types hints to the small subset used by avatar
2023-02-02 12:48:40 -08:00
uael
b731f6f556 overall: add types hints to the small subset used by avatar 2023-02-02 19:37:55 +00:00
Lucas Abel
ed261886e1 Merge pull request #119 from google/uael/fix-ci-packages-version
build: fix version of packages running checks in CI
2023-02-02 11:03:34 -08:00
uael
5e18094c31 build: fix version of packages running checks in CI 2023-02-02 17:23:15 +00:00
Lucas Abel
9a9b4e5bf1 Merge pull request #117 from google/uael/host-fixes
host: fixed `.latency` attribute error
2023-01-27 17:38:11 -08:00
Abel Lucas
895f1618d8 host: fixed .latency attribute error 2023-01-27 23:05:43 +00:00
Gilles Boccon-Gibod
52746e0c68 Merge pull request #116 from google/barbibulle-patch-1
fix libusb-package dependency
2023-01-25 15:59:42 -08:00
Gilles Boccon-Gibod
f9b7072423 Update setup.cfg 2023-01-25 15:37:33 -08:00
Gilles Boccon-Gibod
fa4be1958f Merge pull request #114 from google/gbg/fix-constant-typo
fix typo in constant name
2023-01-23 08:50:07 -08:00
19 changed files with 413 additions and 259 deletions

View File

@@ -71,5 +71,10 @@
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled"
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@@ -19,8 +19,8 @@ import asyncio
import os
import logging
import click
import aioconsole
from colors import color
from prompt_toolkit.shortcuts import PromptSession
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
@@ -42,9 +42,23 @@ from bumble.att import (
)
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
def terminate(self):
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
@@ -58,7 +72,18 @@ class Delegate(PairingDelegate):
self.mode = mode
self.peer = Peer(connection)
self.peer_name = None
self.prompt = prompt
self.do_prompt = do_prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
response = await session.prompt_async()
return response.lower().strip()
async def update_peer_name(self):
if self.peer_name is not None:
@@ -73,19 +98,15 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]'
async def accept(self):
if self.prompt:
if self.do_prompt:
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
response = await self.prompt('>>> Accept? ')
if response == 'yes':
return True
@@ -96,23 +117,17 @@ class Delegate(PairingDelegate):
# Accept silently
return True
async def compare_numbers(self, number, digits=6):
async def compare_numbers(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
response = response.lower().strip()
if response == 'yes':
return True
@@ -123,30 +138,24 @@ class Delegate(PairingDelegate):
async def get_number(self):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN
while True:
try:
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
return int(await self.prompt('>>> Enter PIN: '))
except ValueError:
pass
async def display_number(self, number, digits=6):
async def display_number(self, number, digits):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print(f'### PIN: {number:0{digits}}')
self.print('###-----------------------------------')
# -----------------------------------------------------------------------------
@@ -238,6 +247,7 @@ def on_pairing(keys):
print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -245,6 +255,7 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# -----------------------------------------------------------------------------
@@ -262,6 +273,8 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter()
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -332,7 +345,19 @@ async def pair(
# Advertise so that peers can find us and connect
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# Run until the user asks to exit
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
# -----------------------------------------------------------------------------
@@ -366,7 +391,11 @@ async def pair(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.option(
'--keystore-file',
metavar='<filename>',
help='File in which to store the pairing keys',
)
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
@@ -384,7 +413,13 @@ def main(
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run(
pair(
mode,

View File

@@ -18,7 +18,7 @@
import struct
import logging
from collections import namedtuple
import bitstruct
import construct
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -258,7 +258,17 @@ class SbcMediaCodecInformation(
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
BIT_FIELDS = 'u4u4u4u2u2u8u8'
BIT_FIELDS = construct.Bitwise(
construct.Sequence(
construct.BitsInteger(4),
construct.BitsInteger(4),
construct.BitsInteger(4),
construct.BitsInteger(2),
construct.BitsInteger(2),
construct.BitsInteger(8),
construct.BitsInteger(8),
)
)
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
@@ -276,7 +286,7 @@ class SbcMediaCodecInformation(
@staticmethod
def from_bytes(data):
return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
*SbcMediaCodecInformation.BIT_FIELDS.parse(data)
)
@classmethod
@@ -326,7 +336,7 @@ class SbcMediaCodecInformation(
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
return self.BIT_FIELDS.build(self)
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
@@ -350,14 +360,23 @@ class SbcMediaCodecInformation(
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
BIT_FIELDS = 'u8u12u2p2u1u23'
BIT_FIELDS = construct.Bitwise(
construct.Sequence(
construct.BitsInteger(8),
construct.BitsInteger(12),
construct.BitsInteger(2),
construct.BitsInteger(2),
construct.BitsInteger(1),
construct.BitsInteger(23),
)
)
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
@@ -383,7 +402,7 @@ class AacMediaCodecInformation(
@staticmethod
def from_bytes(data):
return AacMediaCodecInformation(
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
*AacMediaCodecInformation.BIT_FIELDS.parse(data)
)
@classmethod
@@ -394,6 +413,7 @@ class AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
rfa=0,
vbr=vbr,
bitrate=bitrate,
)
@@ -411,7 +431,7 @@ class AacMediaCodecInformation(
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
return self.BIT_FIELDS.build(self)
def __str__(self):
object_types = [

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List, Optional, Tuple, Union, cast
from .company_ids import COMPANY_IDENTIFIERS
@@ -146,7 +147,7 @@ class UUID:
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS: list[UUID] = [] # Registry of all instances created
UUIDS: List[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
if isinstance(uuid_str_or_int, int):
@@ -181,7 +182,7 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes, name=None):
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
@@ -225,7 +226,7 @@ class UUID:
'''
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
def to_hex_str(self):
def to_hex_str(self) -> str:
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
@@ -607,6 +608,11 @@ class DeviceClass:
# -----------------------------------------------------------------------------
# Advertising Data
# -----------------------------------------------------------------------------
AdvertisingObject = Union[
List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes]
]
class AdvertisingData:
# fmt: off
# pylint: disable=line-too-long
@@ -722,10 +728,12 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10
ad_structures: List[Tuple[int, bytes]]
# fmt: on
# pylint: enable=line-too-long
def __init__(self, ad_structures=None):
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None:
if ad_structures is None:
ad_structures = []
self.ad_structures = ad_structures[:]
@@ -752,7 +760,7 @@ class AdvertisingData:
return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod
def uuid_list_to_objects(ad_data, uuid_size):
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
uuids = []
offset = 0
while (uuid_size * (offset + 1)) <= len(ad_data):
@@ -829,7 +837,7 @@ class AdvertisingData:
# pylint: disable=too-many-return-statements
@staticmethod
def ad_data_to_object(ad_type, ad_data):
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject:
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
@@ -868,22 +876,22 @@ class AdvertisingData:
return ad_data.decode("utf-8")
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return ad_data[0]
return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
):
return struct.unpack('<H', ad_data)[0]
return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return struct.unpack('<HH', ad_data)
return cast(Tuple[int, int], struct.unpack('<HH', ad_data))
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
return ad_data
@@ -898,26 +906,27 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data))
offset += length
def get(self, type_id, return_all=False, raw=False):
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingObject]:
'''
Get Advertising Data Structure(s) with a given type
If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches.
Returns a (possibly empty) list of matches.
'''
def process_ad_data(ad_data):
def process_ad_data(ad_data: bytes) -> AdvertisingObject:
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all:
return [
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
]
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
return next(
(process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id),
None,
)
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]:
'''
Get Advertising Data Structure(s) with a given type
Returns the first entry, or None if no structure matches.
'''
all = self.get_all(type_id, raw=raw)
return all[0] if all else None
def __bytes__(self):
return b''.join(

View File

@@ -23,7 +23,7 @@ import asyncio
import logging
from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass
from typing import ClassVar
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
from colors import color
@@ -197,6 +197,8 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
# -----------------------------------------------------------------------------
class Advertisement:
address: Address
TX_POWER_NOT_AVAILABLE = (
HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
)
@@ -511,6 +513,17 @@ ConnectionParametersPreferences.default = ConnectionParametersPreferences()
# -----------------------------------------------------------------------------
class Connection(CompositeEventEmitter):
device: Device
handle: int
transport: int
self_address: Address
peer_address: Address
role: int
encryption: int
authenticated: bool
sc: bool
link_key_type: int
@composite_listener
class Listener:
def on_disconnection(self, reason):
@@ -611,6 +624,10 @@ class Connection(CompositeEventEmitter):
def is_encrypted(self):
return self.encryption != 0
@property
def is_incomplete(self) -> bool:
return self.handle == None
def send_l2cap_pdu(self, cid, pdu):
self.device.send_l2cap_pdu(self.handle, cid, pdu)
@@ -626,20 +643,22 @@ class Connection(CompositeEventEmitter):
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
return await self.device.disconnect(self, reason)
async def disconnect(
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
) -> None:
await self.device.disconnect(self, reason)
async def pair(self):
async def pair(self) -> None:
return await self.device.pair(self)
def request_pairing(self):
def request_pairing(self) -> None:
return self.device.request_pairing(self)
# [Classic only]
async def authenticate(self):
async def authenticate(self) -> None:
return await self.device.authenticate(self)
async def encrypt(self, enable=True):
async def encrypt(self, enable: bool = True) -> None:
return await self.device.encrypt(self, enable)
async def sustain(self, timeout=None):
@@ -707,10 +726,10 @@ class Connection(CompositeEventEmitter):
# -----------------------------------------------------------------------------
class DeviceConfiguration:
def __init__(self):
def __init__(self) -> None:
# Setup defaults
self.name = DEVICE_DEFAULT_NAME
self.address = DEVICE_DEFAULT_ADDRESS
self.address = Address(DEVICE_DEFAULT_ADDRESS)
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
@@ -730,12 +749,13 @@ class DeviceConfiguration:
)
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
self.gatt_services = []
self.gatt_services: List[Dict[str, Any]] = []
def load_from_dict(self, config):
def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties
self.name = config.get('name', self.name)
self.address = Address(config.get('address', self.address))
if address := config.get('address', None):
self.address = Address(address)
self.class_of_device = config.get('class_of_device', self.class_of_device)
self.advertising_interval_min = config.get(
'advertising_interval', self.advertising_interval_min
@@ -842,6 +862,22 @@ device_host_event_handlers: list[str] = []
# -----------------------------------------------------------------------------
class Device(CompositeEventEmitter):
# incomplete list of fields.
random_address: Address
public_address: Address
classic_enabled: bool
name: str
class_of_device: int
gatt_server: gatt_server.Server
advertising_data: bytes
scan_response_data: bytes
connections: Dict[int, Connection]
pending_connections: Dict[Address, Connection]
classic_pending_accepts: Dict[
Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]]
]
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
@composite_listener
class Listener:
def on_advertisement(self, advertisement):
@@ -888,12 +924,12 @@ class Device(CompositeEventEmitter):
def __init__(
self,
name=None,
address=None,
config=None,
host=None,
generic_access_service=True,
):
name: Optional[str] = None,
address: Optional[Address] = None,
config: Optional[DeviceConfiguration] = None,
host: Optional[Host] = None,
generic_access_service: bool = True,
) -> None:
super().__init__()
self._host = None
@@ -995,10 +1031,12 @@ class Device(CompositeEventEmitter):
setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription')
# Set the initial host
self.host = host
if host:
self.host = host
@property
def host(self):
def host(self) -> Host:
assert self._host
return self._host
@host.setter
@@ -1032,15 +1070,18 @@ class Device(CompositeEventEmitter):
def sdp_service_records(self, service_records):
self.sdp_server.service_records = service_records
def lookup_connection(self, connection_handle):
def lookup_connection(self, connection_handle: int) -> Optional[Connection]:
if connection := self.connections.get(connection_handle):
return connection
return None
def find_connection_by_bd_addr(
self, bd_addr, transport=None, check_address_type=False
):
self,
bd_addr: Address,
transport: Optional[int] = None,
check_address_type: bool = False,
) -> Optional[Connection]:
for connection in self.connections.values():
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
if (
@@ -1098,11 +1139,11 @@ class Device(CompositeEventEmitter):
logger.warning('!!! Command timed out')
raise CommandTimeoutError() from error
async def power_on(self):
async def power_on(self) -> None:
# Reset the controller
await self.host.reset()
response = await self.send_command(HCI_Read_BD_ADDR_Command())
response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg]
if response.return_parameters.status == HCI_SUCCESS:
logger.debug(
color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow')
@@ -1114,7 +1155,7 @@ class Device(CompositeEventEmitter):
HCI_Write_LE_Host_Support_Command(
le_supported_host=int(self.le_enabled),
simultaneous_le_host=int(self.le_simultaneous_enabled),
)
) # type: ignore[call-arg]
)
if self.le_enabled:
@@ -1124,7 +1165,7 @@ class Device(CompositeEventEmitter):
if self.host.supports_command(HCI_LE_RAND_COMMAND):
# Get 8 random bytes
response = await self.send_command(
HCI_LE_Rand_Command(), check_result=True
HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg]
)
# Ensure the address bytes can be a static random address
@@ -1145,7 +1186,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Random_Address_Command(
random_address=self.random_address
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1153,7 +1194,7 @@ class Device(CompositeEventEmitter):
if self.keystore and self.host.supports_command(
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
):
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
resolving_keys = await self.keystore.get_resolving_keys()
for (irk, address) in resolving_keys:
@@ -1163,7 +1204,7 @@ class Device(CompositeEventEmitter):
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
)
) # type: ignore[call-arg]
)
# Enable address resolution
@@ -1178,28 +1219,24 @@ class Device(CompositeEventEmitter):
if self.classic_enabled:
await self.send_command(
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8'))
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device)
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Simple_Pairing_Mode_Command(
simple_pairing_mode=int(self.classic_ssp_enabled)
)
) # type: ignore[call-arg]
)
await self.send_command(
HCI_Write_Secure_Connections_Host_Support_Command(
secure_connections_host_support=int(self.classic_sc_enabled)
)
) # type: ignore[call-arg]
)
await self.set_connectable(self.connectable)
await self.set_discoverable(self.discoverable)
# Let the SMP manager know about the address
# TODO: allow using a public address
self.smp_manager.address = self.random_address
# Done
self.powered_on = True
@@ -1221,11 +1258,11 @@ class Device(CompositeEventEmitter):
async def start_advertising(
self,
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target=None,
own_address_type=OwnAddressType.RANDOM,
auto_restart=False,
):
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target: Optional[Address] = None,
own_address_type: int = OwnAddressType.RANDOM,
auto_restart: bool = False,
) -> None:
# If we're advertising, stop first
if self.advertising:
await self.stop_advertising()
@@ -1235,7 +1272,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Advertising_Data_Command(
advertising_data=self.advertising_data
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1244,7 +1281,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Scan_Response_Data_Command(
scan_response_data=self.scan_response_data
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1270,13 +1307,13 @@ class Device(CompositeEventEmitter):
peer_address=peer_address,
advertising_channel_map=7,
advertising_filter_policy=0,
),
), # type: ignore[call-arg]
check_result=True,
)
# Enable advertising
await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1),
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg]
check_result=True,
)
@@ -1285,11 +1322,11 @@ class Device(CompositeEventEmitter):
self.advertising_type = advertising_type
self.advertising = True
async def stop_advertising(self):
async def stop_advertising(self) -> None:
# Disable advertising
if self.advertising:
await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg]
check_result=True,
)
@@ -1304,14 +1341,14 @@ class Device(CompositeEventEmitter):
async def start_scanning(
self,
legacy=False,
active=True,
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type=OwnAddressType.RANDOM,
filter_duplicates=False,
scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
):
legacy: bool = False,
active: bool = True,
scan_interval: int = DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
own_address_type: int = OwnAddressType.RANDOM,
filter_duplicates: bool = False,
scanning_phys: Tuple[int, int] = (HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
) -> None:
# Check that the arguments are legal
if scan_interval < scan_window:
raise ValueError('scan_interval must be >= scan_window')
@@ -1361,7 +1398,7 @@ class Device(CompositeEventEmitter):
scan_types=[scan_type] * scanning_phy_count,
scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count,
scan_windows=[int(scan_window / 0.625)] * scanning_phy_count,
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1372,7 +1409,7 @@ class Device(CompositeEventEmitter):
filter_duplicates=1 if filter_duplicates else 0,
duration=0, # TODO allow other values
period=0, # TODO allow other values
),
), # type: ignore[call-arg]
check_result=True,
)
else:
@@ -1390,7 +1427,7 @@ class Device(CompositeEventEmitter):
le_scan_window=int(scan_window / 0.625),
own_address_type=own_address_type,
scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY,
),
), # type: ignore[call-arg]
check_result=True,
)
@@ -1398,25 +1435,25 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_LE_Set_Scan_Enable_Command(
le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0
),
), # type: ignore[call-arg]
check_result=True,
)
self.scanning_is_passive = not active
self.scanning = True
async def stop_scanning(self):
async def stop_scanning(self) -> None:
# Disable scanning
if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
await self.send_command(
HCI_LE_Set_Extended_Scan_Enable_Command(
enable=0, filter_duplicates=0, duration=0, period=0
),
), # type: ignore[call-arg]
check_result=True,
)
else:
await self.send_command(
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0),
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg]
check_result=True,
)
@@ -1434,9 +1471,9 @@ class Device(CompositeEventEmitter):
if advertisement := accumulator.update(report):
self.emit('advertisement', advertisement)
async def start_discovery(self, auto_restart=True):
async def start_discovery(self, auto_restart: bool = True) -> None:
await self.send_command(
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE),
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg]
check_result=True,
)
@@ -1445,7 +1482,7 @@ class Device(CompositeEventEmitter):
lap=HCI_GENERAL_INQUIRY_LAP,
inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH,
num_responses=0, # Unlimited number of responses.
)
) # type: ignore[call-arg]
)
if response.status != HCI_Command_Status_Event.PENDING:
self.discovering = False
@@ -1454,9 +1491,9 @@ class Device(CompositeEventEmitter):
self.auto_restart_inquiry = auto_restart
self.discovering = True
async def stop_discovery(self):
async def stop_discovery(self) -> None:
if self.discovering:
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg]
self.auto_restart_inquiry = True
self.discovering = False
@@ -1484,7 +1521,7 @@ class Device(CompositeEventEmitter):
HCI_Write_Scan_Enable_Command(scan_enable=scan_enable)
)
async def set_discoverable(self, discoverable=True):
async def set_discoverable(self, discoverable: bool = True) -> None:
self.discoverable = discoverable
if self.classic_enabled:
# Synthesize an inquiry response if none is set already
@@ -1504,7 +1541,7 @@ class Device(CompositeEventEmitter):
await self.send_command(
HCI_Write_Extended_Inquiry_Response_Command(
fec_required=0, extended_inquiry_response=self.inquiry_response
),
), # type: ignore[call-arg]
check_result=True,
)
await self.set_scan_enable(
@@ -1512,7 +1549,7 @@ class Device(CompositeEventEmitter):
page_scan_enabled=self.connectable,
)
async def set_connectable(self, connectable=True):
async def set_connectable(self, connectable: bool = True) -> None:
self.connectable = connectable
if self.classic_enabled:
await self.set_scan_enable(
@@ -1522,12 +1559,14 @@ class Device(CompositeEventEmitter):
async def connect(
self,
peer_address,
transport=BT_LE_TRANSPORT,
connection_parameters_preferences=None,
own_address_type=OwnAddressType.RANDOM,
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
):
peer_address: Union[Address, str],
transport: int = BT_LE_TRANSPORT,
connection_parameters_preferences: Optional[
Dict[int, ConnectionParametersPreferences]
] = None,
own_address_type: int = OwnAddressType.RANDOM,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
'''
Request a connection to a peer.
When transport is BLE, this method cannot be called if there is already a
@@ -1574,6 +1613,8 @@ class Device(CompositeEventEmitter):
):
raise ValueError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, Address)
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
@@ -1691,7 +1732,7 @@ class Device(CompositeEventEmitter):
supervision_timeouts=supervision_timeouts,
min_ce_lengths=min_ce_lengths,
max_ce_lengths=max_ce_lengths,
)
) # type: ignore[call-arg]
)
else:
if HCI_LE_1M_PHY not in connection_parameters_preferences:
@@ -1720,7 +1761,7 @@ class Device(CompositeEventEmitter):
supervision_timeout=int(prefs.supervision_timeout / 10),
min_ce_length=int(prefs.min_ce_length / 0.625),
max_ce_length=int(prefs.max_ce_length / 0.625),
)
) # type: ignore[call-arg]
)
else:
# Save pending connection
@@ -1737,7 +1778,7 @@ class Device(CompositeEventEmitter):
clock_offset=0x0000,
allow_role_switch=0x01,
reserved=0,
)
) # type: ignore[call-arg]
)
if result.status != HCI_Command_Status_Event.PENDING:
@@ -1756,10 +1797,10 @@ class Device(CompositeEventEmitter):
)
except asyncio.TimeoutError:
if transport == BT_LE_TRANSPORT:
await self.send_command(HCI_LE_Create_Connection_Cancel_Command())
await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg]
else:
await self.send_command(
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg]
)
try:
@@ -1777,10 +1818,10 @@ class Device(CompositeEventEmitter):
async def accept(
self,
peer_address=Address.ANY,
role=BT_PERIPHERAL_ROLE,
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
):
peer_address: Union[Address, str] = Address.ANY,
role: int = BT_PERIPHERAL_ROLE,
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
) -> Connection:
'''
Wait and accept any incoming connection or a connection from `peer_address` when
set.
@@ -1802,22 +1843,24 @@ class Device(CompositeEventEmitter):
peer_address, BT_BR_EDR_TRANSPORT
) # TODO: timeout
assert isinstance(peer_address, Address)
if peer_address == Address.NIL:
raise ValueError('accept on nil address')
# Create a future so that we can wait for the request
pending_request = asyncio.get_running_loop().create_future()
pending_request_fut = asyncio.get_running_loop().create_future()
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].append(pending_request)
self.classic_pending_accepts[Address.ANY].append(pending_request_fut)
elif peer_address in self.classic_pending_accepts:
raise InvalidStateError('accept connection already pending')
else:
self.classic_pending_accepts[peer_address] = pending_request
self.classic_pending_accepts[peer_address] = [pending_request_fut]
try:
# Wait for a request or a completed connection
pending_request = self.abort_on('flush', pending_request)
pending_request = self.abort_on('flush', pending_request_fut)
result = await (
asyncio.wait_for(pending_request, timeout)
if timeout
@@ -1826,7 +1869,7 @@ class Device(CompositeEventEmitter):
except Exception:
# Remove future from device context
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].remove(pending_request)
self.classic_pending_accepts[Address.ANY].remove(pending_request_fut)
else:
self.classic_pending_accepts.pop(peer_address)
raise
@@ -1838,6 +1881,7 @@ class Device(CompositeEventEmitter):
# Otherwise, result came from `on_connection_request`
peer_address, _class_of_device, _link_type = result
assert isinstance(peer_address, Address)
# Create a future so that we can wait for the connection's result
pending_connection = asyncio.get_running_loop().create_future()
@@ -1867,7 +1911,7 @@ class Device(CompositeEventEmitter):
try:
# Accept connection request
await self.send_command(
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role)
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg]
)
# Wait for connection complete
@@ -2243,7 +2287,7 @@ class Device(CompositeEventEmitter):
)
# [Classic only]
async def request_remote_name(self, remote): # remote: Connection | Address
async def request_remote_name(self, remote: Union[Address, Connection]) -> str:
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
@@ -2271,7 +2315,7 @@ class Device(CompositeEventEmitter):
page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2,
reserved=0,
clock_offset=0, # TODO investigate non-0 values
)
) # type: ignore[call-arg]
)
if result.status != HCI_COMMAND_STATUS_PENDING:
@@ -2372,7 +2416,7 @@ class Device(CompositeEventEmitter):
# In this case, set the completed `connection` to the `accept` future
# result.
if peer_address in self.classic_pending_accepts:
future = self.classic_pending_accepts.pop(peer_address)
future, *_ = self.classic_pending_accepts.pop(peer_address)
future.set_result(connection)
# Emit an event to notify listeners of the new connection
@@ -2473,7 +2517,7 @@ class Device(CompositeEventEmitter):
# match a pending future using `bd_addr`
if bd_addr in self.classic_pending_accepts:
future = self.classic_pending_accepts.pop(bd_addr)
future, *_ = self.classic_pending_accepts.pop(bd_addr)
future.set_result((bd_addr, class_of_device, link_type))
# match first pending future for ANY address

View File

@@ -28,7 +28,7 @@ import enum
import functools
import logging
import struct
from typing import Sequence
from typing import Optional, Sequence
from colors import color
from .core import UUID, get_dict_key_by_value
@@ -204,6 +204,8 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION
'''
uuid: UUID
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str):
@@ -221,7 +223,7 @@ class Service(Attribute):
self.characteristics = characteristics[:]
self.primary = primary
def get_advertising_data(self):
def get_advertising_data(self) -> Optional[bytes]:
"""
Get Service specific advertising data
Defined by each Service, default value is empty

View File

@@ -27,7 +27,7 @@ import asyncio
import logging
from collections import defaultdict
import struct
from typing import Tuple, Optional
from typing import List, Tuple, Optional
from pyee import EventEmitter
from colors import color
@@ -90,6 +90,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# GATT Server
# -----------------------------------------------------------------------------
class Server(EventEmitter):
attributes: List[Attribute]
def __init__(self, device):
super().__init__()
self.device = device
@@ -140,6 +142,7 @@ class Server(EventEmitter):
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and isinstance(attribute, Service)
and attribute.uuid == service_uuid
),
None,

View File

@@ -21,7 +21,7 @@ import collections
import logging
import functools
from colors import color
from typing import Dict, Type
from typing import Dict, Type, Union
from .core import (
BT_BR_EDR_TRANSPORT,
@@ -1729,7 +1729,9 @@ class Address:
address_type = data[offset - 1]
return Address.parse_address_with_type(data, offset, address_type)
def __init__(self, address, address_type=RANDOM_DEVICE_ADDRESS):
def __init__(
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
):
'''
Initialize an instance. `address` may be a byte array in little-endian
format, or a hex string in big-endian format (with optional ':'

View File

@@ -141,7 +141,7 @@ class Host(AbortableEventEmitter):
if controller_sink:
self.set_packet_sink(controller_sink)
async def flush(self):
async def flush(self) -> None:
# Make sure no command is pending
await self.command_semaphore.acquire()
@@ -660,7 +660,7 @@ class Host(AbortableEventEmitter):
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
latency=event.latency,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length=0,

View File

@@ -24,6 +24,7 @@ import asyncio
import logging
import os
import json
from typing import Optional
from colors import color
from .hci import Address
@@ -129,7 +130,7 @@ class PairingKeys:
for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
print(f'{prefix}{color(property, "cyan")}: {value}')
print(f'{prefix}{color(container_property, "cyan")}: {value}')
# -----------------------------------------------------------------------------
@@ -216,7 +217,7 @@ class JsonKeyStore(KeyStore):
params = device_config.keystore.split(':', 1)[1:]
namespace = str(device_config.address)
if params:
filename = params[1]
filename = params[0]
else:
filename = None
@@ -242,7 +243,7 @@ class JsonKeyStore(KeyStore):
# Atomically replace the previous file
os.rename(temp_filename, self.filename)
async def delete(self, name):
async def delete(self, name: str) -> None:
db = await self.load()
namespace = db.get(self.namespace)
@@ -278,7 +279,7 @@ class JsonKeyStore(KeyStore):
await self.save(db)
async def get(self, name):
async def get(self, name: str) -> Optional[PairingKeys]:
db = await self.load()
namespace = db.get(self.namespace)

View File

@@ -20,7 +20,7 @@ import logging
import struct
from colors import color
import colors
from typing import Dict, Type
from typing import Dict, List, Type
from . import core
from .core import InvalidStateError
@@ -183,63 +183,63 @@ class DataElement:
raise ValueError('integer types must have a value size specified')
@staticmethod
def nil():
def nil() -> DataElement:
return DataElement(DataElement.NIL, None)
@staticmethod
def unsigned_integer(value, value_size):
def unsigned_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
@staticmethod
def unsigned_integer_8(value):
def unsigned_integer_8(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
@staticmethod
def unsigned_integer_16(value):
def unsigned_integer_16(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
@staticmethod
def unsigned_integer_32(value):
def unsigned_integer_32(value: int) -> DataElement:
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
@staticmethod
def signed_integer(value, value_size):
def signed_integer(value: int, value_size: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
@staticmethod
def signed_integer_8(value):
def signed_integer_8(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
@staticmethod
def signed_integer_16(value):
def signed_integer_16(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
@staticmethod
def signed_integer_32(value):
def signed_integer_32(value: int) -> DataElement:
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
@staticmethod
def uuid(value):
def uuid(value: core.UUID) -> DataElement:
return DataElement(DataElement.UUID, value)
@staticmethod
def text_string(value):
def text_string(value: str) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@staticmethod
def boolean(value):
def boolean(value: bool) -> DataElement:
return DataElement(DataElement.BOOLEAN, value)
@staticmethod
def sequence(value):
def sequence(value: List[DataElement]) -> DataElement:
return DataElement(DataElement.SEQUENCE, value)
@staticmethod
def alternative(value):
def alternative(value: List[DataElement]) -> DataElement:
return DataElement(DataElement.ALTERNATIVE, value)
@staticmethod
def url(value):
def url(value: str) -> DataElement:
return DataElement(DataElement.URL, value)
@staticmethod
@@ -458,7 +458,7 @@ class DataElement:
# -----------------------------------------------------------------------------
class ServiceAttribute:
def __init__(self, attribute_id, value):
def __init__(self, attribute_id: int, value: DataElement) -> None:
self.id = attribute_id
self.value = value

View File

@@ -26,7 +26,7 @@ from __future__ import annotations
import logging
import asyncio
import secrets
from typing import Dict, Type
from typing import Dict, Optional, Type
from pyee import EventEmitter
from colors import color
@@ -504,27 +504,29 @@ class PairingDelegate:
def __init__(
self,
io_capability=NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
):
io_capability: int = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution
async def accept(self):
async def accept(self) -> bool:
return True
async def confirm(self):
async def confirm(self) -> bool:
return True
async def compare_numbers(self, _number, _digits=6):
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
return True
async def get_number(self):
async def get_number(self) -> int:
return 0
async def display_number(self, _number, _digits=6):
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
pass
async def key_distribution_response(
@@ -538,7 +540,13 @@ class PairingDelegate:
# -----------------------------------------------------------------------------
class PairingConfig:
def __init__(self, sc=True, mitm=True, bonding=True, delegate=None):
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
@@ -655,7 +663,8 @@ class Session:
self.peer_expected_distributions = []
self.dh_key = None
self.confirm_value = None
self.passkey = 0
self.passkey = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
@@ -833,6 +842,7 @@ class Session:
# Generate random Passkey/PIN code
self.passkey = secrets.randbelow(1000000)
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
self.passkey_ready.set()
# The value of TK is computed from the PIN code
if not self.sc:
@@ -853,6 +863,8 @@ class Session:
self.tk = passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
self.passkey_ready.set()
if next_steps is not None:
next_steps()
@@ -904,17 +916,29 @@ class Session:
logger.debug(f'generated random: {self.r.hex()}')
if self.sc:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
async def next_steps():
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
z = 0
elif self.pairing_method == self.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
else:
return
if self.is_initiator:
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
else:
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
self.send_command(
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
)
# Perform the next steps asynchronously in case we need to wait for input
self.connection.abort_on('disconnection', next_steps())
else:
confirm_value = crypto.c1(
self.tk,
@@ -927,7 +951,7 @@ class Session:
self.ra,
)
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
def send_pairing_random_command(self):
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
@@ -1358,8 +1382,8 @@ class Session:
# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
self.display_passkey()
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
@@ -1420,18 +1444,22 @@ class Session:
else:
srand = self.r
mrand = command.random_value
stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {stk.hex()}')
self.stk = crypto.s1(self.tk, srand, mrand)
logger.debug(f'STK = {self.stk.hex()}')
# Generate LTK
self.ltk = crypto.r()
if self.is_initiator:
self.start_encryption(stk)
self.start_encryption(self.stk)
else:
self.send_pairing_random_command()
def on_smp_pairing_random_command_secure_connections(self, command):
if self.pairing_method == self.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
@@ -1559,17 +1587,13 @@ class Session:
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.is_initiator:
if self.pairing_method == self.PASSKEY:
if self.passkey_display:
self.send_pairing_confirm_command()
else:
self.input_passkey(self.send_pairing_confirm_command)
self.send_pairing_confirm_command()
else:
# Send our public key back to the initiator
if self.pairing_method == self.PASSKEY:
self.display_or_input_passkey(self.send_public_key_command)
else:
self.send_public_key_command()
self.display_or_input_passkey()
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
# We can now send the confirmation value

View File

@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_transport(name):
async def open_transport(name: str) -> Transport:
'''
Open a transport by name.
The name must be <type>:<parameters>

View File

@@ -259,7 +259,7 @@ class Transport:
def __iter__(self):
return iter((self.source, self.sink))
async def close(self):
async def close(self) -> None:
self.source.close()
self.sink.close()

View File

@@ -20,7 +20,7 @@ import logging
import traceback
import collections
import sys
from typing import Awaitable
from typing import Awaitable, TypeVar
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -65,8 +65,11 @@ def composite_listener(cls):
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable):
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
"""
Set a coroutine or future to abort when an event occur.
"""
@@ -75,6 +78,8 @@ class AbortableEventEmitter(EventEmitter):
return future
def on_event(*_):
if future.done():
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`

View File

@@ -64,11 +64,11 @@ module = "aioconsole.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "bitstruct.*"
module = "colors.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "colors.*"
module = "construct.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]

View File

@@ -30,15 +30,14 @@ package_dir =
bumble.apps = apps
include-package-data = True
install_requires =
aioconsole >= 0.4.1
ansicolors >= 1.1
appdirs >= 1.4
bitstruct >= 8.12
click >= 7.1.2; platform_system!='Emscripten'
construct >= 2.10
cryptography == 35; platform_system!='Emscripten'
grpcio >= 1.46; platform_system!='Emscripten'
libusb1 >= 2.0.1; platform_system!='Emscripten'
libusb-package == 1.0.26.0; platform_system!='Emscripten'
libusb-package == 1.0.26.1; platform_system!='Emscripten'
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
protobuf >= 3.12.4
pyee >= 8.2.2
@@ -73,11 +72,11 @@ test =
pytest-html >= 3.2.0
coverage >= 6.4
development =
black >= 22.10
black == 22.10
invoke >= 1.7.3
mypy >= 0.991
mypy == 0.991
nox >= 2022
pylint >= 2.15.8
pylint == 2.15.8
types-appdirs >= 1.4.3
types-invoke >= 1.7.3
types-protobuf >= 4.21.0

View File

@@ -25,10 +25,8 @@ def test_ad_data():
assert data == ad_bytes
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
bytes([123])
]
assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == []
assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [bytes([123])]
data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234])
ad.append(data2)
@@ -36,8 +34,8 @@ def test_ad_data():
assert ad_bytes == data + data2
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == []
assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [
bytes([123]),
bytes([234]),
]

View File

@@ -32,7 +32,6 @@ from bumble.smp import (
PairingDelegate,
SMP_PAIRING_NOT_SUPPORTED_ERROR,
SMP_CONFIRM_VALUE_FAILED_ERROR,
SMP_ID_KEY_DISTRIBUTION_FLAG,
)
from bumble.core import ProtocolError
@@ -273,9 +272,15 @@ KEY_DIST = range(16)
@pytest.mark.asyncio
@pytest.mark.parametrize(
'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST)
'io_caps, sc, mitm, key_dist',
itertools.chain(
itertools.product([IO_CAP], SC, MITM, [15]),
itertools.product(
[[PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]], SC, MITM, KEY_DIST
),
),
)
async def test_self_smp(io_cap, sc, mitm, key_dist):
async def test_self_smp(io_caps, sc, mitm, key_dist):
class Delegate(PairingDelegate):
def __init__(
self,
@@ -296,6 +301,7 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
self.peer_delegate = None
self.number = asyncio.get_running_loop().create_future()
# pylint: disable-next=unused-argument
async def compare_numbers(self, number, digits):
if self.peer_delegate is None:
logger.warning(f'[{self.name}] no peer delegate')
@@ -331,8 +337,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
pairing_config_sets = [('Initiator', [None]), ('Responder', [None])]
for pairing_config_set in pairing_config_sets:
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
for io_cap in io_caps:
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
for pairing_config1 in pairing_config_sets[0][1]:
for pairing_config2 in pairing_config_sets[1][1]: