forked from auracaster/bumble_mirror
Compare commits
23 Commits
gbg/fix-co
...
gbg/smp-im
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8beb6b1ff | ||
|
|
2d44de611f | ||
|
|
e6fc63b2d8 | ||
|
|
1321c7da81 | ||
|
|
5a1b03fd91 | ||
|
|
de47721753 | ||
|
|
83a76a75d3 | ||
|
|
d5b5ef8313 | ||
|
|
856a8d53cd | ||
|
|
177c273a57 | ||
|
|
24a863983d | ||
|
|
b7ef09d4a3 | ||
|
|
b5b6cd13b8 | ||
|
|
ef781bc374 | ||
|
|
00978c1d63 | ||
|
|
b731f6f556 | ||
|
|
ed261886e1 | ||
|
|
5e18094c31 | ||
|
|
9a9b4e5bf1 | ||
|
|
895f1618d8 | ||
|
|
52746e0c68 | ||
|
|
f9b7072423 | ||
|
|
fa4be1958f |
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
@@ -71,5 +71,10 @@
|
||||
"editor.rulers": [88]
|
||||
},
|
||||
"python.formatting.provider": "black",
|
||||
"pylint.importStrategy": "useBundled"
|
||||
"pylint.importStrategy": "useBundled",
|
||||
"python.testing.pytestArgs": [
|
||||
"."
|
||||
],
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestEnabled": true
|
||||
}
|
||||
|
||||
119
apps/pair.py
119
apps/pair.py
@@ -19,8 +19,8 @@ import asyncio
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
import aioconsole
|
||||
from colors import color
|
||||
from prompt_toolkit.shortcuts import PromptSession
|
||||
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.transport import open_transport_or_link
|
||||
@@ -42,9 +42,23 @@ from bumble.att import (
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Delegate(PairingDelegate):
|
||||
def __init__(self, mode, connection, capability_string, prompt):
|
||||
def __init__(self, mode, connection, capability_string, do_prompt):
|
||||
super().__init__(
|
||||
{
|
||||
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
|
||||
@@ -58,7 +72,18 @@ class Delegate(PairingDelegate):
|
||||
self.mode = mode
|
||||
self.peer = Peer(connection)
|
||||
self.peer_name = None
|
||||
self.prompt = prompt
|
||||
self.do_prompt = do_prompt
|
||||
|
||||
def print(self, message):
|
||||
print(color(message, 'yellow'))
|
||||
|
||||
async def prompt(self, message):
|
||||
# Wait a bit to allow some of the log lines to print before we prompt
|
||||
await asyncio.sleep(1)
|
||||
|
||||
session = PromptSession(message)
|
||||
response = await session.prompt_async()
|
||||
return response.lower().strip()
|
||||
|
||||
async def update_peer_name(self):
|
||||
if self.peer_name is not None:
|
||||
@@ -73,19 +98,15 @@ class Delegate(PairingDelegate):
|
||||
self.peer_name = '[?]'
|
||||
|
||||
async def accept(self):
|
||||
if self.prompt:
|
||||
if self.do_prompt:
|
||||
await self.update_peer_name()
|
||||
|
||||
# Wait a bit to allow some of the log lines to print before we prompt
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Prompt for acceptance
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
self.print('###-----------------------------------')
|
||||
self.print(f'### Pairing request from {self.peer_name}')
|
||||
self.print('###-----------------------------------')
|
||||
while True:
|
||||
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
|
||||
response = response.lower().strip()
|
||||
response = await self.prompt('>>> Accept? ')
|
||||
|
||||
if response == 'yes':
|
||||
return True
|
||||
@@ -96,23 +117,17 @@ class Delegate(PairingDelegate):
|
||||
# Accept silently
|
||||
return True
|
||||
|
||||
async def compare_numbers(self, number, digits=6):
|
||||
async def compare_numbers(self, number, digits):
|
||||
await self.update_peer_name()
|
||||
|
||||
# Wait a bit to allow some of the log lines to print before we prompt
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Prompt for a numeric comparison
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
self.print('###-----------------------------------')
|
||||
self.print(f'### Pairing with {self.peer_name}')
|
||||
self.print('###-----------------------------------')
|
||||
while True:
|
||||
response = await aioconsole.ainput(
|
||||
color(
|
||||
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
|
||||
)
|
||||
response = await self.prompt(
|
||||
f'>>> Does the other device display {number:0{digits}}? '
|
||||
)
|
||||
response = response.lower().strip()
|
||||
|
||||
if response == 'yes':
|
||||
return True
|
||||
@@ -123,30 +138,24 @@ class Delegate(PairingDelegate):
|
||||
async def get_number(self):
|
||||
await self.update_peer_name()
|
||||
|
||||
# Wait a bit to allow some of the log lines to print before we prompt
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Prompt for a PIN
|
||||
while True:
|
||||
try:
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
|
||||
self.print('###-----------------------------------')
|
||||
self.print(f'### Pairing with {self.peer_name}')
|
||||
self.print('###-----------------------------------')
|
||||
return int(await self.prompt('>>> Enter PIN: '))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def display_number(self, number, digits=6):
|
||||
async def display_number(self, number, digits):
|
||||
await self.update_peer_name()
|
||||
|
||||
# Wait a bit to allow some of the log lines to print before we prompt
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Display a PIN code
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
|
||||
print(color(f'### PIN: {number:0{digits}}', 'yellow'))
|
||||
print(color('###-----------------------------------', 'yellow'))
|
||||
self.print('###-----------------------------------')
|
||||
self.print(f'### Pairing with {self.peer_name}')
|
||||
self.print(f'### PIN: {number:0{digits}}')
|
||||
self.print('###-----------------------------------')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -238,6 +247,7 @@ def on_pairing(keys):
|
||||
print(color('*** Paired!', 'cyan'))
|
||||
keys.print(prefix=color('*** ', 'cyan'))
|
||||
print(color('***-----------------------------------', 'cyan'))
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -245,6 +255,7 @@ def on_pairing_failure(reason):
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
|
||||
print(color('***-----------------------------------', 'red'))
|
||||
Waiter.instance.terminate()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -262,6 +273,8 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
@@ -332,7 +345,19 @@ async def pair(
|
||||
# Advertise so that peers can find us and connect
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
await hci_source.wait_for_termination()
|
||||
# Run until the user asks to exit
|
||||
await Waiter.instance.wait_until_terminated()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LogHandler(logging.Handler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
|
||||
|
||||
def emit(self, record):
|
||||
message = self.format(record)
|
||||
print(message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -366,7 +391,11 @@ async def pair(
|
||||
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
|
||||
)
|
||||
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
|
||||
@click.option('--keystore-file', help='File in which to store the pairing keys')
|
||||
@click.option(
|
||||
'--keystore-file',
|
||||
metavar='<filename>',
|
||||
help='File in which to store the pairing keys',
|
||||
)
|
||||
@click.argument('device-config')
|
||||
@click.argument('hci_transport')
|
||||
@click.argument('address-or-name', required=False)
|
||||
@@ -384,7 +413,13 @@ def main(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
# Setup logging
|
||||
log_handler = LogHandler()
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(log_handler)
|
||||
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
|
||||
# Pair
|
||||
asyncio.run(
|
||||
pair(
|
||||
mode,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import struct
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
import bitstruct
|
||||
import construct
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
from .sdp import (
|
||||
@@ -258,7 +258,17 @@ class SbcMediaCodecInformation(
|
||||
A2DP spec - 4.3.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
BIT_FIELDS = 'u4u4u4u2u2u8u8'
|
||||
BIT_FIELDS = construct.Bitwise(
|
||||
construct.Sequence(
|
||||
construct.BitsInteger(4),
|
||||
construct.BitsInteger(4),
|
||||
construct.BitsInteger(4),
|
||||
construct.BitsInteger(2),
|
||||
construct.BitsInteger(2),
|
||||
construct.BitsInteger(8),
|
||||
construct.BitsInteger(8),
|
||||
)
|
||||
)
|
||||
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
|
||||
CHANNEL_MODE_BITS = {
|
||||
SBC_MONO_CHANNEL_MODE: 1 << 3,
|
||||
@@ -276,7 +286,7 @@ class SbcMediaCodecInformation(
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
return SbcMediaCodecInformation(
|
||||
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
|
||||
*SbcMediaCodecInformation.BIT_FIELDS.parse(data)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -326,7 +336,7 @@ class SbcMediaCodecInformation(
|
||||
)
|
||||
|
||||
def __bytes__(self):
|
||||
return bitstruct.pack(self.BIT_FIELDS, *self)
|
||||
return self.BIT_FIELDS.build(self)
|
||||
|
||||
def __str__(self):
|
||||
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
|
||||
@@ -350,14 +360,23 @@ class SbcMediaCodecInformation(
|
||||
class AacMediaCodecInformation(
|
||||
namedtuple(
|
||||
'AacMediaCodecInformation',
|
||||
['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
|
||||
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
|
||||
)
|
||||
):
|
||||
'''
|
||||
A2DP spec - 4.5.2 Codec Specific Information Elements
|
||||
'''
|
||||
|
||||
BIT_FIELDS = 'u8u12u2p2u1u23'
|
||||
BIT_FIELDS = construct.Bitwise(
|
||||
construct.Sequence(
|
||||
construct.BitsInteger(8),
|
||||
construct.BitsInteger(12),
|
||||
construct.BitsInteger(2),
|
||||
construct.BitsInteger(2),
|
||||
construct.BitsInteger(1),
|
||||
construct.BitsInteger(23),
|
||||
)
|
||||
)
|
||||
OBJECT_TYPE_BITS = {
|
||||
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
|
||||
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
|
||||
@@ -383,7 +402,7 @@ class AacMediaCodecInformation(
|
||||
@staticmethod
|
||||
def from_bytes(data):
|
||||
return AacMediaCodecInformation(
|
||||
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
|
||||
*AacMediaCodecInformation.BIT_FIELDS.parse(data)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -394,6 +413,7 @@ class AacMediaCodecInformation(
|
||||
object_type=cls.OBJECT_TYPE_BITS[object_type],
|
||||
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
|
||||
channels=cls.CHANNELS_BITS[channels],
|
||||
rfa=0,
|
||||
vbr=vbr,
|
||||
bitrate=bitrate,
|
||||
)
|
||||
@@ -411,7 +431,7 @@ class AacMediaCodecInformation(
|
||||
)
|
||||
|
||||
def __bytes__(self):
|
||||
return bitstruct.pack(self.BIT_FIELDS, *self)
|
||||
return self.BIT_FIELDS.build(self)
|
||||
|
||||
def __str__(self):
|
||||
object_types = [
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import struct
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
from .company_ids import COMPANY_IDENTIFIERS
|
||||
|
||||
@@ -146,7 +147,7 @@ class UUID:
|
||||
'''
|
||||
|
||||
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
|
||||
UUIDS: list[UUID] = [] # Registry of all instances created
|
||||
UUIDS: List[UUID] = [] # Registry of all instances created
|
||||
|
||||
def __init__(self, uuid_str_or_int, name=None):
|
||||
if isinstance(uuid_str_or_int, int):
|
||||
@@ -181,7 +182,7 @@ class UUID:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, uuid_bytes, name=None):
|
||||
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
|
||||
if len(uuid_bytes) in (2, 4, 16):
|
||||
self = cls.__new__(cls)
|
||||
self.uuid_bytes = uuid_bytes
|
||||
@@ -225,7 +226,7 @@ class UUID:
|
||||
'''
|
||||
return self.to_bytes(force_128=(len(self.uuid_bytes) == 4))
|
||||
|
||||
def to_hex_str(self):
|
||||
def to_hex_str(self) -> str:
|
||||
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
|
||||
return bytes(reversed(self.uuid_bytes)).hex().upper()
|
||||
|
||||
@@ -607,6 +608,11 @@ class DeviceClass:
|
||||
# -----------------------------------------------------------------------------
|
||||
# Advertising Data
|
||||
# -----------------------------------------------------------------------------
|
||||
AdvertisingObject = Union[
|
||||
List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes]
|
||||
]
|
||||
|
||||
|
||||
class AdvertisingData:
|
||||
# fmt: off
|
||||
# pylint: disable=line-too-long
|
||||
@@ -722,10 +728,12 @@ class AdvertisingData:
|
||||
BR_EDR_CONTROLLER_FLAG = 0x08
|
||||
BR_EDR_HOST_FLAG = 0x10
|
||||
|
||||
ad_structures: List[Tuple[int, bytes]]
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
def __init__(self, ad_structures=None):
|
||||
def __init__(self, ad_structures: Optional[List[Tuple[int, bytes]]] = None) -> None:
|
||||
if ad_structures is None:
|
||||
ad_structures = []
|
||||
self.ad_structures = ad_structures[:]
|
||||
@@ -752,7 +760,7 @@ class AdvertisingData:
|
||||
return ','.join(bit_flags_to_strings(flags, flag_names))
|
||||
|
||||
@staticmethod
|
||||
def uuid_list_to_objects(ad_data, uuid_size):
|
||||
def uuid_list_to_objects(ad_data: bytes, uuid_size: int) -> List[UUID]:
|
||||
uuids = []
|
||||
offset = 0
|
||||
while (uuid_size * (offset + 1)) <= len(ad_data):
|
||||
@@ -829,7 +837,7 @@ class AdvertisingData:
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
@staticmethod
|
||||
def ad_data_to_object(ad_type, ad_data):
|
||||
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject:
|
||||
if ad_type in (
|
||||
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
@@ -868,22 +876,22 @@ class AdvertisingData:
|
||||
return ad_data.decode("utf-8")
|
||||
|
||||
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
|
||||
return ad_data[0]
|
||||
return cast(int, struct.unpack('B', ad_data)[0])
|
||||
|
||||
if ad_type in (
|
||||
AdvertisingData.APPEARANCE,
|
||||
AdvertisingData.ADVERTISING_INTERVAL,
|
||||
):
|
||||
return struct.unpack('<H', ad_data)[0]
|
||||
return cast(int, struct.unpack('<H', ad_data)[0])
|
||||
|
||||
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
|
||||
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
|
||||
return cast(int, struct.unpack('<I', bytes([*ad_data, 0]))[0])
|
||||
|
||||
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
|
||||
return struct.unpack('<HH', ad_data)
|
||||
return cast(Tuple[int, int], struct.unpack('<HH', ad_data))
|
||||
|
||||
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
|
||||
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
|
||||
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
|
||||
|
||||
return ad_data
|
||||
|
||||
@@ -898,26 +906,27 @@ class AdvertisingData:
|
||||
self.ad_structures.append((ad_type, ad_data))
|
||||
offset += length
|
||||
|
||||
def get(self, type_id, return_all=False, raw=False):
|
||||
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingObject]:
|
||||
'''
|
||||
Get Advertising Data Structure(s) with a given type
|
||||
|
||||
If return_all is True, returns a (possibly empty) list of matches,
|
||||
else returns the first entry, or None if no structure matches.
|
||||
Returns a (possibly empty) list of matches.
|
||||
'''
|
||||
|
||||
def process_ad_data(ad_data):
|
||||
def process_ad_data(ad_data: bytes) -> AdvertisingObject:
|
||||
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
|
||||
|
||||
if return_all:
|
||||
return [
|
||||
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
|
||||
]
|
||||
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
|
||||
|
||||
return next(
|
||||
(process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id),
|
||||
None,
|
||||
)
|
||||
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]:
|
||||
'''
|
||||
Get Advertising Data Structure(s) with a given type
|
||||
|
||||
Returns the first entry, or None if no structure matches.
|
||||
'''
|
||||
|
||||
all = self.get_all(type_id, raw=raw)
|
||||
return all[0] if all else None
|
||||
|
||||
def __bytes__(self):
|
||||
return b''.join(
|
||||
|
||||
238
bumble/device.py
238
bumble/device.py
@@ -23,7 +23,7 @@ import asyncio
|
||||
import logging
|
||||
from contextlib import asynccontextmanager, AsyncExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from colors import color
|
||||
|
||||
@@ -197,6 +197,8 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Advertisement:
|
||||
address: Address
|
||||
|
||||
TX_POWER_NOT_AVAILABLE = (
|
||||
HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
|
||||
)
|
||||
@@ -511,6 +513,17 @@ ConnectionParametersPreferences.default = ConnectionParametersPreferences()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Connection(CompositeEventEmitter):
|
||||
device: Device
|
||||
handle: int
|
||||
transport: int
|
||||
self_address: Address
|
||||
peer_address: Address
|
||||
role: int
|
||||
encryption: int
|
||||
authenticated: bool
|
||||
sc: bool
|
||||
link_key_type: int
|
||||
|
||||
@composite_listener
|
||||
class Listener:
|
||||
def on_disconnection(self, reason):
|
||||
@@ -611,6 +624,10 @@ class Connection(CompositeEventEmitter):
|
||||
def is_encrypted(self):
|
||||
return self.encryption != 0
|
||||
|
||||
@property
|
||||
def is_incomplete(self) -> bool:
|
||||
return self.handle == None
|
||||
|
||||
def send_l2cap_pdu(self, cid, pdu):
|
||||
self.device.send_l2cap_pdu(self.handle, cid, pdu)
|
||||
|
||||
@@ -626,20 +643,22 @@ class Connection(CompositeEventEmitter):
|
||||
):
|
||||
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
|
||||
|
||||
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
|
||||
return await self.device.disconnect(self, reason)
|
||||
async def disconnect(
|
||||
self, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
|
||||
) -> None:
|
||||
await self.device.disconnect(self, reason)
|
||||
|
||||
async def pair(self):
|
||||
async def pair(self) -> None:
|
||||
return await self.device.pair(self)
|
||||
|
||||
def request_pairing(self):
|
||||
def request_pairing(self) -> None:
|
||||
return self.device.request_pairing(self)
|
||||
|
||||
# [Classic only]
|
||||
async def authenticate(self):
|
||||
async def authenticate(self) -> None:
|
||||
return await self.device.authenticate(self)
|
||||
|
||||
async def encrypt(self, enable=True):
|
||||
async def encrypt(self, enable: bool = True) -> None:
|
||||
return await self.device.encrypt(self, enable)
|
||||
|
||||
async def sustain(self, timeout=None):
|
||||
@@ -707,10 +726,10 @@ class Connection(CompositeEventEmitter):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class DeviceConfiguration:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# Setup defaults
|
||||
self.name = DEVICE_DEFAULT_NAME
|
||||
self.address = DEVICE_DEFAULT_ADDRESS
|
||||
self.address = Address(DEVICE_DEFAULT_ADDRESS)
|
||||
self.class_of_device = DEVICE_DEFAULT_CLASS_OF_DEVICE
|
||||
self.scan_response_data = DEVICE_DEFAULT_SCAN_RESPONSE_DATA
|
||||
self.advertising_interval_min = DEVICE_DEFAULT_ADVERTISING_INTERVAL
|
||||
@@ -730,12 +749,13 @@ class DeviceConfiguration:
|
||||
)
|
||||
self.irk = bytes(16) # This really must be changed for any level of security
|
||||
self.keystore = None
|
||||
self.gatt_services = []
|
||||
self.gatt_services: List[Dict[str, Any]] = []
|
||||
|
||||
def load_from_dict(self, config):
|
||||
def load_from_dict(self, config: Dict[str, Any]) -> None:
|
||||
# Load simple properties
|
||||
self.name = config.get('name', self.name)
|
||||
self.address = Address(config.get('address', self.address))
|
||||
if address := config.get('address', None):
|
||||
self.address = Address(address)
|
||||
self.class_of_device = config.get('class_of_device', self.class_of_device)
|
||||
self.advertising_interval_min = config.get(
|
||||
'advertising_interval', self.advertising_interval_min
|
||||
@@ -842,6 +862,22 @@ device_host_event_handlers: list[str] = []
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Device(CompositeEventEmitter):
|
||||
# incomplete list of fields.
|
||||
random_address: Address
|
||||
public_address: Address
|
||||
classic_enabled: bool
|
||||
name: str
|
||||
class_of_device: int
|
||||
gatt_server: gatt_server.Server
|
||||
advertising_data: bytes
|
||||
scan_response_data: bytes
|
||||
connections: Dict[int, Connection]
|
||||
pending_connections: Dict[Address, Connection]
|
||||
classic_pending_accepts: Dict[
|
||||
Address, List[asyncio.Future[Union[Connection, Tuple[Address, int, int]]]]
|
||||
]
|
||||
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
|
||||
|
||||
@composite_listener
|
||||
class Listener:
|
||||
def on_advertisement(self, advertisement):
|
||||
@@ -888,12 +924,12 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name=None,
|
||||
address=None,
|
||||
config=None,
|
||||
host=None,
|
||||
generic_access_service=True,
|
||||
):
|
||||
name: Optional[str] = None,
|
||||
address: Optional[Address] = None,
|
||||
config: Optional[DeviceConfiguration] = None,
|
||||
host: Optional[Host] = None,
|
||||
generic_access_service: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._host = None
|
||||
@@ -995,10 +1031,12 @@ class Device(CompositeEventEmitter):
|
||||
setup_event_forwarding(self.gatt_server, self, 'characteristic_subscription')
|
||||
|
||||
# Set the initial host
|
||||
self.host = host
|
||||
if host:
|
||||
self.host = host
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
def host(self) -> Host:
|
||||
assert self._host
|
||||
return self._host
|
||||
|
||||
@host.setter
|
||||
@@ -1032,15 +1070,18 @@ class Device(CompositeEventEmitter):
|
||||
def sdp_service_records(self, service_records):
|
||||
self.sdp_server.service_records = service_records
|
||||
|
||||
def lookup_connection(self, connection_handle):
|
||||
def lookup_connection(self, connection_handle: int) -> Optional[Connection]:
|
||||
if connection := self.connections.get(connection_handle):
|
||||
return connection
|
||||
|
||||
return None
|
||||
|
||||
def find_connection_by_bd_addr(
|
||||
self, bd_addr, transport=None, check_address_type=False
|
||||
):
|
||||
self,
|
||||
bd_addr: Address,
|
||||
transport: Optional[int] = None,
|
||||
check_address_type: bool = False,
|
||||
) -> Optional[Connection]:
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
|
||||
if (
|
||||
@@ -1098,11 +1139,11 @@ class Device(CompositeEventEmitter):
|
||||
logger.warning('!!! Command timed out')
|
||||
raise CommandTimeoutError() from error
|
||||
|
||||
async def power_on(self):
|
||||
async def power_on(self) -> None:
|
||||
# Reset the controller
|
||||
await self.host.reset()
|
||||
|
||||
response = await self.send_command(HCI_Read_BD_ADDR_Command())
|
||||
response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg]
|
||||
if response.return_parameters.status == HCI_SUCCESS:
|
||||
logger.debug(
|
||||
color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow')
|
||||
@@ -1114,7 +1155,7 @@ class Device(CompositeEventEmitter):
|
||||
HCI_Write_LE_Host_Support_Command(
|
||||
le_supported_host=int(self.le_enabled),
|
||||
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
if self.le_enabled:
|
||||
@@ -1124,7 +1165,7 @@ class Device(CompositeEventEmitter):
|
||||
if self.host.supports_command(HCI_LE_RAND_COMMAND):
|
||||
# Get 8 random bytes
|
||||
response = await self.send_command(
|
||||
HCI_LE_Rand_Command(), check_result=True
|
||||
HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Ensure the address bytes can be a static random address
|
||||
@@ -1145,7 +1186,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Random_Address_Command(
|
||||
random_address=self.random_address
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1153,7 +1194,7 @@ class Device(CompositeEventEmitter):
|
||||
if self.keystore and self.host.supports_command(
|
||||
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
|
||||
):
|
||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
|
||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
|
||||
|
||||
resolving_keys = await self.keystore.get_resolving_keys()
|
||||
for (irk, address) in resolving_keys:
|
||||
@@ -1163,7 +1204,7 @@ class Device(CompositeEventEmitter):
|
||||
peer_identity_address=address,
|
||||
peer_irk=irk,
|
||||
local_irk=self.irk,
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Enable address resolution
|
||||
@@ -1178,28 +1219,24 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
if self.classic_enabled:
|
||||
await self.send_command(
|
||||
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8'))
|
||||
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg]
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device)
|
||||
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg]
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Simple_Pairing_Mode_Command(
|
||||
simple_pairing_mode=int(self.classic_ssp_enabled)
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Secure_Connections_Host_Support_Command(
|
||||
secure_connections_host_support=int(self.classic_sc_enabled)
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
await self.set_connectable(self.connectable)
|
||||
await self.set_discoverable(self.discoverable)
|
||||
|
||||
# Let the SMP manager know about the address
|
||||
# TODO: allow using a public address
|
||||
self.smp_manager.address = self.random_address
|
||||
|
||||
# Done
|
||||
self.powered_on = True
|
||||
|
||||
@@ -1221,11 +1258,11 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def start_advertising(
|
||||
self,
|
||||
advertising_type=AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target=None,
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
auto_restart=False,
|
||||
):
|
||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target: Optional[Address] = None,
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = False,
|
||||
) -> None:
|
||||
# If we're advertising, stop first
|
||||
if self.advertising:
|
||||
await self.stop_advertising()
|
||||
@@ -1235,7 +1272,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Data_Command(
|
||||
advertising_data=self.advertising_data
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1244,7 +1281,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Response_Data_Command(
|
||||
scan_response_data=self.scan_response_data
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1270,13 +1307,13 @@ class Device(CompositeEventEmitter):
|
||||
peer_address=peer_address,
|
||||
advertising_channel_map=7,
|
||||
advertising_filter_policy=0,
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
# Enable advertising
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1),
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1285,11 +1322,11 @@ class Device(CompositeEventEmitter):
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising = True
|
||||
|
||||
async def stop_advertising(self):
|
||||
async def stop_advertising(self) -> None:
|
||||
# Disable advertising
|
||||
if self.advertising:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1304,14 +1341,14 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def start_scanning(
|
||||
self,
|
||||
legacy=False,
|
||||
active=True,
|
||||
scan_interval=DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
|
||||
scan_window=DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
filter_duplicates=False,
|
||||
scanning_phys=(HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
|
||||
):
|
||||
legacy: bool = False,
|
||||
active: bool = True,
|
||||
scan_interval: int = DEVICE_DEFAULT_SCAN_INTERVAL, # Scan interval in ms
|
||||
scan_window: int = DEVICE_DEFAULT_SCAN_WINDOW, # Scan window in ms
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
filter_duplicates: bool = False,
|
||||
scanning_phys: Tuple[int, int] = (HCI_LE_1M_PHY, HCI_LE_CODED_PHY),
|
||||
) -> None:
|
||||
# Check that the arguments are legal
|
||||
if scan_interval < scan_window:
|
||||
raise ValueError('scan_interval must be >= scan_window')
|
||||
@@ -1361,7 +1398,7 @@ class Device(CompositeEventEmitter):
|
||||
scan_types=[scan_type] * scanning_phy_count,
|
||||
scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
scan_windows=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1372,7 +1409,7 @@ class Device(CompositeEventEmitter):
|
||||
filter_duplicates=1 if filter_duplicates else 0,
|
||||
duration=0, # TODO allow other values
|
||||
period=0, # TODO allow other values
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
else:
|
||||
@@ -1390,7 +1427,7 @@ class Device(CompositeEventEmitter):
|
||||
le_scan_window=int(scan_window / 0.625),
|
||||
own_address_type=own_address_type,
|
||||
scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY,
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1398,25 +1435,25 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Enable_Command(
|
||||
le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.scanning_is_passive = not active
|
||||
self.scanning = True
|
||||
|
||||
async def stop_scanning(self):
|
||||
async def stop_scanning(self) -> None:
|
||||
# Disable scanning
|
||||
if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Extended_Scan_Enable_Command(
|
||||
enable=0, filter_duplicates=0, duration=0, period=0
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
else:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0),
|
||||
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1434,9 +1471,9 @@ class Device(CompositeEventEmitter):
|
||||
if advertisement := accumulator.update(report):
|
||||
self.emit('advertisement', advertisement)
|
||||
|
||||
async def start_discovery(self, auto_restart=True):
|
||||
async def start_discovery(self, auto_restart: bool = True) -> None:
|
||||
await self.send_command(
|
||||
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE),
|
||||
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1445,7 +1482,7 @@ class Device(CompositeEventEmitter):
|
||||
lap=HCI_GENERAL_INQUIRY_LAP,
|
||||
inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH,
|
||||
num_responses=0, # Unlimited number of responses.
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
if response.status != HCI_Command_Status_Event.PENDING:
|
||||
self.discovering = False
|
||||
@@ -1454,9 +1491,9 @@ class Device(CompositeEventEmitter):
|
||||
self.auto_restart_inquiry = auto_restart
|
||||
self.discovering = True
|
||||
|
||||
async def stop_discovery(self):
|
||||
async def stop_discovery(self) -> None:
|
||||
if self.discovering:
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg]
|
||||
self.auto_restart_inquiry = True
|
||||
self.discovering = False
|
||||
|
||||
@@ -1484,7 +1521,7 @@ class Device(CompositeEventEmitter):
|
||||
HCI_Write_Scan_Enable_Command(scan_enable=scan_enable)
|
||||
)
|
||||
|
||||
async def set_discoverable(self, discoverable=True):
|
||||
async def set_discoverable(self, discoverable: bool = True) -> None:
|
||||
self.discoverable = discoverable
|
||||
if self.classic_enabled:
|
||||
# Synthesize an inquiry response if none is set already
|
||||
@@ -1504,7 +1541,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_Write_Extended_Inquiry_Response_Command(
|
||||
fec_required=0, extended_inquiry_response=self.inquiry_response
|
||||
),
|
||||
), # type: ignore[call-arg]
|
||||
check_result=True,
|
||||
)
|
||||
await self.set_scan_enable(
|
||||
@@ -1512,7 +1549,7 @@ class Device(CompositeEventEmitter):
|
||||
page_scan_enabled=self.connectable,
|
||||
)
|
||||
|
||||
async def set_connectable(self, connectable=True):
|
||||
async def set_connectable(self, connectable: bool = True) -> None:
|
||||
self.connectable = connectable
|
||||
if self.classic_enabled:
|
||||
await self.set_scan_enable(
|
||||
@@ -1522,12 +1559,14 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
peer_address,
|
||||
transport=BT_LE_TRANSPORT,
|
||||
connection_parameters_preferences=None,
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
):
|
||||
peer_address: Union[Address, str],
|
||||
transport: int = BT_LE_TRANSPORT,
|
||||
connection_parameters_preferences: Optional[
|
||||
Dict[int, ConnectionParametersPreferences]
|
||||
] = None,
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
) -> Connection:
|
||||
'''
|
||||
Request a connection to a peer.
|
||||
When transport is BLE, this method cannot be called if there is already a
|
||||
@@ -1574,6 +1613,8 @@ class Device(CompositeEventEmitter):
|
||||
):
|
||||
raise ValueError('BR/EDR addresses must be PUBLIC')
|
||||
|
||||
assert isinstance(peer_address, Address)
|
||||
|
||||
def on_connection(connection):
|
||||
if transport == BT_LE_TRANSPORT or (
|
||||
# match BR/EDR connection event against peer address
|
||||
@@ -1691,7 +1732,7 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeouts=supervision_timeouts,
|
||||
min_ce_lengths=min_ce_lengths,
|
||||
max_ce_lengths=max_ce_lengths,
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
else:
|
||||
if HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||
@@ -1720,7 +1761,7 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeout=int(prefs.supervision_timeout / 10),
|
||||
min_ce_length=int(prefs.min_ce_length / 0.625),
|
||||
max_ce_length=int(prefs.max_ce_length / 0.625),
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
else:
|
||||
# Save pending connection
|
||||
@@ -1737,7 +1778,7 @@ class Device(CompositeEventEmitter):
|
||||
clock_offset=0x0000,
|
||||
allow_role_switch=0x01,
|
||||
reserved=0,
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
if result.status != HCI_Command_Status_Event.PENDING:
|
||||
@@ -1756,10 +1797,10 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command())
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg]
|
||||
else:
|
||||
await self.send_command(
|
||||
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
|
||||
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1777,10 +1818,10 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def accept(
|
||||
self,
|
||||
peer_address=Address.ANY,
|
||||
role=BT_PERIPHERAL_ROLE,
|
||||
timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
):
|
||||
peer_address: Union[Address, str] = Address.ANY,
|
||||
role: int = BT_PERIPHERAL_ROLE,
|
||||
timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT,
|
||||
) -> Connection:
|
||||
'''
|
||||
Wait and accept any incoming connection or a connection from `peer_address` when
|
||||
set.
|
||||
@@ -1802,22 +1843,24 @@ class Device(CompositeEventEmitter):
|
||||
peer_address, BT_BR_EDR_TRANSPORT
|
||||
) # TODO: timeout
|
||||
|
||||
assert isinstance(peer_address, Address)
|
||||
|
||||
if peer_address == Address.NIL:
|
||||
raise ValueError('accept on nil address')
|
||||
|
||||
# Create a future so that we can wait for the request
|
||||
pending_request = asyncio.get_running_loop().create_future()
|
||||
pending_request_fut = asyncio.get_running_loop().create_future()
|
||||
|
||||
if peer_address == Address.ANY:
|
||||
self.classic_pending_accepts[Address.ANY].append(pending_request)
|
||||
self.classic_pending_accepts[Address.ANY].append(pending_request_fut)
|
||||
elif peer_address in self.classic_pending_accepts:
|
||||
raise InvalidStateError('accept connection already pending')
|
||||
else:
|
||||
self.classic_pending_accepts[peer_address] = pending_request
|
||||
self.classic_pending_accepts[peer_address] = [pending_request_fut]
|
||||
|
||||
try:
|
||||
# Wait for a request or a completed connection
|
||||
pending_request = self.abort_on('flush', pending_request)
|
||||
pending_request = self.abort_on('flush', pending_request_fut)
|
||||
result = await (
|
||||
asyncio.wait_for(pending_request, timeout)
|
||||
if timeout
|
||||
@@ -1826,7 +1869,7 @@ class Device(CompositeEventEmitter):
|
||||
except Exception:
|
||||
# Remove future from device context
|
||||
if peer_address == Address.ANY:
|
||||
self.classic_pending_accepts[Address.ANY].remove(pending_request)
|
||||
self.classic_pending_accepts[Address.ANY].remove(pending_request_fut)
|
||||
else:
|
||||
self.classic_pending_accepts.pop(peer_address)
|
||||
raise
|
||||
@@ -1838,6 +1881,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# Otherwise, result came from `on_connection_request`
|
||||
peer_address, _class_of_device, _link_type = result
|
||||
assert isinstance(peer_address, Address)
|
||||
|
||||
# Create a future so that we can wait for the connection's result
|
||||
pending_connection = asyncio.get_running_loop().create_future()
|
||||
@@ -1867,7 +1911,7 @@ class Device(CompositeEventEmitter):
|
||||
try:
|
||||
# Accept connection request
|
||||
await self.send_command(
|
||||
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role)
|
||||
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Wait for connection complete
|
||||
@@ -2243,7 +2287,7 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
|
||||
# [Classic only]
|
||||
async def request_remote_name(self, remote): # remote: Connection | Address
|
||||
async def request_remote_name(self, remote: Union[Address, Connection]) -> str:
|
||||
# Set up event handlers
|
||||
pending_name = asyncio.get_running_loop().create_future()
|
||||
|
||||
@@ -2271,7 +2315,7 @@ class Device(CompositeEventEmitter):
|
||||
page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2,
|
||||
reserved=0,
|
||||
clock_offset=0, # TODO investigate non-0 values
|
||||
)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
@@ -2372,7 +2416,7 @@ class Device(CompositeEventEmitter):
|
||||
# In this case, set the completed `connection` to the `accept` future
|
||||
# result.
|
||||
if peer_address in self.classic_pending_accepts:
|
||||
future = self.classic_pending_accepts.pop(peer_address)
|
||||
future, *_ = self.classic_pending_accepts.pop(peer_address)
|
||||
future.set_result(connection)
|
||||
|
||||
# Emit an event to notify listeners of the new connection
|
||||
@@ -2473,7 +2517,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# match a pending future using `bd_addr`
|
||||
if bd_addr in self.classic_pending_accepts:
|
||||
future = self.classic_pending_accepts.pop(bd_addr)
|
||||
future, *_ = self.classic_pending_accepts.pop(bd_addr)
|
||||
future.set_result((bd_addr, class_of_device, link_type))
|
||||
|
||||
# match first pending future for ANY address
|
||||
|
||||
@@ -28,7 +28,7 @@ import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Sequence
|
||||
from typing import Optional, Sequence
|
||||
from colors import color
|
||||
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
@@ -204,6 +204,8 @@ class Service(Attribute):
|
||||
See Vol 3, Part G - 3.1 SERVICE DEFINITION
|
||||
'''
|
||||
|
||||
uuid: UUID
|
||||
|
||||
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
|
||||
# Convert the uuid to a UUID object if it isn't already
|
||||
if isinstance(uuid, str):
|
||||
@@ -221,7 +223,7 @@ class Service(Attribute):
|
||||
self.characteristics = characteristics[:]
|
||||
self.primary = primary
|
||||
|
||||
def get_advertising_data(self):
|
||||
def get_advertising_data(self) -> Optional[bytes]:
|
||||
"""
|
||||
Get Service specific advertising data
|
||||
Defined by each Service, default value is empty
|
||||
|
||||
@@ -27,7 +27,7 @@ import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import Tuple, Optional
|
||||
from typing import List, Tuple, Optional
|
||||
from pyee import EventEmitter
|
||||
from colors import color
|
||||
|
||||
@@ -90,6 +90,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
||||
# GATT Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server(EventEmitter):
|
||||
attributes: List[Attribute]
|
||||
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
@@ -140,6 +142,7 @@ class Server(EventEmitter):
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
|
||||
and isinstance(attribute, Service)
|
||||
and attribute.uuid == service_uuid
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -21,7 +21,7 @@ import collections
|
||||
import logging
|
||||
import functools
|
||||
from colors import color
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
@@ -1729,7 +1729,9 @@ class Address:
|
||||
address_type = data[offset - 1]
|
||||
return Address.parse_address_with_type(data, offset, address_type)
|
||||
|
||||
def __init__(self, address, address_type=RANDOM_DEVICE_ADDRESS):
|
||||
def __init__(
|
||||
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
|
||||
):
|
||||
'''
|
||||
Initialize an instance. `address` may be a byte array in little-endian
|
||||
format, or a hex string in big-endian format (with optional ':'
|
||||
|
||||
@@ -141,7 +141,7 @@ class Host(AbortableEventEmitter):
|
||||
if controller_sink:
|
||||
self.set_packet_sink(controller_sink)
|
||||
|
||||
async def flush(self):
|
||||
async def flush(self) -> None:
|
||||
# Make sure no command is pending
|
||||
await self.command_semaphore.acquire()
|
||||
|
||||
@@ -660,7 +660,7 @@ class Host(AbortableEventEmitter):
|
||||
connection_handle=event.connection_handle,
|
||||
interval_min=event.interval_min,
|
||||
interval_max=event.interval_max,
|
||||
latency=event.latency,
|
||||
max_latency=event.max_latency,
|
||||
timeout=event.timeout,
|
||||
min_ce_length=0,
|
||||
max_ce_length=0,
|
||||
|
||||
@@ -24,6 +24,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import Optional
|
||||
from colors import color
|
||||
|
||||
from .hci import Address
|
||||
@@ -129,7 +130,7 @@ class PairingKeys:
|
||||
for (key_property, key_value) in value.items():
|
||||
print(f'{prefix} {color(key_property, "green")}: {key_value}')
|
||||
else:
|
||||
print(f'{prefix}{color(property, "cyan")}: {value}')
|
||||
print(f'{prefix}{color(container_property, "cyan")}: {value}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -216,7 +217,7 @@ class JsonKeyStore(KeyStore):
|
||||
params = device_config.keystore.split(':', 1)[1:]
|
||||
namespace = str(device_config.address)
|
||||
if params:
|
||||
filename = params[1]
|
||||
filename = params[0]
|
||||
else:
|
||||
filename = None
|
||||
|
||||
@@ -242,7 +243,7 @@ class JsonKeyStore(KeyStore):
|
||||
# Atomically replace the previous file
|
||||
os.rename(temp_filename, self.filename)
|
||||
|
||||
async def delete(self, name):
|
||||
async def delete(self, name: str) -> None:
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.get(self.namespace)
|
||||
@@ -278,7 +279,7 @@ class JsonKeyStore(KeyStore):
|
||||
|
||||
await self.save(db)
|
||||
|
||||
async def get(self, name):
|
||||
async def get(self, name: str) -> Optional[PairingKeys]:
|
||||
db = await self.load()
|
||||
|
||||
namespace = db.get(self.namespace)
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import struct
|
||||
from colors import color
|
||||
import colors
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from . import core
|
||||
from .core import InvalidStateError
|
||||
@@ -183,63 +183,63 @@ class DataElement:
|
||||
raise ValueError('integer types must have a value size specified')
|
||||
|
||||
@staticmethod
|
||||
def nil():
|
||||
def nil() -> DataElement:
|
||||
return DataElement(DataElement.NIL, None)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer(value, value_size):
|
||||
def unsigned_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_8(value):
|
||||
def unsigned_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_16(value):
|
||||
def unsigned_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def unsigned_integer_32(value):
|
||||
def unsigned_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.UNSIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer(value, value_size):
|
||||
def signed_integer(value: int, value_size: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_8(value):
|
||||
def signed_integer_8(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=1)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_16(value):
|
||||
def signed_integer_16(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=2)
|
||||
|
||||
@staticmethod
|
||||
def signed_integer_32(value):
|
||||
def signed_integer_32(value: int) -> DataElement:
|
||||
return DataElement(DataElement.SIGNED_INTEGER, value, value_size=4)
|
||||
|
||||
@staticmethod
|
||||
def uuid(value):
|
||||
def uuid(value: core.UUID) -> DataElement:
|
||||
return DataElement(DataElement.UUID, value)
|
||||
|
||||
@staticmethod
|
||||
def text_string(value):
|
||||
def text_string(value: str) -> DataElement:
|
||||
return DataElement(DataElement.TEXT_STRING, value)
|
||||
|
||||
@staticmethod
|
||||
def boolean(value):
|
||||
def boolean(value: bool) -> DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value):
|
||||
def sequence(value: List[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value):
|
||||
def alternative(value: List[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
def url(value):
|
||||
def url(value: str) -> DataElement:
|
||||
return DataElement(DataElement.URL, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -458,7 +458,7 @@ class DataElement:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServiceAttribute:
|
||||
def __init__(self, attribute_id, value):
|
||||
def __init__(self, attribute_id: int, value: DataElement) -> None:
|
||||
self.id = attribute_id
|
||||
self.value = value
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import secrets
|
||||
from typing import Dict, Type
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from pyee import EventEmitter
|
||||
from colors import color
|
||||
@@ -504,27 +504,29 @@ class PairingDelegate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability=NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution=DEFAULT_KEY_DISTRIBUTION,
|
||||
):
|
||||
io_capability: int = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: int = DEFAULT_KEY_DISTRIBUTION,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
self.local_responder_key_distribution = local_responder_key_distribution
|
||||
|
||||
async def accept(self):
|
||||
async def accept(self) -> bool:
|
||||
return True
|
||||
|
||||
async def confirm(self):
|
||||
async def confirm(self) -> bool:
|
||||
return True
|
||||
|
||||
async def compare_numbers(self, _number, _digits=6):
|
||||
# pylint: disable-next=unused-argument
|
||||
async def compare_numbers(self, number: int, digits: int) -> bool:
|
||||
return True
|
||||
|
||||
async def get_number(self):
|
||||
async def get_number(self) -> int:
|
||||
return 0
|
||||
|
||||
async def display_number(self, _number, _digits=6):
|
||||
# pylint: disable-next=unused-argument
|
||||
async def display_number(self, number: int, digits: int) -> None:
|
||||
pass
|
||||
|
||||
async def key_distribution_response(
|
||||
@@ -538,7 +540,13 @@ class PairingDelegate:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class PairingConfig:
|
||||
def __init__(self, sc=True, mitm=True, bonding=True, delegate=None):
|
||||
def __init__(
|
||||
self,
|
||||
sc: bool = True,
|
||||
mitm: bool = True,
|
||||
bonding: bool = True,
|
||||
delegate: Optional[PairingDelegate] = None,
|
||||
) -> None:
|
||||
self.sc = sc
|
||||
self.mitm = mitm
|
||||
self.bonding = bonding
|
||||
@@ -655,7 +663,8 @@ class Session:
|
||||
self.peer_expected_distributions = []
|
||||
self.dh_key = None
|
||||
self.confirm_value = None
|
||||
self.passkey = 0
|
||||
self.passkey = None
|
||||
self.passkey_ready = asyncio.Event()
|
||||
self.passkey_step = 0
|
||||
self.passkey_display = False
|
||||
self.pairing_method = 0
|
||||
@@ -833,6 +842,7 @@ class Session:
|
||||
# Generate random Passkey/PIN code
|
||||
self.passkey = secrets.randbelow(1000000)
|
||||
logger.debug(f'Pairing PIN CODE: {self.passkey:06}')
|
||||
self.passkey_ready.set()
|
||||
|
||||
# The value of TK is computed from the PIN code
|
||||
if not self.sc:
|
||||
@@ -853,6 +863,8 @@ class Session:
|
||||
self.tk = passkey.to_bytes(16, byteorder='little')
|
||||
logger.debug(f'TK from passkey = {self.tk.hex()}')
|
||||
|
||||
self.passkey_ready.set()
|
||||
|
||||
if next_steps is not None:
|
||||
next_steps()
|
||||
|
||||
@@ -904,17 +916,29 @@ class Session:
|
||||
logger.debug(f'generated random: {self.r.hex()}')
|
||||
|
||||
if self.sc:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
z = 0
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
|
||||
else:
|
||||
return
|
||||
|
||||
if self.is_initiator:
|
||||
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
|
||||
else:
|
||||
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
|
||||
async def next_steps():
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
z = 0
|
||||
elif self.pairing_method == self.PASSKEY:
|
||||
# We need a passkey
|
||||
await self.passkey_ready.wait()
|
||||
|
||||
z = 0x80 + ((self.passkey >> self.passkey_step) & 1)
|
||||
else:
|
||||
return
|
||||
|
||||
if self.is_initiator:
|
||||
confirm_value = crypto.f4(self.pka, self.pkb, self.r, bytes([z]))
|
||||
else:
|
||||
confirm_value = crypto.f4(self.pkb, self.pka, self.r, bytes([z]))
|
||||
|
||||
self.send_command(
|
||||
SMP_Pairing_Confirm_Command(confirm_value=confirm_value)
|
||||
)
|
||||
|
||||
# Perform the next steps asynchronously in case we need to wait for input
|
||||
self.connection.abort_on('disconnection', next_steps())
|
||||
else:
|
||||
confirm_value = crypto.c1(
|
||||
self.tk,
|
||||
@@ -927,7 +951,7 @@ class Session:
|
||||
self.ra,
|
||||
)
|
||||
|
||||
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
|
||||
self.send_command(SMP_Pairing_Confirm_Command(confirm_value=confirm_value))
|
||||
|
||||
def send_pairing_random_command(self):
|
||||
self.send_command(SMP_Pairing_Random_Command(random_value=self.r))
|
||||
@@ -1358,8 +1382,8 @@ class Session:
|
||||
|
||||
# Start phase 2
|
||||
if self.sc:
|
||||
if self.pairing_method == self.PASSKEY and self.passkey_display:
|
||||
self.display_passkey()
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
self.display_or_input_passkey()
|
||||
|
||||
self.send_public_key_command()
|
||||
else:
|
||||
@@ -1420,18 +1444,22 @@ class Session:
|
||||
else:
|
||||
srand = self.r
|
||||
mrand = command.random_value
|
||||
stk = crypto.s1(self.tk, srand, mrand)
|
||||
logger.debug(f'STK = {stk.hex()}')
|
||||
self.stk = crypto.s1(self.tk, srand, mrand)
|
||||
logger.debug(f'STK = {self.stk.hex()}')
|
||||
|
||||
# Generate LTK
|
||||
self.ltk = crypto.r()
|
||||
|
||||
if self.is_initiator:
|
||||
self.start_encryption(stk)
|
||||
self.start_encryption(self.stk)
|
||||
else:
|
||||
self.send_pairing_random_command()
|
||||
|
||||
def on_smp_pairing_random_command_secure_connections(self, command):
|
||||
if self.pairing_method == self.PASSKEY and self.passkey is None:
|
||||
logger.warning('no passkey entered, ignoring command')
|
||||
return
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
if self.is_initiator:
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
@@ -1559,17 +1587,13 @@ class Session:
|
||||
logger.debug(f'DH key: {self.dh_key.hex()}')
|
||||
|
||||
if self.is_initiator:
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
if self.passkey_display:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
self.input_passkey(self.send_pairing_confirm_command)
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
# Send our public key back to the initiator
|
||||
if self.pairing_method == self.PASSKEY:
|
||||
self.display_or_input_passkey(self.send_public_key_command)
|
||||
else:
|
||||
self.send_public_key_command()
|
||||
self.display_or_input_passkey()
|
||||
|
||||
# Send our public key back to the initiator
|
||||
self.send_public_key_command()
|
||||
|
||||
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
|
||||
# We can now send the confirmation value
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def open_transport(name):
|
||||
async def open_transport(name: str) -> Transport:
|
||||
'''
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
|
||||
@@ -259,7 +259,7 @@ class Transport:
|
||||
def __iter__(self):
|
||||
return iter((self.source, self.sink))
|
||||
|
||||
async def close(self):
|
||||
async def close(self) -> None:
|
||||
self.source.close()
|
||||
self.sink.close()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import traceback
|
||||
import collections
|
||||
import sys
|
||||
from typing import Awaitable
|
||||
from typing import Awaitable, TypeVar
|
||||
from functools import wraps
|
||||
from colors import color
|
||||
from pyee import EventEmitter
|
||||
@@ -65,8 +65,11 @@ def composite_listener(cls):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
class AbortableEventEmitter(EventEmitter):
|
||||
def abort_on(self, event: str, awaitable: Awaitable):
|
||||
def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
|
||||
"""
|
||||
Set a coroutine or future to abort when an event occur.
|
||||
"""
|
||||
@@ -75,6 +78,8 @@ class AbortableEventEmitter(EventEmitter):
|
||||
return future
|
||||
|
||||
def on_event(*_):
|
||||
if future.done():
|
||||
return
|
||||
msg = f'abort: {event} event occurred.'
|
||||
if isinstance(future, asyncio.Task):
|
||||
# python < 3.9 does not support passing a message on `Task.cancel`
|
||||
|
||||
@@ -64,11 +64,11 @@ module = "aioconsole.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "bitstruct.*"
|
||||
module = "colors.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "colors.*"
|
||||
module = "construct.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
|
||||
11
setup.cfg
11
setup.cfg
@@ -30,15 +30,14 @@ package_dir =
|
||||
bumble.apps = apps
|
||||
include-package-data = True
|
||||
install_requires =
|
||||
aioconsole >= 0.4.1
|
||||
ansicolors >= 1.1
|
||||
appdirs >= 1.4
|
||||
bitstruct >= 8.12
|
||||
click >= 7.1.2; platform_system!='Emscripten'
|
||||
construct >= 2.10
|
||||
cryptography == 35; platform_system!='Emscripten'
|
||||
grpcio >= 1.46; platform_system!='Emscripten'
|
||||
libusb1 >= 2.0.1; platform_system!='Emscripten'
|
||||
libusb-package == 1.0.26.0; platform_system!='Emscripten'
|
||||
libusb-package == 1.0.26.1; platform_system!='Emscripten'
|
||||
prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
|
||||
protobuf >= 3.12.4
|
||||
pyee >= 8.2.2
|
||||
@@ -73,11 +72,11 @@ test =
|
||||
pytest-html >= 3.2.0
|
||||
coverage >= 6.4
|
||||
development =
|
||||
black >= 22.10
|
||||
black == 22.10
|
||||
invoke >= 1.7.3
|
||||
mypy >= 0.991
|
||||
mypy == 0.991
|
||||
nox >= 2022
|
||||
pylint >= 2.15.8
|
||||
pylint == 2.15.8
|
||||
types-appdirs >= 1.4.3
|
||||
types-invoke >= 1.7.3
|
||||
types-protobuf >= 4.21.0
|
||||
|
||||
@@ -25,10 +25,8 @@ def test_ad_data():
|
||||
assert data == ad_bytes
|
||||
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
|
||||
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
|
||||
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
|
||||
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
|
||||
bytes([123])
|
||||
]
|
||||
assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == []
|
||||
assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [bytes([123])]
|
||||
|
||||
data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234])
|
||||
ad.append(data2)
|
||||
@@ -36,8 +34,8 @@ def test_ad_data():
|
||||
assert ad_bytes == data + data2
|
||||
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
|
||||
assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
|
||||
assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
|
||||
assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
|
||||
assert ad.get_all(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) == []
|
||||
assert ad.get_all(AdvertisingData.TX_POWER_LEVEL, raw=True) == [
|
||||
bytes([123]),
|
||||
bytes([234]),
|
||||
]
|
||||
|
||||
@@ -32,7 +32,6 @@ from bumble.smp import (
|
||||
PairingDelegate,
|
||||
SMP_PAIRING_NOT_SUPPORTED_ERROR,
|
||||
SMP_CONFIRM_VALUE_FAILED_ERROR,
|
||||
SMP_ID_KEY_DISTRIBUTION_FLAG,
|
||||
)
|
||||
from bumble.core import ProtocolError
|
||||
|
||||
@@ -273,9 +272,15 @@ KEY_DIST = range(16)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'io_cap, sc, mitm, key_dist', itertools.product(IO_CAP, SC, MITM, KEY_DIST)
|
||||
'io_caps, sc, mitm, key_dist',
|
||||
itertools.chain(
|
||||
itertools.product([IO_CAP], SC, MITM, [15]),
|
||||
itertools.product(
|
||||
[[PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT]], SC, MITM, KEY_DIST
|
||||
),
|
||||
),
|
||||
)
|
||||
async def test_self_smp(io_cap, sc, mitm, key_dist):
|
||||
async def test_self_smp(io_caps, sc, mitm, key_dist):
|
||||
class Delegate(PairingDelegate):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -296,6 +301,7 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
|
||||
self.peer_delegate = None
|
||||
self.number = asyncio.get_running_loop().create_future()
|
||||
|
||||
# pylint: disable-next=unused-argument
|
||||
async def compare_numbers(self, number, digits):
|
||||
if self.peer_delegate is None:
|
||||
logger.warning(f'[{self.name}] no peer delegate')
|
||||
@@ -331,8 +337,9 @@ async def test_self_smp(io_cap, sc, mitm, key_dist):
|
||||
|
||||
pairing_config_sets = [('Initiator', [None]), ('Responder', [None])]
|
||||
for pairing_config_set in pairing_config_sets:
|
||||
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
|
||||
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
|
||||
for io_cap in io_caps:
|
||||
delegate = Delegate(pairing_config_set[0], io_cap, key_dist, key_dist)
|
||||
pairing_config_set[1].append(PairingConfig(sc, mitm, True, delegate))
|
||||
|
||||
for pairing_config1 in pairing_config_sets[0][1]:
|
||||
for pairing_config2 in pairing_config_sets[1][1]:
|
||||
|
||||
Reference in New Issue
Block a user