Compare commits

..

1 Commits

Author SHA1 Message Date
7e4948d9ef add small asrc example 2025-10-06 11:04:06 +02:00
160 changed files with 4676 additions and 7188 deletions

View File

@@ -18,18 +18,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13.0", "3.14"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies

View File

@@ -40,7 +40,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL

View File

@@ -22,10 +22,10 @@ jobs:
steps:
- name: Check out from Git
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up JDK
uses: actions/setup-java@v5
uses: actions/setup-java@v4
with:
distribution: 'zulu'
java-version: 17

View File

@@ -26,9 +26,9 @@ jobs:
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
@@ -46,7 +46,7 @@ jobs:
run: cat rootcanal.log
- name: Upload Mobly logs
if: always()
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: mobly-logs-${{ strategy.job-index }}
path: /tmp/logs/mobly/bumble.bumbles/

View File

@@ -18,18 +18,18 @@ jobs:
strategy:
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
@@ -48,15 +48,14 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
# Rust runtime doesn't support 3.14 yet.
python-version: ["3.10", "3.11", "3.12", "3.13"]
rust-version: [ "1.80.0", "1.91.0" ]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
rust-version: [ "1.80.0", "stable" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
@@ -69,7 +68,7 @@ jobs:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features --version 1.11.0 # allows building/testing combinations of features
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build

View File

@@ -14,13 +14,13 @@ jobs:
steps:
- name: Check out from Git
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
@@ -31,7 +31,7 @@ jobs:
run: python -m build
- name: Publish package to PyPI
if: github.event_name == 'release' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1.13
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if you are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License

File diff suppressed because it is too large Load Diff

View File

@@ -22,6 +22,7 @@ import logging
import statistics
import struct
import time
from typing import Optional
import click
@@ -256,8 +257,8 @@ async def pre_power_on(device: Device, classic: bool) -> None:
async def post_power_on(
device: Device,
le_scan: tuple[int, int] | None,
le_advertise: int | None,
le_scan: Optional[tuple[int, int]],
le_advertise: Optional[int],
classic_page_scan: bool,
classic_inquiry_scan: bool,
) -> None:
@@ -1299,7 +1300,7 @@ class IsoClient(StreamedPacketIO):
super().__init__()
self.device = device
self.ready = asyncio.Event()
self.cis_link: CisLink | None = None
self.cis_link: Optional[CisLink] = None
async def on_connection(
self, connection: Connection, cis_link: CisLink, sender: bool
@@ -1340,7 +1341,7 @@ class IsoServer(StreamedPacketIO):
):
super().__init__()
self.device = device
self.cis_link: CisLink | None = None
self.cis_link: Optional[CisLink] = None
self.ready = asyncio.Event()
logging.info(

View File

@@ -24,6 +24,7 @@ import logging
import os
import re
from collections import OrderedDict
from typing import Optional, Union
import click
import humanize
@@ -125,8 +126,8 @@ def parse_phys(phys):
# Console App
# -----------------------------------------------------------------------------
class ConsoleApp:
connected_peer: Peer | None
connection_phy: ConnectionPHY | None
connected_peer: Optional[Peer]
connection_phy: Optional[ConnectionPHY]
def __init__(self):
self.known_addresses = set()
@@ -519,7 +520,7 @@ class ConsoleApp:
self.show_attributes(attributes)
def find_remote_characteristic(self, param) -> CharacteristicProxy | None:
def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
if not self.connected_peer:
return None
parts = param.split('.')
@@ -541,7 +542,9 @@ class ConsoleApp:
return None
def find_local_attribute(self, param) -> Characteristic | Descriptor | None:
def find_local_attribute(
self, param
) -> Optional[Union[Characteristic, Descriptor]]:
parts = param.split('.')
if len(parts) == 3:
service_uuid = UUID(parts[0])
@@ -1093,7 +1096,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
if self.app.connected_peer.connection.is_encrypted
else 'not encrypted'
)
self.app.append_to_output(f'connection encryption change: {encryption_state}')
self.app.append_to_output(
'connection encryption change: ' f'{encryption_state}'
)
def on_connection_data_length_change(self):
self.app.append_to_output(

View File

@@ -35,6 +35,8 @@ from bumble.hci import (
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_SUCCESS,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
CodecID,
HCI_Command,
HCI_Command_Complete_Event,
@@ -52,7 +54,6 @@ from bumble.hci import (
HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command,
LeFeature,
SpecificationVersion,
map_null_terminated_utf8_string,
)
from bumble.host import Host
@@ -274,7 +275,7 @@ async def async_main(
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies) / len(latencies):.2f},'
f'average={sum(latencies)/len(latencies):.2f},'
),
[f'{latency:.4}' for latency in latencies],
'\n',
@@ -288,20 +289,14 @@ async def async_main(
)
print(
color(' HCI Version: ', 'green'),
SpecificationVersion(host.local_version.hci_version).name,
)
print(
color(' HCI Subversion:', 'green'),
f'0x{host.local_version.hci_subversion:04x}',
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(
color(' LMP Version: ', 'green'),
SpecificationVersion(host.local_version.lmp_version).name,
)
print(
color(' LMP Subversion:', 'green'),
f'0x{host.local_version.lmp_subversion:04x}',
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info
await get_classic_info(host)

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import asyncio
import time
from typing import Optional
import click
@@ -40,7 +41,7 @@ class Loopback:
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: int | None = None
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0

View File

@@ -16,7 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
from collections.abc import Callable, Iterable
from typing import Callable, Iterable, Optional
import click
@@ -174,7 +174,7 @@ async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
# -----------------------------------------------------------------------------
async def show_device_info(peer, done: asyncio.Future | None) -> None:
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try:
# Discover all services
print(color('### Discovering Services and Characteristics', 'magenta'))
@@ -215,6 +215,7 @@ async def show_device_info(peer, done: asyncio.Future | None) -> None:
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(

View File

@@ -61,6 +61,7 @@ async def dump_gatt_db(peer, done):
# -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(

View File

@@ -268,6 +268,7 @@ class UiServer:
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: str | None,
@@ -298,7 +299,6 @@ class Speaker:
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
)
device_config.le_enabled = True

View File

@@ -15,11 +15,10 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
import os
import struct
import click
from prompt_toolkit.shortcuts import PromptSession
@@ -65,7 +64,7 @@ POST_PAIRING_DELAY = 1
# -----------------------------------------------------------------------------
class Waiter:
instance: Waiter | None = None
instance = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
@@ -329,25 +328,25 @@ async def on_pairing_failure(connection, reason):
# -----------------------------------------------------------------------------
async def pair(
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuids: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuids,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
):
Waiter.instance = Waiter(linger=linger)
@@ -405,7 +404,6 @@ async def pair(
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
legacy_context: OobLegacyContext | None
if oob == '-':
shared_data = None
legacy_context = OobLegacyContext()
@@ -530,9 +528,7 @@ async def pair(
if advertise_appearance:
advertise_appearance = advertise_appearance.upper()
try:
appearance = data_types.Appearance.from_int(
int(advertise_appearance)
)
advertise_appearance_int = int(advertise_appearance)
except ValueError:
category, subcategory = advertise_appearance.split('/')
try:
@@ -550,11 +546,12 @@ async def pair(
except ValueError:
print(color(f'Invalid subcategory {subcategory}', 'red'))
return
appearance = data_types.Appearance(
category_enum, subcategory_enum
advertise_appearance_int = int(
Appearance(category_enum, subcategory_enum)
)
advertising_data_types.append(appearance)
advertising_data_types.append(
data_types.Appearance(category_enum, subcategory_enum)
)
device.advertising_data = bytes(AdvertisingData(advertising_data_types))
await device.start_advertising(
auto_restart=True,
@@ -664,25 +661,25 @@ class LogHandler(logging.Handler):
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(
mode: str,
sc: bool,
mitm: bool,
bond: bool,
ctkd: bool,
advertising_address: str,
identity_address: str,
linger: bool,
io: str,
oob: str,
prompt: bool,
request: bool,
print_keys: bool,
keystore_file: str,
advertise_service_uuid: str,
advertise_appearance: str,
device_config: str,
hci_transport: str,
address_or_name: str,
mode,
sc,
mitm,
bond,
ctkd,
advertising_address,
identity_address,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
advertise_service_uuid,
advertise_appearance,
device_config,
hci_transport,
address_or_name,
):
# Setup logging
log_handler = LogHandler()

View File

@@ -19,7 +19,7 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
@click.option(
'--transport',
help='HCI transport',
default='tcp-client:127.0.0.1:<rootcanal-port>',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
)
@click.option(
'--config',
@@ -44,7 +44,7 @@ def retrieve_config(config: str) -> dict[str, Any]:
if not config:
return {}
with open(config) as f:
with open(config, 'r') as f:
return json.load(f)

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import asyncio
import logging
from typing import Optional, Union
import click
@@ -46,13 +47,14 @@ from bumble.avdtp import (
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
MediaCodecCapabilities,
MediaPacketPump,
find_avdtp_service_with_connection,
)
from bumble.avdtp import Protocol as AvdtpProtocol
from bumble.avdtp import find_avdtp_service_with_connection
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.colors import color
from bumble.core import AdvertisingData, DeviceClass, PhysicalTransport
from bumble.core import AdvertisingData
from bumble.core import ConnectionError as BumbleConnectionError
from bumble.core import DeviceClass, PhysicalTransport
from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import HCI_CONNECTION_ALREADY_EXISTS_ERROR, Address, HCI_Constant
from bumble.pairing import PairingConfig
@@ -189,7 +191,7 @@ class Player:
def __init__(
self,
transport: str,
device_config: str | None,
device_config: Optional[str],
authenticate: bool,
encrypt: bool,
) -> None:
@@ -197,8 +199,8 @@ class Player:
self.device_config = device_config
self.authenticate = authenticate
self.encrypt = encrypt
self.avrcp_protocol: AvrcpProtocol | None = None
self.done: asyncio.Event | None
self.avrcp_protocol: Optional[AvrcpProtocol] = None
self.done: Optional[asyncio.Event]
async def run(self, workload) -> None:
self.done = asyncio.Event()
@@ -313,7 +315,7 @@ class Player:
codec_type: int,
vendor_id: int,
codec_id: int,
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource,
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
codec_capabilities: MediaCodecCapabilities,
):
# Discover all endpoints on the remote device
@@ -379,11 +381,11 @@ class Player:
print(f">>> {color(address.to_string(False), 'yellow')}:")
print(f" Device Class (raw): {class_of_device:06X}")
major_class_name = DeviceClass.major_device_class_name(major_device_class)
print(f" Device Major Class: {major_class_name}")
print(" Device Major Class: " f"{major_class_name}")
minor_class_name = DeviceClass.minor_device_class_name(
major_device_class, minor_device_class
)
print(f" Device Minor Class: {minor_class_name}")
print(" Device Minor Class: " f"{minor_class_name}")
print(
" Device Services: "
f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
@@ -418,7 +420,7 @@ class Player:
async def play(
self,
device: Device,
address: str | None,
address: Optional[str],
audio_format: str,
audio_file: str,
) -> None:
@@ -447,7 +449,7 @@ class Player:
return input_file.read(byte_count)
# Obtain the codec capabilities from the stream
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource
packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
vendor_id = 0
codec_id = 0
if audio_format == "sbc":

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import asyncio
import time
from typing import Optional
import click
@@ -81,14 +82,14 @@ class ServerBridge:
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Device | None = None
self.device: Optional[Device] = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: rfcomm.DLC | None = None
self.tcp_tracer: Tracer | None
self.rfcomm_tracer: Tracer | None
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
@@ -241,14 +242,14 @@ class ClientBridge:
self.tcp_port = tcp_port
self.authenticate = authenticate
self.encrypt = encrypt
self.device: Device | None = None
self.connection: Connection | None = None
self.rfcomm_client: rfcomm.Client | None
self.rfcomm_mux: rfcomm.Multiplexer | None
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False
self.tcp_tracer: Tracer | None
self.rfcomm_tracer: Tracer | None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))

View File

@@ -217,7 +217,9 @@ async def scan(
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=('Use this IRK for resolving private addresses (may be used more than once)'),
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(

View File

@@ -26,6 +26,7 @@ import pathlib
import subprocess
import weakref
from importlib import resources
from typing import Optional
import aiohttp
import click
@@ -155,7 +156,7 @@ class QueuedOutput(Output):
packets: asyncio.Queue
extractor: AudioExtractor
packet_pump_task: asyncio.Task | None
packet_pump_task: Optional[asyncio.Task]
started: bool
def __init__(self, extractor):
@@ -229,8 +230,8 @@ class WebSocketOutput(QueuedOutput):
class FfplayOutput(QueuedOutput):
MAX_QUEUE_SIZE = 32768
subprocess: asyncio.subprocess.Process | None
ffplay_task: asyncio.Task | None
subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: Optional[asyncio.Task]
def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec))

View File

@@ -21,12 +21,11 @@ import dataclasses
import enum
import logging
import struct
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import ClassVar
from collections.abc import AsyncGenerator
from typing import Awaitable, Callable
from typing_extensions import Self
from typing_extensions import ClassVar, Self
from bumble import utils
from bumble.codecs import AacAudioRtpPacket
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.core import (
@@ -60,18 +59,19 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# fmt: off
class CodecType(utils.OpenIntEnum):
SBC = 0x00
MPEG_1_2_AUDIO = 0x01
MPEG_2_4_AAC = 0x02
ATRAC_FAMILY = 0x03
NON_A2DP = 0xFF
A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
A2DP_MPEG_2_4_AAC_CODEC_TYPE = 0x02
A2DP_ATRAC_FAMILY_CODEC_TYPE = 0x03
A2DP_NON_A2DP_CODEC_TYPE = 0xFF
A2DP_SBC_CODEC_TYPE = CodecType.SBC
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = CodecType.MPEG_1_2_AUDIO
A2DP_MPEG_2_4_AAC_CODEC_TYPE = CodecType.MPEG_2_4_AAC
A2DP_ATRAC_FAMILY_CODEC_TYPE = CodecType.ATRAC_FAMILY
A2DP_NON_A2DP_CODEC_TYPE = CodecType.NON_A2DP
A2DP_CODEC_TYPE_NAMES = {
A2DP_SBC_CODEC_TYPE: 'A2DP_SBC_CODEC_TYPE',
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE: 'A2DP_MPEG_1_2_AUDIO_CODEC_TYPE',
A2DP_MPEG_2_4_AAC_CODEC_TYPE: 'A2DP_MPEG_2_4_AAC_CODEC_TYPE',
A2DP_ATRAC_FAMILY_CODEC_TYPE: 'A2DP_ATRAC_FAMILY_CODEC_TYPE',
A2DP_NON_A2DP_CODEC_TYPE: 'A2DP_NON_A2DP_CODEC_TYPE'
}
SBC_SYNC_WORD = 0x9C
@@ -259,48 +259,9 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
]
# -----------------------------------------------------------------------------
class MediaCodecInformation:
'''Base Media Codec Information.'''
@classmethod
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
elif media_codec_type == CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod
def from_bytes(cls, data: bytes) -> Self:
del data # Unused.
raise NotImplementedError
def __bytes__(self) -> bytes:
raise NotImplementedError
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcMediaCodecInformation(MediaCodecInformation):
class SbcMediaCodecInformation:
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
@@ -384,7 +345,7 @@ class SbcMediaCodecInformation(MediaCodecInformation):
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacMediaCodecInformation(MediaCodecInformation):
class AacMediaCodecInformation:
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
@@ -466,7 +427,7 @@ class AacMediaCodecInformation(MediaCodecInformation):
@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation(MediaCodecInformation):
class VendorSpecificMediaCodecInformation:
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''
@@ -490,7 +451,7 @@ class VendorSpecificMediaCodecInformation(MediaCodecInformation):
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()})',
f' value: {self.value.hex()}' ')',
]
)
@@ -686,7 +647,7 @@ class SbcPacketSource:
# Prepare for next packets
sequence_number += 1
sequence_number &= 0xFFFF
sample_count += sum(frame.sample_count for frame in frames)
sample_count += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from bumble import core
@@ -35,7 +36,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
if in_quotes:
token.extend(char)
if char == b'"':
if char == b'\"':
in_quotes = False
tokens.append(token[1:-1])
token = bytearray()
@@ -62,18 +63,18 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> list[bytes | list]:
def parse_parameters(buffer: bytes) -> list[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: list[list] = [[]]
current: bytes | list = b''
current: Union[bytes, list] = bytes()
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = b''
current = bytes()
elif token == b'(':
accumulator.append([])
elif token == b')':

View File

@@ -29,18 +29,18 @@ import enum
import functools
import inspect
import struct
from collections.abc import Awaitable, Callable
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
ClassVar,
Generic,
TypeAlias,
Optional,
TypeVar,
Union,
)
from typing_extensions import TypeIs
from bumble import hci, l2cap, utils
from bumble import hci, utils
from bumble.colors import color
from bumble.core import UUID, InvalidOperationError, ProtocolError
from bumble.hci import HCI_Object
@@ -53,14 +53,6 @@ if TYPE_CHECKING:
_T = TypeVar('_T')
Bearer: TypeAlias = "Connection | l2cap.LeCreditBasedChannel"
EnhancedBearer: TypeAlias = l2cap.LeCreditBasedChannel
def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
return isinstance(bearer, EnhancedBearer)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -69,7 +61,6 @@ def is_enhanced_bearer(bearer: Bearer) -> TypeIs[EnhancedBearer]:
ATT_CID = 0x04
ATT_PSM = 0x001F
EATT_PSM = 0x0027
class Opcode(hci.SpecableEnum):
ATT_ERROR_RESPONSE = 0x01
@@ -229,7 +220,7 @@ class ATT_PDU:
fields: ClassVar[hci.Fields] = ()
op_code: int = dataclasses.field(init=False)
name: str = dataclasses.field(init=False)
_payload: bytes | None = dataclasses.field(default=None, init=False)
_payload: Optional[bytes] = dataclasses.field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> ATT_PDU:
@@ -769,66 +760,31 @@ class AttributeValue(Generic[_T]):
def __init__(
self,
read: (
Callable[[Connection], _T] | Callable[[Connection], Awaitable[_T]] | None
) = None,
write: (
Callable[[Connection, _T], None]
| Callable[[Connection, _T], Awaitable[None]]
| None
) = None,
read: Union[
Callable[[Connection], _T],
Callable[[Connection], Awaitable[_T]],
None,
] = None,
write: Union[
Callable[[Connection, _T], None],
Callable[[Connection, _T], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Connection) -> _T | Awaitable[_T]:
def read(self, connection: Connection) -> Union[_T, Awaitable[_T]]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(connection)
def write(self, connection: Connection, value: _T) -> Awaitable[None] | None:
def write(self, connection: Connection, value: _T) -> Union[Awaitable[None], None]:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(connection, value)
# -----------------------------------------------------------------------------
class AttributeValueV2(Generic[_T]):
'''
Attribute value compatible with enhanced bearers.
The only difference between AttributeValue and AttributeValueV2 is that the actual
bearer (ACL connection for un-enhanced bearer, L2CAP channel for enhanced bearer)
will be passed into read and write callbacks in V2, while in V1 it is always
the base ACL connection.
This is only required when attributes must distinguish bearers, otherwise normal
`AttributeValue` objects are also applicable in enhanced bearers.
'''
def __init__(
self,
read: Callable[[Bearer], Awaitable[_T]] | Callable[[Bearer], _T] | None = None,
write: (
Callable[[Bearer, _T], Awaitable[None]]
| Callable[[Bearer, _T], None]
| None
) = None,
):
self._read = read
self._write = write
def read(self, bearer: Bearer) -> _T | Awaitable[_T]:
if self._read is None:
raise InvalidOperationError('AttributeValue has no read function')
return self._read(bearer)
def write(self, bearer: Bearer, value: _T) -> Awaitable[None] | None:
if self._write is None:
raise InvalidOperationError('AttributeValue has no write function')
return self._write(bearer, value)
# -----------------------------------------------------------------------------
class Attribute(utils.EventEmitter, Generic[_T]):
class Permissions(enum.IntFlag):
@@ -872,13 +828,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
EVENT_READ = "read"
EVENT_WRITE = "write"
value: AttributeValue[_T] | _T | None
value: Union[AttributeValue[_T], _T, None]
def __init__(
self,
attribute_type: str | bytes | UUID,
permissions: str | Attribute.Permissions,
value: AttributeValue[_T] | _T | None = None,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[AttributeValue[_T], _T, None] = None,
) -> None:
utils.EventEmitter.__init__(self)
self.handle = 0
@@ -904,8 +860,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
def decode_value(self, value: bytes) -> _T:
return value # type: ignore
async def read_value(self, bearer: Bearer) -> bytes:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
async def read_value(self, connection: Connection) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -928,7 +883,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value: _T | None
value: Union[_T, None]
if isinstance(self.value, AttributeValue):
try:
read_value = self.value.read(connection)
@@ -940,17 +895,6 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
read_value = self.value.read(bearer)
if inspect.isawaitable(read_value):
value = await read_value
else:
value = read_value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
@@ -958,8 +902,7 @@ class Attribute(utils.EventEmitter, Generic[_T]):
return b'' if value is None else self.encode_value(value)
async def write_value(self, bearer: Bearer, value: bytes) -> None:
connection = bearer.connection if is_enhanced_bearer(bearer) else bearer
async def write_value(self, connection: Connection, value: bytes) -> None:
if (
(self.permissions & self.WRITE_REQUIRES_ENCRYPTION)
and connection is not None
@@ -993,15 +936,6 @@ class Attribute(utils.EventEmitter, Generic[_T]):
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
elif isinstance(self.value, AttributeValueV2):
try:
result = self.value.write(bearer, decoded_value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = decoded_value

View File

@@ -19,15 +19,14 @@ from __future__ import annotations
import abc
import asyncio
import concurrent.futures
import dataclasses
import enum
import logging
import pathlib
import sys
import wave
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, BinaryIO
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO
from bumble.colors import color
@@ -177,7 +176,7 @@ class ThreadedAudioOutput(AudioOutput):
"""
def __init__(self) -> None:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1)
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
self._write_task = asyncio.create_task(self._write_loop())
@@ -406,7 +405,7 @@ class ThreadedAudioInput(AudioInput):
"""Base class for AudioInput implementation where reading samples may block."""
def __init__(self) -> None:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1)
self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
@abc.abstractmethod
@@ -546,6 +545,5 @@ class SoundDeviceAudioInput(ThreadedAudioInput):
return bytes(pcm_buffer)
def _close(self):
if self._stream:
self._stream.stop()
self._stream = None
self._stream.stop()
self._stream = None

339
bumble/audio/io_asrc.py Normal file
View File

@@ -0,0 +1,339 @@
# Copyright 2025
#
# Drop-in replacement for `SoundDeviceAudioInput` that adds a tiny ASRC stage.
#
# Constraints per request:
# - Only import io_bumble.py at module level.
# - Reuse the ASRC functionality from asrc.py conceptually (PI control + FIFO +
# linear/sinc resampling behavior). We implement a minimal, dependency-free
# variant (linear interpolation with a small PI loop) so this module does not
# import anything else at top-level.
#
# Notes:
# - Input stream is captured via sounddevice (imported lazily inside methods).
# - Input is mono float32 for simplicity; output matches the original class
# signature: INT16, stereo, at the same nominal sample rate as requested.
from .io import PcmFormat, ThreadedAudioInput, logger # only top-level import
class SoundDeviceAudioInputAsrc(ThreadedAudioInput):
"""Sound device audio input with a simple ASRC stage.
Interface-compatible with `io_bumble.SoundDeviceAudioInput`:
- __init__(device_name: str, pcm_format: PcmFormat)
- _open() -> PcmFormat
- _read(frame_size: int) -> bytes
- _close() -> None
Behavior:
- Captures mono float32 frames from the device.
- Buffers into an internal ring buffer.
- Produces stereo INT16 frames using a linear-interp resampler whose
ratio is adjusted by a tiny PI loop to hold FIFO depth near a target.
"""
def __init__(self, device_name: str, pcm_format: str) -> None:
super().__init__()
# Device & format
self._device = int(device_name) if device_name else None
pcm_format: PcmFormat | None
if pcm_format == 'auto':
pcm_format = None
else:
pcm_format = PcmFormat.from_str(pcm_format)
self._pcm_format_in = pcm_format
# We always output stereo INT16 at the same nominal sample rate.
self._pcm_format_out = PcmFormat(
PcmFormat.Endianness.LITTLE,
PcmFormat.SampleType.INT16,
pcm_format.sample_rate,
2,
)
# sounddevice stream (created in _open)
self._stream = None # type: ignore[assignment]
# --- ASRC state (inspired by asrc.py) ---
# Nominal input/output rate ratio
self._r = 1.0
self._integral = 0.0
self._phi = 0.0 # fractional read position within current chunk
# PI gains (tiny to avoid warble)
self._Kp = 2e-6
self._Ki = 5e-8
self._R0 = 1.0
# Target FIFO level and deadband (≈10 ms target, 0.5 ms deadband)
fs = float(self._pcm_format_in.sample_rate)
self._target_samples = max(1, int(0.010 * fs))
self._deadband = max(1, int(0.0005 * fs))
# Ring buffer for mono float32 samples
# Capacity ~2 seconds for headroom
self._rb_cap = max(self._target_samples * 32, int(2 * fs))
self._rb = None # created in _init_rb()
self._ridx = 0
self._size = 0
self._lock = None # created in _init_rb()
self._init_rb()
# Light logging timer
self._last_log = 0.0
# Streaming resampler and internal output buffer (lazy init)
self._rs = None # samplerate.Resampler
self._out_buf = None # numpy.ndarray float32
# ---------------- Internal helpers -----------------
def _init_rb(self) -> None:
# Lazy import standard libs to keep only io_bumble imported at top level
import threading
from array import array
self._rb = array('f', [0.0] * self._rb_cap) # float32 ring buffer
self._lock = threading.Lock()
self._ridx = 0
self._size = 0
def _fifo_len(self) -> int:
with self._lock:
return self._size
def _fifo_write(self, x_f32) -> None:
# x_f32: 1-D float32-like iterable
k = len(x_f32)
if k <= 0:
return
rb = self._rb
if rb is None:
return
with self._lock:
# Trim if larger than capacity: keep last N
if k >= self._rb_cap:
x_f32 = x_f32[-self._rb_cap:]
k = self._rb_cap
# Make room on overflow (drop oldest)
excess = max(0, self._size + k - self._rb_cap)
if excess:
self._ridx = (self._ridx + excess) % self._rb_cap
self._size -= excess
# Write at tail position
wpos = (self._ridx + self._size) % self._rb_cap
first = min(k, self._rb_cap - wpos)
# Write first chunk
from array import array as _array # lazy import
rb[wpos:wpos + first] = _array('f', x_f32[:first])
# Wrap if needed
second = k - first
if second:
rb[0:second] = _array('f', x_f32[first:])
self._size += k
def _fifo_peek_array(self, n: int):
# Returns a Python list[float] copy of up to n samples
rb = self._rb
if rb is None:
return []
m = max(0, min(n, self._fifo_len()))
if m <= 0:
return []
pos = self._ridx
first = min(m, self._rb_cap - pos)
# Copy out
out = [0.0] * m
# First chunk
out[:first] = rb[pos:pos + first]
# Second chunk if wrap
second = m - first
if second > 0:
out[first:] = rb[0:second]
return out
def _fifo_discard(self, n: int) -> None:
with self._lock:
d = max(0, min(n, self._size))
self._ridx = (self._ridx + d) % self._rb_cap
self._size -= d
def _update_ratio(self) -> None:
# PI loop to hold buffer near target
e = self._target_samples - self._fifo_len()
if -self._deadband <= e <= self._deadband:
e = 0.0
cand_integral = self._integral + e
r_unclamped = self._R0 * (1.0 + self._Kp * e + self._Ki * cand_integral)
# Limit to ±1000 ppm vs nominal
ppm_unclamped = 1e6 * (r_unclamped / self._R0 - 1.0)
saturated_high = ppm_unclamped > 1000.0
saturated_low = ppm_unclamped < -1000.0
if saturated_high:
self._r = self._R0 * (1 + 1000e-6)
if e <= 0:
self._integral = cand_integral
self._integral *= 0.99
elif saturated_low:
self._r = self._R0 * (1 - 1000e-6)
if e >= 0:
self._integral = cand_integral
self._integral *= 0.99
else:
self._integral = cand_integral
self._r = r_unclamped
# Occasional log
try:
import time as _time
now = _time.time()
if now - self._last_log > 1.0:
buf_ms = 1000.0 * self._fifo_len() / float(self._pcm_format_in.sample_rate)
print(
f"\nASRC buf={buf_ms:5.1f} ms r={self._r:.9f} corr={1e6 * (self._r / self._R0 - 1.0):+7.1f} ppm"
)
self._last_log = now
except Exception:
# Logging must never break audio
pass
def _process(self, n_out: int) -> list[float]:
# Accumulate at least n_out samples using samplerate.Resampler
if n_out <= 0:
return []
# Lazy imports
import numpy as np # type: ignore
# Lazy init output buffer
if self._out_buf is None:
self._out_buf = np.zeros(0, dtype=np.float32)
# Choose chunk so we don't take too much from FIFO each time
max_chunk = max(256, int(np.ceil(n_out / max(1e-9, self._r))))
safety_iters = 0
while self._out_buf.size < n_out and safety_iters < 16:
safety_iters += 1
available = self._fifo_len()
if available <= 0:
break
take = min(available, max_chunk)
x = self._fifo_peek_array(take)
self._fifo_discard(take)
if not x:
break
x_arr = np.asarray(x, dtype=np.float32)
if self._rs is not None:
try:
y = self._rs.process(x_arr, ratio=float(self._r), end_of_input=False)
except Exception:
logger.exception("ASRC resampler error")
y = None
else:
y = None
if y is not None and getattr(y, 'size', 0):
y = y.astype(np.float32, copy=False)
if self._out_buf.size == 0:
self._out_buf = y
else:
self._out_buf = np.concatenate((self._out_buf, y))
if self._out_buf.size >= n_out:
out = self._out_buf[:n_out]
self._out_buf = self._out_buf[n_out:]
return out.tolist()
else:
# Not enough data produced; pad with zeros
out = np.zeros(n_out, dtype=np.float32)
if self._out_buf.size:
out[: self._out_buf.size] = self._out_buf
self._out_buf = np.zeros(0, dtype=np.float32)
return out.tolist()
def _mono_to_stereo_int16_bytes(self, mono_f32: list[float]) -> bytes:
# Convert [-1,1] float list to stereo int16 little-endian bytes
import struct
ba = bytearray()
for v in mono_f32:
# clip
if v > 1.0:
v = 1.0
elif v < -1.0:
v = -1.0
i16 = int(v * 32767.0)
ba += struct.pack('<hh', i16, i16)
return bytes(ba)
# ---------------- ThreadedAudioInput hooks -----------------
def _open(self) -> PcmFormat:
# Set up sounddevice RawInputStream (int16) and start callback producer
import sounddevice # pylint: disable=import-error
import math
import samplerate as sr # type: ignore
# We capture mono regardless of requested channels, then output stereo.
channels = 1
samplerate = int(self._pcm_format_in.sample_rate)
def _callback(indata, frames, time_info, status): # noqa: ARG001 (signature is fixed)
# indata: raw int16 bytes-like buffer of shape (frames, channels)
try:
if status:
logger.warning("Input status: %s", status)
if frames <= 0:
return
# Interpret raw bytes as little-endian int16 mono
mv = memoryview(indata).cast('h') # len == frames * channels
# Convert to float in [-1, 1]
# Avoid division errors; protect NaN/Inf
mono = []
for i in range(frames):
v = mv[i]
f = float(v) / 32768.0
if not (f == f) or math.isinf(f):
f = 0.0
mono.append(f)
self._fifo_write(mono)
except Exception: # never let callback raise
logger.exception("Audio input callback error")
# Create streaming resampler (mono)
try:
self._rs = sr.Resampler(converter_type="sinc_fastest", channels=1)
except Exception:
logger.exception("Failed to create samplerate.Resampler; audio may be silent")
self._rs = None
self._stream = sounddevice.RawInputStream(
samplerate=samplerate,
device=self._device,
channels=channels,
dtype='int16',
callback=_callback,
)
self._stream.start()
return self._pcm_format_out
def _read(self, frame_size: int) -> bytes:
# Produce 'frame_size' output frames (stereo INT16)
if frame_size <= 0:
return b''
# Update resampling ratio based on FIFO level
try:
self._update_ratio()
except Exception:
# keep going even if update failed
pass
# Process mono float32
mono = self._process(frame_size)
# Convert to stereo int16 LE bytes
return self._mono_to_stereo_int16_bytes(mono)
def _close(self) -> None:
try:
if self._stream is not None:
self._stream.stop()
self._stream.close()
except Exception:
logger.exception('Error closing input stream')
finally:
self._stream = None

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import enum
import struct
from typing import Union
from bumble import core, utils
@@ -165,7 +166,7 @@ class Frame:
def to_bytes(
self,
ctype_or_response: CommandFrame.CommandType | ResponseFrame.ResponseCode,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (

View File

@@ -19,10 +19,10 @@ from __future__ import annotations
import logging
import struct
from collections.abc import Callable
from enum import IntEnum
from typing import Callable, Optional, cast
from bumble import core, l2cap
from bumble import avc, core, l2cap
from bumble.colors import color
# -----------------------------------------------------------------------------
@@ -144,9 +144,9 @@ class MessageAssembler:
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, bytes], None]
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, bytes | None], None]
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
@@ -204,15 +204,20 @@ class Protocol:
self.send_ipid(transaction_label, pid)
return
self.command_handlers[pid](transaction_label, payload)
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
response_payload = None if ipid else payload
self.response_handlers[pid](transaction_label, response_payload)
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
@@ -257,7 +262,7 @@ class Protocol:
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}"
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')

File diff suppressed because it is too large Load Diff

View File

@@ -22,9 +22,21 @@ import enum
import functools
import logging
import struct
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass, field
from typing import ClassVar, SupportsBytes, TypeVar
from typing import (
AsyncIterator,
Awaitable,
Callable,
ClassVar,
Iterable,
List,
Optional,
Sequence,
SupportsBytes,
TypeVar,
Union,
cast,
)
from bumble import avc, avctp, core, hci, l2cap, utils
from bumble.colors import color
@@ -196,7 +208,7 @@ def make_controller_service_sdp_records(
service_record_handle: int,
avctp_version: tuple[int, int] = (1, 4),
avrcp_version: tuple[int, int] = (1, 6),
supported_features: int | ControllerFeatures = 1,
supported_features: Union[int, ControllerFeatures] = 1,
) -> list[ServiceAttribute]:
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1]
@@ -288,7 +300,7 @@ def make_target_service_sdp_records(
service_record_handle: int,
avctp_version: tuple[int, int] = (1, 4),
avrcp_version: tuple[int, int] = (1, 6),
supported_features: int | TargetFeatures = 0x23,
supported_features: Union[int, TargetFeatures] = 0x23,
) -> list[ServiceAttribute]:
# TODO: support a way to compute the supported features from a feature list
avctp_version_int = avctp_version[0] << 8 | avctp_version[1]
@@ -478,7 +490,7 @@ class BrowseableItem:
MEDIA_ELEMENT = 0x03
item_type: ClassVar[Type]
_payload: bytes | None = None
_payload: Optional[bytes] = None
subclasses: ClassVar[dict[Type, type[BrowseableItem]]] = {}
fields: ClassVar[hci.Fields] = ()
@@ -672,7 +684,7 @@ class PduAssembler:
6.3.1 AVRCP specific AV//C commands
"""
pdu_id: PduId | None
pdu_id: Optional[PduId]
payload: bytes
def __init__(self, callback: Callable[[PduId, bytes], None]) -> None:
@@ -725,7 +737,7 @@ class PduAssembler:
# -----------------------------------------------------------------------------
class Command:
pdu_id: ClassVar[PduId]
_payload: bytes | None = None
_payload: Optional[bytes] = None
_Command = TypeVar('_Command', bound='Command')
subclasses: ClassVar[dict[int, type[Command]]] = {}
@@ -1017,7 +1029,7 @@ class AddToNowPlayingCommand(Command):
# -----------------------------------------------------------------------------
class Response:
pdu_id: PduId
_payload: bytes | None = None
_payload: Optional[bytes] = None
fields: ClassVar[hci.Fields] = ()
subclasses: ClassVar[dict[PduId, type[Response]]] = {}
@@ -1079,7 +1091,7 @@ class NotImplementedResponse(Response):
class GetCapabilitiesResponse(Response):
pdu_id = PduId.GET_CAPABILITIES
capability_id: GetCapabilitiesCommand.CapabilityId
capabilities: Sequence[SupportsBytes | bytes]
capabilities: Sequence[Union[SupportsBytes, bytes]]
@classmethod
def from_parameters(cls, parameters: bytes) -> Response:
@@ -1092,7 +1104,7 @@ class GetCapabilitiesResponse(Response):
capability_id = GetCapabilitiesCommand.CapabilityId(parameters[0])
capability_count = parameters[1]
capabilities: list[SupportsBytes | bytes]
capabilities: list[Union[SupportsBytes, bytes]]
if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
capabilities = [EventId(parameters[2 + x]) for x in range(capability_count)]
else:
@@ -1363,7 +1375,7 @@ class AddToNowPlayingResponse(Response):
# -----------------------------------------------------------------------------
class Event:
event_id: EventId
_pdu: bytes | None = None
_pdu: Optional[bytes] = None
_Event = TypeVar('_Event', bound='Event')
subclasses: ClassVar[dict[int, type[Event]]] = {}
@@ -1436,13 +1448,13 @@ class PlayerApplicationSettingChangedEvent(Event):
attribute_id: ApplicationSetting.AttributeId = field(
metadata=ApplicationSetting.AttributeId.type_metadata(1)
)
value_id: (
ApplicationSetting.EqualizerOnOffStatus
| ApplicationSetting.RepeatModeStatus
| ApplicationSetting.ShuffleOnOffStatus
| ApplicationSetting.ScanOnOffStatus
| ApplicationSetting.GenericValue
) = field(metadata=hci.metadata(1))
value_id: Union[
ApplicationSetting.EqualizerOnOffStatus,
ApplicationSetting.RepeatModeStatus,
ApplicationSetting.ShuffleOnOffStatus,
ApplicationSetting.ScanOnOffStatus,
ApplicationSetting.GenericValue,
] = field(metadata=hci.metadata(1))
def __post_init__(self) -> None:
super().__post_init__()
@@ -1628,17 +1640,17 @@ class Protocol(utils.EventEmitter):
delegate: Delegate
send_transaction_label: int
command_pdu_assembler: PduAssembler
receive_command_state: ReceiveCommandState | None
receive_command_state: Optional[ReceiveCommandState]
response_pdu_assembler: PduAssembler
receive_response_state: ReceiveResponseState | None
avctp_protocol: avctp.Protocol | None
receive_response_state: Optional[ReceiveResponseState]
avctp_protocol: Optional[avctp.Protocol]
free_commands: asyncio.Queue
pending_commands: dict[int, PendingCommand] # Pending commands, by label
notification_listeners: dict[EventId, NotificationListener]
@staticmethod
def _check_vendor_dependent_frame(
frame: avc.VendorDependentCommandFrame | avc.VendorDependentResponseFrame,
frame: Union[avc.VendorDependentCommandFrame, avc.VendorDependentResponseFrame],
) -> bool:
if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID:
logger.debug("unsupported company id, ignoring")
@@ -1650,7 +1662,7 @@ class Protocol(utils.EventEmitter):
return True
def __init__(self, delegate: Delegate | None = None) -> None:
def __init__(self, delegate: Optional[Delegate] = None) -> None:
super().__init__()
self.delegate = delegate if delegate else Delegate()
self.command_pdu_assembler = PduAssembler(self._on_command_pdu)
@@ -1750,11 +1762,7 @@ class Protocol(utils.EventEmitter):
),
)
response = self._check_response(response_context, GetCapabilitiesResponse)
return list(
capability
for capability in response.capabilities
if isinstance(capability, EventId)
)
return cast(List[EventId], response.capabilities)
async def get_play_status(self) -> SongAndPlayStatus:
"""Get the play status of the connected peer."""
@@ -2004,14 +2012,11 @@ class Protocol(utils.EventEmitter):
self.emit(self.EVENT_STOP)
def _on_avctp_command(self, transaction_label: int, payload: bytes) -> None:
command = avc.CommandFrame.from_bytes(payload)
if not isinstance(command, avc.CommandFrame):
raise core.InvalidPacketError(
f"{command} is not a valid AV/C Command Frame"
)
def _on_avctp_command(
self, transaction_label: int, command: avc.CommandFrame
) -> None:
logger.debug(
f"<<< AVCTP Command, transaction_label={transaction_label}: {command}"
f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}"
)
# Only addressing the unit, or the PANEL subunit with subunit ID 0 is supported
@@ -2067,12 +2072,9 @@ class Protocol(utils.EventEmitter):
# TODO handle other types
self.send_not_implemented_response(transaction_label, command)
def _on_avctp_response(self, transaction_label: int, payload: bytes | None) -> None:
response = avc.ResponseFrame.from_bytes(payload) if payload else None
if not isinstance(response, avc.ResponseFrame):
raise core.InvalidPacketError(
f"{response} is not a valid AV/C Response Frame"
)
def _on_avctp_response(
self, transaction_label: int, response: Optional[avc.ResponseFrame]
) -> None:
logger.debug(
f"<<< AVCTP Response, transaction_label={transaction_label}: {response}"
)
@@ -2174,7 +2176,7 @@ class Protocol(utils.EventEmitter):
# NOTE: with a small number of supported responses, a manual switch like this
# is Ok, but if/when more responses are supported, a lookup mechanism would be
# more appropriate.
response: Response | None = None
response: Optional[Response] = None
if response_code == avc.ResponseFrame.ResponseCode.REJECTED:
response = RejectedResponse(pdu_id=pdu_id, status_code=StatusCode(pdu[0]))
elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED:
@@ -2389,7 +2391,7 @@ class Protocol(utils.EventEmitter):
effective_volume = await self.delegate.get_absolute_volume()
self.send_avrcp_response(
transaction_label,
avc.ResponseFrame.ResponseCode.ACCEPTED,
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
SetAbsoluteVolumeResponse(effective_volume),
)

View File

@@ -163,23 +163,23 @@ class AacAudioRtpPacket:
cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
reader.read(1) # frame_length_flag
frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
reader.read(14) # core_coder_delay
core_coder_delay = reader.read(14)
extension_flag = reader.read(1)
if not channel_configuration:
raise core.InvalidPacketError('program_config_element not supported')
if audio_object_type in (6, 20):
reader.read(3) # layer_nr
layer_nr = reader.read(3)
if extension_flag:
if audio_object_type == 22:
reader.read(5) # num_of_sub_frame
reader.read(11) # layer_length
num_of_sub_frame = reader.read(5)
layer_length = reader.read(11)
if audio_object_type in (17, 19, 20, 23):
reader.read(1) # aac_section_data_resilience_flags
reader.read(1) # aac_scale_factor_data_resilience_flags
reader.read(1) # aac_spectral_data_resilience_flags
aac_section_data_resilience_flags = reader.read(1)
aac_scale_factor_data_resilience_flags = reader.read(1)
aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@@ -364,10 +364,10 @@ class AacAudioRtpPacket:
if audio_mux_version_a != 0:
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
AacAudioRtpPacket.read_latm_value(reader) # tara_buffer_fullness
# stream_cnt = 0
reader.read(1) # all_streams_same_time_framing
reader.read(6) # num_sub_frames
tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
num_program = reader.read(4)
if num_program != 0:
raise core.InvalidPacketError('num_program != 0 not supported')
@@ -391,9 +391,9 @@ class AacAudioRtpPacket:
reader.skip(asc_len)
frame_length_type = reader.read(3)
if frame_length_type == 0:
reader.read(8) # latm_buffer_fullness
latm_buffer_fullness = reader.read(8)
elif frame_length_type == 1:
reader.read(9) # frame_length
frame_length = reader.read(9)
else:
raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
@@ -413,7 +413,7 @@ class AacAudioRtpPacket:
break
crc_check_present = reader.read(1)
if crc_check_present:
reader.read(8) # crc_checksum
crc_checksum = reader.read(8)
return cls(other_data_present, other_data_len_bits, audio_specific_config)

View File

@@ -13,6 +13,7 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial
from typing import Optional, Union
class ColorError(ValueError):
@@ -37,7 +38,7 @@ STYLES = (
)
ColorSpec = str | int
ColorSpec = Union[str, int]
def _join(*values: ColorSpec) -> str:
@@ -55,14 +56,14 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ColorError(f'Invalid color spec "{spec}"')
raise ColorError('Invalid color spec "%s"' % spec)
def color(
s: str,
fg: ColorSpec | None = None,
bg: ColorSpec | None = None,
style: str | None = None,
fg: Optional[ColorSpec] = None,
bg: Optional[ColorSpec] = None,
style: Optional[str] = None,
) -> str:
codes: list[ColorSpec] = []
@@ -75,10 +76,10 @@ def color(
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ColorError(f'Invalid style "{style_part}"')
raise ColorError('Invalid style "%s"' % style_part)
if codes:
return f'\x1b[{_join(*codes)}m{s}\x1b[0m'
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
else:
return s

File diff suppressed because it is too large Load Diff

View File

@@ -20,11 +20,15 @@ from __future__ import annotations
import dataclasses
import enum
import struct
from collections.abc import Iterable
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Iterable,
Literal,
Optional,
Type,
Union,
cast,
overload,
)
@@ -99,7 +103,7 @@ class BaseError(BaseBumbleError):
def __init__(
self,
error_code: int | None,
error_code: Optional[int],
error_namespace: str = '',
error_name: str = '',
details: str = '',
@@ -212,9 +216,11 @@ class UUID:
UUIDS: list[UUID] = [] # Registry of all instances created
uuid_bytes: bytes
name: str | None
name: Optional[str]
def __init__(self, uuid_str_or_int: str | int, name: str | None = None) -> None:
def __init__(
self, uuid_str_or_int: Union[str, int], name: Optional[str] = None
) -> None:
if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
@@ -247,7 +253,7 @@ class UUID:
return self
@classmethod
def from_bytes(cls, uuid_bytes: bytes, name: str | None = None) -> UUID:
def from_bytes(cls, uuid_bytes: bytes, name: Optional[str] = None) -> UUID:
if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
@@ -258,11 +264,11 @@ class UUID:
raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
@classmethod
def from_16_bits(cls, uuid_16: int, name: str | None = None) -> UUID:
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<H', uuid_16), name)
@classmethod
def from_32_bits(cls, uuid_32: int, name: str | None = None) -> UUID:
def from_32_bits(cls, uuid_32: int, name: Optional[str] = None) -> UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
@@ -728,7 +734,7 @@ class ClassOfDevice:
MajorDeviceClass.HEALTH: HEALTH_MINOR_DEVICE_CLASS_LABELS,
}
_MINOR_DEVICE_CLASSES: ClassVar[dict[MajorDeviceClass, type]] = {
_MINOR_DEVICE_CLASSES: ClassVar[dict[MajorDeviceClass, Type]] = {
MajorDeviceClass.COMPUTER: ComputerMinorDeviceClass,
MajorDeviceClass.PHONE: PhoneMinorDeviceClass,
MajorDeviceClass.LAN_NETWORK_ACCESS_POINT: LanNetworkMinorDeviceClass,
@@ -743,17 +749,17 @@ class ClassOfDevice:
major_service_classes: MajorServiceClasses
major_device_class: MajorDeviceClass
minor_device_class: (
ComputerMinorDeviceClass
| PhoneMinorDeviceClass
| LanNetworkMinorDeviceClass
| AudioVideoMinorDeviceClass
| PeripheralMinorDeviceClass
| WearableMinorDeviceClass
| ToyMinorDeviceClass
| HealthMinorDeviceClass
| int
)
minor_device_class: Union[
ComputerMinorDeviceClass,
PhoneMinorDeviceClass,
LanNetworkMinorDeviceClass,
AudioVideoMinorDeviceClass,
PeripheralMinorDeviceClass,
WearableMinorDeviceClass,
ToyMinorDeviceClass,
HealthMinorDeviceClass,
int,
]
@classmethod
def from_int(cls, class_of_device: int) -> Self:
@@ -1542,7 +1548,7 @@ class DataType:
return f"{self.__class__.__name__}({self.value_string()})"
@classmethod
def from_advertising_data(cls, advertising_data: AdvertisingData) -> Self | None:
def from_advertising_data(cls, advertising_data: AdvertisingData) -> Optional[Self]:
if (data := advertising_data.get(cls.ad_type, raw=True)) is None:
return None
@@ -1570,16 +1576,16 @@ class DataType:
# -----------------------------------------------------------------------------
# Advertising Data
# -----------------------------------------------------------------------------
AdvertisingDataObject = (
list[UUID]
| tuple[UUID, bytes]
| bytes
| str
| int
| tuple[int, int]
| tuple[int, bytes]
| Appearance
)
AdvertisingDataObject = Union[
list[UUID],
tuple[UUID, bytes],
bytes,
str,
int,
tuple[int, int],
tuple[int, bytes],
Appearance,
]
class AdvertisingData:
@@ -1716,7 +1722,7 @@ class AdvertisingData:
def __init__(
self,
ad_structures: Iterable[tuple[int, bytes] | DataType] | None = None,
ad_structures: Optional[Iterable[Union[tuple[int, bytes], DataType]]] = None,
) -> None:
if ad_structures is None:
ad_structures = []
@@ -2014,7 +2020,7 @@ class AdvertisingData:
AdvertisingData.Type.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
],
raw: Literal[False] = False,
) -> list[UUID] | None: ...
) -> Optional[list[UUID]]: ...
@overload
def get(
@@ -2025,7 +2031,7 @@ class AdvertisingData:
AdvertisingData.Type.SERVICE_DATA_128_BIT_UUID,
],
raw: Literal[False] = False,
) -> tuple[UUID, bytes] | None: ...
) -> Optional[tuple[UUID, bytes]]: ...
@overload
def get(
@@ -2037,7 +2043,7 @@ class AdvertisingData:
AdvertisingData.Type.BROADCAST_NAME,
],
raw: Literal[False] = False,
) -> str | None: ...
) -> Optional[Optional[str]]: ...
@overload
def get(
@@ -2049,36 +2055,38 @@ class AdvertisingData:
AdvertisingData.Type.CLASS_OF_DEVICE,
],
raw: Literal[False] = False,
) -> int | None: ...
) -> Optional[int]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.PERIPHERAL_CONNECTION_INTERVAL_RANGE,],
raw: Literal[False] = False,
) -> tuple[int, int] | None: ...
) -> Optional[tuple[int, int]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.MANUFACTURER_SPECIFIC_DATA,],
raw: Literal[False] = False,
) -> tuple[int, bytes] | None: ...
) -> Optional[tuple[int, bytes]]: ...
@overload
def get(
self,
type_id: Literal[AdvertisingData.Type.APPEARANCE,],
raw: Literal[False] = False,
) -> Appearance | None: ...
) -> Optional[Appearance]: ...
@overload
def get(self, type_id: int, raw: Literal[True]) -> bytes | None: ...
def get(self, type_id: int, raw: Literal[True]) -> Optional[bytes]: ...
@overload
def get(self, type_id: int, raw: bool = False) -> AdvertisingDataObject | None: ...
def get(
self, type_id: int, raw: bool = False
) -> Optional[AdvertisingDataObject]: ...
def get(self, type_id: int, raw: bool = False) -> AdvertisingDataObject | None:
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]:
'''
Get advertising data as a simple AdvertisingDataObject object.

View File

@@ -25,11 +25,10 @@ try:
from bumble.crypto.cryptography import EccKey, aes_cmac, e
except ImportError:
logging.getLogger(__name__).debug(
"Unable to import cryptography, using built-in primitives."
"Unable to import cryptography, use built-in primitives."
)
from bumble.crypto.builtin import EccKey, aes_cmac, e # type: ignore[assignment]
_EccKey = EccKey # For the linter only
# -----------------------------------------------------------------------------
# Logging

View File

@@ -29,6 +29,7 @@ import dataclasses
import functools
import secrets
import struct
from typing import Optional
from bumble import core
@@ -84,6 +85,7 @@ class _AES:
# fmt: on
def __init__(self, key: bytes) -> None:
if len(key) not in (16, 24, 32):
raise core.InvalidArgumentError(f'Invalid key size {len(key)}')
@@ -110,6 +112,7 @@ class _AES:
r_con_pointer = 0
t = kc
while t < round_key_count:
tt = tk[kc - 1]
tk[0] ^= (
(self._S[(tt >> 16) & 0xFF] << 24)
@@ -266,6 +269,7 @@ class _ECB:
class _CBC:
def __init__(self, key: bytes, iv: bytes = bytes(16)) -> None:
if len(iv) != 16:
raise core.InvalidArgumentError(
@@ -298,6 +302,7 @@ class _CBC:
class _CMAC:
def __init__(
self,
key: bytes,
@@ -308,7 +313,7 @@ class _CMAC:
self.digest_size = mac_len
self._key = key
self._block_size = bs = 16
self._mac_tag: bytes | None = None
self._mac_tag: Optional[bytes] = None
self._update_after_digest = update_after_digest
# Section 5.3 of NIST SP 800 38B and Appendix B
@@ -347,7 +352,7 @@ class _CMAC:
self._last_ct = zero_block
# Last block that was encrypted with AES
self._last_pt: bytes | None = None
self._last_pt: Optional[bytes] = None
# Counter for total message size
self._data_size = 0
@@ -409,6 +414,7 @@ class _CMAC:
self._last_pt = _xor(second_last, data_block[-bs:])
def digest(self) -> bytes:
bs = self._block_size
if self._mac_tag is not None and not self._update_after_digest:

View File

@@ -25,8 +25,7 @@ from __future__ import annotations
import dataclasses
import math
import struct
from collections.abc import Sequence
from typing import Any, ClassVar
from typing import Any, ClassVar, Sequence
from typing_extensions import Self

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
# -----------------------------------------------------------------------------
# Constants
@@ -166,12 +167,12 @@ class G722Decoder:
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data: bytes | bytearray) -> bytearray:
def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data: bytes | bytearray) -> int:
def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0

File diff suppressed because it is too large Load Diff

View File

@@ -24,8 +24,7 @@ from __future__ import annotations
import logging
import pathlib
import platform
from collections.abc import Iterable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable, Optional
from bumble.drivers import intel, rtk
from bumble.drivers.common import Driver
@@ -42,7 +41,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Driver | None:
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
@@ -50,10 +49,6 @@ async def get_driver_for_host(host: Host) -> Driver | None:
driver_classes: dict[str, type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# The "driver" metadata may include runtime options after a '/' (for example
# "intel/ddc=..."). Keep only the base driver name (the portion before the
# first slash) so it matches a key in driver_classes (e.g. "intel").
driver_name = driver_name.split("/")[0]
# Only probe a single driver
probe_list = [driver_name]
else:

View File

@@ -29,7 +29,7 @@ import os
import pathlib
import platform
import struct
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from bumble import core, hci, utils
from bumble.drivers import common
@@ -353,8 +353,8 @@ class Driver(common.Driver):
self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name.
self.ddc_addon: bytes | None = None
self.ddc_override: bytes | None = None
self.ddc_addon: Optional[bytes] = None
self.ddc_override: Optional[bytes] = None
driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"):
for key, value in [
@@ -380,7 +380,7 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list"
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
@@ -459,10 +459,6 @@ class Driver(common.Driver):
== ModeOfOperation.OPERATIONAL
):
logger.debug("firmware already loaded")
# If the firmeare is already loaded, still attempt to load any
# device configuration (DDC). DDC can be applied independently of a
# firmware reload and may contain runtime overrides or patches.
await self.load_ddc_if_any()
return
# We only support some platforms and variants.
@@ -483,7 +479,9 @@ class Driver(common.Driver):
raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = (
f"ibt-{device_info[ValueType.CNVI]:04X}-{device_info[ValueType.CNVR]:04X}"
"ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
)
logger.debug(f"FW base name: {firmware_base_name}")
@@ -600,39 +598,17 @@ class Driver(common.Driver):
await self.reset_complete.wait()
logger.debug("reset complete")
await self.load_ddc_if_any(firmware_base_name)
async def load_ddc_if_any(self, firmware_base_name: str | None = None) -> None:
"""
Check for and load any Device Data Configuration (DDC) blobs.
Args:
firmware_base_name: Base name of the selected firmware (e.g. "ibt-XXXX-YYYY").
If None, don't attempt to look up a .ddc file that
corresponds to the firmware image.
Priority:
1. If a ddc_override was provided via driver metadata, use it (highest priority).
2. Otherwise, if firmware_base_name is provided, attempt to find a .ddc file
that corresponds to the selected firmware image.
3. Finally, if a ddc_addon was provided, append/load it after the primary DDC.
"""
# If an explicit DDC override was supplied, use it and skip file lookup.
# Load the device config if there is one.
if self.ddc_override:
logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override)
else:
# Only attempt .ddc file lookup if a firmware_base_name was provided.
if firmware_base_name is None:
logger.debug(
"no firmware_base_name provided; skipping .ddc file lookup"
)
else:
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon:
logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon)

View File

@@ -115,14 +115,12 @@ RTK_USB_PRODUCTS = {
# Realtek 8761BUV
(0x0B05, 0x190E),
(0x0BDA, 0x8771),
(0x0BDA, 0x877B),
(0x0BDA, 0xA728),
(0x0BDA, 0xA729),
(0x2230, 0x0016),
(0x2357, 0x0604),
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
(0x0BDA, 0x877B),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
@@ -484,7 +482,7 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list"
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False

View File

@@ -28,10 +28,9 @@ import enum
import functools
import logging
import struct
from collections.abc import Iterable, Sequence
from typing import TypeVar
from typing import Iterable, Optional, Sequence, TypeVar, Union
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.att import Attribute, AttributeValue
from bumble.colors import color
from bumble.core import UUID, BaseBumbleError
@@ -228,6 +227,7 @@ GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
@@ -356,7 +356,7 @@ class Service(Attribute):
def __init__(
self,
uuid: str | UUID,
uuid: Union[str, UUID],
characteristics: Iterable[Characteristic],
primary=True,
included_services: Iterable[Service] = (),
@@ -379,7 +379,7 @@ class Service(Attribute):
self.characteristics = list(characteristics)
self.primary = primary
def get_advertising_data(self) -> bytes | None:
def get_advertising_data(self) -> Optional[bytes]:
"""
Get Service specific advertising data
Defined by each Service, default value is empty
@@ -503,10 +503,10 @@ class Characteristic(Attribute[_T]):
def __init__(
self,
uuid: str | bytes | UUID,
uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
permissions: str | Attribute.Permissions,
value: AttributeValue[_T] | _T | None = None,
permissions: Union[str, Attribute.Permissions],
value: Union[AttributeValue[_T], _T, None] = None,
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
@@ -579,7 +579,7 @@ class Descriptor(Attribute):
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, (AttributeValue, AttributeValueV2)):
elif isinstance(self.value, CharacteristicValue):
value_str = '<dynamic>'
else:
value_str = '<...>'

View File

@@ -22,8 +22,7 @@
from __future__ import annotations
import struct
from collections.abc import Callable, Iterable
from typing import Any, Generic, Literal, TypeVar
from typing import Any, Callable, Generic, Iterable, Literal, Optional, TypeVar
from bumble import utils
from bumble.core import InvalidOperationError
@@ -75,8 +74,8 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter[_T]):
def __init__(
self,
characteristic: Characteristic,
encode: Callable[[_T], bytes] | None = None,
decode: Callable[[bytes], _T] | None = None,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
):
super().__init__(characteristic)
self.encode = encode
@@ -102,8 +101,8 @@ class DelegatedCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T]):
def __init__(
self,
characteristic_proxy: CharacteristicProxy,
encode: Callable[[_T], bytes] | None = None,
decode: Callable[[bytes], _T] | None = None,
encode: Optional[Callable[[_T], bytes]] = None,
decode: Optional[Callable[[bytes], _T]] = None,
):
super().__init__(characteristic_proxy)
self.encode = encode
@@ -362,4 +361,5 @@ class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
def decode_value(self, value: bytes) -> _T3:
int_value = int.from_bytes(value, self.byteorder)
a = self.cls(int_value)
return self.cls(int_value)

View File

@@ -26,20 +26,21 @@
from __future__ import annotations
import asyncio
import functools
import logging
import struct
from collections.abc import Callable, Iterable
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterable,
Optional,
TypeVar,
overload,
Union,
)
from bumble import att, core, l2cap, utils
from bumble import att, core, utils
from bumble.colors import color
from bumble.core import UUID, InvalidStateError
from bumble.gatt import (
@@ -56,12 +57,12 @@ from bumble.gatt import (
)
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# -----------------------------------------------------------------------------
# Typing
# -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T')
# -----------------------------------------------------------------------------
@@ -191,7 +192,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type: UUID) -> DescriptorProxy | None:
def get_descriptor(self, descriptor_type: UUID) -> Optional[DescriptorProxy]:
for descriptor in self.descriptors:
if descriptor.type == descriptor_type:
return descriptor
@@ -203,7 +204,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
async def subscribe(
self,
subscriber: Callable[[_T], Any] | None = None,
subscriber: Optional[Callable[[_T], Any]] = None,
prefer_notify: bool = True,
) -> None:
if subscriber is not None:
@@ -252,7 +253,7 @@ class ProfileServiceProxy:
SERVICE_CLASS: type[TemplateService]
@classmethod
def from_client(cls, client: Client) -> ProfileServiceProxy | None:
def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -263,14 +264,16 @@ class Client:
services: list[ServiceProxy]
cached_values: dict[int, tuple[datetime, bytes]]
notification_subscribers: dict[
int, set[CharacteristicProxy | Callable[[bytes], Any]]
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: dict[int, set[CharacteristicProxy | Callable[[bytes], Any]]]
pending_response: asyncio.futures.Future[att.ATT_PDU] | None
pending_request: att.ATT_PDU | None
indication_subscribers: dict[
int, set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[att.ATT_PDU]]
pending_request: Optional[att.ATT_PDU]
def __init__(self, bearer: att.Bearer) -> None:
self.bearer = bearer
def __init__(self, connection: Connection) -> None:
self.connection = connection
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
@@ -280,78 +283,21 @@ class Client:
self.services = []
self.cached_values = {}
if att.is_enhanced_bearer(bearer):
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
# Fill the mtu.
bearer.on_att_mtu_update(att.ATT_DEFAULT_MTU)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
def send_gatt_pdu(self, pdu: bytes) -> None:
if att.is_enhanced_bearer(self.bearer):
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
self.connection.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None:
logger.debug(f'GATT Command from client: {self._bearer_id} {command}')
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU):
logger.debug(f'GATT Request from client: {self._bearer_id} {request}')
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -380,7 +326,10 @@ class Client:
def send_confirmation(
self, confirmation: att.ATT_Handle_Value_Confirmation
) -> None:
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}')
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int:
@@ -392,7 +341,7 @@ class Client:
# We can only send one request per connection
if self.mtu_exchange_done:
return self.mtu
return self.connection.att_mtu
# Send the request
self.mtu_exchange_done = True
@@ -403,15 +352,15 @@ class Client:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU
self.mtu = min(mtu, response.server_rx_mtu)
self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.mtu
return self.connection.att_mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid(
self, uuid: UUID, service: ServiceProxy | None = None
self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> list[CharacteristicProxy[bytes]]:
services = [service] if service else self.services
return [
@@ -420,14 +369,13 @@ class Client:
if c.uuid == uuid
]
def get_attribute_grouping(
self, attribute_handle: int
) -> (
ServiceProxy
| tuple[ServiceProxy, CharacteristicProxy]
| tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy]
| None
):
def get_attribute_grouping(self, attribute_handle: int) -> Optional[
Union[
ServiceProxy,
tuple[ServiceProxy, CharacteristicProxy],
tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
]
]:
"""
Get the attribute(s) associated with an attribute handle
"""
@@ -530,7 +478,7 @@ class Client:
return services
async def discover_service(self, uuid: str | UUID) -> list[ServiceProxy]:
async def discover_service(self, uuid: Union[str, UUID]) -> list[ServiceProxy]:
'''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
'''
@@ -664,7 +612,7 @@ class Client:
return included_services
async def discover_characteristics(
self, uuids, service: ServiceProxy | None
self, uuids, service: Optional[ServiceProxy]
) -> list[CharacteristicProxy[bytes]]:
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
@@ -751,9 +699,9 @@ class Client:
async def discover_descriptors(
self,
characteristic: CharacteristicProxy | None = None,
start_handle: int | None = None,
end_handle: int | None = None,
characteristic: Optional[CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
) -> list[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -862,7 +810,7 @@ class Client:
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Callable[[Any], Any] | None = None,
subscriber: Optional[Callable[[Any], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic,
@@ -912,7 +860,7 @@ class Client:
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Callable[[Any], Any] | None = None,
subscriber: Optional[Callable[[Any], Any]] = None,
force: bool = False,
) -> None:
'''
@@ -977,7 +925,7 @@ class Client:
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(
self, attribute: int | AttributeProxy, no_long_read: bool = False
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> bytes:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -998,7 +946,7 @@ class Client:
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1:
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
@@ -1022,7 +970,7 @@ class Client:
part = response.part_attribute_value
attribute_value += part
if len(part) < self.mtu - 1:
if len(part) < self.connection.att_mtu - 1:
break
offset += len(part)
@@ -1032,7 +980,7 @@ class Client:
return attribute_value
async def read_characteristics_by_uuid(
self, uuid: UUID, service: ServiceProxy | None
self, uuid: UUID, service: Optional[ServiceProxy]
) -> list[bytes]:
'''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
@@ -1090,7 +1038,7 @@ class Client:
async def write_value(
self,
attribute: int | AttributeProxy,
attribute: Union[int, AttributeProxy],
value: bytes,
with_response: bool = False,
) -> None:
@@ -1118,13 +1066,14 @@ class Client:
)
)
def on_disconnection(self, *args) -> None:
del args # unused.
def on_disconnection(self, _) -> None:
if self.pending_response and not self.pending_response.done():
self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}')
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in att.ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
@@ -1154,7 +1103,8 @@ class Client:
else:
logger.warning(
color(
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ',
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)

View File

@@ -29,11 +29,11 @@ import asyncio
import logging
import struct
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Iterable, Optional, TypeVar
from bumble import att, core, l2cap, utils
from bumble import att, utils
from bumble.colors import color
from bumble.core import UUID
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
@@ -43,13 +43,14 @@ from bumble.gatt import (
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
IncludedServiceDeclaration,
Service,
)
if TYPE_CHECKING:
from bumble.device import Device
from bumble.device import Connection, Device
# -----------------------------------------------------------------------------
# Logging
@@ -63,18 +64,6 @@ logger = logging.getLogger(__name__)
GATT_SERVER_DEFAULT_MAX_MTU = 517
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
return f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
else:
return f'[0x{bearer.handle:04X}]'
# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
@@ -82,9 +71,9 @@ class Server(utils.EventEmitter):
attributes: list[att.Attribute]
services: list[Service]
attributes_by_handle: dict[int, att.Attribute]
subscribers: dict[att.Bearer, dict[int, bytes]]
indication_semaphores: defaultdict[att.Bearer, asyncio.Semaphore]
pending_confirmations: defaultdict[att.Bearer, asyncio.futures.Future | None]
subscribers: dict[int, dict[int, bytes]]
indication_semaphores: defaultdict[int, asyncio.Semaphore]
pending_confirmations: defaultdict[int, Optional[asyncio.futures.Future]]
EVENT_CHARACTERISTIC_SUBSCRIPTION = "characteristic_subscription"
@@ -106,29 +95,8 @@ class Server(utils.EventEmitter):
def __str__(self) -> str:
return "\n".join(map(str, self.attributes))
def register_eatt(
self, spec: l2cap.LeCreditBasedChannelSpec | None = None
) -> l2cap.LeCreditBasedChannelServer:
def on_channel(channel: l2cap.LeCreditBasedChannel):
logger.debug(
"New EATT Bearer Connection=0x%04X CID=0x%04X",
channel.connection.handle,
channel.source_cid,
)
channel.att_mtu = att.ATT_DEFAULT_MTU
channel.sink = lambda pdu: self.on_gatt_pdu(
channel, att.ATT_PDU.from_bytes(pdu)
)
return self.device.create_l2cap_server(
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM), handler=on_channel
)
def send_gatt_pdu(self, bearer: att.Bearer, pdu: bytes) -> None:
if att.is_enhanced_bearer(bearer):
bearer.write(pdu)
else:
self.device.send_l2cap_pdu(bearer.handle, att.ATT_CID, pdu)
def send_gatt_pdu(self, connection_handle: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(connection_handle, att.ATT_CID, pdu)
def next_handle(self) -> int:
return 1 + len(self.attributes)
@@ -141,7 +109,7 @@ class Server(utils.EventEmitter):
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle: int) -> att.Attribute | None:
def get_attribute(self, handle: int) -> Optional[att.Attribute]:
attribute = self.attributes_by_handle.get(handle)
if attribute:
return attribute
@@ -158,7 +126,7 @@ class Server(utils.EventEmitter):
def get_attribute_group(
self, handle: int, group_type: type[AttributeGroupType]
) -> AttributeGroupType | None:
) -> Optional[AttributeGroupType]:
return next(
(
attribute
@@ -169,7 +137,7 @@ class Server(utils.EventEmitter):
None,
)
def get_service_attribute(self, service_uuid: core.UUID) -> Service | None:
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
return next(
(
attribute
@@ -182,8 +150,8 @@ class Server(utils.EventEmitter):
)
def get_characteristic_attributes(
self, service_uuid: core.UUID, characteristic_uuid: core.UUID
) -> tuple[CharacteristicDeclaration, Characteristic] | None:
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[tuple[CharacteristicDeclaration, Characteristic]]:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
@@ -207,11 +175,8 @@ class Server(utils.EventEmitter):
)
def get_descriptor_attribute(
self,
service_uuid: core.UUID,
characteristic_uuid: core.UUID,
descriptor_uuid: core.UUID,
) -> Descriptor | None:
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
@@ -291,7 +256,14 @@ class Server(utils.EventEmitter):
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
att.Attribute.READABLE | att.Attribute.WRITEABLE,
self.make_descriptor_value(characteristic),
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
)
)
@@ -307,21 +279,10 @@ class Server(utils.EventEmitter):
for service in services:
self.add_service(service)
def make_descriptor_value(
self, characteristic: Characteristic
) -> att.AttributeValueV2:
# It is necessary to use Attribute Value V2 here to identify the bearer of CCCD.
return att.AttributeValueV2(
lambda bearer, characteristic=characteristic: self.read_cccd(
bearer, characteristic
),
write=lambda bearer, value, characteristic=characteristic: self.write_cccd(
bearer, characteristic, value
),
)
def read_cccd(self, bearer: att.Bearer, characteristic: Characteristic) -> bytes:
subscribers = self.subscribers.get(bearer)
def read_cccd(
self, connection: Connection, characteristic: Characteristic
) -> bytes:
subscribers = self.subscribers.get(connection.handle)
cccd = None
if subscribers:
cccd = subscribers.get(characteristic.handle)
@@ -330,12 +291,12 @@ class Server(utils.EventEmitter):
def write_cccd(
self,
bearer: att.Bearer,
connection: Connection,
characteristic: Characteristic,
value: bytes,
) -> None:
logger.debug(
f'Subscription update for connection={_bearer_id(bearer)}, '
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
@@ -344,60 +305,41 @@ class Server(utils.EventEmitter):
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(bearer, {})
cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
characteristic.EVENT_SUBSCRIPTION,
bearer,
connection,
notify_enabled,
indicate_enabled,
)
self.emit(
self.EVENT_CHARACTERISTIC_SUBSCRIPTION,
bearer,
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None:
logger.debug(f'GATT Response from server: {_bearer_id(bearer)} {response}')
self.send_gatt_pdu(bearer, bytes(response))
def send_response(self, connection: Connection, response: att.ATT_PDU) -> None:
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, bytes(response))
async def notify_subscriber(
self,
bearer: att.Bearer,
connection: Connection,
attribute: att.Attribute,
value: bytes | None = None,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to notify all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._notify_single_subscriber(bearer, attribute, value, force)
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(bearer)
subscribers = self.subscribers.get(connection.handle)
if not subscribers:
logger.debug('not notifying, no subscribers')
return
@@ -413,53 +355,34 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(bearer)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(
self,
bearer: att.Bearer,
connection: Connection,
attribute: att.Attribute,
value: bytes | None = None,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
return await self._notify_single_subscriber(bearer, attribute, value, force)
else:
# If API is called to a Connection and not forced, try to indicate all subscribed bearers on it.
bearers = [
channel
for channel in self.device.l2cap_channel_manager.le_coc_channels.get(
bearer.handle, {}
).values()
if channel.psm == att.EATT_PSM
] + [bearer]
for bearer in bearers:
await self._indicate_single_bearer(bearer, attribute, value, force)
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
force: bool,
) -> None:
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(bearer)
subscribers = self.subscribers.get(connection.handle)
if not subscribers:
logger.debug('not indicating, no subscribers')
return
@@ -475,71 +398,73 @@ class Server(utils.EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(bearer)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[: connection.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[bearer]:
assert self.pending_confirmations[bearer] is None
async with self.indication_semaphores[connection.handle]:
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[bearer] = (
pending_confirmation = self.pending_confirmations[connection.handle] = (
asyncio.get_running_loop().create_future()
)
try:
self.send_gatt_pdu(bearer, bytes(indication))
self.send_gatt_pdu(connection.handle, bytes(indication))
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[bearer] = None
self.pending_confirmations[connection.handle] = None
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute,
value: bytes | None = None,
value: Optional[bytes] = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
bearers: list[att.Bearer] = [
bearer
for bearer, subscribers in self.subscribers.items()
if force or subscribers.get(attribute.handle)
# Get all the connections for which there's at least one subscription
connections = [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
]
if connection is not None
]
# Indicate or notify for each connection
if bearers:
coroutine = (
self._indicate_single_bearer
if indicate
else self._notify_single_subscriber
)
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait(
[
asyncio.create_task(coroutine(bearer, attribute, value, force))
for bearer in bearers
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
]
)
async def notify_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
value: Optional[bytes] = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -549,23 +474,26 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
value: Optional[bytes] = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
def on_disconnection(self, bearer: att.Bearer) -> None:
self.subscribers.pop(bearer, None)
self.indication_semaphores.pop(bearer, None)
self.pending_confirmations.pop(bearer, None)
def on_disconnection(self, connection: Connection) -> None:
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
del self.indication_semaphores[connection.handle]
if connection.handle in self.pending_confirmations:
del self.pending_confirmations[connection.handle]
def on_gatt_pdu(self, bearer: att.Bearer, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: {_bearer_id(bearer)} {att_pdu}')
def on_gatt_pdu(self, connection: Connection, att_pdu: att.ATT_PDU) -> None:
logger.debug(f'GATT Request to server: [0x{connection.handle:04X}] {att_pdu}')
handler_name = f'on_{att_pdu.name.lower()}'
handler = getattr(self, handler_name, None)
if handler is not None:
try:
handler(bearer, att_pdu)
handler(connection, att_pdu)
except att.ATT_Error as error:
logger.debug(f'normal exception returned by handler: {error}')
response = att.ATT_Error_Response(
@@ -573,7 +501,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=error.att_handle,
error_code=error.error_code,
)
self.send_response(bearer, response)
self.send_response(connection, response)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
response = att.ATT_Error_Response(
@@ -581,18 +509,18 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
raise
else:
# No specific handler registered
if att_pdu.op_code in att.ATT_REQUESTS:
# Invoke the generic handler
self.on_att_request(bearer, att_pdu)
self.on_att_request(connection, att_pdu)
else:
# Just ignore
logger.warning(
color(
f'--- Ignoring GATT Request from {_bearer_id(bearer)}: ',
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
@@ -601,14 +529,13 @@ class Server(utils.EventEmitter):
#######################################################
# ATT handlers
#######################################################
def on_att_request(self, bearer: att.Bearer, pdu: att.ATT_PDU) -> None:
def on_att_request(self, connection: Connection, pdu: att.ATT_PDU) -> None:
'''
Handler for requests without a more specific handler
'''
logger.warning(
color(
f'--- Unsupported ATT Request from {_bearer_id(bearer)}: ',
'red',
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
)
+ str(pdu)
)
@@ -617,28 +544,29 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=0x0000,
error_code=att.ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
def on_att_exchange_mtu_request(
self, bearer: att.Bearer, request: att.ATT_Exchange_MTU_Request
self, connection: Connection, request: att.ATT_Exchange_MTU_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(
bearer, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
connection, att.ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= att.ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)
bearer.on_att_mtu_update(mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')
def on_att_find_information_request(
self, bearer: att.Bearer, request: att.ATT_Find_Information_Request
self, connection: Connection, request: att.ATT_Find_Information_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -651,7 +579,7 @@ class Server(utils.EventEmitter):
or request.starting_handle > request.ending_handle
):
self.send_response(
bearer,
connection,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
@@ -661,7 +589,7 @@ class Server(utils.EventEmitter):
return
# Build list of returned attributes
pdu_space_available = bearer.att_mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes: list[att.Attribute] = []
uuid_size = 0
for attribute in (
@@ -703,18 +631,18 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(
self, bearer: att.Bearer, request: att.ATT_Find_By_Type_Value_Request
self, connection: Connection, request: att.ATT_Find_By_Type_Value_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
# Build list of returned attributes
pdu_space_available = bearer.att_mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
response: att.ATT_PDU
async for attribute in (
@@ -723,7 +651,7 @@ class Server(utils.EventEmitter):
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(bearer)) == request.attribute_value
and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -759,17 +687,17 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_type_request(
self, bearer: att.Bearer, request: att.ATT_Read_By_Type_Request
self, connection: Connection, request: att.ATT_Read_By_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
pdu_space_available = bearer.att_mtu - 2
pdu_space_available = connection.att_mtu - 2
response: att.ATT_PDU = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -787,7 +715,7 @@ class Server(utils.EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(bearer)
attribute_value = await attribute.read_value(connection)
except att.ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -800,7 +728,7 @@ class Server(utils.EventEmitter):
break
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 4, 253)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -827,11 +755,11 @@ class Server(utils.EventEmitter):
else:
logging.debug(f"not found {request}")
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_request(
self, bearer: att.Bearer, request: att.ATT_Read_Request
self, connection: Connection, request: att.ATT_Read_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
@@ -840,7 +768,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(bearer)
value = await attribute.read_value(connection)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -848,7 +776,7 @@ class Server(utils.EventEmitter):
error_code=error.error_code,
)
else:
value_size = min(bearer.att_mtu - 1, len(value))
value_size = min(connection.att_mtu - 1, len(value))
response = att.ATT_Read_Response(attribute_value=value[:value_size])
else:
response = att.ATT_Error_Response(
@@ -856,11 +784,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_blob_request(
self, bearer: att.Bearer, request: att.ATT_Read_Blob_Request
self, connection: Connection, request: att.ATT_Read_Blob_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -869,7 +797,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(bearer)
value = await attribute.read_value(connection)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -883,7 +811,7 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= bearer.att_mtu - 1:
elif len(value) <= connection.att_mtu - 1:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -891,7 +819,7 @@ class Server(utils.EventEmitter):
)
else:
part_size = min(
bearer.att_mtu - 1, len(value) - request.value_offset
connection.att_mtu - 1, len(value) - request.value_offset
)
response = att.ATT_Read_Blob_Response(
part_attribute_value=value[
@@ -904,11 +832,11 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.attribute_handle,
error_code=att.ATT_INVALID_HANDLE_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(
self, bearer: att.Bearer, request: att.ATT_Read_By_Group_Type_Request
self, connection: Connection, request: att.ATT_Read_By_Group_Type_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -923,10 +851,10 @@ class Server(utils.EventEmitter):
attribute_handle_in_error=request.starting_handle,
error_code=att.ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
return
pdu_space_available = bearer.att_mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes: list[tuple[int, int, bytes]] = []
for attribute in (
attribute
@@ -938,9 +866,9 @@ class Server(utils.EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(bearer)
attribute_value = await attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(bearer.att_mtu - 6, 251)
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
@@ -975,11 +903,11 @@ class Server(utils.EventEmitter):
error_code=att.ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_request(
self, bearer: att.Bearer, request: att.ATT_Write_Request
self, connection: Connection, request: att.ATT_Write_Request
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
@@ -989,7 +917,7 @@ class Server(utils.EventEmitter):
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(
bearer,
connection,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -1003,7 +931,7 @@ class Server(utils.EventEmitter):
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(
bearer,
connection,
att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
@@ -1015,7 +943,7 @@ class Server(utils.EventEmitter):
response: att.ATT_PDU
try:
# Accept the value
await attribute.write_value(bearer, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
except att.ATT_Error as error:
response = att.ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -1025,11 +953,11 @@ class Server(utils.EventEmitter):
else:
# Done
response = att.ATT_Write_Response()
self.send_response(bearer, response)
self.send_response(connection, response)
@utils.AsyncRunner.run_in_task()
async def on_att_write_command(
self, bearer: att.Bearer, request: att.ATT_Write_Command
self, connection: Connection, request: att.ATT_Write_Command
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
@@ -1048,20 +976,22 @@ class Server(utils.EventEmitter):
# Accept the value
try:
await attribute.write_value(bearer, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
except Exception:
logger.exception('!!! ignoring exception')
def on_att_handle_value_confirmation(
self,
bearer: att.Bearer,
connection: Connection,
confirmation: att.ATT_Handle_Value_Confirmation,
):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
del confirmation # Unused.
if (pending_confirmation := self.pending_confirmations[bearer]) is None:
if (
pending_confirmation := self.pending_confirmations[connection.handle]
) is None:
# Not expected!
logger.warning(
'!!! unexpected confirmation, there is no pending indication'

View File

@@ -24,13 +24,17 @@ import functools
import logging
import secrets
import struct
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Sequence
from dataclasses import field
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Literal,
Optional,
TypeVar,
Union,
cast,
)
@@ -102,7 +106,7 @@ def map_class_of_device(class_of_device):
)
def phy_list_to_bits(phys: Iterable[Phy] | None) -> int:
def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int:
if phys is None:
return 0
@@ -115,6 +119,7 @@ def phy_list_to_bits(phys: Iterable[Phy] | None) -> int:
class SpecableEnum(utils.OpenIntEnum):
@classmethod
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
@@ -142,6 +147,7 @@ class SpecableEnum(utils.OpenIntEnum):
class SpecableFlag(enum.IntFlag):
@classmethod
def type_spec(cls, size: int, byteorder: Literal['little', 'big'] = 'little'):
return {
@@ -180,8 +186,8 @@ class SpecableFlag(enum.IntFlag):
# - "v" for variable length bytes with a leading length byte
# - an integer [1, 4] for 1-byte, 2-byte or 4-byte unsigned little-endian integers
# - an integer [-2, -1] for 1-byte, 2-byte signed little-endian integers
FieldSpec = dict[str, Any] | Callable[[bytes, int], tuple[int, Any]] | str | int
Fields = Sequence['tuple[str, FieldSpec] | Fields']
FieldSpec = Union[dict[str, Any], Callable[[bytes, int], tuple[int, Any]], str, int]
Fields = Sequence[Union[tuple[str, FieldSpec], 'Fields']]
@dataclasses.dataclass
@@ -207,44 +213,22 @@ def metadata(
HCI_VENDOR_OGF = 0x3F
# Specification Version
class SpecificationVersion(utils.OpenIntEnum):
BLUETOOTH_CORE_1_0B = 0
BLUETOOTH_CORE_1_1 = 1
BLUETOOTH_CORE_1_2 = 2
BLUETOOTH_CORE_2_0_EDR = 3
BLUETOOTH_CORE_2_1_EDR = 4
BLUETOOTH_CORE_3_0_HS = 5
BLUETOOTH_CORE_4_0 = 6
BLUETOOTH_CORE_4_1 = 7
BLUETOOTH_CORE_4_2 = 8
BLUETOOTH_CORE_5_0 = 9
BLUETOOTH_CORE_5_1 = 10
BLUETOOTH_CORE_5_2 = 11
BLUETOOTH_CORE_5_3 = 12
BLUETOOTH_CORE_5_4 = 13
BLUETOOTH_CORE_6_0 = 14
BLUETOOTH_CORE_6_1 = 15
BLUETOOTH_CORE_6_2 = 16
# For backwards compatibility only
HCI_VERSION_BLUETOOTH_CORE_1_0B = SpecificationVersion.BLUETOOTH_CORE_1_0B
HCI_VERSION_BLUETOOTH_CORE_1_1 = SpecificationVersion.BLUETOOTH_CORE_1_1
HCI_VERSION_BLUETOOTH_CORE_1_2 = SpecificationVersion.BLUETOOTH_CORE_1_2
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR = SpecificationVersion.BLUETOOTH_CORE_2_0_EDR
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR = SpecificationVersion.BLUETOOTH_CORE_2_1_EDR
HCI_VERSION_BLUETOOTH_CORE_3_0_HS = SpecificationVersion.BLUETOOTH_CORE_3_0_HS
HCI_VERSION_BLUETOOTH_CORE_4_0 = SpecificationVersion.BLUETOOTH_CORE_4_0
HCI_VERSION_BLUETOOTH_CORE_4_1 = SpecificationVersion.BLUETOOTH_CORE_4_1
HCI_VERSION_BLUETOOTH_CORE_4_2 = SpecificationVersion.BLUETOOTH_CORE_4_2
HCI_VERSION_BLUETOOTH_CORE_5_0 = SpecificationVersion.BLUETOOTH_CORE_5_0
HCI_VERSION_BLUETOOTH_CORE_5_1 = SpecificationVersion.BLUETOOTH_CORE_5_1
HCI_VERSION_BLUETOOTH_CORE_5_2 = SpecificationVersion.BLUETOOTH_CORE_5_2
HCI_VERSION_BLUETOOTH_CORE_5_3 = SpecificationVersion.BLUETOOTH_CORE_5_3
HCI_VERSION_BLUETOOTH_CORE_5_4 = SpecificationVersion.BLUETOOTH_CORE_5_4
HCI_VERSION_BLUETOOTH_CORE_6_0 = SpecificationVersion.BLUETOOTH_CORE_6_0
HCI_VERSION_BLUETOOTH_CORE_6_1 = SpecificationVersion.BLUETOOTH_CORE_6_1
HCI_VERSION_BLUETOOTH_CORE_6_2 = SpecificationVersion.BLUETOOTH_CORE_6_2
# HCI Version
HCI_VERSION_BLUETOOTH_CORE_1_0B = 0
HCI_VERSION_BLUETOOTH_CORE_1_1 = 1
HCI_VERSION_BLUETOOTH_CORE_1_2 = 2
HCI_VERSION_BLUETOOTH_CORE_2_0_EDR = 3
HCI_VERSION_BLUETOOTH_CORE_2_1_EDR = 4
HCI_VERSION_BLUETOOTH_CORE_3_0_HS = 5
HCI_VERSION_BLUETOOTH_CORE_4_0 = 6
HCI_VERSION_BLUETOOTH_CORE_4_1 = 7
HCI_VERSION_BLUETOOTH_CORE_4_2 = 8
HCI_VERSION_BLUETOOTH_CORE_5_0 = 9
HCI_VERSION_BLUETOOTH_CORE_5_1 = 10
HCI_VERSION_BLUETOOTH_CORE_5_2 = 11
HCI_VERSION_BLUETOOTH_CORE_5_3 = 12
HCI_VERSION_BLUETOOTH_CORE_5_4 = 13
HCI_VERSION_BLUETOOTH_CORE_6_0 = 14
HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_1_0B: 'HCI_VERSION_BLUETOOTH_CORE_1_0B',
@@ -262,10 +246,9 @@ HCI_VERSION_NAMES = {
HCI_VERSION_BLUETOOTH_CORE_5_3: 'HCI_VERSION_BLUETOOTH_CORE_5_3',
HCI_VERSION_BLUETOOTH_CORE_5_4: 'HCI_VERSION_BLUETOOTH_CORE_5_4',
HCI_VERSION_BLUETOOTH_CORE_6_0: 'HCI_VERSION_BLUETOOTH_CORE_6_0',
HCI_VERSION_BLUETOOTH_CORE_6_1: 'HCI_VERSION_BLUETOOTH_CORE_6_1',
HCI_VERSION_BLUETOOTH_CORE_6_2: 'HCI_VERSION_BLUETOOTH_CORE_6_2',
}
# LMP Version
LMP_VERSION_NAMES = HCI_VERSION_NAMES
# HCI Packet types
@@ -1803,20 +1786,20 @@ class HCI_Object:
@classmethod
def dict_and_offset_from_bytes(
cls, data: bytes, offset: int, object_fields: Fields
cls, data: bytes, offset: int, fields: Fields
) -> tuple[int, collections.OrderedDict[str, Any]]:
result = collections.OrderedDict[str, Any]()
for object_field in object_fields:
if isinstance(object_field, list):
for field in fields:
if isinstance(field, list):
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
# Set fields first, because item_count might be 0.
for sub_field_name, _ in object_field:
for sub_field_name, _ in field:
result[sub_field_name] = []
for _ in range(item_count):
for sub_field_name, sub_field_type in object_field:
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
@@ -1824,7 +1807,7 @@ class HCI_Object:
offset += size
continue
field_name, field_type = object_field
field_name, field_type = field
assert isinstance(field_name, str)
field_value, field_size = HCI_Object.parse_field(
data, offset, cast(FieldSpec, field_type)
@@ -1907,26 +1890,26 @@ class HCI_Object:
return field_bytes
@staticmethod
def dict_to_bytes(hci_object, object_fields):
def dict_to_bytes(hci_object, fields):
result = bytearray()
for object_field in object_fields:
if isinstance(object_field, list):
for field in fields:
if isinstance(field, list):
# The field is an array. The serialized form starts with a 1-byte
# item count. We use the length of the first array field as the
# array count, since all array fields have the same number of items.
item_count = len(hci_object[object_field[0][0]])
item_count = len(hci_object[field[0][0]])
result += bytes([item_count]) + b''.join(
b''.join(
HCI_Object.serialize_field(
hci_object[sub_field_name][i], sub_field_type
)
for sub_field_name, sub_field_type in object_field
for sub_field_name, sub_field_type in field
)
for i in range(item_count)
)
continue
(field_name, field_type) = object_field
(field_name, field_type) = field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -1984,15 +1967,15 @@ class HCI_Object:
)
@staticmethod
def format_fields(hci_object, object_fields, indentation='', value_mappers=None):
if not object_fields:
def format_fields(hci_object, fields, indentation='', value_mappers=None):
if not fields:
return ''
# Build array of formatted key:value pairs
field_strings = []
for object_field in object_fields:
if isinstance(object_field, list):
for sub_field in object_field:
for field in fields:
if isinstance(field, list):
for sub_field in field:
sub_field_name, sub_field_type = sub_field
item_count = len(hci_object[sub_field_name])
for i in range(item_count):
@@ -2010,7 +1993,7 @@ class HCI_Object:
)
continue
field_name, field_type = object_field
field_name, field_type = field
field_value = hci_object[field_name]
field_strings.append(
(
@@ -2033,16 +2016,16 @@ class HCI_Object:
@classmethod
def fields_from_dataclass(cls, obj: Any) -> list[Any]:
stack: list[list[Any]] = [[]]
for object_field in dataclasses.fields(obj):
for field in dataclasses.fields(obj):
# Fields without metadata should be ignored.
if not isinstance(
(metadata := object_field.metadata.get("bumble.hci")), FieldMetadata
(metadata := field.metadata.get("bumble.hci")), FieldMetadata
):
continue
if metadata.list_begin:
stack.append([])
if metadata.spec:
stack[-1].append((object_field.name, metadata.spec))
stack[-1].append((field.name, metadata.spec))
if metadata.list_end:
top = stack.pop()
stack[-1].append(top)
@@ -2175,7 +2158,7 @@ class Address:
def __init__(
self,
address: bytes | str,
address: Union[bytes, str],
address_type: AddressType = RANDOM_DEVICE_ADDRESS,
) -> None:
'''
@@ -2440,9 +2423,9 @@ class HCI_Command(HCI_Packet):
def __init__(
self,
parameters: bytes | None = None,
parameters: Optional[bytes] = None,
*,
op_code: int | None = None,
op_code: Optional[int] = None,
**kwargs,
) -> None:
# op_code should be set in cls.
@@ -3458,17 +3441,6 @@ class HCI_Write_Synchronous_Flow_Control_Enable_Command(HCI_Command):
synchronous_flow_control_enable: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_Set_Controller_To_Host_Flow_Control_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.38 Set Controller To Host Flow Control command
'''
flow_control_enable: int = field(metadata=metadata(1))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -4366,15 +4338,6 @@ class HCI_LE_Write_Suggested_Default_Data_Length_Command(HCI_Command):
suggested_max_tx_time: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Local_P_256_Public_Key_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.36 LE LE Read Local P-256 Public Key command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -4402,15 +4365,6 @@ class HCI_LE_Clear_Resolving_List_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Resolving_List_Size_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.41 LE Read Resolving List Size command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -5074,15 +5028,6 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command):
sync_handle: int = field(metadata=metadata(2))
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
class HCI_LE_Read_Transmit_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.74 LE Read Transmit Power command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command
@dataclasses.dataclass
@@ -5733,7 +5678,7 @@ class HCI_Event(HCI_Packet):
hci_packet_type = HCI_EVENT_PACKET
event_names: dict[int, str] = {}
event_classes: dict[int, type[HCI_Event]] = {}
vendor_factories: list[Callable[[bytes], HCI_Event | None]] = []
vendor_factories: list[Callable[[bytes], Optional[HCI_Event]]] = []
event_code: int
fields: Fields = ()
_parameters: bytes = b''
@@ -5794,12 +5739,14 @@ class HCI_Event(HCI_Packet):
return event_class
@classmethod
def add_vendor_factory(cls, factory: Callable[[bytes], HCI_Event | None]) -> None:
def add_vendor_factory(
cls, factory: Callable[[bytes], Optional[HCI_Event]]
) -> None:
cls.vendor_factories.append(factory)
@classmethod
def remove_vendor_factory(
cls, factory: Callable[[bytes], HCI_Event | None]
cls, factory: Callable[[bytes], Optional[HCI_Event]]
) -> None:
if factory in cls.vendor_factories:
cls.vendor_factories.remove(factory)
@@ -5812,7 +5759,7 @@ class HCI_Event(HCI_Packet):
if len(parameters) != length:
raise InvalidPacketError('invalid packet length')
subclass: type[HCI_Event] | None
subclass: Optional[type[HCI_Event]]
if event_code == HCI_LE_META_EVENT:
# We do this dispatch here and not in the subclass in order to avoid call
# loops
@@ -5850,9 +5797,9 @@ class HCI_Event(HCI_Packet):
def __init__(
self,
parameters: bytes | None = None,
parameters: Optional[bytes] = None,
*,
event_code: int | None = None,
event_code: Optional[int] = None,
**kwargs,
):
if event_code is not None:
@@ -5961,7 +5908,9 @@ class HCI_Extended_Event(HCI_Event):
cls.subevent_names.update(cls.subevent_map(symbols))
@classmethod
def subclass_from_parameters(cls, parameters: bytes) -> HCI_Extended_Event | None:
def subclass_from_parameters(
cls, parameters: bytes
) -> Optional[HCI_Extended_Event]:
"""
Factory method that parses the subevent code, finds a registered subclass,
and creates an instance if found.
@@ -5981,9 +5930,9 @@ class HCI_Extended_Event(HCI_Event):
def __init__(
self,
parameters: bytes | None = None,
parameters: Optional[bytes] = None,
*,
subevent_code: int | None = None,
subevent_code: Optional[int] = None,
**kwargs,
) -> None:
if subevent_code is not None:
@@ -6979,7 +6928,7 @@ class HCI_Command_Complete_Event(HCI_Event):
command_opcode: int = field(
metadata=metadata({'size': 2, 'mapper': HCI_Command.command_name})
)
return_parameters: bytes | HCI_Object | int = field(metadata=metadata("*"))
return_parameters: Union[bytes, HCI_Object, int] = field(metadata=metadata("*"))
def map_return_parameters(self, return_parameters):
'''Map simple 'status' return parameters to their named constant form'''
@@ -7563,20 +7512,20 @@ class HCI_IsoDataPacket(HCI_Packet):
iso_sdu_fragment: bytes
pb_flag: int
ts_flag: int = 0
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
packet_status_flag: int | None = None
time_stamp: Optional[int] = None
packet_sequence_number: Optional[int] = None
iso_sdu_length: Optional[int] = None
packet_status_flag: Optional[int] = None
def __post_init__(self) -> None:
self.ts_flag = self.time_stamp is not None
@staticmethod
def from_bytes(packet: bytes) -> HCI_IsoDataPacket:
time_stamp: int | None = None
packet_sequence_number: int | None = None
iso_sdu_length: int | None = None
packet_status_flag: int | None = None
time_stamp: Optional[int] = None
packet_sequence_number: Optional[int] = None
iso_sdu_length: Optional[int] = None
packet_status_flag: Optional[int] = None
pos = 1
pdu_info, data_total_length = struct.unpack_from('<HH', packet, pos)
@@ -7659,7 +7608,7 @@ class HCI_IsoDataPacket(HCI_Packet):
# -----------------------------------------------------------------------------
class HCI_AclDataPacketAssembler:
current_data: bytes | None
current_data: Optional[bytes]
def __init__(self, callback: Callable[[bytes], Any]) -> None:
self.callback = callback

View File

@@ -20,7 +20,7 @@ from __future__ import annotations
import datetime
import logging
from collections.abc import Callable, MutableMapping
from typing import Any, cast
from typing import Any, Optional, cast
from bumble import avc, avctp, avdtp, avrcp, crypto, rfcomm, sdp
from bumble.att import ATT_CID, ATT_PDU
@@ -70,7 +70,7 @@ AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
class PacketTracer:
class AclStream:
psms: MutableMapping[int, int]
peer: PacketTracer.AclStream | None
peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
@@ -201,7 +201,7 @@ class PacketTracer:
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: datetime.datetime | None = None
self.packet_timestamp: Optional[datetime.datetime] = None
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info(
@@ -230,7 +230,7 @@ class PacketTracer:
self.peer.end_acl_stream(connection_handle)
def on_packet(
self, timestamp: datetime.datetime | None, packet: HCI_Packet
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet)
@@ -262,7 +262,7 @@ class PacketTracer:
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: datetime.datetime | None = None,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet)

View File

@@ -25,8 +25,7 @@ import enum
import logging
import re
import traceback
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Union
from typing_extensions import Self
@@ -81,7 +80,7 @@ class HfpProtocol:
dlc.sink = self.feed
def feed(self, data: bytes | str) -> None:
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
@@ -325,8 +324,8 @@ class CallInfo:
status: CallInfoStatus
mode: CallInfoMode
multi_party: CallInfoMultiParty
number: str | None = None
type: int | None = None
number: Optional[str] = None
type: Optional[int] = None
@dataclasses.dataclass
@@ -354,10 +353,10 @@ class CallLineIdentification:
number: str
type: int
subaddr: str | None = None
satype: int | None = None
alpha: str | None = None
cli_validity: int | None = None
subaddr: Optional[str] = None
satype: Optional[int] = None
alpha: Optional[str] = None
cli_validity: Optional[int] = None
@classmethod
def parse_from(cls, parameters: list[bytes]) -> Self:
@@ -490,9 +489,9 @@ STATUS_CODES = {
@dataclasses.dataclass
class HfConfiguration:
supported_hf_features: collections.abc.Sequence[HfFeature]
supported_hf_indicators: collections.abc.Sequence[HfIndicator]
supported_audio_codecs: collections.abc.Sequence[AudioCodec]
supported_hf_features: list[HfFeature]
supported_hf_indicators: list[HfIndicator]
supported_audio_codecs: list[AudioCodec]
@dataclasses.dataclass
@@ -585,7 +584,7 @@ class AgIndicatorState:
indicator: AgIndicator
supported_values: set[int]
current_status: int
index: int | None = None
index: Optional[int] = None
enabled: bool = True
@property
@@ -598,7 +597,7 @@ class AgIndicatorState:
supported_values_text = (
f'({",".join(str(v) for v in self.supported_values)})'
)
return f'("{self.indicator.value}",{supported_values_text})'
return f'(\"{self.indicator.value}\",{supported_values_text})'
@classmethod
def call(cls: type[Self]) -> Self:
@@ -729,7 +728,7 @@ class HfProtocol(utils.EventEmitter):
command_lock: asyncio.Lock
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None]
unsolicited_queue: asyncio.Queue[Optional[AtResponse]]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
@@ -754,7 +753,7 @@ class HfProtocol(utils.EventEmitter):
# Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features)
self.supported_audio_codecs = list(configuration.supported_audio_codecs)
self.supported_audio_codecs = configuration.supported_audio_codecs
self.hf_indicators = {
indicator: HfIndicatorState(indicator=indicator)
@@ -821,7 +820,7 @@ class HfProtocol(utils.EventEmitter):
cmd: str,
timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE,
) -> None | AtResponse | list[AtResponse]:
) -> Union[None, AtResponse, list[AtResponse]]:
"""
Sends an AT command and wait for the peer response.
Wait for the AT responses sent by the peer, to the status code.
@@ -1352,7 +1351,7 @@ class AgProtocol(utils.EventEmitter):
logger.warning(f'AG indicator {indicator} is disabled')
indicator_state.current_status = value
self.send_response(f'+CIEV: {index + 1},{value}')
self.send_response(f'+CIEV: {index+1},{value}')
async def negotiate_codec(self, codec: AudioCodec) -> None:
"""Starts codec negotiation."""
@@ -1412,13 +1411,13 @@ class AgProtocol(utils.EventEmitter):
self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec)))
def _on_chld(self, operation_code: bytes) -> None:
call_index: int | None = None
call_index: Optional[int] = None
if len(operation_code) > 1:
call_index = int(operation_code[1:])
operation_code = operation_code[:1] + b'x'
try:
operation = CallHoldOperation(operation_code.decode())
except Exception:
except:
logger.error(f'Invalid operation: {operation_code.decode()}')
self.send_cme_error(CmeError.OPERATION_NOT_SUPPORTED)
return
@@ -1482,8 +1481,8 @@ class AgProtocol(utils.EventEmitter):
def _on_cmer(
self,
mode: bytes,
keypad: bytes | None = None,
display: bytes | None = None,
keypad: Optional[bytes] = None,
display: Optional[bytes] = None,
indicator: bytes = b'',
) -> None:
if (
@@ -1590,7 +1589,7 @@ class AgProtocol(utils.EventEmitter):
def _on_clcc(self) -> None:
for call in self.calls:
number_text = f',"{call.number}"' if call.number is not None else ''
number_text = f',\"{call.number}\"' if call.number is not None else ''
type_text = f',{call.type}' if call.type is not None else ''
response = (
f'+CLCC: {call.index}'
@@ -1845,7 +1844,7 @@ def make_ag_sdp_records(
async def find_hf_sdp_record(
connection: device.Connection,
) -> tuple[int, ProfileVersion, HfSdpFeature] | None:
) -> Optional[tuple[int, ProfileVersion, HfSdpFeature]]:
"""Searches a Hands-Free SDP record from remote device.
Args:
@@ -1865,9 +1864,9 @@ async def find_hf_sdp_record(
],
)
for attribute_lists in search_result:
channel: int | None = None
version: ProfileVersion | None = None
features: HfSdpFeature | None = None
channel: Optional[int] = None
version: Optional[ProfileVersion] = None
features: Optional[HfSdpFeature] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -1897,7 +1896,7 @@ async def find_hf_sdp_record(
async def find_ag_sdp_record(
connection: device.Connection,
) -> tuple[int, ProfileVersion, AgSdpFeature] | None:
) -> Optional[tuple[int, ProfileVersion, AgSdpFeature]]:
"""Searches an Audio-Gateway SDP record from remote device.
Args:
@@ -1916,9 +1915,9 @@ async def find_ag_sdp_record(
],
)
for attribute_lists in search_result:
channel: int | None = None
version: ProfileVersion | None = None
features: AgSdpFeature | None = None
channel: Optional[int] = None
version: Optional[ProfileVersion] = None
features: Optional[AgSdpFeature] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:

View File

@@ -21,8 +21,8 @@ import enum
import logging
import struct
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, Optional
from typing_extensions import override
@@ -195,9 +195,9 @@ class SendHandshakeMessage(Message):
# -----------------------------------------------------------------------------
class HID(ABC, utils.EventEmitter):
l2cap_ctrl_channel: l2cap.ClassicChannel | None = None
l2cap_intr_channel: l2cap.ClassicChannel | None = None
connection: device.Connection | None = None
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
EVENT_INTERRUPT_DATA = "interrupt_data"
EVENT_CONTROL_DATA = "control_data"
@@ -212,7 +212,7 @@ class HID(ABC, utils.EventEmitter):
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Address | None = None
self.remote_device_bd_address: Optional[Address] = None
self.device = device
self.role = role
@@ -246,7 +246,7 @@ class HID(ABC, utils.EventEmitter):
# Create a new L2CAP connection - interrupt channel
try:
channel = await self.connection.create_l2cap_channel(
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM)
l2cap.ClassicChannelSpec(HID_CONTROL_PSM)
)
channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel
@@ -353,10 +353,10 @@ class Device(HID):
data: bytes = b''
status: int = 0
get_report_cb: Callable[[int, int, int], GetSetStatus] | None = None
set_report_cb: Callable[[int, int, int, bytes], GetSetStatus] | None = None
get_protocol_cb: Callable[[], GetSetStatus] | None = None
set_protocol_cb: Callable[[int], GetSetStatus] | None = None
get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)

View File

@@ -22,8 +22,7 @@ import collections
import dataclasses
import logging
import struct
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Union, cast
from bumble import drivers, hci, utils
from bumble.colors import color
@@ -109,7 +108,8 @@ class DataPacketQueue(utils.EventEmitter):
if self._packets:
logger.debug(
f'{self._in_flight} packets in flight, {len(self._packets)} in queue'
f'{self._in_flight} packets in flight, '
f'{len(self._packets)} in queue'
)
def flush(self, connection_handle: int) -> None:
@@ -199,7 +199,7 @@ class Connection:
self.peer_address = peer_address
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
acl_packet_queue: DataPacketQueue | None = (
acl_packet_queue: Optional[DataPacketQueue] = (
host.le_acl_packet_queue
if transport == PhysicalTransport.LE
else host.acl_packet_queue
@@ -242,18 +242,20 @@ class Host(utils.EventEmitter):
bis_links: dict[int, IsoLink]
sco_links: dict[int, ScoLink]
bigs: dict[int, set[int]]
acl_packet_queue: DataPacketQueue | None = None
le_acl_packet_queue: DataPacketQueue | None = None
iso_packet_queue: DataPacketQueue | None = None
hci_sink: TransportSink | None = None
acl_packet_queue: Optional[DataPacketQueue] = None
le_acl_packet_queue: Optional[DataPacketQueue] = None
iso_packet_queue: Optional[DataPacketQueue] = None
hci_sink: Optional[TransportSink] = None
hci_metadata: dict[str, Any]
long_term_key_provider: Callable[[int, bytes, int], Awaitable[bytes | None]] | None
link_key_provider: Callable[[hci.Address], Awaitable[bytes | None]] | None
long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
]
link_key_provider: Optional[Callable[[hci.Address], Awaitable[Optional[bytes]]]]
def __init__(
self,
controller_source: TransportSource | None = None,
controller_sink: TransportSink | None = None,
controller_source: Optional[TransportSource] = None,
controller_sink: Optional[TransportSink] = None,
) -> None:
super().__init__()
@@ -265,7 +267,7 @@ class Host(utils.EventEmitter):
self.sco_links = {} # SCO links, by connection handle
self.bigs = {} # BIG Handle to BIS Handles
self.pending_command = None
self.pending_response: asyncio.Future[Any] | None = None
self.pending_response: Optional[asyncio.Future[Any]] = None
self.number_of_supported_advertising_sets = 0
self.maximum_advertising_data_length = 31
self.local_version = None
@@ -278,7 +280,7 @@ class Host(utils.EventEmitter):
self.long_term_key_provider = None
self.link_key_provider = None
self.pairing_io_capability_provider = None # Classic only
self.snooper: Snooper | None = None
self.snooper: Optional[Snooper] = None
# Connect to the source and sink if specified
if controller_source:
@@ -289,9 +291,9 @@ class Host(utils.EventEmitter):
def find_connection_by_bd_addr(
self,
bd_addr: hci.Address,
transport: int | None = None,
transport: Optional[int] = None,
check_address_type: bool = False,
) -> Connection | None:
) -> Optional[Connection]:
for connection in self.connections.values():
if bytes(connection.peer_address) == bytes(bd_addr):
if (
@@ -631,7 +633,7 @@ class Host(utils.EventEmitter):
)
@property
def controller(self) -> TransportSink | None:
def controller(self) -> Optional[TransportSink]:
return self.hci_sink
@controller.setter
@@ -640,7 +642,7 @@ class Host(utils.EventEmitter):
if controller:
self.set_packet_source(controller)
def set_packet_sink(self, sink: TransportSink | None) -> None:
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
self.hci_sink = sink
def set_packet_source(self, source: TransportSource) -> None:
@@ -655,7 +657,7 @@ class Host(utils.EventEmitter):
self.hci_sink.on_packet(bytes(packet))
async def send_command(
self, command, check_result=False, response_timeout: int | None = None
self, command, check_result=False, response_timeout: Optional[int] = None
):
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
@@ -705,7 +707,7 @@ class Host(utils.EventEmitter):
asyncio.create_task(send_command(command))
def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None:
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
@@ -716,24 +718,27 @@ class Host(utils.EventEmitter):
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
max_packet_size = packet_queue.max_packet_size
for offset in range(0, len(sdu), max_packet_size):
pdu = sdu[offset : offset + max_packet_size]
bytes_remaining = len(l2cap_pdu)
offset = 0
pb_flag = 0
while bytes_remaining:
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
acl_packet = hci.HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=1 if offset > 0 else 0,
pb_flag=pb_flag,
bc_flag=0,
data_total_length=len(pdu),
data=pdu,
)
logger.debug(
'>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet, connection_handle)
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.send_acl_sdu(connection_handle, bytes(L2CAP_PDU(cid, pdu)))
pb_flag = 1
offset += data_total_length
bytes_remaining -= data_total_length
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
if connection := self.connections.get(connection_handle):
@@ -898,7 +903,7 @@ class Host(utils.EventEmitter):
self.emit('l2cap_pdu', connection.handle, cid, pdu)
def on_command_processed(
self, event: hci.HCI_Command_Complete_Event | hci.HCI_Command_Status_Event
self, event: Union[hci.HCI_Command_Complete_Event, hci.HCI_Command_Status_Event]
):
if self.pending_response:
# Check that it is what we were expecting
@@ -961,11 +966,11 @@ class Host(utils.EventEmitter):
def on_hci_le_connection_complete_event(
self,
event: (
hci.HCI_LE_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
),
event: Union[
hci.HCI_LE_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Check if this is a cancellation
if event.status == hci.HCI_SUCCESS:
@@ -1010,10 +1015,10 @@ class Host(utils.EventEmitter):
def on_hci_le_enhanced_connection_complete_event(
self,
event: (
hci.HCI_LE_Enhanced_Connection_Complete_Event
| hci.HCI_LE_Enhanced_Connection_Complete_V2_Event
),
event: Union[
hci.HCI_LE_Enhanced_Connection_Complete_Event,
hci.HCI_LE_Enhanced_Connection_Complete_V2_Event,
],
):
# Just use the same implementation as for the non-enhanced event for now
self.on_hci_le_connection_complete_event(event)
@@ -1392,7 +1397,8 @@ class Host(utils.EventEmitter):
if event.status == hci.HCI_SUCCESS:
# Create/update the connection
logger.debug(
f'### SCO CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}'
f'### SCO CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
self.sco_links[event.connection_handle] = ScoLink(
@@ -1444,7 +1450,7 @@ class Host(utils.EventEmitter):
def on_hci_le_data_length_change_event(
self, event: hci.HCI_LE_Data_Length_Change_Event
):
if event.connection_handle not in self.connections:
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! DATA LENGTH CHANGE: unknown handle')
return

View File

@@ -27,7 +27,7 @@ import dataclasses
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from typing_extensions import Self
@@ -51,8 +51,8 @@ class PairingKeys:
class Key:
value: bytes
authenticated: bool = False
ediv: int | None = None
rand: bytes | None = None
ediv: Optional[int] = None
rand: Optional[bytes] = None
@classmethod
def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
@@ -74,17 +74,17 @@ class PairingKeys:
return key_dict
address_type: hci.AddressType | None = None
ltk: Key | None = None
ltk_central: Key | None = None
ltk_peripheral: Key | None = None
irk: Key | None = None
csrk: Key | None = None
link_key: Key | None = None # Classic
link_key_type: int | None = None # Classic
address_type: Optional[hci.AddressType] = None
ltk: Optional[Key] = None
ltk_central: Optional[Key] = None
ltk_peripheral: Optional[Key] = None
irk: Optional[Key] = None
csrk: Optional[Key] = None
link_key: Optional[Key] = None # Classic
link_key_type: Optional[int] = None # Classic
@classmethod
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Key | None:
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
key_dict = keys_dict.get(key_name)
if key_dict is None:
return None
@@ -156,7 +156,7 @@ class KeyStore:
async def update(self, name: str, keys: PairingKeys) -> None:
pass
async def get(self, _name: str) -> PairingKeys | None:
async def get(self, _name: str) -> Optional[PairingKeys]:
return None
async def get_all(self) -> list[tuple[str, PairingKeys]]:
@@ -274,7 +274,7 @@ class JsonKeyStore(KeyStore):
@classmethod
def from_device(
cls: type[Self], device: Device, filename: str | None = None
cls: type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename:
# Extract the filename from the config if there is one
@@ -297,7 +297,7 @@ class JsonKeyStore(KeyStore):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try:
with open(self.filename, encoding='utf-8') as json_file:
with open(self.filename, 'r', encoding='utf-8') as json_file:
db = json.load(json_file)
except FileNotFoundError:
db = {}
@@ -348,7 +348,7 @@ class JsonKeyStore(KeyStore):
key_map.clear()
await self.save(db)
async def get(self, name: str) -> PairingKeys | None:
async def get(self, name: str) -> Optional[PairingKeys]:
_, key_map = await self.load()
if name not in key_map:
return None
@@ -370,7 +370,7 @@ class MemoryKeyStore(KeyStore):
async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys
async def get(self, name: str) -> PairingKeys | None:
async def get(self, name: str) -> Optional[PairingKeys]:
return self.all_keys.get(name)
async def get_all(self) -> list[tuple[str, PairingKeys]]:

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
@@ -19,12 +18,18 @@ import asyncio
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import TYPE_CHECKING
from typing import Optional
from bumble import core, hci, ll, lmp
if TYPE_CHECKING:
from bumble import controller
from bumble import controller, core
from bumble.hci import (
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_SUCCESS,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
Address,
HCI_Connection_Complete_Event,
Role,
)
# -----------------------------------------------------------------------------
# Logging
@@ -32,6 +37,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
result = {}
for param_str in params_str.split(','):
if '=' in param_str:
key, value = param_str.split('=')
result[key] = value
return result
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
@@ -45,34 +62,37 @@ class LocalLink:
def __init__(self):
self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None
############################################################
# Common utils
############################################################
def add_controller(self, controller: controller.Controller):
def add_controller(self, controller):
logger.debug(f'new controller: {controller}')
self.controllers.add(controller)
def remove_controller(self, controller: controller.Controller):
def remove_controller(self, controller):
self.controllers.remove(controller)
def find_le_controller(self, address: hci.Address) -> controller.Controller | None:
def find_controller(self, address):
for controller in self.controllers:
for connection in controller.le_connections.values():
if connection.self_address == address:
return controller
if controller.random_address == address:
return controller
return None
def find_classic_controller(
self, address: hci.Address
) -> controller.Controller | None:
self, address: Address
) -> Optional[controller.Controller]:
for controller in self.controllers:
if controller.public_address == address:
return controller
return None
def get_pending_connection(self):
return self.pending_connection
############################################################
# LE handlers
############################################################
@@ -80,16 +100,16 @@ class LocalLink:
def on_address_changed(self, controller):
pass
def send_acl_data(
self,
sender_controller: controller.Controller,
destination_address: hci.Address,
transport: core.PhysicalTransport,
data: bytes,
):
def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address
if transport == core.PhysicalTransport.LE:
destination_controller = self.find_le_controller(destination_address)
destination_controller = self.find_controller(destination_address)
source_address = sender_controller.random_address
elif transport == core.PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address)
@@ -98,52 +118,262 @@ class LocalLink:
raise ValueError("unsupported transport type")
if destination_controller is not None:
asyncio.get_running_loop().call_soon(
lambda: destination_controller.on_link_acl_data(
source_address, transport, data
)
)
destination_controller.on_link_acl_data(source_address, transport, data)
def send_advertising_pdu(
self,
sender_controller: controller.Controller,
packet: ll.AdvertisingPdu,
):
loop = asyncio.get_running_loop()
for c in self.controllers:
if c != sender_controller:
loop.call_soon(c.on_ll_advertising_pdu, packet)
def on_connection_complete(self):
# Check that we expect this call
if not self.pending_connection:
logger.warning('on_connection_complete with no pending connection')
return
def send_ll_control_pdu(
self,
sender_address: hci.Address,
receiver_address: hci.Address,
packet: ll.ControlPdu,
):
if not (receiver_controller := self.find_le_controller(receiver_address)):
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}"
central_address, le_create_connection_command = self.pending_connection
self.pending_connection = None
# Find the controller that initiated the connection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
asyncio.get_running_loop().call_soon(
lambda: receiver_controller.on_ll_control_pdu(sender_address, packet)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(self, central_address, le_create_connection_command):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, initiating_address, target_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (initiating_controller := self.find_controller(initiating_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if target_controller := self.find_controller(target_address):
target_controller.on_link_disconnected(
initiating_address, disconnect_command.reason
)
initiating_controller.on_link_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, initiating_address, target_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {initiating_address} -> '
f'{target_address}: reason = {disconnect_command.reason}'
)
args = [initiating_address, target_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
def send_lmp_packet(
self,
sender_controller: controller.Controller,
receiver_address: hci.Address,
packet: lmp.Packet,
):
if not (receiver_controller := self.find_classic_controller(receiver_address)):
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}"
)
asyncio.get_running_loop().call_soon(
lambda: receiver_controller.on_lmp_packet(
sender_controller.public_address, packet
)
def classic_connect(self, initiator_controller, responder_address):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
initiator_controller.on_classic_connection_complete(
responder_address, HCI_PAGE_TIMEOUT_ERROR
)
return
self.pending_classic_connection = (initiator_controller, responder_controller)
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
HCI_Connection_Complete_Event.LinkType.ACL,
)
def classic_accept_connection(
self, responder_controller, initiator_address, responder_role
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
)
return
async def task():
if responder_role != Role.PERIPHERAL:
initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role))
)
initiator_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_SUCCESS
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, responder_role
)
responder_controller.on_classic_connection_complete(
initiator_controller.public_address, HCI_SUCCESS
)
self.pending_classic_connection = None
def classic_disconnect(self, initiator_controller, responder_address, reason):
logger.debug(
f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
async def task():
initiator_controller.on_classic_disconnected(responder_address, reason)
asyncio.create_task(task())
responder_controller.on_classic_disconnected(
initiator_controller.public_address, reason
)
def classic_switch_role(
self, initiator_controller, responder_address, initiator_new_role
):
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
return
async def task():
initiator_controller.on_classic_role_change(
responder_address, initiator_new_role
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)

View File

@@ -1,200 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import ClassVar
from bumble import hci
# -----------------------------------------------------------------------------
# Advertising PDU
# -----------------------------------------------------------------------------
class AdvertisingPdu:
"""Base Advertising Physical Channel PDU class.
See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
@dataclasses.dataclass
class ConnectInd(AdvertisingPdu):
initiator_address: hci.Address
advertiser_address: hci.Address
interval: int
latency: int
timeout: int
@dataclasses.dataclass
class AdvInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvDirectInd(AdvertisingPdu):
advertiser_address: hci.Address
target_address: hci.Address
@dataclasses.dataclass
class AdvNonConnInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvExtInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
target_address: hci.Address | None = None
adi: int | None = None
tx_power: int | None = None
# -----------------------------------------------------------------------------
# LL Control PDU
# -----------------------------------------------------------------------------
class ControlPdu:
"""Base LL Control PDU Class.
See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
class Opcode(hci.SpecableEnum):
LL_CONNECTION_UPDATE_IND = 0x00
LL_CHANNEL_MAP_IND = 0x01
LL_TERMINATE_IND = 0x02
LL_ENC_REQ = 0x03
LL_ENC_RSP = 0x04
LL_START_ENC_REQ = 0x05
LL_START_ENC_RSP = 0x06
LL_UNKNOWN_RSP = 0x07
LL_FEATURE_REQ = 0x08
LL_FEATURE_RSP = 0x09
LL_PAUSE_ENC_REQ = 0x0A
LL_PAUSE_ENC_RSP = 0x0B
LL_VERSION_IND = 0x0C
LL_REJECT_IND = 0x0D
LL_PERIPHERAL_FEATURE_REQ = 0x0E
LL_CONNECTION_PARAM_REQ = 0x0F
LL_CONNECTION_PARAM_RSP = 0x10
LL_REJECT_EXT_IND = 0x11
LL_PING_REQ = 0x12
LL_PING_RSP = 0x13
LL_LENGTH_REQ = 0x14
LL_LENGTH_RSP = 0x15
LL_PHY_REQ = 0x16
LL_PHY_RSP = 0x17
LL_PHY_UPDATE_IND = 0x18
LL_MIN_USED_CHANNELS_IND = 0x19
LL_CTE_REQ = 0x1A
LL_CTE_RSP = 0x1B
LL_PERIODIC_SYNC_IND = 0x1C
LL_CLOCK_ACCURACY_REQ = 0x1D
LL_CLOCK_ACCURACY_RSP = 0x1E
LL_CIS_REQ = 0x1F
LL_CIS_RSP = 0x20
LL_CIS_IND = 0x21
LL_CIS_TERMINATE_IND = 0x22
LL_POWER_CONTROL_REQ = 0x23
LL_POWER_CONTROL_RSP = 0x24
LL_POWER_CHANGE_IND = 0x25
LL_SUBRATE_REQ = 0x26
LL_SUBRATE_IND = 0x27
LL_CHANNEL_REPORTING_IND = 0x28
LL_CHANNEL_STATUS_IND = 0x29
LL_PERIODIC_SYNC_WR_IND = 0x2A
LL_FEATURE_EXT_REQ = 0x2B
LL_FEATURE_EXT_RSP = 0x2C
LL_CS_SEC_RSP = 0x2D
LL_CS_CAPABILITIES_REQ = 0x2E
LL_CS_CAPABILITIES_RSP = 0x2F
LL_CS_CONFIG_REQ = 0x30
LL_CS_CONFIG_RSP = 0x31
LL_CS_REQ = 0x32
LL_CS_RSP = 0x33
LL_CS_IND = 0x34
LL_CS_TERMINATE_REQ = 0x35
LL_CS_FAE_REQ = 0x36
LL_CS_FAE_RSP = 0x37
LL_CS_CHANNEL_MAP_IND = 0x38
LL_CS_SEC_REQ = 0x39
LL_CS_TERMINATE_RSP = 0x3A
LL_FRAME_SPACE_REQ = 0x3B
LL_FRAME_SPACE_RSP = 0x3C
opcode: ClassVar[Opcode]
@dataclasses.dataclass
class TerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_TERMINATE_IND
error_code: int
@dataclasses.dataclass
class EncReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_ENC_REQ
rand: bytes
ediv: int
ltk: bytes
@dataclasses.dataclass
class CisReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisTerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND
cig_id: int
cis_id: int
error_code: int

View File

@@ -1,324 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from dataclasses import dataclass, field
from typing import TypeVar
from bumble import hci, utils
class Opcode(utils.OpenIntEnum):
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary.
Follow the alphabetical order defined there.
'''
# fmt: off
LMP_ACCEPTED = 3
LMP_ACCEPTED_EXT = 127 << 8 + 1
LMP_AU_RAND = 11
LMP_AUTO_RATE = 35
LMP_CHANNEL_CLASSIFICATION = 127 << 8 + 17
LMP_CHANNEL_CLASSIFICATION_REQ = 127 << 8 + 16
LMP_CLK_ADJ = 127 << 8 + 5
LMP_CLK_ADJ_ACK = 127 << 8 + 6
LMP_CLK_ADJ_REQ = 127 << 8 + 7
LMP_CLKOFFSET_REQ = 5
LMP_CLKOFFSET_RES = 6
LMP_COMB_KEY = 9
LMP_DECR_POWER_REQ = 32
LMP_DETACH = 7
LMP_DHKEY_CHECK = 65
LMP_ENCAPSULATED_HEADER = 61
LMP_ENCAPSULATED_PAYLOAD = 62
LMP_ENCRYPTION_KEY_SIZE_MASK_REQ= 58
LMP_ENCRYPTION_KEY_SIZE_MASK_RES= 59
LMP_ENCRYPTION_KEY_SIZE_REQ = 16
LMP_ENCRYPTION_MODE_REQ = 15
LMP_ESCO_LINK_REQ = 127 << 8 + 12
LMP_FEATURES_REQ = 39
LMP_FEATURES_REQ_EXT = 127 << 8 + 3
LMP_FEATURES_RES = 40
LMP_FEATURES_RES_EXT = 127 << 8 + 4
LMP_HOLD = 20
LMP_HOLD_REQ = 21
LMP_HOST_CONNECTION_REQ = 51
LMP_IN_RAND = 8
LMP_INCR_POWER_REQ = 31
LMP_IO_CAPABILITY_REQ = 127 << 8 + 25
LMP_IO_CAPABILITY_RES = 127 << 8 + 26
LMP_KEYPRESS_NOTIFICATION = 127 << 8 + 30
LMP_MAX_POWER = 33
LMP_MAX_SLOT = 45
LMP_MAX_SLOT_REQ = 46
LMP_MIN_POWER = 34
LMP_NAME_REQ = 1
LMP_NAME_RES = 2
LMP_NOT_ACCEPTED = 4
LMP_NOT_ACCEPTED_EXT = 127 << 8 + 2
LMP_NUMERIC_COMPARISON_FAILED = 127 << 8 + 27
LMP_OOB_FAILED = 127 << 8 + 29
LMP_PACKET_TYPE_TABLE_REQ = 127 << 8 + 11
LMP_PAGE_MODE_REQ = 53
LMP_PAGE_SCAN_MODE_REQ = 54
LMP_PASSKEY_FAILED = 127 << 8 + 28
LMP_PAUSE_ENCRYPTION_AES_REQ = 66
LMP_PAUSE_ENCRYPTION_REQ = 127 << 8 + 23
LMP_PING_REQ = 127 << 8 + 33
LMP_PING_RES = 127 << 8 + 34
LMP_POWER_CONTROL_REQ = 127 << 8 + 31
LMP_POWER_CONTROL_RES = 127 << 8 + 32
LMP_PREFERRED_RATE = 36
LMP_QUALITY_OF_SERVICE = 41
LMP_QUALITY_OF_SERVICE_REQ = 42
LMP_REMOVE_ESCO_LINK_REQ = 127 << 8 + 13
LMP_REMOVE_SCO_LINK_REQ = 44
LMP_RESUME_ENCRYPTION_REQ = 127 << 8 + 24
LMP_SAM_DEFINE_MAP = 127 << 8 + 36
LMP_SAM_SET_TYPE0 = 127 << 8 + 35
LMP_SAM_SWITCH = 127 << 8 + 37
LMP_SCO_LINK_REQ = 43
LMP_SET_AFH = 60
LMP_SETUP_COMPLETE = 49
LMP_SIMPLE_PAIRING_CONFIRM = 63
LMP_SIMPLE_PAIRING_NUMBER = 64
LMP_SLOT_OFFSET = 52
LMP_SNIFF_REQ = 23
LMP_SNIFF_SUBRATING_REQ = 127 << 8 + 21
LMP_SNIFF_SUBRATING_RES = 127 << 8 + 22
LMP_SRES = 12
LMP_START_ENCRYPTION_REQ = 17
LMP_STOP_ENCRYPTION_REQ = 18
LMP_SUPERVISION_TIMEOUT = 55
LMP_SWITCH_REQ = 19
LMP_TEMP_KEY = 14
LMP_TEMP_RAND = 13
LMP_TEST_ACTIVATE = 56
LMP_TEST_CONTROL = 57
LMP_TIMING_ACCURACY_REQ = 47
LMP_TIMING_ACCURACY_RES = 48
LMP_UNIT_KEY = 10
LMP_UNSNIFF_REQ = 24
LMP_USE_SEMI_PERMANENT_KEY = 50
LMP_VERSION_REQ = 37
LMP_VERSION_RES = 38
# fmt: on
@classmethod
def parse_from(cls, data: bytes, offset: int = 0) -> tuple[int, Opcode]:
opcode = data[offset]
if opcode in (124, 127):
opcode = struct.unpack('>H', data)[0]
return offset + 2, Opcode(opcode)
return offset + 1, Opcode(opcode)
def __bytes__(self) -> bytes:
if self.value >> 8:
return struct.pack('>H', self.value)
return bytes([self.value])
@classmethod
def type_metadata(cls):
return hci.metadata(
{
'serializer': bytes,
'parser': lambda data, offset: (Opcode.parse_from(data, offset)),
}
)
class Packet:
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary
'''
subclasses: dict[int, type[Packet]] = {}
opcode: Opcode
fields: hci.Fields = ()
_payload: bytes = b''
_Packet = TypeVar("_Packet", bound="Packet")
@classmethod
def subclass(cls, subclass: type[_Packet]) -> type[_Packet]:
# Register a factory for this class
cls.subclasses[subclass.opcode] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod
def from_bytes(cls, data: bytes) -> Packet:
offset, opcode = Opcode.parse_from(data)
if not (subclass := cls.subclasses.get(opcode)):
instance = Packet()
instance.opcode = opcode
else:
instance = subclass(
**hci.HCI_Object.dict_from_bytes(data, offset, subclass.fields)
)
instance.payload = data[offset:]
return instance
@property
def payload(self) -> bytes:
if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@payload.setter
def payload(self, value: bytes) -> None:
self._payload = value
def __bytes__(self) -> bytes:
return bytes(self.opcode) + self.payload
@Packet.subclass
@dataclass
class LmpAccepted(Packet):
opcode = Opcode.LMP_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAccepted(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAcceptedExt(Packet):
opcode = Opcode.LMP_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAcceptedExt(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAuRand(Packet):
opcode = Opcode.LMP_AU_RAND
random_number: bytes = field(metadata=hci.metadata(16))
@Packet.subclass
@dataclass
class LmpDetach(Packet):
opcode = Opcode.LMP_DETACH
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpEscoLinkReq(Packet):
opcode = Opcode.LMP_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
esco_lt_addr: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_esco: int = field(metadata=hci.metadata(1))
t_esco: int = field(metadata=hci.metadata(1))
w_esco: int = field(metadata=hci.metadata(1))
esco_packet_type_c_to_p: int = field(metadata=hci.metadata(1))
esco_packet_type_p_to_c: int = field(metadata=hci.metadata(1))
packet_length_c_to_p: int = field(metadata=hci.metadata(2))
packet_length_p_to_c: int = field(metadata=hci.metadata(2))
air_mode: int = field(metadata=hci.metadata(1))
negotiation_state: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpHostConnectionReq(Packet):
opcode = Opcode.LMP_HOST_CONNECTION_REQ
@Packet.subclass
@dataclass
class LmpRemoveEscoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpRemoveScoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpScoLinkReq(Packet):
opcode = Opcode.LMP_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_sco: int = field(metadata=hci.metadata(1))
t_sco: int = field(metadata=hci.metadata(1))
sco_packet: int = field(metadata=hci.metadata(1))
air_mode: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpSwitchReq(Packet):
opcode = Opcode.LMP_SWITCH_REQ
switch_instant: int = field(metadata=hci.metadata(4), default=0)
@Packet.subclass
@dataclass
class LmpNameReq(Packet):
opcode = Opcode.LMP_NAME_REQ
name_offset: int = field(metadata=hci.metadata(2))
@Packet.subclass
@dataclass
class LmpNameRes(Packet):
opcode = Opcode.LMP_NAME_RES
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import enum
import secrets
from dataclasses import dataclass
from typing import Optional
from bumble import hci
from bumble.core import AdvertisingData, LeRole
@@ -44,16 +45,16 @@ from bumble.smp import (
class OobData:
"""OOB data that can be sent from one device to another."""
address: hci.Address | None = None
role: LeRole | None = None
shared_data: OobSharedData | None = None
legacy_context: OobLegacyContext | None = None
address: Optional[hci.Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: bytes | None = None
shared_data_r: bytes | None = None
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = hci.Address(ad_data)
@@ -180,14 +181,14 @@ class PairingDelegate:
"""Compare two numbers."""
return True
async def get_number(self) -> int | None:
async def get_number(self) -> Optional[int]:
"""
Return an optional number as an answer to a passkey request.
Returning `None` will result in a negative reply.
"""
return 0
async def get_string(self, max_length: int) -> str | None:
async def get_string(self, max_length: int) -> Optional[str]:
"""
Return a string whose utf-8 encoding is up to max_length bytes.
"""
@@ -238,18 +239,18 @@ class PairingConfig:
class OobConfig:
"""Config for OOB pairing."""
our_context: OobContext | None
peer_data: OobSharedData | None
legacy_context: OobLegacyContext | None
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: PairingDelegate | None = None,
identity_address_type: AddressType | None = None,
oob: OobConfig | None = None,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None:
self.sc = sc
self.mitm = mitm

View File

@@ -19,7 +19,7 @@ This module implement the Pandora Bluetooth test APIs for the Bumble stack.
__version__ = "0.0.1"
from collections.abc import Callable
from typing import Callable, List, Optional
import grpc
import grpc.aio
@@ -58,7 +58,7 @@ def register_servicer_hook(
async def serve(
bumble: PandoraDevice,
config: Config = Config(),
grpc_server: grpc.aio.Server | None = None,
grpc_server: Optional[grpc.aio.Server] = None,
port: int = 0,
) -> None:
# initialize a gRPC server if not provided.

View File

@@ -16,7 +16,7 @@
from __future__ import annotations
from typing import Any
from typing import Any, Optional
from bumble import transport
from bumble.core import (
@@ -54,7 +54,7 @@ class PandoraDevice:
# HCI transport name & instance.
_hci_name: str
_hci: transport.Transport | None # type: ignore[name-defined]
_hci: Optional[transport.Transport] # type: ignore[name-defined]
def __init__(self, config: dict[str, Any]) -> None:
self.config = config
@@ -74,9 +74,7 @@ class PandoraDevice:
# open HCI transport & set device host.
self._hci = await transport.open_transport(self._hci_name)
self.device.host = Host(
controller_source=self._hci.source, controller_sink=self._hci.sink
) # type: ignore[no-untyped-call]
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
# power-on.
await self.device.power_on()
@@ -98,7 +96,7 @@ class PandoraDevice:
await self.close()
await self.open()
def info(self) -> dict[str, str] | None:
def info(self) -> Optional[dict[str, str]]:
return {
'public_bd_address': str(self.device.public_address),
'random_address': str(self.device.random_address),

View File

@@ -17,15 +17,12 @@ from __future__ import annotations
import asyncio
import logging
import struct
from collections.abc import AsyncGenerator
from typing import cast
from typing import AsyncGenerator, Optional, cast
import grpc
import grpc.aio
from google.protobuf import (
any_pb2, # pytype: disable=pyi-error
empty_pb2, # pytype: disable=pyi-error
)
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora import host_pb2
from pandora.host_grpc_aio import HostServicer
from pandora.host_pb2 import (
@@ -305,9 +302,7 @@ class HostService(HostServicer):
await disconnection_future
self.log.debug("Disconnected")
finally:
connection.remove_listener(
connection.EVENT_DISCONNECTION, on_disconnection
) # type: ignore
connection.remove_listener(connection.EVENT_DISCONNECTION, on_disconnection) # type: ignore
return empty_pb2.Empty()
@@ -544,7 +539,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_advertising()
)
except Exception:
except:
pass
@utils.rpc
@@ -614,7 +609,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_scanning()
)
except Exception:
except:
pass
@utils.rpc
@@ -624,7 +619,7 @@ class HostService(HostServicer):
self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[
tuple[Address, int, AdvertisingData, int] | None
Optional[tuple[Address, int, AdvertisingData, int]]
] = asyncio.Queue()
complete_handler = self.device.on(
self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None)
@@ -649,18 +644,14 @@ class HostService(HostServicer):
)
finally:
self.device.remove_listener(
self.device.EVENT_INQUIRY_COMPLETE, complete_handler
) # type: ignore
self.device.remove_listener(
self.device.EVENT_INQUIRY_RESULT, result_handler
) # type: ignore
self.device.remove_listener(self.device.EVENT_INQUIRY_COMPLETE, complete_handler) # type: ignore
self.device.remove_listener(self.device.EVENT_INQUIRY_RESULT, result_handler) # type: ignore
try:
self.log.debug('Stop inquiry')
await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_discovery()
)
except Exception:
except:
pass
@utils.rpc

View File

@@ -18,15 +18,15 @@ import json
import logging
from asyncio import Future
from asyncio import Queue as AsyncQueue
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import AsyncGenerator, Optional, Union
import grpc
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import COMMAND_NOT_UNDERSTOOD, INVALID_CID_IN_REQUEST
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error
from pandora.l2cap_pb2 import (
COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST,
ConnectRequest,
ConnectResponse,
CreditBasedChannelRequest,
@@ -41,7 +41,6 @@ from pandora.l2cap_pb2 import (
WaitDisconnectionRequest,
WaitDisconnectionResponse,
)
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error
from bumble.core import InvalidArgumentError, OutOfResourcesError
from bumble.device import Device
@@ -56,7 +55,7 @@ from bumble.l2cap import (
from bumble.pandora import utils
from bumble.pandora.config import Config
L2capChannel = ClassicChannel | LeCreditBasedChannel
L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
@dataclass
@@ -107,8 +106,10 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None
l2cap_server: ClassicChannelServer | LeCreditBasedChannelServer | None = None
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,
@@ -215,7 +216,7 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof)
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None
spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm,

View File

@@ -17,15 +17,13 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from typing import Any
from collections.abc import Awaitable
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Optional, Union
import grpc
from google.protobuf import (
any_pb2, # pytype: disable=pyi-error
empty_pb2, # pytype: disable=pyi-error
wrappers_pb2, # pytype: disable=pyi-error
)
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
@@ -66,7 +64,7 @@ class PairingDelegate(BasePairingDelegate):
def __init__(
self,
connection: BumbleConnection,
service: SecurityService,
service: "SecurityService",
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
@@ -132,7 +130,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def get_number(self) -> int | None:
async def get_number(self) -> Optional[int]:
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -149,7 +147,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'passkey'
return answer.passkey
async def get_string(self, max_length: int) -> str | None:
async def get_string(self, max_length: int) -> Optional[str]:
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -197,8 +195,8 @@ class SecurityService(SecurityServicer):
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Security', 'device': device}
)
self.event_queue: asyncio.Queue[PairingEvent] | None = None
self.event_answer: AsyncIterator[PairingEventAnswer] | None = None
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.device = device
self.config = config
@@ -233,7 +231,7 @@ class SecurityService(SecurityServicer):
if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated
link_key_type: int | None = None
link_key_type: Optional[int] = None
if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address))
):
@@ -412,8 +410,8 @@ class SecurityService(SecurityServicer):
wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future()
)
authenticate_task: asyncio.Future[None] | None = None
pair_task: asyncio.Future[None] | None = None
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
if (encryption := connection.encryption) != 0:
@@ -457,9 +455,9 @@ class SecurityService(SecurityServicer):
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
bumble.utils.AsyncRunner.spawn(connection.pair())
pair_task = asyncio.create_task(connection.pair())
listeners: dict[str, Callable[..., None | Awaitable[None]]] = {
listeners: dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'),
@@ -502,7 +500,7 @@ class SecurityService(SecurityServicer):
return WaitSecurityResponse(**kwargs)
async def reached_security_level(
self, connection: BumbleConnection, level: SecurityLevel | LESecurityLevel
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool:
self.log.debug(
str(

View File

@@ -18,8 +18,7 @@ import contextlib
import functools
import inspect
import logging
from collections.abc import Generator, MutableMapping
from typing import Any
from typing import Any, Generator, MutableMapping, Optional
import grpc
from google.protobuf.message import Message # pytype: disable=pyi-error
@@ -35,7 +34,7 @@ ADDRESS_TYPES: dict[str, AddressType] = {
}
def address_from_request(request: Message, field: str | None) -> Address:
def address_from_request(request: Message, field: Optional[str]) -> Address:
if field is None:
return Address.ANY
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
@@ -96,7 +95,8 @@ def rpc(func: Any) -> Any:
@functools.wraps(func)
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
yield from func(self, request, context)
for v in func(self, request, context):
yield v
@functools.wraps(func)
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:

View File

@@ -22,6 +22,7 @@ from __future__ import annotations
import logging
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import utils
from bumble.att import ATT_Error
@@ -128,7 +129,7 @@ class AudioInputState:
mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0
attribute: Attribute | None = None
attribute: Optional[Attribute] = None
def __bytes__(self) -> bytes:
return bytes(
@@ -198,6 +199,7 @@ class AudioInputControlPoint:
gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
@@ -315,7 +317,7 @@ class AudioInputDescription:
'''
audio_input_description: str = "Bluetooth"
attribute: Attribute | None = None
attribute: Optional[Attribute] = None
def on_read(self, _connection: Connection) -> str:
return self.audio_input_description
@@ -338,11 +340,11 @@ class AICSService(TemplateService):
def __init__(
self,
audio_input_state: AudioInputState | None = None,
gain_settings_properties: GainSettingsProperties | None = None,
audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_type: str = "local",
audio_input_status: AudioInputStatus | None = None,
audio_input_description: AudioInputDescription | None = None,
audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: Optional[AudioInputDescription] = None,
):
self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state

View File

@@ -25,7 +25,7 @@ import asyncio
import dataclasses
import enum
import logging
from collections.abc import Iterable
from typing import Iterable, Optional, Union
from bumble import utils
from bumble.device import Peer
@@ -230,7 +230,7 @@ class AmsClient(utils.EventEmitter):
self.supported_commands = set()
@classmethod
async def for_peer(cls, peer: Peer) -> AmsClient | None:
async def for_peer(cls, peer: Peer) -> Optional[AmsClient]:
ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy)
if ams_proxy is None:
return None
@@ -263,7 +263,9 @@ class AmsClient(utils.EventEmitter):
async def observe(
self,
entity: EntityId,
attributes: Iterable[PlayerAttributeId | QueueAttributeId | TrackAttributeId],
attributes: Iterable[
Union[PlayerAttributeId, QueueAttributeId, TrackAttributeId]
],
) -> None:
await self._ams_proxy.entity_update.write_value(
bytes([entity] + list(attributes)), with_response=True

View File

@@ -27,7 +27,7 @@ import datetime
import enum
import logging
import struct
from collections.abc import Sequence
from typing import Optional, Sequence, Union
from bumble import utils
from bumble.att import ATT_Error
@@ -116,7 +116,7 @@ class NotificationAttributeId(utils.OpenIntEnum):
@dataclasses.dataclass
class NotificationAttribute:
attribute_id: NotificationAttributeId
value: str | int | datetime.datetime
value: Union[str, int, datetime.datetime]
@dataclasses.dataclass
@@ -242,10 +242,10 @@ class AncsProxy(ProfileServiceProxy):
class AncsClient(utils.EventEmitter):
_expected_response_command_id: CommandId | None
_expected_response_notification_uid: int | None
_expected_response_app_identifier: str | None
_expected_app_identifier: str | None
_expected_response_command_id: Optional[CommandId]
_expected_response_notification_uid: Optional[int]
_expected_response_app_identifier: Optional[str]
_expected_app_identifier: Optional[str]
_expected_response_tuples: int
_response_accumulator: bytes
@@ -255,12 +255,12 @@ class AncsClient(utils.EventEmitter):
super().__init__()
self._ancs_proxy = ancs_proxy
self._command_semaphore = asyncio.Semaphore()
self._response: asyncio.Future | None = None
self._response: Optional[asyncio.Future] = None
self._reset_response()
self._started = False
@classmethod
async def for_peer(cls, peer: Peer) -> AncsClient | None:
async def for_peer(cls, peer: Peer) -> Optional[AncsClient]:
ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy)
if ancs_proxy is None:
return None
@@ -316,7 +316,7 @@ class AncsClient(utils.EventEmitter):
# Not enough data yet.
return
attributes: list[NotificationAttribute | AppAttribute] = []
attributes: list[Union[NotificationAttribute, AppAttribute]] = []
if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES:
(notification_uid,) = struct.unpack_from(
@@ -342,7 +342,7 @@ class AncsClient(utils.EventEmitter):
str_value = attribute_data[3 : 3 + attribute_data_length].decode(
"utf-8"
)
value: str | int | datetime.datetime
value: Union[str, int, datetime.datetime]
if attribute_id == NotificationAttributeId.MESSAGE_SIZE:
value = int(str_value)
elif attribute_id == NotificationAttributeId.DATE:
@@ -415,7 +415,7 @@ class AncsClient(utils.EventEmitter):
self,
notification_uid: int,
attributes: Sequence[
NotificationAttributeId | tuple[NotificationAttributeId, int]
Union[NotificationAttributeId, tuple[NotificationAttributeId, int]]
],
) -> list[NotificationAttribute]:
if not self._started:

View File

@@ -24,7 +24,7 @@ import logging
import struct
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, TypeVar
from typing import Any, Optional, TypeVar, Union
from bumble import colors, device, gatt, gatt_client, hci, utils
from bumble.profiles import le_audio
@@ -49,7 +49,7 @@ class ASE_Operation:
classes: dict[int, type[ASE_Operation]] = {}
op_code: Opcode
name: str
fields: Sequence[Any] | None = None
fields: Optional[Sequence[Any]] = None
ase_id: Sequence[int]
class Opcode(enum.IntEnum):
@@ -278,7 +278,7 @@ class AseStateMachine(gatt.Characteristic):
EVENT_STATE_CHANGE = "state_change"
cis_link: device.CisLink | None = None
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
@@ -290,7 +290,7 @@ class AseStateMachine(gatt.Characteristic):
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: CodecSpecificConfiguration | bytes = b''
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
@@ -610,7 +610,7 @@ class AudioStreamControlService(gatt.TemplateService):
ase_state_machines: dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic[bytes]
_active_client: device.Connection | None = None
_active_client: Optional[device.Connection] = None
def __init__(
self,

View File

@@ -19,8 +19,7 @@
import enum
import logging
import struct
from collections.abc import Callable
from typing import Any
from typing import Any, Callable, Optional, Union
from bumble import data_types, gatt, gatt_client, l2cap, utils
from bumble.core import AdvertisingData
@@ -91,20 +90,20 @@ class AshaService(gatt.TemplateService):
EVENT_DISCONNECTED = "disconnected"
EVENT_VOLUME_CHANGED = "volume_changed"
audio_sink: Callable[[bytes], Any] | None
active_codec: Codec | None = None
audio_type: AudioType | None = None
volume: int | None = None
other_state: int | None = None
connection: Connection | None = None
audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Optional[Codec] = None
audio_type: Optional[AudioType] = None
volume: Optional[int] = None
other_state: Optional[int] = None
connection: Optional[Connection] = None
def __init__(
self,
capability: int,
hisyncid: list[int] | bytes,
hisyncid: Union[list[int], bytes],
device: Device,
psm: int = 0,
audio_sink: Callable[[bytes], Any] | None = None,
audio_sink: Optional[Callable[[bytes], Any]] = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01,
render_delay_milliseconds: int = 0,

View File

@@ -21,8 +21,7 @@ from __future__ import annotations
import dataclasses
import logging
import struct
from collections.abc import Sequence
from typing import ClassVar
from typing import ClassVar, Optional, Sequence
from bumble import core, device, gatt, gatt_adapters, gatt_client, hci, utils
@@ -338,12 +337,7 @@ class BroadcastAudioScanService(gatt.TemplateService):
b"12", # TEST
)
super().__init__(
[
self.broadcast_audio_scan_control_point_characteristic,
self.broadcast_receive_state_characteristic,
]
)
super().__init__([self.battery_level_characteristic])
def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes
@@ -357,7 +351,7 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes]
broadcast_receive_states: list[
gatt_client.CharacteristicProxy[BroadcastReceiveState | None]
gatt_client.CharacteristicProxy[Optional[BroadcastReceiveState]]
]
def __init__(self, service_proxy: gatt_client.ServiceProxy):

View File

@@ -16,6 +16,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from typing import Optional
from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC,
@@ -55,7 +56,7 @@ class BatteryService(TemplateService):
class BatteryServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = BatteryService
battery_level: CharacteristicProxy[int] | None
battery_level: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy):
self.service_proxy = service_proxy

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import enum
import struct
from typing import Optional
from bumble import core, crypto, device, gatt, gatt_client
@@ -95,17 +96,17 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
coordinated_set_size_characteristic: gatt.Characteristic[bytes] | None = None
set_member_lock_characteristic: gatt.Characteristic[bytes] | None = None
set_member_rank_characteristic: gatt.Characteristic[bytes] | None = None
coordinated_set_size_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_lock_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_rank_characteristic: Optional[gatt.Characteristic[bytes]] = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: int | None = None,
set_member_lock: MemberLock | None = None,
set_member_rank: int | None = None,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise core.InvalidArgumentError(
@@ -197,9 +198,9 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
coordinated_set_size: gatt_client.CharacteristicProxy[bytes] | None = None
set_member_lock: gatt_client.CharacteristicProxy[bytes] | None = None
set_member_rank: gatt_client.CharacteristicProxy[bytes] | None = None
coordinated_set_size: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy[bytes]] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -17,6 +17,7 @@
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Optional
from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE,
@@ -53,14 +54,14 @@ class DeviceInformationService(TemplateService):
def __init__(
self,
manufacturer_name: str | None = None,
model_number: str | None = None,
serial_number: str | None = None,
hardware_revision: str | None = None,
firmware_revision: str | None = None,
software_revision: str | None = None,
system_id: tuple[int, int] | None = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes | None = None,
manufacturer_name: Optional[str] = None,
model_number: Optional[str] = None,
serial_number: Optional[str] = None,
hardware_revision: Optional[str] = None,
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None,
# TODO: pnp_id
):
characteristics: list[Characteristic[bytes]] = [
@@ -108,14 +109,14 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService
manufacturer_name: CharacteristicProxy[str] | None
model_number: CharacteristicProxy[str] | None
serial_number: CharacteristicProxy[str] | None
hardware_revision: CharacteristicProxy[str] | None
firmware_revision: CharacteristicProxy[str] | None
software_revision: CharacteristicProxy[str] | None
system_id: CharacteristicProxy[tuple[int, int]] | None
ieee_regulatory_certification_data_list: CharacteristicProxy[bytes] | None
manufacturer_name: Optional[CharacteristicProxy[str]]
model_number: Optional[CharacteristicProxy[str]]
serial_number: Optional[CharacteristicProxy[str]]
hardware_revision: Optional[CharacteristicProxy[str]]
firmware_revision: Optional[CharacteristicProxy[str]]
software_revision: Optional[CharacteristicProxy[str]]
system_id: Optional[CharacteristicProxy[tuple[int, int]]]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy[bytes]]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy

View File

@@ -19,6 +19,7 @@
# -----------------------------------------------------------------------------
import logging
import struct
from typing import Optional, Union
from bumble.core import Appearance
from bumble.gatt import (
@@ -53,7 +54,7 @@ class GenericAccessService(TemplateService):
appearance_characteristic: Characteristic[bytes]
def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
self, device_name: str, appearance: Union[Appearance, tuple[int, int], int] = 0
):
if isinstance(appearance, int):
appearance_int = appearance
@@ -87,8 +88,8 @@ class GenericAccessService(TemplateService):
class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService
device_name: CharacteristicProxy[str] | None
appearance: CharacteristicProxy[Appearance] | None
device_name: Optional[CharacteristicProxy[str]]
appearance: Optional[CharacteristicProxy[Appearance]]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy

View File

@@ -40,6 +40,7 @@ class GenericAttributeProfileService(gatt.TemplateService):
database_hash_enabled: bool = True,
service_change_enabled: bool = True,
) -> None:
if server_supported_features is not None:
self.server_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,

View File

@@ -19,6 +19,7 @@
# -----------------------------------------------------------------------------
import struct
from enum import IntFlag
from typing import Optional
from bumble.gatt import (
GATT_BGR_FEATURES_CHARACTERISTIC,
@@ -76,18 +77,18 @@ class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic
ugg_features: Characteristic | None = None
ugt_features: Characteristic | None = None
bgs_features: Characteristic | None = None
bgr_features: Characteristic | None = None
ugg_features: Optional[Characteristic] = None
ugt_features: Optional[Characteristic] = None
bgs_features: Optional[Characteristic] = None
bgr_features: Optional[Characteristic] = None
def __init__(
self,
gmap_role: GmapRole,
ugg_features: UggFeatures | None = None,
ugt_features: UgtFeatures | None = None,
bgs_features: BgsFeatures | None = None,
bgr_features: BgrFeatures | None = None,
ugg_features: Optional[UggFeatures] = None,
ugt_features: Optional[UgtFeatures] = None,
bgs_features: Optional[BgsFeatures] = None,
bgr_features: Optional[BgrFeatures] = None,
) -> None:
characteristics = []
@@ -149,10 +150,10 @@ class GamingAudioService(TemplateService):
class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService
ugg_features: CharacteristicProxy[UggFeatures] | None = None
ugt_features: CharacteristicProxy[UgtFeatures] | None = None
bgs_features: CharacteristicProxy[BgsFeatures] | None = None
bgr_features: CharacteristicProxy[BgrFeatures] | None = None
ugg_features: Optional[CharacteristicProxy[UggFeatures]] = None
ugt_features: Optional[CharacteristicProxy[UgtFeatures]] = None
bgs_features: Optional[CharacteristicProxy[BgsFeatures]] = None
bgr_features: Optional[CharacteristicProxy[BgrFeatures]] = None
def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Optional, Union
from bumble import att, gatt, gatt_adapters, gatt_client, utils
from bumble.core import InvalidArgumentError, InvalidStateError
@@ -145,7 +145,7 @@ class PresetChangedOperation:
return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId
additional_parameters: Generic | int
additional_parameters: Union[Generic, int]
def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
@@ -235,7 +235,7 @@ class HearingAccessService(gatt.TemplateService):
preset_records: dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool
other_server_in_binaural_set: HearingAccessService | None = None
other_server_in_binaural_set: Optional[HearingAccessService] = None
preset_changed_operations_history_per_device: dict[
Address, list[PresetChangedOperation]

View File

@@ -20,6 +20,7 @@ from __future__ import annotations
import struct
from enum import IntEnum
from typing import Optional
from bumble import core
from bumble.att import ATT_Error
@@ -206,13 +207,13 @@ class HeartRateService(TemplateService):
class HeartRateServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = HeartRateService
heart_rate_measurement: (
CharacteristicProxy[HeartRateService.HeartRateMeasurement] | None
)
body_sensor_location: (
CharacteristicProxy[HeartRateService.BodySensorLocation] | None
)
heart_rate_control_point: CharacteristicProxy[int] | None
heart_rate_measurement: Optional[
CharacteristicProxy[HeartRateService.HeartRateMeasurement]
]
body_sensor_location: Optional[
CharacteristicProxy[HeartRateService.BodySensorLocation]
]
heart_rate_control_point: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy):
self.service_proxy = service_proxy

View File

@@ -137,7 +137,7 @@ class Metadata:
values.append(str(decoded))
return '\n'.join(
f'{indent}{key}: {" " * (max_key_length - len(key))}{value}'
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
for key, value in zip(keys, values)
)

View File

@@ -22,7 +22,7 @@ import asyncio
import dataclasses
import enum
import struct
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, Optional
from typing_extensions import Self
@@ -196,7 +196,7 @@ class MediaControlService(gatt.TemplateService):
UUID = gatt.GATT_MEDIA_CONTROL_SERVICE
def __init__(self, media_player_name: str | None = None) -> None:
def __init__(self, media_player_name: Optional[str] = None) -> None:
self.track_position = 0
self.media_player_name_characteristic = gatt.Characteristic(
@@ -337,32 +337,32 @@ class MediaControlServiceProxy(
EVENT_TRACK_DURATION = "track_duration"
EVENT_TRACK_POSITION = "track_position"
media_player_name: gatt_client.CharacteristicProxy[bytes] | None = None
media_player_icon_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
media_player_icon_url: gatt_client.CharacteristicProxy[bytes] | None = None
track_changed: gatt_client.CharacteristicProxy[bytes] | None = None
track_title: gatt_client.CharacteristicProxy[bytes] | None = None
track_duration: gatt_client.CharacteristicProxy[bytes] | None = None
track_position: gatt_client.CharacteristicProxy[bytes] | None = None
playback_speed: gatt_client.CharacteristicProxy[bytes] | None = None
seeking_speed: gatt_client.CharacteristicProxy[bytes] | None = None
current_track_segments_object_id: gatt_client.CharacteristicProxy[bytes] | None = (
None
)
current_track_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
next_track_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
parent_group_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
current_group_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
playing_order: gatt_client.CharacteristicProxy[bytes] | None = None
playing_orders_supported: gatt_client.CharacteristicProxy[bytes] | None = None
media_state: gatt_client.CharacteristicProxy[bytes] | None = None
media_control_point: gatt_client.CharacteristicProxy[bytes] | None = None
media_control_point_opcodes_supported: (
gatt_client.CharacteristicProxy[bytes] | None
) = None
search_control_point: gatt_client.CharacteristicProxy[bytes] | None = None
search_results_object_id: gatt_client.CharacteristicProxy[bytes] | None = None
content_control_id: gatt_client.CharacteristicProxy[bytes] | None = None
media_player_name: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_player_icon_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_player_icon_url: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_changed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_title: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_duration: Optional[gatt_client.CharacteristicProxy[bytes]] = None
track_position: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playback_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
seeking_speed: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_track_segments_object_id: Optional[
gatt_client.CharacteristicProxy[bytes]
] = None
current_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
next_track_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
parent_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
current_group_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_order: Optional[gatt_client.CharacteristicProxy[bytes]] = None
playing_orders_supported: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_state: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
media_control_point_opcodes_supported: Optional[
gatt_client.CharacteristicProxy[bytes]
] = None
search_control_point: Optional[gatt_client.CharacteristicProxy[bytes]] = None
search_results_object_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
content_control_id: Optional[gatt_client.CharacteristicProxy[bytes]] = None
if TYPE_CHECKING:
media_control_point_notifications: asyncio.Queue[bytes]

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import dataclasses
import logging
import struct
from collections.abc import Sequence
from typing import Optional, Sequence, Union
from bumble import gatt, gatt_adapters, gatt_client, hci
from bumble.profiles import le_audio
@@ -39,7 +39,7 @@ class PacRecord:
'''Published Audio Capabilities Service, Table 3.2/3.4.'''
coding_format: hci.CodingFormat
codec_specific_capabilities: CodecSpecificCapabilities | bytes
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata)
@classmethod
@@ -56,7 +56,7 @@ class PacRecord:
offset += 1
metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size])
codec_specific_capabilities: CodecSpecificCapabilities | bytes
codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes]
if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC:
codec_specific_capabilities = codec_specific_capabilities_bytes
else:
@@ -101,10 +101,10 @@ class PacRecord:
class PublishedAudioCapabilitiesService(gatt.TemplateService):
UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE
sink_pac: gatt.Characteristic[bytes] | None
sink_audio_locations: gatt.Characteristic[bytes] | None
source_pac: gatt.Characteristic[bytes] | None
source_audio_locations: gatt.Characteristic[bytes] | None
sink_pac: Optional[gatt.Characteristic[bytes]]
sink_audio_locations: Optional[gatt.Characteristic[bytes]]
source_pac: Optional[gatt.Characteristic[bytes]]
source_audio_locations: Optional[gatt.Characteristic[bytes]]
available_audio_contexts: gatt.Characteristic[bytes]
supported_audio_contexts: gatt.Characteristic[bytes]
@@ -115,9 +115,9 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
available_source_context: ContextType,
available_sink_context: ContextType,
sink_pac: Sequence[PacRecord] = (),
sink_audio_locations: AudioLocation | None = None,
sink_audio_locations: Optional[AudioLocation] = None,
source_pac: Sequence[PacRecord] = (),
source_audio_locations: AudioLocation | None = None,
source_audio_locations: Optional[AudioLocation] = None,
) -> None:
characteristics = []
@@ -183,10 +183,14 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = PublishedAudioCapabilitiesService
sink_pac: gatt_client.CharacteristicProxy[list[PacRecord]] | None = None
sink_audio_locations: gatt_client.CharacteristicProxy[AudioLocation] | None = None
source_pac: gatt_client.CharacteristicProxy[list[PacRecord]] | None = None
source_audio_locations: gatt_client.CharacteristicProxy[AudioLocation] | None = None
sink_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
sink_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
None
)
source_pac: Optional[gatt_client.CharacteristicProxy[list[PacRecord]]] = None
source_audio_locations: Optional[gatt_client.CharacteristicProxy[AudioLocation]] = (
None
)
available_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]
supported_audio_contexts: gatt_client.CharacteristicProxy[tuple[ContextType, ...]]

View File

@@ -22,7 +22,6 @@ import enum
from typing_extensions import Self
from bumble import core, data_types, gatt
from bumble.profiles import le_audio
@@ -47,18 +46,3 @@ class PublicBroadcastAnnouncement:
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
data_types.ServiceData16BitUUID(
gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE, bytes(self)
)
]
)
)
def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata)
return bytes([self.features, len(metadata_bytes)]) + metadata_bytes

View File

@@ -20,9 +20,9 @@ from __future__ import annotations
import dataclasses
import enum
from collections.abc import Sequence
from typing import Sequence
from bumble import att, device, gatt, gatt_adapters, gatt_client
from bumble import att, device, gatt, gatt_adapters, gatt_client, utils
# -----------------------------------------------------------------------------
# Constants

View File

@@ -18,6 +18,7 @@
# -----------------------------------------------------------------------------
import struct
from dataclasses import dataclass
from typing import Optional
from bumble import utils
from bumble.att import ATT_Error
@@ -68,7 +69,7 @@ class ErrorCode(utils.OpenIntEnum):
class VolumeOffsetState:
volume_offset: int = 0
change_counter: int = 0
attribute: Characteristic | None = None
attribute: Optional[Characteristic] = None
def __bytes__(self) -> bytes:
return struct.pack('<hB', self.volume_offset, self.change_counter)
@@ -92,7 +93,7 @@ class VolumeOffsetState:
@dataclass
class VocsAudioLocation:
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
attribute: Characteristic | None = None
attribute: Optional[Characteristic] = None
def __bytes__(self) -> bytes:
return struct.pack('<I', self.audio_location)
@@ -117,6 +118,7 @@ class VolumeOffsetControlPoint:
volume_offset_state: VolumeOffsetState
async def on_write(self, connection: Connection, value: bytes) -> None:
opcode = value[0]
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
@@ -146,7 +148,7 @@ class VolumeOffsetControlPoint:
@dataclass
class AudioOutputDescription:
audio_output_description: str = ''
attribute: Characteristic | None = None
attribute: Optional[Characteristic] = None
@classmethod
def from_bytes(cls, data: bytes):
@@ -171,10 +173,11 @@ class VolumeOffsetControlService(TemplateService):
def __init__(
self,
volume_offset_state: VolumeOffsetState | None = None,
audio_location: VocsAudioLocation | None = None,
audio_output_description: AudioOutputDescription | None = None,
volume_offset_state: Optional[VolumeOffsetState] = None,
audio_location: Optional[VocsAudioLocation] = None,
audio_output_description: Optional[AudioOutputDescription] = None,
) -> None:
self.volume_offset_state = (
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
)

View File

@@ -22,8 +22,7 @@ import collections
import dataclasses
import enum
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing_extensions import Self
@@ -120,7 +119,7 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: UUID | None = None
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> list[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
@@ -187,7 +186,7 @@ async def find_rfcomm_channels(connection: Connection) -> dict[int, list[UUID]]:
)
for attribute_lists in search_result:
service_classes: list[UUID] = []
channel: int | None = None
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -208,7 +207,7 @@ async def find_rfcomm_channels(connection: Connection) -> dict[int, list[UUID]]:
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> int | None:
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
@@ -474,15 +473,15 @@ class DLC(utils.EventEmitter):
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.connection_result: asyncio.Future | None = None
self.disconnection_result: asyncio.Future | None = None
self.connection_result: Optional[asyncio.Future] = None
self.disconnection_result: Optional[asyncio.Future] = None
self.drained = asyncio.Event()
self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
maxlen=DEFAULT_RX_QUEUE_SIZE
)
self._sink: Callable[[bytes], None] | None = None
self._sink: Optional[Callable[[bytes], None]] = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -491,11 +490,11 @@ class DLC(utils.EventEmitter):
)
@property
def sink(self) -> Callable[[bytes], None] | None:
def sink(self) -> Optional[Callable[[bytes], None]]:
return self._sink
@sink.setter
def sink(self, sink: Callable[[bytes], None] | None) -> None:
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
@@ -675,14 +674,10 @@ class DLC(utils.EventEmitter):
while (self.tx_buffer and self.tx_credits > 0) or rx_credits_needed > 0:
# Get the next chunk, up to MTU size
if rx_credits_needed > 0:
chunk = bytes([rx_credits_needed])
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
self.rx_credits += rx_credits_needed
if self.tx_buffer and self.tx_credits > 0:
chunk += self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
tx_credit_spent = True
else:
tx_credit_spent = False
tx_credit_spent = len(chunk) > 1
else:
chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :]
@@ -713,7 +708,7 @@ class DLC(utils.EventEmitter):
self.drained.set()
# Stream protocol
def write(self, data: bytes | str) -> None:
def write(self, data: Union[bytes, str]) -> None:
# We can only send bytes
if not isinstance(data, bytes):
if isinstance(data, str):
@@ -770,10 +765,10 @@ class Multiplexer(utils.EventEmitter):
EVENT_DLC = "dlc"
connection_result: asyncio.Future | None
disconnection_result: asyncio.Future | None
open_result: asyncio.Future | None
acceptor: Callable[[int], tuple[int, int] | None] | None
connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], Optional[tuple[int, int]]]]
dlcs: dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
@@ -785,7 +780,7 @@ class Multiplexer(utils.EventEmitter):
self.connection_result = None
self.disconnection_result = None
self.open_result = None
self.open_pn: RFCOMM_MCC_PN | None = None
self.open_pn: Optional[RFCOMM_MCC_PN] = None
self.open_rx_max_credits = 0
self.acceptor = None
@@ -1032,8 +1027,8 @@ class Multiplexer(utils.EventEmitter):
# -----------------------------------------------------------------------------
class Client:
multiplexer: Multiplexer | None
l2cap_channel: l2cap.ClassicChannel | None
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]
def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
@@ -1146,7 +1141,7 @@ class Server(utils.EventEmitter):
# Notify
self.emit(self.EVENT_START, multiplexer)
def accept_dlc(self, channel_number: int) -> tuple[int, int] | None:
def accept_dlc(self, channel_number: int) -> Optional[tuple[int, int]]:
return self.dlc_configs.get(channel_number)
def on_dlc(self, dlc: DLC) -> None:

View File

@@ -20,8 +20,7 @@ from __future__ import annotations
import asyncio
import logging
import struct
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, NewType
from typing import TYPE_CHECKING, Iterable, NewType, Optional, Sequence, Union
from typing_extensions import Self
@@ -498,7 +497,7 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(
attribute_list: Iterable[ServiceAttribute], attribute_id: int
) -> DataElement | None:
) -> Optional[DataElement]:
return next(
(
attribute.value
@@ -529,7 +528,7 @@ class ServiceAttribute:
def to_string(self, with_colors=False):
if with_colors:
return (
f'Attribute(id={color(self.id_name(self.id), "magenta")},'
f'Attribute(id={color(self.id_name(self.id),"magenta")},'
f'value={self.value})'
)
@@ -779,11 +778,11 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
class Client:
def __init__(self, connection: Connection, mtu: int = 0) -> None:
self.connection = connection
self.channel: l2cap.ClassicChannel | None = None
self.channel: Optional[l2cap.ClassicChannel] = None
self.mtu = mtu
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request: SDP_PDU | None = None
self.pending_response: asyncio.futures.Future[SDP_PDU] | None = None
self.pending_request: Optional[SDP_PDU] = None
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
self.next_transaction_id = 0
async def connect(self) -> None:
@@ -899,7 +898,7 @@ class Client:
async def search_attributes(
self,
uuids: Iterable[core.UUID],
attribute_ids: Iterable[int | tuple[int, int]],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[list[ServiceAttribute]]:
"""
Search for attributes by UUID and attribute IDs.
@@ -971,7 +970,7 @@ class Client:
async def get_attributes(
self,
service_record_handle: int,
attribute_ids: Iterable[int | tuple[int, int]],
attribute_ids: Iterable[Union[int, tuple[int, int]]],
) -> list[ServiceAttribute]:
"""
Get attributes for a service.
@@ -1043,10 +1042,10 @@ class Client:
# -----------------------------------------------------------------------------
class Server:
CONTINUATION_STATE = bytes([0x01, 0x00])
channel: l2cap.ClassicChannel | None
channel: Optional[l2cap.ClassicChannel]
Service = NewType('Service', list[ServiceAttribute])
service_records: dict[int, Service]
current_response: None | bytes | tuple[int, list[int]]
current_response: Union[None, bytes, tuple[int, list[int]]]
def __init__(self, device: Device) -> None:
self.device = device
@@ -1124,7 +1123,7 @@ class Server:
self,
continuation_state: bytes,
transaction_id: int,
) -> bool | None:
) -> Optional[bool]:
# Check if this is a valid continuation
if len(continuation_state) > 1:
if (

View File

@@ -27,9 +27,17 @@ from __future__ import annotations
import asyncio
import enum
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
ClassVar,
Optional,
TypeVar,
cast,
)
from bumble import crypto, utils
from bumble.colors import color
@@ -205,10 +213,10 @@ class SMP_Command:
fields: ClassVar[Fields]
code: int = field(default=0, init=False)
name: str = field(default='', init=False)
_payload: bytes | None = field(default=None, init=False)
_payload: Optional[bytes] = field(default=None, init=False)
@classmethod
def from_bytes(cls, pdu: bytes) -> SMP_Command:
def from_bytes(cls, pdu: bytes) -> "SMP_Command":
code = pdu[0]
subclass = SMP_Command.smp_classes.get(code)
@@ -546,7 +554,7 @@ class OobContext:
r: bytes
def __init__(
self, ecc_key: crypto.EccKey | None = None, r: bytes | None = None
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
@@ -562,7 +570,7 @@ class OobLegacyContext:
tk: bytes
def __init__(self, tk: bytes | None = None) -> None:
def __init__(self, tk: Optional[bytes] = None) -> None:
self.tk = crypto.r() if tk is None else tk
@@ -669,31 +677,31 @@ class Session:
self.stk = None
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key: bytes | None = None
self.link_key: Optional[bytes] = None
self.maximum_encryption_key_size: int = 0
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: bytes | None = None
self.peer_random_value: Optional[bytes] = None
self.peer_public_key_x: bytes = bytes(32)
self.peer_public_key_y = bytes(32)
self.peer_ltk = None
self.peer_ediv = None
self.peer_rand: bytes | None = None
self.peer_rand: Optional[bytes] = None
self.peer_identity_resolving_key = None
self.peer_bd_addr: Address | None = None
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions: list[type[SMP_Command]] = []
self.dh_key = b''
self.confirm_value = None
self.passkey: int | None = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method: PairingMethod = PairingMethod.JUST_WORKS
self.pairing_config = pairing_config
self.wait_before_continuing: asyncio.Future[None] | None = None
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False
self.ctkd_task: Awaitable[None] | None = None
self.ctkd_task: Optional[Awaitable[None]] = None
# Decide if we're the initiator or the responder
self.is_initiator = is_initiator
@@ -712,7 +720,7 @@ class Session:
# Create a future that can be used to wait for the session to complete
if self.is_initiator:
self.pairing_result: asyncio.Future[None] | None = (
self.pairing_result: Optional[asyncio.Future[None]] = (
asyncio.get_running_loop().create_future()
)
else:
@@ -820,7 +828,7 @@ class Session:
def auth_req(self) -> int:
return smp_auth_req(self.bonding, self.mitm, self.sc, self.keypress, self.ct2)
def get_long_term_key(self, rand: bytes, ediv: int) -> bytes | None:
def get_long_term_key(self, rand: bytes, ediv: int) -> Optional[bytes]:
if not self.sc and not self.completed:
if rand == self.ltk_rand and ediv == self.ltk_ediv:
return self.stk
@@ -931,7 +939,7 @@ class Session:
self.pairing_config.delegate.display_number(self.passkey, digits=6)
)
def input_passkey(self, next_steps: Callable[[], None] | None = None) -> None:
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
def after_input(passkey: int) -> None:
self.passkey = passkey
@@ -948,7 +956,7 @@ class Session:
self.prompt_user_for_number(after_input)
def display_or_input_passkey(
self, next_steps: Callable[[], None] | None = None
self, next_steps: Optional[Callable[[], None]] = None
) -> None:
if self.passkey_display:
@@ -998,6 +1006,7 @@ class Session:
self.send_command(response)
def send_pairing_confirm_command(self) -> None:
if self.pairing_method != PairingMethod.OOB:
self.r = crypto.r()
logger.debug(f'generated random: {self.r.hex()}')
@@ -1833,6 +1842,7 @@ class Session:
self.send_public_key_command()
def next_steps() -> None:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
@@ -1919,7 +1929,7 @@ class Manager(utils.EventEmitter):
sessions: dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: type[Session]
_ecc_key: crypto.EccKey | None
_ecc_key: Optional[crypto.EccKey]
def __init__(
self,
@@ -2012,7 +2022,7 @@ class Manager(utils.EventEmitter):
self.device.on_pairing_start(session.connection)
async def on_pairing(
self, session: Session, identity_address: Address | None, keys: PairingKeys
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
@@ -2031,7 +2041,7 @@ class Manager(utils.EventEmitter):
def get_long_term_key(
self, connection: Connection, rand: bytes, ediv: int
) -> bytes | None:
) -> Optional[bytes]:
if session := self.sessions.get(connection.handle):
return session.get_long_term_key(rand, ediv)

View File

@@ -16,14 +16,13 @@ import datetime
import logging
import os
import struct
from collections.abc import Generator
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
from typing import BinaryIO
from typing import BinaryIO, Generator
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -66,7 +65,7 @@ class BtSnooper(Snooper):
"""
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MS = datetime.timedelta(microseconds=1)
@@ -86,13 +85,7 @@ class BtSnooper(Snooper):
# Compute the current timestamp
timestamp = (
int(
(
datetime.datetime.now(tz=datetime.timezone.utc)
- self.TIMESTAMP_ANCHOR
)
/ self.ONE_MS
)
int((datetime.datetime.utcnow() - self.TIMESTAMP_ANCHOR) / self.ONE_MS)
+ self.TIMESTAMP_DELTA
)
@@ -136,7 +129,7 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.now(tz=datetime.timezone.utc)`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
@@ -160,7 +153,7 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.now(tz=datetime.timezone.utc),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)

View File

@@ -18,6 +18,7 @@
import logging
import os
import re
from typing import Optional
from bumble import utils
from bumble.snoop import create_snooper
@@ -83,12 +84,7 @@ async def open_transport(name: str) -> Transport:
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
# If a spec is provided, check for a metadata section in square brackets.
# The regex captures a comma-separated list of key=value pairs (allowing an
# optional trailing comma). The key is matched by \w+ and the value by [^,\]]+,
# meaning the value may contain any character except a comma or a closing
# bracket (']').
if spec and (m := re.search(r'\[(\w+=[^,\]]+(?:,\w+=[^,\]]+)*,?)\]', spec)):
if spec and (m := re.search(r'\[(\w+=\w+(?:,\w+=\w+)*,?)\]', spec)):
metadata_str = m.group(1)
if m.start() == 0:
# <metadata><spec>
@@ -110,7 +106,7 @@ async def open_transport(name: str) -> Transport:
# -----------------------------------------------------------------------------
async def _open_transport(scheme: str, spec: str | None) -> Transport:
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements

View File

@@ -16,6 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional, Union
import grpc.aio
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec: str | None) -> Transport:
async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax:
@@ -88,7 +89,7 @@ async def open_android_emulator_transport(spec: str | None) -> Transport:
logger.debug('connecting to gRPC server at %s', server_address)
channel = grpc.aio.insecure_channel(server_address)
service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host':
# Connect as a host
service = EmulatedBluetoothServiceStub(channel)

View File

@@ -22,6 +22,7 @@ import os
import pathlib
import platform
import sys
from typing import Optional
import grpc.aio
@@ -65,7 +66,7 @@ DEFAULT_VARIANT = ''
# -----------------------------------------------------------------------------
def get_ini_dir() -> pathlib.Path | None:
def get_ini_dir() -> Optional[pathlib.Path]:
if sys.platform == 'darwin':
if tmpdir := os.getenv('TMPDIR', None):
return pathlib.Path(tmpdir)
@@ -99,7 +100,7 @@ def find_grpc_port(instance_number: int) -> int:
ini_file = ini_dir / ini_file_name(instance_number)
logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file():
with open(ini_file) as ini_file_data:
with open(ini_file, 'r') as ini_file_data:
for line in ini_file_data.readlines():
if '=' in line:
key, value = line.split('=')
@@ -145,7 +146,7 @@ def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(
server_host: str | None, server_port: int, options: dict[str, str]
server_host: Optional[str], server_port: int, options: dict[str, str]
) -> Transport:
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -155,26 +156,21 @@ async def open_android_netsim_controller_transport(
logger.warning("unable to publish gRPC port")
class HciDevice:
def __init__(self, context, server):
def __init__(self, context, on_data_received):
self.context = context
self.server = server
self.on_data_received = on_data_received
self.name = None
self.sink = None
self.loop = asyncio.get_running_loop()
self.done = self.loop.create_future()
self.task = self.loop.create_task(self.pump())
async def pump(self):
try:
await self.pump_loop()
except asyncio.CancelledError:
logger.debug('Pump task canceled')
finally:
if self.sink:
logger.debug('Releasing sink')
self.server.release_sink()
self.sink = None
logger.debug('Pump task terminated')
if not self.done.done():
self.done.set_result(None)
async def pump_loop(self):
while True:
@@ -190,26 +186,15 @@ async def open_android_netsim_controller_transport(
if request.WhichOneof('request_type') == 'initial_info':
logger.debug(f'Received initial info: {request}')
self.name = request.initial_info.name
# We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type')
error = PacketResponse(error='Unsupported chip type')
await self.context.write(error)
# return
continue
return
# Lease the sink so that no other device can send
self.sink = self.server.lease_sink(self)
if self.sink is None:
logger.warning('Another device is already connected')
error = PacketResponse(error='Device busy')
await self.context.write(error)
# return
continue
continue
self.name = request.initial_info.name
continue
# Expect a data packet
request_type = request.WhichOneof('request_type')
@@ -220,10 +205,10 @@ async def open_android_netsim_controller_transport(
continue
# Process the packet
assert self.sink is not None
self.sink(
data = (
bytes([request.hci_packet.packet_type]) + request.hci_packet.packet
)
self.on_data_received(data)
async def send_packet(self, data):
return await self.context.write(
@@ -232,6 +217,12 @@ async def open_android_netsim_controller_transport(
)
)
def terminate(self):
self.task.cancel()
async def wait_for_termination(self):
await self.done
server_address = f'{server_host}:{server_port}'
class Server(PacketStreamerServicer, ParserSource):
@@ -267,27 +258,27 @@ async def open_android_netsim_controller_transport(
return await self.device.send_packet(packet)
def lease_sink(self, device):
if self.device:
return None
self.device = device
return self.parser.feed_data
def release_sink(self):
self.device = None
async def StreamPackets(self, request_iterator, context):
async def StreamPackets(self, _request_iterator, context):
logger.debug('StreamPackets request')
# Instantiate a new device
device = HciDevice(context, self)
# Check that we don't already have a device
if self.device:
logger.debug('Busy, already serving a device')
return PacketResponse(error='Busy')
# Pump packets to/from the device
logger.debug('Pumping device packets')
# Instantiate a new device
self.device = HciDevice(context, self.parser.feed_data)
# Wait for the device to terminate
logger.debug('Waiting for device to terminate')
try:
await device.pump()
finally:
logger.debug('Pump terminated')
await self.device.wait_for_termination()
except asyncio.CancelledError:
logger.debug('Request canceled')
self.device.terminate()
logger.debug('Device terminated')
self.device = None
server = Server()
await server.start()
@@ -300,9 +291,9 @@ async def open_android_netsim_controller_transport(
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_address(
server_host: str | None,
server_host: Optional[str],
server_port: int,
options: dict[str, str] | None = None,
options: Optional[dict[str, str]] = None,
):
if server_host == '_' or not server_host:
server_host = 'localhost'
@@ -327,7 +318,7 @@ async def open_android_netsim_host_transport_with_address(
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_channel(
channel, options: dict[str, str] | None = None
channel, options: Optional[dict[str, str]] = None
):
# Wrapper for I/O operations
class HciDevice:
@@ -407,7 +398,7 @@ async def open_android_netsim_host_transport_with_channel(
# -----------------------------------------------------------------------------
async def open_android_netsim_transport(spec: str | None) -> Transport:
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC.

View File

@@ -22,8 +22,7 @@ import contextlib
import io
import logging
import struct
from collections.abc import Awaitable, Callable
from typing import Any, Protocol
from typing import Any, ContextManager, Optional, Protocol
from bumble import core, hci
from bumble.colors import color
@@ -107,11 +106,11 @@ class PacketParser:
NEED_LENGTH = 1
NEED_BODY = 2
sink: TransportSink | None
sink: Optional[TransportSink]
extended_packet_info: dict[int, tuple[int, int, str]]
packet_info: tuple[int, int, str] | None = None
packet_info: Optional[tuple[int, int, str]] = None
def __init__(self, sink: TransportSink | None = None) -> None:
def __init__(self, sink: Optional[TransportSink] = None) -> None:
self.sink = sink
self.extended_packet_info = {}
self.reset()
@@ -176,7 +175,7 @@ class PacketReader:
self.source = source
self.at_end = False
def next_packet(self) -> bytes | None:
def next_packet(self) -> Optional[bytes]:
# Get the packet type
packet_type = self.source.read(1)
if len(packet_type) != 1:
@@ -253,7 +252,7 @@ class BaseSource:
"""
terminated: asyncio.Future[None]
sink: TransportSink | None
sink: Optional[TransportSink]
def __init__(self) -> None:
self.terminated = asyncio.get_running_loop().create_future()
@@ -357,7 +356,7 @@ class Transport:
# -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource):
pump_task: asyncio.Task[None] | None
pump_task: Optional[asyncio.Task[None]]
def __init__(self, receive) -> None:
super().__init__()
@@ -390,17 +389,15 @@ class PumpedPacketSource(ParserSource):
# -----------------------------------------------------------------------------
class PumpedPacketSink:
pump_task: asyncio.Task[None] | None
def __init__(self, send: Callable[[bytes], Awaitable[Any]]):
def __init__(self, send):
self.send_function = send
self.packet_queue = asyncio.Queue[bytes]()
self.packet_queue = asyncio.Queue()
self.pump_task = None
def on_packet(self, packet: bytes) -> None:
self.packet_queue.put_nowait(packet)
def start(self) -> None:
def start(self):
async def pump_packets():
while True:
try:
@@ -443,7 +440,7 @@ class SnoopingTransport(Transport):
@staticmethod
def create_with(
transport: Transport, snooper: contextlib.AbstractContextManager[Snooper]
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.

View File

@@ -16,6 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import io
import logging
from bumble.transport.common import StreamPacketSink, StreamPacketSource, Transport
@@ -35,7 +36,7 @@ async def open_file_transport(spec: str) -> Transport:
'''
# Open the file
file = open(spec, 'r+b', buffering=0)
file = io.open(spec, 'r+b', buffering=0)
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(

View File

@@ -22,6 +22,7 @@ import logging
import os
import socket
import struct
from typing import Optional
from bumble.transport.common import ParserSource, Transport
@@ -32,7 +33,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_hci_socket_transport(spec: str | None) -> Transport:
async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
'''
Open an HCI Socket (only available on some platforms).
The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -86,7 +87,7 @@ async def open_hci_socket_transport(spec: str | None) -> Transport:
)
!= 0
):
raise OSError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, hci_socket):

View File

@@ -17,10 +17,12 @@
# -----------------------------------------------------------------------------
import asyncio
import atexit
import io
import logging
import os
import pty
import tty
from typing import Optional
from bumble.transport.common import StreamPacketSink, StreamPacketSource, Transport
@@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_pty_transport(spec: str | None) -> Transport:
async def open_pty_transport(spec: Optional[str]) -> Transport:
'''
Open a PTY transport.
The parameter string may be empty, or a path name where a symbolic link
@@ -46,11 +48,11 @@ async def open_pty_transport(spec: str | None) -> Transport:
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
StreamPacketSource, open(primary, 'rb', closefd=False)
StreamPacketSource, io.open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
asyncio.BaseProtocol, open(primary, 'wb', closefd=False)
asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

@@ -19,6 +19,7 @@ import asyncio
import logging
import threading
import time
from typing import Optional
import usb.core
import usb.util
@@ -283,9 +284,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
device = await _power_cycle(device) # type: ignore
except Exception as e:
logging.debug(e, stack_info=True)
logging.info(
f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}"
) # type: ignore
logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore
# Collect the metadata
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
@@ -371,9 +370,7 @@ async def _power_cycle(device: UsbDevice) -> UsbDevice:
# Device needs to be find again otherwise it will appear as disconnected
return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore
except USBError:
logger.exception(
f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition."
) # type: ignore
logger.exception(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore
return device
@@ -388,7 +385,7 @@ def _set_port_status(device: UsbDevice, port: int, on: bool):
)
def _find_device_by_path(sys_path: str) -> UsbDevice | None:
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds a USB device based on its system path."""
bus_num, *port_parts = sys_path.split('-')
ports = [int(port) for port in port_parts[0].split('.')]
@@ -401,7 +398,7 @@ def _find_device_by_path(sys_path: str) -> UsbDevice | None:
return None
def _find_hub_by_device_path(sys_path: str) -> UsbDevice | None:
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds the USB hub associated with a specific device path."""
hub_sys_path = sys_path.rsplit('.', 1)[0]
hub_device = _find_device_by_path(hub_sys_path)

View File

@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
from typing import Optional
import serial_asyncio
@@ -51,7 +52,7 @@ class SerialPacketSource(StreamPacketSource):
logger.debug('connection made')
self._ready.set()
def connection_lost(self, exc: Exception | None) -> None:
def connection_lost(self, exc: Optional[Exception]) -> None:
logger.debug('connection lost')
self.on_transport_lost()

View File

@@ -16,6 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import logging
from typing import Optional
from bumble.transport.common import Transport
from bumble.transport.file import open_file_transport
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_vhci_transport(spec: str | None) -> Transport:
async def open_vhci_transport(spec: Optional[str]) -> Transport:
'''
Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device

View File

@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
import logging
import websockets.asyncio.client
import websockets.client
from bumble.transport.common import (
PumpedPacketSink,
@@ -42,7 +42,7 @@ async def open_ws_client_transport(spec: str) -> Transport:
Example: ws://localhost:7681/v1/websocket/bt
'''
websocket = await websockets.asyncio.client.connect(spec)
websocket = await websockets.client.connect(spec)
class WsTransport(PumpedTransport):
async def close(self):

View File

@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
import logging
import websockets.asyncio.server
import websockets
from bumble.transport.common import ParserSource, PumpedPacketSink, Transport
@@ -40,12 +40,7 @@ async def open_ws_server_transport(spec: str) -> Transport:
'''
class WsServerTransport(Transport):
sink: PumpedPacketSink
source: ParserSource
connection: websockets.asyncio.server.ServerConnection | None
server: websockets.asyncio.server.Server | None
def __init__(self) -> None:
def __init__(self):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = None
@@ -53,19 +48,17 @@ async def open_ws_server_transport(spec: str) -> Transport:
super().__init__(source, sink)
async def serve(self, local_host: str, local_port: str) -> None:
async def serve(self, local_host, local_port):
self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.asyncio.server.serve(
handler=self.on_connection,
self.server = await websockets.serve(
ws_handler=self.on_connection,
host=local_host if local_host != '_' else None,
port=int(local_port),
)
logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(
self, connection: websockets.asyncio.server.ServerConnection
) -> None:
async def on_connection(self, connection):
logger.debug(
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
@@ -84,11 +77,11 @@ async def open_ws_server_transport(spec: str) -> Transport:
# We're now disconnected
self.connection = None
async def send_packet(self, packet: bytes) -> None:
async def send_packet(self, packet):
if self.connection is None:
logger.debug('no connection, dropping packet')
return
await self.connection.send(packet)
return await self.connection.send(packet)
local_host, local_port = spec.rsplit(':', maxsplit=1)
transport = WsServerTransport()

View File

@@ -22,12 +22,16 @@ import collections
import enum
import functools
import logging
import sys
import warnings
from collections.abc import Awaitable, Callable
from typing import (
Any,
Awaitable,
Callable,
Optional,
Protocol,
TypeVar,
Union,
overload,
)
@@ -166,8 +170,8 @@ class EventWatcher:
) -> _Handler: ...
def on(
self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
) -> _Handler | Callable[[_Handler], _Handler]:
self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event until the context is closed.
Args:
@@ -195,8 +199,8 @@ class EventWatcher:
) -> _Handler: ...
def once(
self, emitter: pyee.EventEmitter, event: str, handler: _Handler | None = None
) -> _Handler | Callable[[_Handler], _Handler]:
self, emitter: pyee.EventEmitter, event: str, handler: Optional[_Handler] = None
) -> Union[_Handler, Callable[[_Handler], _Handler]]:
'''Watch an event for once.
Args:
@@ -237,7 +241,11 @@ def cancel_on_event(
return
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
future.cancel(msg)
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
@@ -529,20 +537,3 @@ class IntConvertible(Protocol):
def __init__(self, value: int) -> None: ...
def __int__(self) -> int: ...
# -----------------------------------------------------------------------------
def crc_16(data: bytes) -> int:
"""Calculate CRC-16-IBM of given data.
Polynomial = x^16 + x^15 + x^2 + 1 = 0x8005 or 0xA001(Reversed)
"""
crc = 0x0000
for byte in data:
crc ^= byte
for _ in range(8):
if (crc & 0x0001) > 0:
crc = (crc >> 1) ^ 0xA001
else:
crc = crc >> 1
return crc

Some files were not shown because too many files have changed in this diff Show More