Reorganize exceptions

* Add BaseBumbleException as a "real" root error
* Add several core error classes and properly replace builtin errors
  with them
* Add several error classes for specific modules (transport, device)
This commit is contained in:
Josh Wu
2024-06-04 16:11:26 +08:00
parent 090309302f
commit f8a2d4f0e0
26 changed files with 260 additions and 159 deletions

View File

@@ -14,13 +14,19 @@
from typing import List, Union from typing import List, Union
from bumble import core
class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> List[bytes]: def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens. """Split input parameters into tokens.
Removes space characters outside of double quote blocks: Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0) T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants" are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = [] tokens = []
in_quotes = False in_quotes = False
@@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
token = bytearray() token = bytearray()
elif char == b'(': elif char == b'(':
if len(token) > 0: if len(token) > 0:
raise ValueError("open_paren following regular character") raise AtParsingError("open_paren following regular character")
tokens.append(char) tokens.append(char)
elif char == b'"': elif char == b'"':
if len(token) > 0: if len(token) > 0:
raise ValueError("quote following regular character") raise AtParsingError("quote following regular character")
in_quotes = True in_quotes = True
token.extend(char) token.extend(char)
else: else:
@@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]:
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]: def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators. """Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer) tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]] accumulator: List[list] = [[]]
@@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator.append([]) accumulator.append([])
elif token == b')': elif token == b')':
if len(accumulator) < 2: if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren") raise AtParsingError("close_paren without matching open_paren")
accumulator[-1].append(current) accumulator[-1].append(current)
current = accumulator.pop() current = accumulator.pop()
else: else:
@@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
accumulator[-1].append(current) accumulator[-1].append(current)
if len(accumulator) > 1: if len(accumulator) > 1:
raise ValueError("missing close_paren") raise AtParsingError("missing close_paren")
return accumulator[0] return accumulator[0]

View File

@@ -20,6 +20,7 @@ import enum
import struct import struct
from typing import Dict, Type, Union, Tuple from typing import Dict, Type, Union, Tuple
from bumble import core
from bumble.utils import OpenIntEnum from bumble.utils import OpenIntEnum
@@ -88,7 +89,9 @@ class Frame:
short_name = subclass.__name__.replace("ResponseFrame", "") short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame category_class = ResponseFrame
else: else:
raise ValueError(f"invalid subclass name {subclass.__name__}") raise core.InvalidArgumentError(
f"invalid subclass name {subclass.__name__}"
)
uppercase_indexes = [ uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper() i for i in range(len(short_name)) if short_name[i].isupper()
@@ -106,7 +109,7 @@ class Frame:
@staticmethod @staticmethod
def from_bytes(data: bytes) -> Frame: def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0: if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s") raise core.InvalidPacketError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3) subunit_type = Frame.SubunitType(data[1] >> 3)
@@ -122,7 +125,7 @@ class Frame:
# Extended to the next byte # Extended to the next byte
extension = data[2] extension = data[2]
if extension == 0: if extension == 0:
raise ValueError("extended subunit ID value reserved") raise core.InvalidPacketError("extended subunit ID value reserved")
if extension == 0xFF: if extension == 0xFF:
subunit_id = 5 + 254 + data[3] subunit_id = 5 + 254 + data[3]
opcode_offset = 4 opcode_offset = 4
@@ -131,7 +134,7 @@ class Frame:
opcode_offset = 3 opcode_offset = 3
elif subunit_id == 6: elif subunit_id == 6:
raise ValueError("reserved subunit ID") raise core.InvalidPacketError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset]) opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :] operands = data[opcode_offset + 1 :]
@@ -448,7 +451,7 @@ class PassThroughFrame:
operation_data: bytes, operation_data: bytes,
) -> None: ) -> None:
if len(operation_data) > 255: if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes") raise core.InvalidArgumentError("operation data must be <= 255 bytes")
self.state_flag = state_flag self.state_flag = state_flag
self.operation_id = operation_id self.operation_id = operation_id
self.operation_data = operation_data self.operation_data = operation_data

View File

@@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional
from bumble.colors import color from bumble.colors import color
from bumble import avc from bumble import avc
from bumble import core
from bumble import l2cap from bumble import l2cap
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -275,7 +276,7 @@ class Protocol:
self, pid: int, handler: Protocol.CommandHandler self, pid: int, handler: Protocol.CommandHandler
) -> None: ) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler: if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered") raise core.InvalidArgumentError("command handler not registered")
del self.command_handlers[pid] del self.command_handlers[pid]
def register_response_handler( def register_response_handler(
@@ -287,5 +288,5 @@ class Protocol:
self, pid: int, handler: Protocol.ResponseHandler self, pid: int, handler: Protocol.ResponseHandler
) -> None: ) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler: if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered") raise core.InvalidArgumentError("response handler not registered")
del self.response_handlers[pid] del self.response_handlers[pid]

View File

@@ -43,6 +43,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError, InvalidStateError,
ProtocolError, ProtocolError,
InvalidArgumentError,
name_or_number, name_or_number,
) )
from .a2dp import ( from .a2dp import (
@@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init
signal_identifier_str = name[:-7] signal_identifier_str = name[:-7]
message_type = Message.MessageType.RESPONSE_REJECT message_type = Message.MessageType.RESPONSE_REJECT
else: else:
raise ValueError('invalid class name') raise InvalidArgumentError('invalid class name')
subclass.message_type = message_type subclass.message_type = message_type

View File

@@ -55,6 +55,7 @@ from bumble.sdp import (
) )
from bumble.utils import AsyncRunner, OpenIntEnum from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.core import ( from bumble.core import (
InvalidArgumentError,
ProtocolError, ProtocolError,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_AVCTP_PROTOCOL_ID, BT_AVCTP_PROTOCOL_ID,
@@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter):
def notify_track_changed(self, identifier: bytes) -> None: def notify_track_changed(self, identifier: bytes) -> None:
"""Notify the connected peer of a Track change.""" """Notify the connected peer of a Track change."""
if len(identifier) != 8: if len(identifier) != 8:
raise ValueError("identifier must be 8 bytes") raise InvalidArgumentError("identifier must be 8 bytes")
self.notify_event(TrackChangedEvent(identifier)) self.notify_event(TrackChangedEvent(identifier))
def notify_playback_position_changed(self, position: int) -> None: def notify_playback_position_changed(self, position: int) -> None:

View File

@@ -18,6 +18,8 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from bumble import core
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BitReader: class BitReader:
@@ -40,7 +42,7 @@ class BitReader:
""" "Read up to 32 bits.""" """ "Read up to 32 bits."""
if bits > 32: if bits > 32:
raise ValueError('maximum read size is 32') raise core.InvalidArgumentError('maximum read size is 32')
if self.bits_cached >= bits: if self.bits_cached >= bits:
# We have enough bits. # We have enough bits.
@@ -53,7 +55,7 @@ class BitReader:
feed_size = len(feed_bytes) feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big') feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits: if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data') raise core.InvalidArgumentError('trying to read past the data')
self.byte_position += feed_size self.byte_position += feed_size
# Combine the new cache and the old cache # Combine the new cache and the old cache
@@ -68,7 +70,7 @@ class BitReader:
def read_bytes(self, count: int): def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data): if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data') raise core.InvalidArgumentError('not enough data')
if self.bit_position % 8: if self.bit_position % 8:
# Not byte aligned # Not byte aligned
@@ -113,7 +115,7 @@ class AacAudioRtpPacket:
@staticmethod @staticmethod
def program_config_element(reader: BitReader): def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported') raise core.InvalidPacketError('program_config_element not supported')
@dataclass @dataclass
class GASpecificConfig: class GASpecificConfig:
@@ -140,7 +142,7 @@ class AacAudioRtpPacket:
aac_spectral_data_resilience_flags = reader.read(1) aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1) extension_flag_3 = reader.read(1)
if extension_flag_3 == 1: if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported') raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@staticmethod @staticmethod
def audio_object_type(reader: BitReader): def audio_object_type(reader: BitReader):
@@ -216,7 +218,7 @@ class AacAudioRtpPacket:
reader, self.channel_configuration, self.audio_object_type reader, self.channel_configuration, self.audio_object_type
) )
else: else:
raise ValueError( raise core.InvalidPacketError(
f'audioObjectType {self.audio_object_type} not supported' f'audioObjectType {self.audio_object_type} not supported'
) )
@@ -260,7 +262,7 @@ class AacAudioRtpPacket:
else: else:
audio_mux_version_a = 0 audio_mux_version_a = 0
if audio_mux_version_a != 0: if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported') raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1: if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader) tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0 stream_cnt = 0
@@ -268,10 +270,10 @@ class AacAudioRtpPacket:
num_sub_frames = reader.read(6) num_sub_frames = reader.read(6)
num_program = reader.read(4) num_program = reader.read(4)
if num_program != 0: if num_program != 0:
raise ValueError('num_program != 0 not supported') raise core.InvalidPacketError('num_program != 0 not supported')
num_layer = reader.read(3) num_layer = reader.read(3)
if num_layer != 0: if num_layer != 0:
raise ValueError('num_layer != 0 not supported') raise core.InvalidPacketError('num_layer != 0 not supported')
if audio_mux_version == 0: if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig( self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader reader
@@ -284,7 +286,7 @@ class AacAudioRtpPacket:
) )
audio_specific_config_len = reader.bit_position - marker audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len: if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len') raise core.InvalidPacketError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len asc_len -= audio_specific_config_len
reader.skip(asc_len) reader.skip(asc_len)
frame_length_type = reader.read(3) frame_length_type = reader.read(3)
@@ -293,7 +295,9 @@ class AacAudioRtpPacket:
elif frame_length_type == 1: elif frame_length_type == 1:
frame_length = reader.read(9) frame_length = reader.read(9)
else: else:
raise ValueError(f'frame_length_type {frame_length_type} not supported') raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported'
)
self.other_data_present = reader.read(1) self.other_data_present = reader.read(1)
if self.other_data_present: if self.other_data_present:
@@ -318,12 +322,12 @@ class AacAudioRtpPacket:
def __init__(self, reader: BitReader, mux_config_present: int): def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0: if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported') raise core.InvalidPacketError('muxConfigPresent == 0 not supported')
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41 # AudioMuxElement - ISO/EIC 14496-3 Table 1.41
use_same_stream_mux = reader.read(1) use_same_stream_mux = reader.read(1)
if use_same_stream_mux: if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported') raise core.InvalidPacketError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader) self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support: # We only support:

View File

@@ -16,6 +16,10 @@ from functools import partial
from typing import List, Optional, Union from typing import List, Optional, Union
class ColorError(ValueError):
"""Error raised when a color spec is invalid."""
# ANSI color names. There is also a "default" # ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
@@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255: elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec) return _join(base + 8, 5, spec)
else: else:
raise ValueError('Invalid color spec "%s"' % spec) raise ColorError('Invalid color spec "%s"' % spec)
def color( def color(
@@ -72,7 +76,7 @@ def color(
if style_part in STYLES: if style_part in STYLES:
codes.append(STYLES.index(style_part)) codes.append(STYLES.index(style_part))
else: else:
raise ValueError('Invalid style "%s"' % style_part) raise ColorError('Invalid style "%s"' % style_part)
if codes: if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s) return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)

View File

@@ -76,7 +76,13 @@ def get_dict_key_by_value(dictionary, value):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BaseError(Exception):
class BaseBumbleError(Exception):
"""Base Error raised by Bumble."""
class BaseError(BaseBumbleError):
"""Base class for errors with an error code, error name and namespace""" """Base class for errors with an error code, error name and namespace"""
def __init__( def __init__(
@@ -115,18 +121,38 @@ class ProtocolError(BaseError):
"""Protocol Error""" """Protocol Error"""
class TimeoutError(Exception): # pylint: disable=redefined-builtin class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin
"""Timeout Error""" """Timeout Error"""
class CommandTimeoutError(Exception): class CommandTimeoutError(BaseBumbleError):
"""Command Timeout Error""" """Command Timeout Error"""
class InvalidStateError(Exception): class InvalidStateError(BaseBumbleError):
"""Invalid State Error""" """Invalid State Error"""
class InvalidArgumentError(BaseBumbleError, ValueError):
"""Invalid Argument Error"""
class InvalidPacketError(BaseBumbleError, ValueError):
"""Invalid Packet Error"""
class InvalidOperationError(BaseBumbleError, RuntimeError):
"""Invalid Operation Error"""
class OutOfResourcesError(BaseBumbleError, RuntimeError):
"""Out of Resources Error"""
class UnreachableError(BaseBumbleError):
"""The code path raising this error should be unreachable."""
class ConnectionError(BaseError): # pylint: disable=redefined-builtin class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error""" """Connection Error"""
@@ -185,12 +211,12 @@ class UUID:
or uuid_str_or_int[18] != '-' or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-' or uuid_str_or_int[23] != '-'
): ):
raise ValueError('invalid UUID format') raise InvalidArgumentError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '') uuid_str = uuid_str_or_int.replace('-', '')
else: else:
uuid_str = uuid_str_or_int uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4: if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError(f"invalid UUID format: {uuid_str}") raise InvalidArgumentError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name self.name = name
@@ -215,7 +241,7 @@ class UUID:
return self.register() return self.register()
raise ValueError('only 2, 4 and 16 bytes are allowed') raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed')
@classmethod @classmethod
def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID: def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID:

View File

@@ -173,10 +173,15 @@ from .core import (
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
AdvertisingData, AdvertisingData,
BaseBumbleError,
ConnectionParameterUpdateError, ConnectionParameterUpdateError,
CommandTimeoutError, CommandTimeoutError,
ConnectionPHY, ConnectionPHY,
InvalidArgumentError,
InvalidOperationError,
InvalidStateError, InvalidStateError,
OutOfResourcesError,
UnreachableError,
) )
from .utils import ( from .utils import (
AsyncRunner, AsyncRunner,
@@ -259,6 +264,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ObjectLookupError(BaseBumbleError):
"""Error raised when failed to lookup an object."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1374,7 +1381,9 @@ def with_connection_from_handle(function):
@functools.wraps(function) @functools.wraps(function)
def wrapper(self, connection_handle, *args, **kwargs): def wrapper(self, connection_handle, *args, **kwargs):
if (connection := self.lookup_connection(connection_handle)) is None: if (connection := self.lookup_connection(connection_handle)) is None:
raise ValueError(f'no connection for handle: 0x{connection_handle:04x}') raise ObjectLookupError(
f'no connection for handle: 0x{connection_handle:04x}'
)
return function(self, connection, *args, **kwargs) return function(self, connection, *args, **kwargs)
return wrapper return wrapper
@@ -1389,7 +1398,7 @@ def with_connection_from_address(function):
for connection in self.connections.values(): for connection in self.connections.values():
if connection.peer_address == address: if connection.peer_address == address:
return function(self, connection, *args, **kwargs) return function(self, connection, *args, **kwargs)
raise ValueError('no connection for address') raise ObjectLookupError('no connection for address')
return wrapper return wrapper
@@ -1798,7 +1807,7 @@ class Device(CompositeEventEmitter):
spec=spec, spec=spec,
) )
else: else:
raise ValueError(f'Unexpected mode {spec}') raise InvalidArgumentError(f'Unexpected mode {spec}')
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
self.host.send_l2cap_pdu(connection_handle, cid, pdu) self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -1959,7 +1968,7 @@ class Device(CompositeEventEmitter):
def supports_le_features(self, feature: LeFeatureMask) -> bool: def supports_le_features(self, feature: LeFeatureMask) -> bool:
return self.host.supports_le_features(feature) return self.host.supports_le_features(feature)
def supports_le_phy(self, phy): def supports_le_phy(self, phy: int) -> bool:
if phy == HCI_LE_1M_PHY: if phy == HCI_LE_1M_PHY:
return True return True
@@ -1968,7 +1977,7 @@ class Device(CompositeEventEmitter):
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY, HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
} }
if phy not in feature_map: if phy not in feature_map:
raise ValueError('invalid PHY') raise InvalidArgumentError('invalid PHY')
return self.supports_le_features(feature_map[phy]) return self.supports_le_features(feature_map[phy])
@@ -2028,7 +2037,7 @@ class Device(CompositeEventEmitter):
# Decide what peer address to use # Decide what peer address to use
if advertising_type.is_directed: if advertising_type.is_directed:
if target is None: if target is None:
raise ValueError('directed advertising requires a target') raise InvalidArgumentError('directed advertising requires a target')
peer_address = target peer_address = target
else: else:
peer_address = Address.ANY peer_address = Address.ANY
@@ -2135,7 +2144,7 @@ class Device(CompositeEventEmitter):
and advertising_data and advertising_data
and scan_response_data and scan_response_data
): ):
raise ValueError( raise InvalidArgumentError(
"Extended advertisements can't have both data and scan \ "Extended advertisements can't have both data and scan \
response data" response data"
) )
@@ -2151,7 +2160,9 @@ class Device(CompositeEventEmitter):
if handle not in self.extended_advertising_sets if handle not in self.extended_advertising_sets
) )
except StopIteration as exc: except StopIteration as exc:
raise RuntimeError("all valid advertising handles already in use") from exc raise OutOfResourcesError(
"all valid advertising handles already in use"
) from exc
# Use the device's random address if a random address is needed but none was # Use the device's random address if a random address is needed but none was
# provided. # provided.
@@ -2250,14 +2261,14 @@ class Device(CompositeEventEmitter):
) -> None: ) -> None:
# Check that the arguments are legal # Check that the arguments are legal
if scan_interval < scan_window: if scan_interval < scan_window:
raise ValueError('scan_interval must be >= scan_window') raise InvalidArgumentError('scan_interval must be >= scan_window')
if ( if (
scan_interval < DEVICE_MIN_SCAN_INTERVAL scan_interval < DEVICE_MIN_SCAN_INTERVAL
or scan_interval > DEVICE_MAX_SCAN_INTERVAL or scan_interval > DEVICE_MAX_SCAN_INTERVAL
): ):
raise ValueError('scan_interval out of range') raise InvalidArgumentError('scan_interval out of range')
if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW: if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW:
raise ValueError('scan_interval out of range') raise InvalidArgumentError('scan_interval out of range')
# Reset the accumulators # Reset the accumulators
self.advertisement_accumulators = {} self.advertisement_accumulators = {}
@@ -2285,7 +2296,7 @@ class Device(CompositeEventEmitter):
scanning_phy_count += 1 scanning_phy_count += 1
if scanning_phy_count == 0: if scanning_phy_count == 0:
raise ValueError('at least one scanning PHY must be enabled') raise InvalidArgumentError('at least one scanning PHY must be enabled')
await self.send_command( await self.send_command(
HCI_LE_Set_Extended_Scan_Parameters_Command( HCI_LE_Set_Extended_Scan_Parameters_Command(
@@ -2479,7 +2490,7 @@ class Device(CompositeEventEmitter):
# Check parameters # Check parameters
if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT): if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT):
raise ValueError('invalid transport') raise InvalidArgumentError('invalid transport')
# Adjust the transport automatically if we need to # Adjust the transport automatically if we need to
if transport == BT_LE_TRANSPORT and not self.le_enabled: if transport == BT_LE_TRANSPORT and not self.le_enabled:
@@ -2496,7 +2507,7 @@ class Device(CompositeEventEmitter):
peer_address = Address.from_string_for_transport( peer_address = Address.from_string_for_transport(
peer_address, transport peer_address, transport
) )
except ValueError: except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
@@ -2508,7 +2519,7 @@ class Device(CompositeEventEmitter):
transport == BT_BR_EDR_TRANSPORT transport == BT_BR_EDR_TRANSPORT
and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS
): ):
raise ValueError('BR/EDR addresses must be PUBLIC') raise InvalidArgumentError('BR/EDR addresses must be PUBLIC')
assert isinstance(peer_address, Address) assert isinstance(peer_address, Address)
@@ -2559,7 +2570,7 @@ class Device(CompositeEventEmitter):
) )
) )
if not phys: if not phys:
raise ValueError('at least one supported PHY needed') raise InvalidArgumentError('at least one supported PHY needed')
phy_count = len(phys) phy_count = len(phys)
initiating_phys = phy_list_to_bits(phys) initiating_phys = phy_list_to_bits(phys)
@@ -2631,7 +2642,7 @@ class Device(CompositeEventEmitter):
) )
else: else:
if HCI_LE_1M_PHY not in connection_parameters_preferences: if HCI_LE_1M_PHY not in connection_parameters_preferences:
raise ValueError('1M PHY preferences required') raise InvalidArgumentError('1M PHY preferences required')
prefs = connection_parameters_preferences[HCI_LE_1M_PHY] prefs = connection_parameters_preferences[HCI_LE_1M_PHY]
result = await self.send_command( result = await self.send_command(
@@ -2731,7 +2742,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str): if isinstance(peer_address, str):
try: try:
peer_address = Address(peer_address) peer_address = Address(peer_address)
except ValueError: except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
@@ -2741,7 +2752,7 @@ class Device(CompositeEventEmitter):
assert isinstance(peer_address, Address) assert isinstance(peer_address, Address)
if peer_address == Address.NIL: if peer_address == Address.NIL:
raise ValueError('accept on nil address') raise InvalidArgumentError('accept on nil address')
# Create a future so that we can wait for the request # Create a future so that we can wait for the request
pending_request_fut = asyncio.get_running_loop().create_future() pending_request_fut = asyncio.get_running_loop().create_future()
@@ -2854,7 +2865,7 @@ class Device(CompositeEventEmitter):
if isinstance(peer_address, str): if isinstance(peer_address, str):
try: try:
peer_address = Address(peer_address) peer_address = Address(peer_address)
except ValueError: except InvalidArgumentError:
# If the address is not parsable, assume it is a name instead # If the address is not parsable, assume it is a name instead
logger.debug('looking for peer by name') logger.debug('looking for peer by name')
peer_address = await self.find_peer_by_name( peer_address = await self.find_peer_by_name(
@@ -2897,10 +2908,10 @@ class Device(CompositeEventEmitter):
async def set_data_length(self, connection, tx_octets, tx_time) -> None: async def set_data_length(self, connection, tx_octets, tx_time) -> None:
if tx_octets < 0x001B or tx_octets > 0x00FB: if tx_octets < 0x001B or tx_octets > 0x00FB:
raise ValueError('tx_octets must be between 0x001B and 0x00FB') raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB')
if tx_time < 0x0148 or tx_time > 0x4290: if tx_time < 0x0148 or tx_time > 0x4290:
raise ValueError('tx_time must be between 0x0148 and 0x4290') raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290')
return await self.send_command( return await self.send_command(
HCI_LE_Set_Data_Length_Command( HCI_LE_Set_Data_Length_Command(
@@ -3175,7 +3186,7 @@ class Device(CompositeEventEmitter):
async def encrypt(self, connection, enable=True): async def encrypt(self, connection, enable=True):
if not enable and connection.transport == BT_LE_TRANSPORT: if not enable and connection.transport == BT_LE_TRANSPORT:
raise ValueError('`enable` parameter is classic only.') raise InvalidArgumentError('`enable` parameter is classic only.')
# Set up event handlers # Set up event handlers
pending_encryption = asyncio.get_running_loop().create_future() pending_encryption = asyncio.get_running_loop().create_future()
@@ -3194,11 +3205,11 @@ class Device(CompositeEventEmitter):
if connection.transport == BT_LE_TRANSPORT: if connection.transport == BT_LE_TRANSPORT:
# Look for a key in the key store # Look for a key in the key store
if self.keystore is None: if self.keystore is None:
raise RuntimeError('no key store') raise InvalidOperationError('no key store')
keys = await self.keystore.get(str(connection.peer_address)) keys = await self.keystore.get(str(connection.peer_address))
if keys is None: if keys is None:
raise RuntimeError('keys not found in key store') raise InvalidOperationError('keys not found in key store')
if keys.ltk is not None: if keys.ltk is not None:
ltk = keys.ltk.value ltk = keys.ltk.value
@@ -3209,7 +3220,7 @@ class Device(CompositeEventEmitter):
rand = keys.ltk_central.rand rand = keys.ltk_central.rand
ediv = keys.ltk_central.ediv ediv = keys.ltk_central.ediv
else: else:
raise RuntimeError('no LTK found for peer') raise InvalidOperationError('no LTK found for peer')
if connection.role != HCI_CENTRAL_ROLE: if connection.role != HCI_CENTRAL_ROLE:
raise InvalidStateError('only centrals can start encryption') raise InvalidStateError('only centrals can start encryption')
@@ -3484,7 +3495,7 @@ class Device(CompositeEventEmitter):
return cis_link return cis_link
# Mypy believes this is reachable when context is an ExitStack. # Mypy believes this is reachable when context is an ExitStack.
raise InvalidStateError('Unreachable') raise UnreachableError()
# [LE only] # [LE only]
@experimental('Only for testing.') @experimental('Only for testing.')
@@ -3950,7 +3961,7 @@ class Device(CompositeEventEmitter):
return await pairing_config.delegate.confirm(auto=True) return await pairing_config.delegate.confirm(auto=True)
async def na() -> bool: async def na() -> bool:
assert False, "N/A: unreachable" raise UnreachableError()
# See Bluetooth spec @ Vol 3, Part C 5.2.2.6 # See Bluetooth spec @ Vol 3, Part C 5.2.2.6
methods = { methods = {

View File

@@ -33,6 +33,7 @@ from typing import Tuple
import weakref import weakref
from bumble import core
from bumble.hci import ( from bumble.hci import (
hci_vendor_command_op_code, hci_vendor_command_op_code,
STATUS_SPEC, STATUS_SPEC,
@@ -49,6 +50,10 @@ from bumble.drivers import common
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RtkFirmwareError(core.BaseBumbleError):
"""Error raised when RTK firmware initialization fails."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -208,15 +213,15 @@ class Firmware:
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77]) extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE): if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise ValueError("Firmware does not start with epatch signature") raise RtkFirmwareError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig): if not firmware.endswith(extension_sig):
raise ValueError("Firmware does not end with extension sig") raise RtkFirmwareError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header. # The firmware should start with a 14 byte header.
epatch_header_size = 14 epatch_header_size = 14
if len(firmware) < epatch_header_size: if len(firmware) < epatch_header_size:
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
# Look for the "project ID", starting from the end. # Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig) offset = len(firmware) - len(extension_sig)
@@ -230,7 +235,7 @@ class Firmware:
break break
if length == 0: if length == 0:
raise ValueError("Invalid 0-length instruction") raise RtkFirmwareError("Invalid 0-length instruction")
if opcode == 0 and length == 1: if opcode == 0 and length == 1:
project_id = firmware[offset - 1] project_id = firmware[offset - 1]
@@ -239,7 +244,7 @@ class Firmware:
offset -= length offset -= length
if project_id < 0: if project_id < 0:
raise ValueError("Project ID not found") raise RtkFirmwareError("Project ID not found")
self.project_id = project_id self.project_id = project_id
@@ -252,7 +257,7 @@ class Firmware:
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each) # <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each) # <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware): if epatch_header_size + 8 * num_patches > len(firmware):
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
chip_id_table_offset = epatch_header_size chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
@@ -266,7 +271,7 @@ class Firmware:
"<I", firmware, patch_offset_table_offset + 4 * patch_index "<I", firmware, patch_offset_table_offset + 4 * patch_index
) )
if patch_offset + patch_length > len(firmware): if patch_offset + patch_length > len(firmware):
raise ValueError("Firmware too short") raise RtkFirmwareError("Firmware too short")
# Get the SVN version for the patch # Get the SVN version for the patch
(svn_version,) = struct.unpack_from( (svn_version,) = struct.unpack_from(
@@ -645,7 +650,7 @@ class Driver(common.Driver):
): ):
return await self.download_for_rtl8723b() return await self.download_for_rtl8723b()
raise ValueError("ROM not supported") raise RtkFirmwareError("ROM not supported")
async def init_controller(self): async def init_controller(self):
await self.download_firmware() await self.download_firmware()

View File

@@ -331,9 +331,9 @@ class Client:
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
if mtu < ATT_DEFAULT_MTU: if mtu < ATT_DEFAULT_MTU:
raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF: if mtu > 0xFFFF:
raise ValueError('MTU must be <= 0xFFFF') raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:

View File

@@ -31,6 +31,8 @@ from .core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
AdvertisingData, AdvertisingData,
DeviceClass, DeviceClass,
InvalidArgumentError,
InvalidPacketError,
ProtocolError, ProtocolError,
bit_flags_to_strings, bit_flags_to_strings,
name_or_number, name_or_number,
@@ -91,14 +93,14 @@ def map_class_of_device(class_of_device):
) )
def phy_list_to_bits(phys): def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int:
if phys is None: if phys is None:
return 0 return 0
phy_bits = 0 phy_bits = 0
for phy in phys: for phy in phys:
if phy not in HCI_LE_PHY_TYPE_TO_BIT: if phy not in HCI_LE_PHY_TYPE_TO_BIT:
raise ValueError('invalid PHY') raise InvalidArgumentError('invalid PHY')
phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy] phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy]
return phy_bits return phy_bits
@@ -1552,7 +1554,7 @@ class HCI_Object:
new_offset, field_value = field_type(data, offset) new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset) return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}') raise InvalidArgumentError(f'unknown field type {field_type}')
@staticmethod @staticmethod
def dict_from_bytes(data, offset, fields): def dict_from_bytes(data, offset, fields):
@@ -1621,7 +1623,7 @@ class HCI_Object:
if 0 <= field_value <= 255: if 0 <= field_value <= 255:
field_bytes = bytes([field_value]) field_bytes = bytes([field_value])
else: else:
raise ValueError('value too large for *-typed field') raise InvalidArgumentError('value too large for *-typed field')
else: else:
field_bytes = bytes(field_value) field_bytes = bytes(field_value)
elif field_type == 'v': elif field_type == 'v':
@@ -1640,7 +1642,9 @@ class HCI_Object:
elif len(field_bytes) > field_type: elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type] field_bytes = field_bytes[:field_type]
else: else:
raise ValueError(f"don't know how to serialize type {type(field_value)}") raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes return field_bytes
@@ -1904,7 +1908,7 @@ class Address:
self.address_bytes = bytes(reversed(bytes.fromhex(address))) self.address_bytes = bytes(reversed(bytes.fromhex(address)))
if len(self.address_bytes) != 6: if len(self.address_bytes) != 6:
raise ValueError('invalid address length') raise InvalidArgumentError('invalid address length')
self.address_type = address_type self.address_type = address_type
@@ -2104,7 +2108,7 @@ class HCI_Command(HCI_Packet):
op_code, length = struct.unpack_from('<HB', packet, 1) op_code, length = struct.unpack_from('<HB', packet, 1)
parameters = packet[4:] parameters = packet[4:]
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
# Look for a registered class # Look for a registered class
cls = HCI_Command.command_classes.get(op_code) cls = HCI_Command.command_classes.get(op_code)
@@ -4729,7 +4733,7 @@ class HCI_Event(HCI_Packet):
length = packet[2] length = packet[2]
parameters = packet[3:] parameters = packet[3:]
if len(parameters) != length: if len(parameters) != length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
cls: Any cls: Any
if event_code == HCI_LE_META_EVENT: if event_code == HCI_LE_META_EVENT:
@@ -6104,7 +6108,7 @@ class HCI_AclDataPacket(HCI_Packet):
bc_flag = (h >> 14) & 3 bc_flag = (h >> 14) & 3
data = packet[5:] data = packet[5:]
if len(data) != data_total_length: if len(data) != data_total_length:
raise ValueError('invalid packet length') raise InvalidPacketError('invalid packet length')
return HCI_AclDataPacket( return HCI_AclDataPacket(
connection_handle, pb_flag, bc_flag, data_total_length, data connection_handle, pb_flag, bc_flag, data_total_length, data
) )
@@ -6152,7 +6156,7 @@ class HCI_SynchronousDataPacket(HCI_Packet):
packet_status = (h >> 12) & 0b11 packet_status = (h >> 12) & 0b11
data = packet[4:] data = packet[4:]
if len(data) != data_total_length: if len(data) != data_total_length:
raise ValueError( raise InvalidPacketError(
f'invalid packet length {len(data)} != {data_total_length}' f'invalid packet length {len(data)} != {data_total_length}'
) )
return HCI_SynchronousDataPacket( return HCI_SynchronousDataPacket(

View File

@@ -41,7 +41,14 @@ from typing import (
from .utils import deprecated from .utils import deprecated
from .colors import color from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError from .core import (
BT_CENTRAL_ROLE,
InvalidStateError,
InvalidArgumentError,
InvalidPacketError,
OutOfResourcesError,
ProtocolError,
)
from .hci import ( from .hci import (
HCI_LE_Connection_Update_Command, HCI_LE_Connection_Update_Command,
HCI_Object, HCI_Object,
@@ -188,17 +195,17 @@ class LeCreditBasedChannelSpec:
self.max_credits < 1 self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
): ):
raise ValueError('max credits out of range') raise InvalidArgumentError('max credits out of range')
if ( if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
): ):
raise ValueError('MTU out of range') raise InvalidArgumentError('MTU out of range')
if ( if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
): ):
raise ValueError('MPS out of range') raise InvalidArgumentError('MPS out of range')
class L2CAP_PDU: class L2CAP_PDU:
@@ -210,7 +217,7 @@ class L2CAP_PDU:
def from_bytes(data: bytes) -> L2CAP_PDU: def from_bytes(data: bytes) -> L2CAP_PDU:
# Check parameters # Check parameters
if len(data) < 4: if len(data) < 4:
raise ValueError('not enough data for L2CAP header') raise InvalidPacketError('not enough data for L2CAP header')
_, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0) _, l2cap_pdu_cid = struct.unpack_from('<HH', data, 0)
l2cap_pdu_payload = data[4:] l2cap_pdu_payload = data[4:]
@@ -815,7 +822,7 @@ class ClassicChannel(EventEmitter):
# Check that we can start a new connection # Check that we can start a new connection
if self.connection_result: if self.connection_result:
raise RuntimeError('connection already pending') raise InvalidStateError('connection already pending')
self._change_state(self.State.WAIT_CONNECT_RSP) self._change_state(self.State.WAIT_CONNECT_RSP)
self.send_control_frame( self.send_control_frame(
@@ -1126,7 +1133,7 @@ class LeCreditBasedChannel(EventEmitter):
# Check that we can start a new connection # Check that we can start a new connection
identifier = self.manager.next_identifier(self.connection) identifier = self.manager.next_identifier(self.connection)
if identifier in self.manager.le_coc_requests: if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests') raise InvalidStateError('too many concurrent connection requests')
self._change_state(self.State.CONNECTING) self._change_state(self.State.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request( request = L2CAP_LE_Credit_Based_Connection_Request(
@@ -1513,7 +1520,7 @@ class ChannelManager:
if cid not in channels: if cid not in channels:
return cid return cid
raise RuntimeError('no free CID available') raise OutOfResourcesError('no free CID available')
@staticmethod @staticmethod
def find_free_le_cid(channels: Iterable[int]) -> int: def find_free_le_cid(channels: Iterable[int]) -> int:
@@ -1526,7 +1533,7 @@ class ChannelManager:
if cid not in channels: if cid not in channels:
return cid return cid
raise RuntimeError('no free CID') raise OutOfResourcesError('no free CID')
def next_identifier(self, connection: Connection) -> int: def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
@@ -1573,15 +1580,15 @@ class ChannelManager:
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if spec.psm in self.servers: if spec.psm in self.servers:
raise ValueError('PSM already in use') raise InvalidArgumentError('PSM already in use')
# Check that the PSM is valid # Check that the PSM is valid
if spec.psm % 2 == 0: if spec.psm % 2 == 0:
raise ValueError('invalid PSM (not odd)') raise InvalidArgumentError('invalid PSM (not odd)')
check = spec.psm >> 8 check = spec.psm >> 8
while check: while check:
if check % 2 != 0: if check % 2 != 0:
raise ValueError('invalid PSM') raise InvalidArgumentError('invalid PSM')
check >>= 8 check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
@@ -1623,7 +1630,7 @@ class ChannelManager:
else: else:
# Check that the PSM isn't already in use # Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers: if spec.psm in self.le_coc_servers:
raise ValueError('PSM already in use') raise InvalidArgumentError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self, self,
@@ -2151,10 +2158,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {}) connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels) source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None: if spec.psm is None:
raise ValueError('PSM cannot be None') raise InvalidArgumentError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}') logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
@@ -2203,10 +2210,10 @@ class ChannelManager:
connection_channels = self.channels.setdefault(connection.handle, {}) connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels) source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen! if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use') raise OutOfResourcesError('all CIDs already in use')
if spec.psm is None: if spec.psm is None:
raise ValueError('PSM cannot be None') raise InvalidArgumentError('PSM cannot be None')
# Create the channel # Create the channel
logger.debug( logger.debug(

View File

@@ -19,7 +19,12 @@ import logging
import asyncio import asyncio
from functools import partial from functools import partial
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT from bumble.core import (
BT_PERIPHERAL_ROLE,
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
InvalidStateError,
)
from bumble.colors import color from bumble.colors import color
from bumble.hci import ( from bumble.hci import (
Address, Address,
@@ -405,12 +410,12 @@ class RemoteLink:
def add_controller(self, controller): def add_controller(self, controller):
if self.controller: if self.controller:
raise ValueError('controller already set') raise InvalidStateError('controller already set')
self.controller = controller self.controller = controller
def remove_controller(self, controller): def remove_controller(self, controller):
if self.controller != controller: if self.controller != controller:
raise ValueError('controller mismatch') raise InvalidStateError('controller mismatch')
self.controller = None self.controller = None
def get_pending_connection(self): def get_pending_connection(self):

View File

@@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_member_rank: Optional[int] = None, set_member_rank: Optional[int] = None,
) -> None: ) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH: if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError( raise core.InvalidArgumentError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}' f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
) )
@@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
key = await connection.device.get_link_key(connection.peer_address) key = await connection.device.get_link_key(connection.peer_address)
if not key: if not key:
raise RuntimeError('LTK or LinkKey is not present') raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key) sirk_bytes = sef(key, self.set_identity_resolving_key)
@@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
'''Reads SIRK and decrypts if encrypted.''' '''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value() response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1: if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value') raise core.InvalidPacketError('Invalid SIRK value')
sirk_type = SirkType(response[0]) sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT: if sirk_type == SirkType.PLAINTEXT:
@@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
key = await device.get_link_key(connection.peer_address) key = await device.get_link_key(connection.peer_address)
if not key: if not key:
raise RuntimeError('LTK or LinkKey is not present') raise core.InvalidOperationError('LTK or LinkKey is not present')
sirk = sef(key, response[1:]) sirk = sef(key, response[1:])

View File

@@ -19,6 +19,7 @@
from enum import IntEnum from enum import IntEnum
import struct import struct
from bumble import core
from ..gatt_client import ProfileServiceProxy from ..gatt_client import ProfileServiceProxy
from ..att import ATT_Error from ..att import ATT_Error
from ..gatt import ( from ..gatt import (
@@ -59,17 +60,17 @@ class HeartRateService(TemplateService):
rr_intervals=None, rr_intervals=None,
): ):
if heart_rate < 0 or heart_rate > 0xFFFF: if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range') raise core.InvalidArgumentError('heart_rate out of range')
if energy_expended is not None and ( if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF energy_expended < 0 or energy_expended > 0xFFFF
): ):
raise ValueError('energy_expended out of range') raise core.InvalidArgumentError('energy_expended out of range')
if rr_intervals: if rr_intervals:
for rr_interval in rr_intervals: for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range') raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected self.sensor_contact_detected = sensor_contact_detected

View File

@@ -36,7 +36,9 @@ from .core import (
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
InvalidArgumentError,
InvalidStateError, InvalidStateError,
InvalidPacketError,
ProtocolError, ProtocolError,
) )
@@ -333,7 +335,7 @@ class RFCOMM_Frame:
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information) frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs: if frame.fcs != fcs:
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}') logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch') raise InvalidPacketError('fcs mismatch')
return frame return frame
@@ -680,7 +682,7 @@ class DLC(EventEmitter):
# Automatically convert strings to bytes using UTF-8 # Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8') data = data.encode('utf-8')
else: else:
raise ValueError('write only accept bytes or strings') raise InvalidArgumentError('write only accept bytes or strings')
self.tx_buffer += data self.tx_buffer += data
self.drained.clear() self.drained.clear()

View File

@@ -23,7 +23,7 @@ from typing_extensions import Self
from . import core, l2cap from . import core, l2cap
from .colors import color from .colors import color
from .core import InvalidStateError from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
from .hci import HCI_Object, name_or_number, key_with_value from .hci import HCI_Object, name_or_number, key_with_value
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -189,7 +189,9 @@ class DataElement:
self.bytes = None self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None: if value_size is None:
raise ValueError('integer types must have a value size specified') raise InvalidArgumentError(
'integer types must have a value size specified'
)
@staticmethod @staticmethod
def nil() -> DataElement: def nil() -> DataElement:
@@ -265,7 +267,7 @@ class DataElement:
if len(data) == 8: if len(data) == 8:
return struct.unpack('>Q', data)[0] return struct.unpack('>Q', data)[0]
raise ValueError(f'invalid integer length {len(data)}') raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def signed_integer_from_bytes(data): def signed_integer_from_bytes(data):
@@ -281,7 +283,7 @@ class DataElement:
if len(data) == 8: if len(data) == 8:
return struct.unpack('>q', data)[0] return struct.unpack('>q', data)[0]
raise ValueError(f'invalid integer length {len(data)}') raise InvalidPacketError(f'invalid integer length {len(data)}')
@staticmethod @staticmethod
def list_from_bytes(data): def list_from_bytes(data):
@@ -354,7 +356,7 @@ class DataElement:
data = b'' data = b''
elif self.type == DataElement.UNSIGNED_INTEGER: elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0: if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative') raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative')
if self.value_size == 1: if self.value_size == 1:
data = struct.pack('B', self.value) data = struct.pack('B', self.value)
@@ -365,7 +367,7 @@ class DataElement:
elif self.value_size == 8: elif self.value_size == 8:
data = struct.pack('>Q', self.value) data = struct.pack('>Q', self.value)
else: else:
raise ValueError('invalid value_size') raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.SIGNED_INTEGER: elif self.type == DataElement.SIGNED_INTEGER:
if self.value_size == 1: if self.value_size == 1:
data = struct.pack('b', self.value) data = struct.pack('b', self.value)
@@ -376,7 +378,7 @@ class DataElement:
elif self.value_size == 8: elif self.value_size == 8:
data = struct.pack('>q', self.value) data = struct.pack('>q', self.value)
else: else:
raise ValueError('invalid value_size') raise InvalidArgumentError('invalid value_size')
elif self.type == DataElement.UUID: elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value))) data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL: elif self.type == DataElement.URL:
@@ -392,7 +394,7 @@ class DataElement:
size_bytes = b'' size_bytes = b''
if self.type == DataElement.NIL: if self.type == DataElement.NIL:
if size != 0: if size != 0:
raise ValueError('NIL must be empty') raise InvalidArgumentError('NIL must be empty')
size_index = 0 size_index = 0
elif self.type in ( elif self.type in (
DataElement.UNSIGNED_INTEGER, DataElement.UNSIGNED_INTEGER,
@@ -410,7 +412,7 @@ class DataElement:
elif size == 16: elif size == 16:
size_index = 4 size_index = 4
else: else:
raise ValueError('invalid data size') raise InvalidArgumentError('invalid data size')
elif self.type in ( elif self.type in (
DataElement.TEXT_STRING, DataElement.TEXT_STRING,
DataElement.SEQUENCE, DataElement.SEQUENCE,
@@ -427,10 +429,10 @@ class DataElement:
size_index = 7 size_index = 7
size_bytes = struct.pack('>I', size) size_bytes = struct.pack('>I', size)
else: else:
raise ValueError('invalid data size') raise InvalidArgumentError('invalid data size')
elif self.type == DataElement.BOOLEAN: elif self.type == DataElement.BOOLEAN:
if size != 1: if size != 1:
raise ValueError('boolean must be 1 byte') raise InvalidArgumentError('boolean must be 1 byte')
size_index = 0 size_index = 0
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data

View File

@@ -55,6 +55,7 @@ from .core import (
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
AdvertisingData, AdvertisingData,
InvalidArgumentError,
ProtocolError, ProtocolError,
name_or_number, name_or_number,
) )
@@ -784,7 +785,7 @@ class Session:
self.peer_oob_data = pairing_config.oob.peer_data self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc: if pairing_config.sc:
if pairing_config.oob.our_context is None: if pairing_config.oob.our_context is None:
raise ValueError( raise InvalidArgumentError(
"oob pairing config requires a context when sc is True" "oob pairing config requires a context when sc is True"
) )
self.r = pairing_config.oob.our_context.r self.r = pairing_config.oob.our_context.r
@@ -793,7 +794,7 @@ class Session:
self.tk = pairing_config.oob.legacy_context.tk self.tk = pairing_config.oob.legacy_context.tk
else: else:
if pairing_config.oob.legacy_context is None: if pairing_config.oob.legacy_context is None:
raise ValueError( raise InvalidArgumentError(
"oob pairing config requires a legacy context when sc is False" "oob pairing config requires a legacy context when sc is False"
) )
self.r = bytes(16) self.r = bytes(16)

View File

@@ -23,6 +23,7 @@ import datetime
from typing import BinaryIO, Generator from typing import BinaryIO, Generator
import os import os
from bumble import core
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
@@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
""" """
if ':' not in spec: if ':' not in spec:
raise ValueError('snooper type prefix missing') raise core.InvalidArgumentError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1) snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop': if snooper_type == 'btsnoop':
if ':' not in snooper_args: if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing') raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1) io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file': if io_type == 'file':
@@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]:
_SNOOPER_INSTANCE_COUNT -= 1 _SNOOPER_INSTANCE_COUNT -= 1
return return
raise ValueError(f'I/O type {io_type} not supported') raise core.InvalidArgumentError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found') raise core.InvalidArgumentError(f'snooper type {snooper_type} not found')

View File

@@ -20,7 +20,7 @@ import logging
import os import os
from typing import Optional from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError
from ..snoop import create_snooper from ..snoop import create_snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -180,7 +180,7 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme') raise TransportSpecError('unknown transport scheme')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -20,7 +20,13 @@ import grpc.aio
from typing import Optional, Union from typing import Optional, Union
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport from .common import (
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
TransportSpecError,
)
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
elif ':' in param: elif ':' in param:
server_host, server_port = param.split(':') server_host, server_port = param.split(':')
else: else:
raise ValueError('invalid parameter') raise TransportSpecError('invalid parameter')
# Connect to the gRPC server # Connect to the gRPC server
server_address = f'{server_host}:{server_port}' server_address = f'{server_host}:{server_port}'
@@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
service = VhciForwardingServiceStub(channel) service = VhciForwardingServiceStub(channel)
hci_device = HciDevice(service.attachVhci()) hci_device = HciDevice(service.attachVhci())
else: else:
raise ValueError('invalid mode') raise TransportSpecError('invalid mode')
# Create the transport object # Create the transport object
class EmulatorTransport(PumpedTransport): class EmulatorTransport(PumpedTransport):

View File

@@ -31,6 +31,8 @@ from .common import (
PumpedPacketSource, PumpedPacketSource,
PumpedPacketSink, PumpedPacketSink,
Transport, Transport,
TransportSpecError,
TransportInitError,
) )
# pylint: disable=no-name-in-module # pylint: disable=no-name-in-module
@@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str] server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport: ) -> Transport:
if not server_port: if not server_port:
raise ValueError('invalid port') raise TransportSpecError('invalid port')
if server_host == '_' or not server_host: if server_host == '_' or not server_host:
server_host = 'localhost' server_host = 'localhost'
@@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address(
instance_number = 0 if options is None else int(options.get('instance', '0')) instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number) server_port = find_grpc_port(instance_number)
if not server_port: if not server_port:
raise RuntimeError('gRPC server port not found') raise TransportInitError('gRPC server port not found')
# Connect to the gRPC server # Connect to the gRPC server
server_address = f'{server_host}:{server_port}' server_address = f'{server_host}:{server_port}'
@@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel(
if response_type == 'error': if response_type == 'error':
logger.warning(f'received error: {response.error}') logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error) raise TransportInitError(response.error)
if response_type == 'hci_packet': if response_type == 'hci_packet':
return ( return (
@@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel(
+ response.hci_packet.packet + response.hci_packet.packet
) )
raise ValueError('unsupported response type') raise TransportSpecError('unsupported response type')
async def write(self, packet): async def write(self, packet):
await self.hci_device.write( await self.hci_device.write(
@@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
options: Dict[str, str] = {} options: Dict[str, str] = {}
for param in params[params_offset:]: for param in params[params_offset:]:
if '=' not in param: if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>') raise TransportSpecError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=') option_name, option_value = param.split('=')
options[option_name] = option_value options[option_name] = option_value
@@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
) )
if mode == 'controller': if mode == 'controller':
if host is None: if host is None:
raise ValueError('<host>:<port> missing') raise TransportSpecError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options) return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option') raise TransportSpecError('invalid mode option')

View File

@@ -23,6 +23,7 @@ import logging
import io import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import core
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble.snoop import Snooper from bumble.snoop import Snooper
@@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Errors # Errors
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TransportLostError(Exception): class TransportLostError(core.BaseBumbleError, RuntimeError):
""" """The Transport has been lost/disconnected."""
The Transport has been lost/disconnected.
"""
class TransportInitError(core.BaseBumbleError, RuntimeError):
"""Error raised when the transport cannot be initialized."""
class TransportSpecError(core.BaseBumbleError, ValueError):
"""Error raised when the transport spec is invalid."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -132,7 +139,9 @@ class PacketParser:
packet_type packet_type
) or self.extended_packet_info.get(packet_type) ) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}') raise core.InvalidPacketError(
f'invalid packet type {packet_type}'
)
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
@@ -178,19 +187,19 @@ class PacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
header = self.source.read(header_size) header = self.source.read(header_size)
if len(header) != header_size: if len(header) != header_size:
raise ValueError('packet too short') raise core.InvalidPacketError('packet too short')
# Read the body # Read the body
body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0] body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0]
body = self.source.read(body_length) body = self.source.read(body_length)
if len(body) != body_length: if len(body) != body_length:
raise ValueError('packet too short') raise core.InvalidPacketError('packet too short')
return packet_type + header + body return packet_type + header + body
@@ -211,7 +220,7 @@ class AsyncPacketReader:
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -420,7 +429,7 @@ class SnoopingTransport(Transport):
return SnoopingTransport( return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
) )
raise RuntimeError('unexpected code path') # Satisfy the type checker raise core.UnreachableError() # Satisfy the type checker
class Source: class Source:
sink: TransportSink sink: TransportSink

View File

@@ -29,7 +29,7 @@ from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource from .common import Transport, ParserSource, TransportInitError
from .. import hci from .. import hci
from ..colors import color from ..colors import color
@@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
device = None device = None
if device is None: if device is None:
raise ValueError('device not found') raise TransportInitError('device not found')
logger.debug(f'USB Device: {device}') logger.debug(f'USB Device: {device}')
# Power Cycle the device # Power Cycle the device

View File

@@ -24,10 +24,9 @@ import platform
import usb1 import usb1
from bumble.transport.common import Transport, ParserSource from bumble.transport.common import Transport, ParserSource, TransportInitError
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -442,7 +441,7 @@ async def open_usb_transport(spec: str) -> Transport:
if found is None: if found is None:
context.close() context.close()
raise ValueError('device not found') raise TransportInitError('device not found')
logger.debug(f'USB Device: {found}') logger.debug(f'USB Device: {found}')
@@ -507,7 +506,7 @@ async def open_usb_transport(spec: str) -> Transport:
endpoints = find_endpoints(found) endpoints = find_endpoints(found)
if endpoints is None: if endpoints is None:
raise ValueError('no compatible interface found for device') raise TransportInitError('no compatible interface found for device')
(configuration, interface, setting, acl_in, acl_out, events_in) = endpoints (configuration, interface, setting, acl_in, acl_out, events_in) = endpoints
logger.debug( logger.debug(
f'selected endpoints: configuration={configuration}, ' f'selected endpoints: configuration={configuration}, '