Merge pull request #851 from zxzxwu/fix

Fix some typos and annotations
This commit is contained in:
zxzxwu
2026-01-06 14:02:40 +08:00
committed by GitHub
6 changed files with 65 additions and 55 deletions

View File

@@ -15,6 +15,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
@@ -63,7 +65,7 @@ POST_PAIRING_DELAY = 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Waiter: class Waiter:
instance = None instance: Waiter | None = None
def __init__(self, linger=False): def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future() self.done = asyncio.get_running_loop().create_future()
@@ -327,25 +329,25 @@ async def on_pairing_failure(connection, reason):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def pair( async def pair(
mode, mode: str,
sc, sc: bool,
mitm, mitm: bool,
bond, bond: bool,
ctkd, ctkd: bool,
advertising_address, advertising_address: str,
identity_address, identity_address: str,
linger, linger: bool,
io, io: str,
oob, oob: str,
prompt, prompt: bool,
request, request: bool,
print_keys, print_keys: bool,
keystore_file, keystore_file: str,
advertise_service_uuids, advertise_service_uuids: str,
advertise_appearance, advertise_appearance: str,
device_config, device_config: str,
hci_transport, hci_transport: str,
address_or_name, address_or_name: str,
): ):
Waiter.instance = Waiter(linger=linger) Waiter.instance = Waiter(linger=linger)
@@ -403,6 +405,7 @@ async def pair(
# Create an OOB context if needed # Create an OOB context if needed
if oob: if oob:
our_oob_context = OobContext() our_oob_context = OobContext()
legacy_context: OobLegacyContext | None
if oob == '-': if oob == '-':
shared_data = None shared_data = None
legacy_context = OobLegacyContext() legacy_context = OobLegacyContext()
@@ -661,25 +664,25 @@ class LogHandler(logging.Handler):
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
def main( def main(
mode, mode: str,
sc, sc: bool,
mitm, mitm: bool,
bond, bond: bool,
ctkd, ctkd: bool,
advertising_address, advertising_address: str,
identity_address, identity_address: str,
linger, linger: bool,
io, io: str,
oob, oob: str,
prompt, prompt: bool,
request, request: bool,
print_keys, print_keys: bool,
keystore_file, keystore_file: str,
advertise_service_uuid, advertise_service_uuid: str,
advertise_appearance, advertise_appearance: str,
device_config, device_config: str,
hci_transport, hci_transport: str,
address_or_name, address_or_name: str,
): ):
# Setup logging # Setup logging
log_handler = LogHandler() log_handler = LogHandler()

View File

@@ -171,6 +171,7 @@ class Advertisement:
) )
sid: int = 0 sid: int = 0
data_bytes: bytes = b'' data_bytes: bytes = b''
data: AdvertisingData = field(init=False)
# Constants # Constants
TX_POWER_NOT_AVAILABLE: ClassVar[int] = ( TX_POWER_NOT_AVAILABLE: ClassVar[int] = (
@@ -480,6 +481,7 @@ class PeriodicAdvertisement:
rssi: int = hci.HCI_LE_Periodic_Advertising_Report_Event.RSSI_NOT_AVAILABLE rssi: int = hci.HCI_LE_Periodic_Advertising_Report_Event.RSSI_NOT_AVAILABLE
is_truncated: bool = False is_truncated: bool = False
data_bytes: bytes = b'' data_bytes: bytes = b''
data: AdvertisingData | None = field(init=False)
# Constants # Constants
TX_POWER_NOT_AVAILABLE: ClassVar[int] = ( TX_POWER_NOT_AVAILABLE: ClassVar[int] = (
@@ -1099,6 +1101,7 @@ class Big(utils.EventEmitter):
max_pdu: int = 0 max_pdu: int = 0
iso_interval: float = 0.0 # ISO interval, in milliseconds iso_interval: float = 0.0 # ISO interval, in milliseconds
bis_links: Sequence[BisLink] = () bis_links: Sequence[BisLink] = ()
device: Device = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__init__() super().__init__()
@@ -1160,6 +1163,7 @@ class BigSync(utils.EventEmitter):
max_pdu: int = 0 max_pdu: int = 0
iso_interval: float = 0.0 iso_interval: float = 0.0
bis_links: Sequence[BisLink] = () bis_links: Sequence[BisLink] = ()
device: Device = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__init__() super().__init__()
@@ -1655,6 +1659,7 @@ class BisLink(_IsoLink):
handle: int handle: int
big: Big | BigSync big: Big | BigSync
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
device: Device = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__init__() super().__init__()
@@ -2088,9 +2093,10 @@ class DeviceConfiguration:
l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE, l2cap.L2CAP_Information_Request.ExtendedFeatures.ENHANCED_RETRANSMISSION_MODE,
) )
eatt_enabled: bool = False eatt_enabled: bool = False
gatt_services: list[dict[str, Any]] = field(init=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.gatt_services: list[dict[str, Any]] = [] self.gatt_services = []
def load_from_dict(self, config: dict[str, Any]) -> None: def load_from_dict(self, config: dict[str, Any]) -> None:
config = copy.deepcopy(config) config = copy.deepcopy(config)
@@ -2270,6 +2276,7 @@ class Device(utils.CompositeEventEmitter):
big_syncs: dict[int, BigSync] big_syncs: dict[int, BigSync]
_pending_cis: dict[int, tuple[int, int]] _pending_cis: dict[int, tuple[int, int]]
gatt_service: gatt_service.GenericAttributeProfileService | None = None gatt_service: gatt_service.GenericAttributeProfileService | None = None
keystore: KeyStore | None = None
EVENT_ADVERTISEMENT = "advertisement" EVENT_ADVERTISEMENT = "advertisement"
EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer" EVENT_PERIODIC_ADVERTISING_SYNC_TRANSFER = "periodic_advertising_sync_transfer"
@@ -4527,8 +4534,8 @@ class Device(utils.CompositeEventEmitter):
ediv = 0 ediv = 0
elif keys.ltk_central is not None: elif keys.ltk_central is not None:
ltk = keys.ltk_central.value ltk = keys.ltk_central.value
rand = keys.ltk_central.rand rand = keys.ltk_central.rand or b''
ediv = keys.ltk_central.ediv ediv = keys.ltk_central.ediv or 0
else: else:
raise InvalidOperationError('no LTK found for peer') raise InvalidOperationError('no LTK found for peer')

View File

@@ -111,7 +111,7 @@ class Server(utils.EventEmitter):
) -> l2cap.LeCreditBasedChannelServer: ) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel): def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug( logger.debug(
"New EATT Bearer Conenction=0x%04X CID=0x%04X", "New EATT Bearer Connection=0x%04X CID=0x%04X",
channel.connection.handle, channel.connection.handle,
channel.source_cid, channel.source_cid,
) )

View File

@@ -292,9 +292,9 @@ async def test_legacy_advertising_disconnection(auto_restart):
await devices[0].start_advertising( await devices[0].start_advertising(
auto_restart=auto_restart, advertising_interval_min=1.0 auto_restart=auto_restart, advertising_interval_min=1.0
) )
connecion = await devices[1].connect(devices[0].random_address) connection = await devices[1].connect(devices[0].random_address)
await connecion.disconnect() await connection.disconnect()
await async_barrier() await async_barrier()
await async_barrier() await async_barrier()

View File

@@ -57,15 +57,13 @@ async def test_self_disconnection():
await two_devices.setup_connection() await two_devices.setup_connection()
await two_devices.connections[0].disconnect() await two_devices.connections[0].disconnect()
await async_barrier() await async_barrier()
assert two_devices.connections[0] is None assert not two_devices.connections
assert two_devices.connections[1] is None
two_devices = TwoDevices() two_devices = TwoDevices()
await two_devices.setup_connection() await two_devices.setup_connection()
await two_devices.connections[1].disconnect() await two_devices.connections[1].disconnect()
await async_barrier() await async_barrier()
assert two_devices.connections[0] is None assert not two_devices.connections
assert two_devices.connections[1] is None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -31,10 +31,10 @@ from bumble.transport.common import AsyncPipeSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Devices: class Devices:
connections: list[Connection | None] connections: dict[int, Connection]
def __init__(self, num_devices: int) -> None: def __init__(self, num_devices: int) -> None:
self.connections = [None for _ in range(num_devices)] self.connections = {}
self.link = LocalLink() self.link = LocalLink()
addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)] addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)]
@@ -60,12 +60,14 @@ class Devices:
asyncio.get_event_loop().create_future() for _ in range(num_devices) asyncio.get_event_loop().create_future() for _ in range(num_devices)
] ]
def on_connection(self, which, connection): def on_connection(self, which: int, connection: Connection) -> None:
self.connections[which] = connection self.connections[which] = connection
connection.on('disconnection', lambda code: self.on_disconnection(which)) connection.on(
connection.EVENT_DISCONNECTION, lambda *_: self.on_disconnection(which)
)
def on_disconnection(self, which): def on_disconnection(self, which: int) -> None:
self.connections[which] = None self.connections.pop(which, None)
def on_paired(self, which: int, keys: PairingKeys) -> None: def on_paired(self, which: int, keys: PairingKeys) -> None:
self.paired[which].set_result(keys) self.paired[which].set_result(keys)