From f8a2d4f0e03020eb2046c970734596cb81448dca Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Tue, 4 Jun 2024 16:11:26 +0800 Subject: [PATCH] Reorganize exceptions * Add BaseBumbleException as a "real" root error * Add several core error classes and properly replace builtin errors with them * Add several error classes for specific modules (transport, device) --- bumble/at.py | 18 ++++--- bumble/avc.py | 13 ++++-- bumble/avctp.py | 5 +- bumble/avdtp.py | 3 +- bumble/avrcp.py | 3 +- bumble/codecs.py | 30 ++++++------ bumble/colors.py | 8 +++- bumble/core.py | 40 +++++++++++++--- bumble/device.py | 67 ++++++++++++++++----------- bumble/drivers/rtk.py | 21 +++++---- bumble/gatt_client.py | 4 +- bumble/hci.py | 24 ++++++---- bumble/l2cap.py | 41 +++++++++------- bumble/link.py | 11 +++-- bumble/profiles/csip.py | 8 ++-- bumble/profiles/heart_rate_service.py | 7 +-- bumble/rfcomm.py | 6 ++- bumble/sdp.py | 24 +++++----- bumble/smp.py | 5 +- bumble/snoop.py | 9 ++-- bumble/transport/__init__.py | 4 +- bumble/transport/android_emulator.py | 12 +++-- bumble/transport/android_netsim.py | 16 ++++--- bumble/transport/common.py | 29 ++++++++---- bumble/transport/pyusb.py | 4 +- bumble/transport/usb.py | 7 ++- 26 files changed, 260 insertions(+), 159 deletions(-) diff --git a/bumble/at.py b/bumble/at.py index 78a4b086..ed9aeed9 100644 --- a/bumble/at.py +++ b/bumble/at.py @@ -14,13 +14,19 @@ from typing import List, Union +from bumble import core + + +class AtParsingError(core.InvalidPacketError): + """Error raised when parsing AT commands fails.""" + def tokenize_parameters(buffer: bytes) -> List[bytes]: """Split input parameters into tokens. Removes space characters outside of double quote blocks: T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0) are ignored [..], unless they are embedded in numeric or string constants" - Raises ValueError in case of invalid input string.""" + Raises AtParsingError in case of invalid input string.""" tokens = [] in_quotes = False @@ -43,11 +49,11 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]: token = bytearray() elif char == b'(': if len(token) > 0: - raise ValueError("open_paren following regular character") + raise AtParsingError("open_paren following regular character") tokens.append(char) elif char == b'"': if len(token) > 0: - raise ValueError("quote following regular character") + raise AtParsingError("quote following regular character") in_quotes = True token.extend(char) else: @@ -59,7 +65,7 @@ def tokenize_parameters(buffer: bytes) -> List[bytes]: def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]: """Parse the parameters using the comma and parenthesis separators. - Raises ValueError in case of invalid input string.""" + Raises AtParsingError in case of invalid input string.""" tokens = tokenize_parameters(buffer) accumulator: List[list] = [[]] @@ -73,7 +79,7 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]: accumulator.append([]) elif token == b')': if len(accumulator) < 2: - raise ValueError("close_paren without matching open_paren") + raise AtParsingError("close_paren without matching open_paren") accumulator[-1].append(current) current = accumulator.pop() else: @@ -81,5 +87,5 @@ def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]: accumulator[-1].append(current) if len(accumulator) > 1: - raise ValueError("missing close_paren") + raise AtParsingError("missing close_paren") return accumulator[0] diff --git a/bumble/avc.py b/bumble/avc.py index 1d0a7dc9..8e6b968e 100644 --- a/bumble/avc.py +++ b/bumble/avc.py @@ -20,6 +20,7 @@ import enum import struct from typing import Dict, Type, Union, Tuple +from bumble import core from bumble.utils import OpenIntEnum @@ -88,7 +89,9 @@ class Frame: short_name = subclass.__name__.replace("ResponseFrame", "") category_class = ResponseFrame else: - raise ValueError(f"invalid subclass name {subclass.__name__}") + raise core.InvalidArgumentError( + f"invalid subclass name {subclass.__name__}" + ) uppercase_indexes = [ i for i in range(len(short_name)) if short_name[i].isupper() @@ -106,7 +109,7 @@ class Frame: @staticmethod def from_bytes(data: bytes) -> Frame: if data[0] >> 4 != 0: - raise ValueError("first 4 bits must be 0s") + raise core.InvalidPacketError("first 4 bits must be 0s") ctype_or_response = data[0] & 0xF subunit_type = Frame.SubunitType(data[1] >> 3) @@ -122,7 +125,7 @@ class Frame: # Extended to the next byte extension = data[2] if extension == 0: - raise ValueError("extended subunit ID value reserved") + raise core.InvalidPacketError("extended subunit ID value reserved") if extension == 0xFF: subunit_id = 5 + 254 + data[3] opcode_offset = 4 @@ -131,7 +134,7 @@ class Frame: opcode_offset = 3 elif subunit_id == 6: - raise ValueError("reserved subunit ID") + raise core.InvalidPacketError("reserved subunit ID") opcode = Frame.OperationCode(data[opcode_offset]) operands = data[opcode_offset + 1 :] @@ -448,7 +451,7 @@ class PassThroughFrame: operation_data: bytes, ) -> None: if len(operation_data) > 255: - raise ValueError("operation data must be <= 255 bytes") + raise core.InvalidArgumentError("operation data must be <= 255 bytes") self.state_flag = state_flag self.operation_id = operation_id self.operation_data = operation_data diff --git a/bumble/avctp.py b/bumble/avctp.py index 22713249..6d702561 100644 --- a/bumble/avctp.py +++ b/bumble/avctp.py @@ -23,6 +23,7 @@ from typing import Callable, cast, Dict, Optional from bumble.colors import color from bumble import avc +from bumble import core from bumble import l2cap # ----------------------------------------------------------------------------- @@ -275,7 +276,7 @@ class Protocol: self, pid: int, handler: Protocol.CommandHandler ) -> None: if pid not in self.command_handlers or self.command_handlers[pid] != handler: - raise ValueError("command handler not registered") + raise core.InvalidArgumentError("command handler not registered") del self.command_handlers[pid] def register_response_handler( @@ -287,5 +288,5 @@ class Protocol: self, pid: int, handler: Protocol.ResponseHandler ) -> None: if pid not in self.response_handlers or self.response_handlers[pid] != handler: - raise ValueError("response handler not registered") + raise core.InvalidArgumentError("response handler not registered") del self.response_handlers[pid] diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 713f7b7f..c9418819 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -43,6 +43,7 @@ from .core import ( BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, InvalidStateError, ProtocolError, + InvalidArgumentError, name_or_number, ) from .a2dp import ( @@ -700,7 +701,7 @@ class Message: # pylint:disable=attribute-defined-outside-init signal_identifier_str = name[:-7] message_type = Message.MessageType.RESPONSE_REJECT else: - raise ValueError('invalid class name') + raise InvalidArgumentError('invalid class name') subclass.message_type = message_type diff --git a/bumble/avrcp.py b/bumble/avrcp.py index 11f4effa..e06a5a67 100644 --- a/bumble/avrcp.py +++ b/bumble/avrcp.py @@ -55,6 +55,7 @@ from bumble.sdp import ( ) from bumble.utils import AsyncRunner, OpenIntEnum from bumble.core import ( + InvalidArgumentError, ProtocolError, BT_L2CAP_PROTOCOL_ID, BT_AVCTP_PROTOCOL_ID, @@ -1411,7 +1412,7 @@ class Protocol(pyee.EventEmitter): def notify_track_changed(self, identifier: bytes) -> None: """Notify the connected peer of a Track change.""" if len(identifier) != 8: - raise ValueError("identifier must be 8 bytes") + raise InvalidArgumentError("identifier must be 8 bytes") self.notify_event(TrackChangedEvent(identifier)) def notify_playback_position_changed(self, position: int) -> None: diff --git a/bumble/codecs.py b/bumble/codecs.py index 1d7ae82c..cfb3cad1 100644 --- a/bumble/codecs.py +++ b/bumble/codecs.py @@ -18,6 +18,8 @@ from __future__ import annotations from dataclasses import dataclass +from bumble import core + # ----------------------------------------------------------------------------- class BitReader: @@ -40,7 +42,7 @@ class BitReader: """ "Read up to 32 bits.""" if bits > 32: - raise ValueError('maximum read size is 32') + raise core.InvalidArgumentError('maximum read size is 32') if self.bits_cached >= bits: # We have enough bits. @@ -53,7 +55,7 @@ class BitReader: feed_size = len(feed_bytes) feed_int = int.from_bytes(feed_bytes, byteorder='big') if 8 * feed_size + self.bits_cached < bits: - raise ValueError('trying to read past the data') + raise core.InvalidArgumentError('trying to read past the data') self.byte_position += feed_size # Combine the new cache and the old cache @@ -68,7 +70,7 @@ class BitReader: def read_bytes(self, count: int): if self.bit_position + 8 * count > 8 * len(self.data): - raise ValueError('not enough data') + raise core.InvalidArgumentError('not enough data') if self.bit_position % 8: # Not byte aligned @@ -113,7 +115,7 @@ class AacAudioRtpPacket: @staticmethod def program_config_element(reader: BitReader): - raise ValueError('program_config_element not supported') + raise core.InvalidPacketError('program_config_element not supported') @dataclass class GASpecificConfig: @@ -140,7 +142,7 @@ class AacAudioRtpPacket: aac_spectral_data_resilience_flags = reader.read(1) extension_flag_3 = reader.read(1) if extension_flag_3 == 1: - raise ValueError('extensionFlag3 == 1 not supported') + raise core.InvalidPacketError('extensionFlag3 == 1 not supported') @staticmethod def audio_object_type(reader: BitReader): @@ -216,7 +218,7 @@ class AacAudioRtpPacket: reader, self.channel_configuration, self.audio_object_type ) else: - raise ValueError( + raise core.InvalidPacketError( f'audioObjectType {self.audio_object_type} not supported' ) @@ -260,7 +262,7 @@ class AacAudioRtpPacket: else: audio_mux_version_a = 0 if audio_mux_version_a != 0: - raise ValueError('audioMuxVersionA != 0 not supported') + raise core.InvalidPacketError('audioMuxVersionA != 0 not supported') if audio_mux_version == 1: tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader) stream_cnt = 0 @@ -268,10 +270,10 @@ class AacAudioRtpPacket: num_sub_frames = reader.read(6) num_program = reader.read(4) if num_program != 0: - raise ValueError('num_program != 0 not supported') + raise core.InvalidPacketError('num_program != 0 not supported') num_layer = reader.read(3) if num_layer != 0: - raise ValueError('num_layer != 0 not supported') + raise core.InvalidPacketError('num_layer != 0 not supported') if audio_mux_version == 0: self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig( reader @@ -284,7 +286,7 @@ class AacAudioRtpPacket: ) audio_specific_config_len = reader.bit_position - marker if asc_len < audio_specific_config_len: - raise ValueError('audio_specific_config_len > asc_len') + raise core.InvalidPacketError('audio_specific_config_len > asc_len') asc_len -= audio_specific_config_len reader.skip(asc_len) frame_length_type = reader.read(3) @@ -293,7 +295,9 @@ class AacAudioRtpPacket: elif frame_length_type == 1: frame_length = reader.read(9) else: - raise ValueError(f'frame_length_type {frame_length_type} not supported') + raise core.InvalidPacketError( + f'frame_length_type {frame_length_type} not supported' + ) self.other_data_present = reader.read(1) if self.other_data_present: @@ -318,12 +322,12 @@ class AacAudioRtpPacket: def __init__(self, reader: BitReader, mux_config_present: int): if mux_config_present == 0: - raise ValueError('muxConfigPresent == 0 not supported') + raise core.InvalidPacketError('muxConfigPresent == 0 not supported') # AudioMuxElement - ISO/EIC 14496-3 Table 1.41 use_same_stream_mux = reader.read(1) if use_same_stream_mux: - raise ValueError('useSameStreamMux == 1 not supported') + raise core.InvalidPacketError('useSameStreamMux == 1 not supported') self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader) # We only support: diff --git a/bumble/colors.py b/bumble/colors.py index 2813cfe5..37ce03a8 100644 --- a/bumble/colors.py +++ b/bumble/colors.py @@ -16,6 +16,10 @@ from functools import partial from typing import List, Optional, Union +class ColorError(ValueError): + """Error raised when a color spec is invalid.""" + + # ANSI color names. There is also a "default" COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white') @@ -52,7 +56,7 @@ def _color_code(spec: ColorSpec, base: int) -> str: elif isinstance(spec, int) and 0 <= spec <= 255: return _join(base + 8, 5, spec) else: - raise ValueError('Invalid color spec "%s"' % spec) + raise ColorError('Invalid color spec "%s"' % spec) def color( @@ -72,7 +76,7 @@ def color( if style_part in STYLES: codes.append(STYLES.index(style_part)) else: - raise ValueError('Invalid style "%s"' % style_part) + raise ColorError('Invalid style "%s"' % style_part) if codes: return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s) diff --git a/bumble/core.py b/bumble/core.py index dce721a4..28b5611c 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -76,7 +76,13 @@ def get_dict_key_by_value(dictionary, value): # ----------------------------------------------------------------------------- # Exceptions # ----------------------------------------------------------------------------- -class BaseError(Exception): + + +class BaseBumbleError(Exception): + """Base Error raised by Bumble.""" + + +class BaseError(BaseBumbleError): """Base class for errors with an error code, error name and namespace""" def __init__( @@ -115,18 +121,38 @@ class ProtocolError(BaseError): """Protocol Error""" -class TimeoutError(Exception): # pylint: disable=redefined-builtin +class TimeoutError(BaseBumbleError): # pylint: disable=redefined-builtin """Timeout Error""" -class CommandTimeoutError(Exception): +class CommandTimeoutError(BaseBumbleError): """Command Timeout Error""" -class InvalidStateError(Exception): +class InvalidStateError(BaseBumbleError): """Invalid State Error""" +class InvalidArgumentError(BaseBumbleError, ValueError): + """Invalid Argument Error""" + + +class InvalidPacketError(BaseBumbleError, ValueError): + """Invalid Packet Error""" + + +class InvalidOperationError(BaseBumbleError, RuntimeError): + """Invalid Operation Error""" + + +class OutOfResourcesError(BaseBumbleError, RuntimeError): + """Out of Resources Error""" + + +class UnreachableError(BaseBumbleError): + """The code path raising this error should be unreachable.""" + + class ConnectionError(BaseError): # pylint: disable=redefined-builtin """Connection Error""" @@ -185,12 +211,12 @@ class UUID: or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-' ): - raise ValueError('invalid UUID format') + raise InvalidArgumentError('invalid UUID format') uuid_str = uuid_str_or_int.replace('-', '') else: uuid_str = uuid_str_or_int if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4: - raise ValueError(f"invalid UUID format: {uuid_str}") + raise InvalidArgumentError(f"invalid UUID format: {uuid_str}") self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.name = name @@ -215,7 +241,7 @@ class UUID: return self.register() - raise ValueError('only 2, 4 and 16 bytes are allowed') + raise InvalidArgumentError('only 2, 4 and 16 bytes are allowed') @classmethod def from_16_bits(cls, uuid_16: int, name: Optional[str] = None) -> UUID: diff --git a/bumble/device.py b/bumble/device.py index 5f686cdd..9387d0ce 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -173,10 +173,15 @@ from .core import ( BT_LE_TRANSPORT, BT_PERIPHERAL_ROLE, AdvertisingData, + BaseBumbleError, ConnectionParameterUpdateError, CommandTimeoutError, ConnectionPHY, + InvalidArgumentError, + InvalidOperationError, InvalidStateError, + OutOfResourcesError, + UnreachableError, ) from .utils import ( AsyncRunner, @@ -259,6 +264,8 @@ DEVICE_MAX_HIGH_DUTY_CYCLE_CONNECTABLE_DIRECTED_ADVERTISING_DURATION = 1.28 # ----------------------------------------------------------------------------- # Classes # ----------------------------------------------------------------------------- +class ObjectLookupError(BaseBumbleError): + """Error raised when failed to lookup an object.""" # ----------------------------------------------------------------------------- @@ -1374,7 +1381,9 @@ def with_connection_from_handle(function): @functools.wraps(function) def wrapper(self, connection_handle, *args, **kwargs): if (connection := self.lookup_connection(connection_handle)) is None: - raise ValueError(f'no connection for handle: 0x{connection_handle:04x}') + raise ObjectLookupError( + f'no connection for handle: 0x{connection_handle:04x}' + ) return function(self, connection, *args, **kwargs) return wrapper @@ -1389,7 +1398,7 @@ def with_connection_from_address(function): for connection in self.connections.values(): if connection.peer_address == address: return function(self, connection, *args, **kwargs) - raise ValueError('no connection for address') + raise ObjectLookupError('no connection for address') return wrapper @@ -1798,7 +1807,7 @@ class Device(CompositeEventEmitter): spec=spec, ) else: - raise ValueError(f'Unexpected mode {spec}') + raise InvalidArgumentError(f'Unexpected mode {spec}') def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: self.host.send_l2cap_pdu(connection_handle, cid, pdu) @@ -1959,7 +1968,7 @@ class Device(CompositeEventEmitter): def supports_le_features(self, feature: LeFeatureMask) -> bool: return self.host.supports_le_features(feature) - def supports_le_phy(self, phy): + def supports_le_phy(self, phy: int) -> bool: if phy == HCI_LE_1M_PHY: return True @@ -1968,7 +1977,7 @@ class Device(CompositeEventEmitter): HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY, } if phy not in feature_map: - raise ValueError('invalid PHY') + raise InvalidArgumentError('invalid PHY') return self.supports_le_features(feature_map[phy]) @@ -2028,7 +2037,7 @@ class Device(CompositeEventEmitter): # Decide what peer address to use if advertising_type.is_directed: if target is None: - raise ValueError('directed advertising requires a target') + raise InvalidArgumentError('directed advertising requires a target') peer_address = target else: peer_address = Address.ANY @@ -2135,7 +2144,7 @@ class Device(CompositeEventEmitter): and advertising_data and scan_response_data ): - raise ValueError( + raise InvalidArgumentError( "Extended advertisements can't have both data and scan \ response data" ) @@ -2151,7 +2160,9 @@ class Device(CompositeEventEmitter): if handle not in self.extended_advertising_sets ) except StopIteration as exc: - raise RuntimeError("all valid advertising handles already in use") from exc + raise OutOfResourcesError( + "all valid advertising handles already in use" + ) from exc # Use the device's random address if a random address is needed but none was # provided. @@ -2250,14 +2261,14 @@ class Device(CompositeEventEmitter): ) -> None: # Check that the arguments are legal if scan_interval < scan_window: - raise ValueError('scan_interval must be >= scan_window') + raise InvalidArgumentError('scan_interval must be >= scan_window') if ( scan_interval < DEVICE_MIN_SCAN_INTERVAL or scan_interval > DEVICE_MAX_SCAN_INTERVAL ): - raise ValueError('scan_interval out of range') + raise InvalidArgumentError('scan_interval out of range') if scan_window < DEVICE_MIN_SCAN_WINDOW or scan_window > DEVICE_MAX_SCAN_WINDOW: - raise ValueError('scan_interval out of range') + raise InvalidArgumentError('scan_interval out of range') # Reset the accumulators self.advertisement_accumulators = {} @@ -2285,7 +2296,7 @@ class Device(CompositeEventEmitter): scanning_phy_count += 1 if scanning_phy_count == 0: - raise ValueError('at least one scanning PHY must be enabled') + raise InvalidArgumentError('at least one scanning PHY must be enabled') await self.send_command( HCI_LE_Set_Extended_Scan_Parameters_Command( @@ -2479,7 +2490,7 @@ class Device(CompositeEventEmitter): # Check parameters if transport not in (BT_LE_TRANSPORT, BT_BR_EDR_TRANSPORT): - raise ValueError('invalid transport') + raise InvalidArgumentError('invalid transport') # Adjust the transport automatically if we need to if transport == BT_LE_TRANSPORT and not self.le_enabled: @@ -2496,7 +2507,7 @@ class Device(CompositeEventEmitter): peer_address = Address.from_string_for_transport( peer_address, transport ) - except ValueError: + except InvalidArgumentError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( @@ -2508,7 +2519,7 @@ class Device(CompositeEventEmitter): transport == BT_BR_EDR_TRANSPORT and peer_address.address_type != Address.PUBLIC_DEVICE_ADDRESS ): - raise ValueError('BR/EDR addresses must be PUBLIC') + raise InvalidArgumentError('BR/EDR addresses must be PUBLIC') assert isinstance(peer_address, Address) @@ -2559,7 +2570,7 @@ class Device(CompositeEventEmitter): ) ) if not phys: - raise ValueError('at least one supported PHY needed') + raise InvalidArgumentError('at least one supported PHY needed') phy_count = len(phys) initiating_phys = phy_list_to_bits(phys) @@ -2631,7 +2642,7 @@ class Device(CompositeEventEmitter): ) else: if HCI_LE_1M_PHY not in connection_parameters_preferences: - raise ValueError('1M PHY preferences required') + raise InvalidArgumentError('1M PHY preferences required') prefs = connection_parameters_preferences[HCI_LE_1M_PHY] result = await self.send_command( @@ -2731,7 +2742,7 @@ class Device(CompositeEventEmitter): if isinstance(peer_address, str): try: peer_address = Address(peer_address) - except ValueError: + except InvalidArgumentError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( @@ -2741,7 +2752,7 @@ class Device(CompositeEventEmitter): assert isinstance(peer_address, Address) if peer_address == Address.NIL: - raise ValueError('accept on nil address') + raise InvalidArgumentError('accept on nil address') # Create a future so that we can wait for the request pending_request_fut = asyncio.get_running_loop().create_future() @@ -2854,7 +2865,7 @@ class Device(CompositeEventEmitter): if isinstance(peer_address, str): try: peer_address = Address(peer_address) - except ValueError: + except InvalidArgumentError: # If the address is not parsable, assume it is a name instead logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( @@ -2897,10 +2908,10 @@ class Device(CompositeEventEmitter): async def set_data_length(self, connection, tx_octets, tx_time) -> None: if tx_octets < 0x001B or tx_octets > 0x00FB: - raise ValueError('tx_octets must be between 0x001B and 0x00FB') + raise InvalidArgumentError('tx_octets must be between 0x001B and 0x00FB') if tx_time < 0x0148 or tx_time > 0x4290: - raise ValueError('tx_time must be between 0x0148 and 0x4290') + raise InvalidArgumentError('tx_time must be between 0x0148 and 0x4290') return await self.send_command( HCI_LE_Set_Data_Length_Command( @@ -3175,7 +3186,7 @@ class Device(CompositeEventEmitter): async def encrypt(self, connection, enable=True): if not enable and connection.transport == BT_LE_TRANSPORT: - raise ValueError('`enable` parameter is classic only.') + raise InvalidArgumentError('`enable` parameter is classic only.') # Set up event handlers pending_encryption = asyncio.get_running_loop().create_future() @@ -3194,11 +3205,11 @@ class Device(CompositeEventEmitter): if connection.transport == BT_LE_TRANSPORT: # Look for a key in the key store if self.keystore is None: - raise RuntimeError('no key store') + raise InvalidOperationError('no key store') keys = await self.keystore.get(str(connection.peer_address)) if keys is None: - raise RuntimeError('keys not found in key store') + raise InvalidOperationError('keys not found in key store') if keys.ltk is not None: ltk = keys.ltk.value @@ -3209,7 +3220,7 @@ class Device(CompositeEventEmitter): rand = keys.ltk_central.rand ediv = keys.ltk_central.ediv else: - raise RuntimeError('no LTK found for peer') + raise InvalidOperationError('no LTK found for peer') if connection.role != HCI_CENTRAL_ROLE: raise InvalidStateError('only centrals can start encryption') @@ -3484,7 +3495,7 @@ class Device(CompositeEventEmitter): return cis_link # Mypy believes this is reachable when context is an ExitStack. - raise InvalidStateError('Unreachable') + raise UnreachableError() # [LE only] @experimental('Only for testing.') @@ -3950,7 +3961,7 @@ class Device(CompositeEventEmitter): return await pairing_config.delegate.confirm(auto=True) async def na() -> bool: - assert False, "N/A: unreachable" + raise UnreachableError() # See Bluetooth spec @ Vol 3, Part C 5.2.2.6 methods = { diff --git a/bumble/drivers/rtk.py b/bumble/drivers/rtk.py index 4a9034db..1336d2c2 100644 --- a/bumble/drivers/rtk.py +++ b/bumble/drivers/rtk.py @@ -33,6 +33,7 @@ from typing import Tuple import weakref +from bumble import core from bumble.hci import ( hci_vendor_command_op_code, STATUS_SPEC, @@ -49,6 +50,10 @@ from bumble.drivers import common logger = logging.getLogger(__name__) +class RtkFirmwareError(core.BaseBumbleError): + """Error raised when RTK firmware initialization fails.""" + + # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- @@ -208,15 +213,15 @@ class Firmware: extension_sig = bytes([0x51, 0x04, 0xFD, 0x77]) if not firmware.startswith(RTK_EPATCH_SIGNATURE): - raise ValueError("Firmware does not start with epatch signature") + raise RtkFirmwareError("Firmware does not start with epatch signature") if not firmware.endswith(extension_sig): - raise ValueError("Firmware does not end with extension sig") + raise RtkFirmwareError("Firmware does not end with extension sig") # The firmware should start with a 14 byte header. epatch_header_size = 14 if len(firmware) < epatch_header_size: - raise ValueError("Firmware too short") + raise RtkFirmwareError("Firmware too short") # Look for the "project ID", starting from the end. offset = len(firmware) - len(extension_sig) @@ -230,7 +235,7 @@ class Firmware: break if length == 0: - raise ValueError("Invalid 0-length instruction") + raise RtkFirmwareError("Invalid 0-length instruction") if opcode == 0 and length == 1: project_id = firmware[offset - 1] @@ -239,7 +244,7 @@ class Firmware: offset -= length if project_id < 0: - raise ValueError("Project ID not found") + raise RtkFirmwareError("Project ID not found") self.project_id = project_id @@ -252,7 +257,7 @@ class Firmware: # ... (16 bits each) # ... (32 bits each) if epatch_header_size + 8 * num_patches > len(firmware): - raise ValueError("Firmware too short") + raise RtkFirmwareError("Firmware too short") chip_id_table_offset = epatch_header_size patch_length_table_offset = chip_id_table_offset + 2 * num_patches patch_offset_table_offset = chip_id_table_offset + 4 * num_patches @@ -266,7 +271,7 @@ class Firmware: " len(firmware): - raise ValueError("Firmware too short") + raise RtkFirmwareError("Firmware too short") # Get the SVN version for the patch (svn_version,) = struct.unpack_from( @@ -645,7 +650,7 @@ class Driver(common.Driver): ): return await self.download_for_rtl8723b() - raise ValueError("ROM not supported") + raise RtkFirmwareError("ROM not supported") async def init_controller(self): await self.download_firmware() diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index c71aabd7..6d4dcf6a 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -331,9 +331,9 @@ class Client: async def request_mtu(self, mtu: int) -> int: # Check the range if mtu < ATT_DEFAULT_MTU: - raise ValueError(f'MTU must be >= {ATT_DEFAULT_MTU}') + raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}') if mtu > 0xFFFF: - raise ValueError('MTU must be <= 0xFFFF') + raise core.InvalidArgumentError('MTU must be <= 0xFFFF') # We can only send one request per connection if self.mtu_exchange_done: diff --git a/bumble/hci.py b/bumble/hci.py index 9ef40bf2..66cc5ba8 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -31,6 +31,8 @@ from .core import ( BT_BR_EDR_TRANSPORT, AdvertisingData, DeviceClass, + InvalidArgumentError, + InvalidPacketError, ProtocolError, bit_flags_to_strings, name_or_number, @@ -91,14 +93,14 @@ def map_class_of_device(class_of_device): ) -def phy_list_to_bits(phys): +def phy_list_to_bits(phys: Optional[Iterable[int]]) -> int: if phys is None: return 0 phy_bits = 0 for phy in phys: if phy not in HCI_LE_PHY_TYPE_TO_BIT: - raise ValueError('invalid PHY') + raise InvalidArgumentError('invalid PHY') phy_bits |= 1 << HCI_LE_PHY_TYPE_TO_BIT[phy] return phy_bits @@ -1552,7 +1554,7 @@ class HCI_Object: new_offset, field_value = field_type(data, offset) return (field_value, new_offset - offset) - raise ValueError(f'unknown field type {field_type}') + raise InvalidArgumentError(f'unknown field type {field_type}') @staticmethod def dict_from_bytes(data, offset, fields): @@ -1621,7 +1623,7 @@ class HCI_Object: if 0 <= field_value <= 255: field_bytes = bytes([field_value]) else: - raise ValueError('value too large for *-typed field') + raise InvalidArgumentError('value too large for *-typed field') else: field_bytes = bytes(field_value) elif field_type == 'v': @@ -1640,7 +1642,9 @@ class HCI_Object: elif len(field_bytes) > field_type: field_bytes = field_bytes[:field_type] else: - raise ValueError(f"don't know how to serialize type {type(field_value)}") + raise InvalidArgumentError( + f"don't know how to serialize type {type(field_value)}" + ) return field_bytes @@ -1904,7 +1908,7 @@ class Address: self.address_bytes = bytes(reversed(bytes.fromhex(address))) if len(self.address_bytes) != 6: - raise ValueError('invalid address length') + raise InvalidArgumentError('invalid address length') self.address_type = address_type @@ -2104,7 +2108,7 @@ class HCI_Command(HCI_Packet): op_code, length = struct.unpack_from('> 14) & 3 data = packet[5:] if len(data) != data_total_length: - raise ValueError('invalid packet length') + raise InvalidPacketError('invalid packet length') return HCI_AclDataPacket( connection_handle, pb_flag, bc_flag, data_total_length, data ) @@ -6152,7 +6156,7 @@ class HCI_SynchronousDataPacket(HCI_Packet): packet_status = (h >> 12) & 0b11 data = packet[4:] if len(data) != data_total_length: - raise ValueError( + raise InvalidPacketError( f'invalid packet length {len(data)} != {data_total_length}' ) return HCI_SynchronousDataPacket( diff --git a/bumble/l2cap.py b/bumble/l2cap.py index cec14b85..b16a27a0 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -41,7 +41,14 @@ from typing import ( from .utils import deprecated from .colors import color -from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError +from .core import ( + BT_CENTRAL_ROLE, + InvalidStateError, + InvalidArgumentError, + InvalidPacketError, + OutOfResourcesError, + ProtocolError, +) from .hci import ( HCI_LE_Connection_Update_Command, HCI_Object, @@ -188,17 +195,17 @@ class LeCreditBasedChannelSpec: self.max_credits < 1 or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS ): - raise ValueError('max credits out of range') + raise InvalidArgumentError('max credits out of range') if ( self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU ): - raise ValueError('MTU out of range') + raise InvalidArgumentError('MTU out of range') if ( self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS ): - raise ValueError('MPS out of range') + raise InvalidArgumentError('MPS out of range') class L2CAP_PDU: @@ -210,7 +217,7 @@ class L2CAP_PDU: def from_bytes(data: bytes) -> L2CAP_PDU: # Check parameters if len(data) < 4: - raise ValueError('not enough data for L2CAP header') + raise InvalidPacketError('not enough data for L2CAP header') _, l2cap_pdu_cid = struct.unpack_from(' int: @@ -1526,7 +1533,7 @@ class ChannelManager: if cid not in channels: return cid - raise RuntimeError('no free CID') + raise OutOfResourcesError('no free CID') def next_identifier(self, connection: Connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 @@ -1573,15 +1580,15 @@ class ChannelManager: else: # Check that the PSM isn't already in use if spec.psm in self.servers: - raise ValueError('PSM already in use') + raise InvalidArgumentError('PSM already in use') # Check that the PSM is valid if spec.psm % 2 == 0: - raise ValueError('invalid PSM (not odd)') + raise InvalidArgumentError('invalid PSM (not odd)') check = spec.psm >> 8 while check: if check % 2 != 0: - raise ValueError('invalid PSM') + raise InvalidArgumentError('invalid PSM') check >>= 8 self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) @@ -1623,7 +1630,7 @@ class ChannelManager: else: # Check that the PSM isn't already in use if spec.psm in self.le_coc_servers: - raise ValueError('PSM already in use') + raise InvalidArgumentError('PSM already in use') self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( self, @@ -2151,10 +2158,10 @@ class ChannelManager: connection_channels = self.channels.setdefault(connection.handle, {}) source_cid = self.find_free_le_cid(connection_channels) if source_cid is None: # Should never happen! - raise RuntimeError('all CIDs already in use') + raise OutOfResourcesError('all CIDs already in use') if spec.psm is None: - raise ValueError('PSM cannot be None') + raise InvalidArgumentError('PSM cannot be None') # Create the channel logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}') @@ -2203,10 +2210,10 @@ class ChannelManager: connection_channels = self.channels.setdefault(connection.handle, {}) source_cid = self.find_free_br_edr_cid(connection_channels) if source_cid is None: # Should never happen! - raise RuntimeError('all CIDs already in use') + raise OutOfResourcesError('all CIDs already in use') if spec.psm is None: - raise ValueError('PSM cannot be None') + raise InvalidArgumentError('PSM cannot be None') # Create the channel logger.debug( diff --git a/bumble/link.py b/bumble/link.py index 5ef56b7a..8971e21b 100644 --- a/bumble/link.py +++ b/bumble/link.py @@ -19,7 +19,12 @@ import logging import asyncio from functools import partial -from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT +from bumble.core import ( + BT_PERIPHERAL_ROLE, + BT_BR_EDR_TRANSPORT, + BT_LE_TRANSPORT, + InvalidStateError, +) from bumble.colors import color from bumble.hci import ( Address, @@ -405,12 +410,12 @@ class RemoteLink: def add_controller(self, controller): if self.controller: - raise ValueError('controller already set') + raise InvalidStateError('controller already set') self.controller = controller def remove_controller(self, controller): if self.controller != controller: - raise ValueError('controller mismatch') + raise InvalidStateError('controller mismatch') self.controller = None def get_pending_connection(self): diff --git a/bumble/profiles/csip.py b/bumble/profiles/csip.py index 03fba9c5..9ba3bafd 100644 --- a/bumble/profiles/csip.py +++ b/bumble/profiles/csip.py @@ -113,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService): set_member_rank: Optional[int] = None, ) -> None: if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH: - raise ValueError( + raise core.InvalidArgumentError( f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}' ) @@ -178,7 +178,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService): key = await connection.device.get_link_key(connection.peer_address) if not key: - raise RuntimeError('LTK or LinkKey is not present') + raise core.InvalidOperationError('LTK or LinkKey is not present') sirk_bytes = sef(key, self.set_identity_resolving_key) @@ -234,7 +234,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy): '''Reads SIRK and decrypts if encrypted.''' response = await self.set_identity_resolving_key.read_value() if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1: - raise RuntimeError('Invalid SIRK value') + raise core.InvalidPacketError('Invalid SIRK value') sirk_type = SirkType(response[0]) if sirk_type == SirkType.PLAINTEXT: @@ -250,7 +250,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy): key = await device.get_link_key(connection.peer_address) if not key: - raise RuntimeError('LTK or LinkKey is not present') + raise core.InvalidOperationError('LTK or LinkKey is not present') sirk = sef(key, response[1:]) diff --git a/bumble/profiles/heart_rate_service.py b/bumble/profiles/heart_rate_service.py index fe46cb26..0c9a12f0 100644 --- a/bumble/profiles/heart_rate_service.py +++ b/bumble/profiles/heart_rate_service.py @@ -19,6 +19,7 @@ from enum import IntEnum import struct +from bumble import core from ..gatt_client import ProfileServiceProxy from ..att import ATT_Error from ..gatt import ( @@ -59,17 +60,17 @@ class HeartRateService(TemplateService): rr_intervals=None, ): if heart_rate < 0 or heart_rate > 0xFFFF: - raise ValueError('heart_rate out of range') + raise core.InvalidArgumentError('heart_rate out of range') if energy_expended is not None and ( energy_expended < 0 or energy_expended > 0xFFFF ): - raise ValueError('energy_expended out of range') + raise core.InvalidArgumentError('energy_expended out of range') if rr_intervals: for rr_interval in rr_intervals: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: - raise ValueError('rr_intervals out of range') + raise core.InvalidArgumentError('rr_intervals out of range') self.heart_rate = heart_rate self.sensor_contact_detected = sensor_contact_detected diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 1020a1ea..a541a730 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -36,7 +36,9 @@ from .core import ( BT_RFCOMM_PROTOCOL_ID, BT_BR_EDR_TRANSPORT, BT_L2CAP_PROTOCOL_ID, + InvalidArgumentError, InvalidStateError, + InvalidPacketError, ProtocolError, ) @@ -333,7 +335,7 @@ class RFCOMM_Frame: frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information) if frame.fcs != fcs: logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}') - raise ValueError('fcs mismatch') + raise InvalidPacketError('fcs mismatch') return frame @@ -680,7 +682,7 @@ class DLC(EventEmitter): # Automatically convert strings to bytes using UTF-8 data = data.encode('utf-8') else: - raise ValueError('write only accept bytes or strings') + raise InvalidArgumentError('write only accept bytes or strings') self.tx_buffer += data self.drained.clear() diff --git a/bumble/sdp.py b/bumble/sdp.py index 35b4a3a1..a98a48cd 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -23,7 +23,7 @@ from typing_extensions import Self from . import core, l2cap from .colors import color -from .core import InvalidStateError +from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError from .hci import HCI_Object, name_or_number, key_with_value if TYPE_CHECKING: @@ -189,7 +189,9 @@ class DataElement: self.bytes = None if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER): if value_size is None: - raise ValueError('integer types must have a value size specified') + raise InvalidArgumentError( + 'integer types must have a value size specified' + ) @staticmethod def nil() -> DataElement: @@ -265,7 +267,7 @@ class DataElement: if len(data) == 8: return struct.unpack('>Q', data)[0] - raise ValueError(f'invalid integer length {len(data)}') + raise InvalidPacketError(f'invalid integer length {len(data)}') @staticmethod def signed_integer_from_bytes(data): @@ -281,7 +283,7 @@ class DataElement: if len(data) == 8: return struct.unpack('>q', data)[0] - raise ValueError(f'invalid integer length {len(data)}') + raise InvalidPacketError(f'invalid integer length {len(data)}') @staticmethod def list_from_bytes(data): @@ -354,7 +356,7 @@ class DataElement: data = b'' elif self.type == DataElement.UNSIGNED_INTEGER: if self.value < 0: - raise ValueError('UNSIGNED_INTEGER cannot be negative') + raise InvalidArgumentError('UNSIGNED_INTEGER cannot be negative') if self.value_size == 1: data = struct.pack('B', self.value) @@ -365,7 +367,7 @@ class DataElement: elif self.value_size == 8: data = struct.pack('>Q', self.value) else: - raise ValueError('invalid value_size') + raise InvalidArgumentError('invalid value_size') elif self.type == DataElement.SIGNED_INTEGER: if self.value_size == 1: data = struct.pack('b', self.value) @@ -376,7 +378,7 @@ class DataElement: elif self.value_size == 8: data = struct.pack('>q', self.value) else: - raise ValueError('invalid value_size') + raise InvalidArgumentError('invalid value_size') elif self.type == DataElement.UUID: data = bytes(reversed(bytes(self.value))) elif self.type == DataElement.URL: @@ -392,7 +394,7 @@ class DataElement: size_bytes = b'' if self.type == DataElement.NIL: if size != 0: - raise ValueError('NIL must be empty') + raise InvalidArgumentError('NIL must be empty') size_index = 0 elif self.type in ( DataElement.UNSIGNED_INTEGER, @@ -410,7 +412,7 @@ class DataElement: elif size == 16: size_index = 4 else: - raise ValueError('invalid data size') + raise InvalidArgumentError('invalid data size') elif self.type in ( DataElement.TEXT_STRING, DataElement.SEQUENCE, @@ -427,10 +429,10 @@ class DataElement: size_index = 7 size_bytes = struct.pack('>I', size) else: - raise ValueError('invalid data size') + raise InvalidArgumentError('invalid data size') elif self.type == DataElement.BOOLEAN: if size != 1: - raise ValueError('boolean must be 1 byte') + raise InvalidArgumentError('boolean must be 1 byte') size_index = 0 self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data diff --git a/bumble/smp.py b/bumble/smp.py index 3a88a31f..cf523e7b 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -55,6 +55,7 @@ from .core import ( BT_CENTRAL_ROLE, BT_LE_TRANSPORT, AdvertisingData, + InvalidArgumentError, ProtocolError, name_or_number, ) @@ -784,7 +785,7 @@ class Session: self.peer_oob_data = pairing_config.oob.peer_data if pairing_config.sc: if pairing_config.oob.our_context is None: - raise ValueError( + raise InvalidArgumentError( "oob pairing config requires a context when sc is True" ) self.r = pairing_config.oob.our_context.r @@ -793,7 +794,7 @@ class Session: self.tk = pairing_config.oob.legacy_context.tk else: if pairing_config.oob.legacy_context is None: - raise ValueError( + raise InvalidArgumentError( "oob pairing config requires a legacy context when sc is False" ) self.r = bytes(16) diff --git a/bumble/snoop.py b/bumble/snoop.py index 4b331d29..326603f5 100644 --- a/bumble/snoop.py +++ b/bumble/snoop.py @@ -23,6 +23,7 @@ import datetime from typing import BinaryIO, Generator import os +from bumble import core from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET @@ -138,13 +139,13 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]: """ if ':' not in spec: - raise ValueError('snooper type prefix missing') + raise core.InvalidArgumentError('snooper type prefix missing') snooper_type, snooper_args = spec.split(':', maxsplit=1) if snooper_type == 'btsnoop': if ':' not in snooper_args: - raise ValueError('I/O type for btsnoop snooper type missing') + raise core.InvalidArgumentError('I/O type for btsnoop snooper type missing') io_type, io_name = snooper_args.split(':', maxsplit=1) if io_type == 'file': @@ -165,6 +166,6 @@ def create_snooper(spec: str) -> Generator[Snooper, None, None]: _SNOOPER_INSTANCE_COUNT -= 1 return - raise ValueError(f'I/O type {io_type} not supported') + raise core.InvalidArgumentError(f'I/O type {io_type} not supported') - raise ValueError(f'snooper type {snooper_type} not found') + raise core.InvalidArgumentError(f'snooper type {snooper_type} not found') diff --git a/bumble/transport/__init__.py b/bumble/transport/__init__.py index 6a9a6b53..73414814 100644 --- a/bumble/transport/__init__.py +++ b/bumble/transport/__init__.py @@ -20,7 +20,7 @@ import logging import os from typing import Optional -from .common import Transport, AsyncPipeSink, SnoopingTransport +from .common import Transport, AsyncPipeSink, SnoopingTransport, TransportSpecError from ..snoop import create_snooper # ----------------------------------------------------------------------------- @@ -180,7 +180,7 @@ async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: return await open_android_netsim_transport(spec) - raise ValueError('unknown transport scheme') + raise TransportSpecError('unknown transport scheme') # ----------------------------------------------------------------------------- diff --git a/bumble/transport/android_emulator.py b/bumble/transport/android_emulator.py index 9cd7ec21..d2bc8ef8 100644 --- a/bumble/transport/android_emulator.py +++ b/bumble/transport/android_emulator.py @@ -20,7 +20,13 @@ import grpc.aio from typing import Optional, Union -from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport +from .common import ( + PumpedTransport, + PumpedPacketSource, + PumpedPacketSink, + Transport, + TransportSpecError, +) # pylint: disable=no-name-in-module from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub @@ -77,7 +83,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: elif ':' in param: server_host, server_port = param.split(':') else: - raise ValueError('invalid parameter') + raise TransportSpecError('invalid parameter') # Connect to the gRPC server server_address = f'{server_host}:{server_port}' @@ -94,7 +100,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport: service = VhciForwardingServiceStub(channel) hci_device = HciDevice(service.attachVhci()) else: - raise ValueError('invalid mode') + raise TransportSpecError('invalid mode') # Create the transport object class EmulatorTransport(PumpedTransport): diff --git a/bumble/transport/android_netsim.py b/bumble/transport/android_netsim.py index e9d36cd5..264266df 100644 --- a/bumble/transport/android_netsim.py +++ b/bumble/transport/android_netsim.py @@ -31,6 +31,8 @@ from .common import ( PumpedPacketSource, PumpedPacketSink, Transport, + TransportSpecError, + TransportInitError, ) # pylint: disable=no-name-in-module @@ -135,7 +137,7 @@ async def open_android_netsim_controller_transport( server_host: Optional[str], server_port: int, options: Dict[str, str] ) -> Transport: if not server_port: - raise ValueError('invalid port') + raise TransportSpecError('invalid port') if server_host == '_' or not server_host: server_host = 'localhost' @@ -288,7 +290,7 @@ async def open_android_netsim_host_transport_with_address( instance_number = 0 if options is None else int(options.get('instance', '0')) server_port = find_grpc_port(instance_number) if not server_port: - raise RuntimeError('gRPC server port not found') + raise TransportInitError('gRPC server port not found') # Connect to the gRPC server server_address = f'{server_host}:{server_port}' @@ -326,7 +328,7 @@ async def open_android_netsim_host_transport_with_channel( if response_type == 'error': logger.warning(f'received error: {response.error}') - raise RuntimeError(response.error) + raise TransportInitError(response.error) if response_type == 'hci_packet': return ( @@ -334,7 +336,7 @@ async def open_android_netsim_host_transport_with_channel( + response.hci_packet.packet ) - raise ValueError('unsupported response type') + raise TransportSpecError('unsupported response type') async def write(self, packet): await self.hci_device.write( @@ -429,7 +431,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport: options: Dict[str, str] = {} for param in params[params_offset:]: if '=' not in param: - raise ValueError('invalid parameter, expected =') + raise TransportSpecError('invalid parameter, expected =') option_name, option_value = param.split('=') options[option_name] = option_value @@ -440,7 +442,7 @@ async def open_android_netsim_transport(spec: Optional[str]) -> Transport: ) if mode == 'controller': if host is None: - raise ValueError(': missing') + raise TransportSpecError(': missing') return await open_android_netsim_controller_transport(host, port, options) - raise ValueError('invalid mode option') + raise TransportSpecError('invalid mode option') diff --git a/bumble/transport/common.py b/bumble/transport/common.py index ffbf7b07..60286b2d 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -23,6 +23,7 @@ import logging import io from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict +from bumble import core from bumble import hci from bumble.colors import color from bumble.snoop import Snooper @@ -49,10 +50,16 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = { # ----------------------------------------------------------------------------- # Errors # ----------------------------------------------------------------------------- -class TransportLostError(Exception): - """ - The Transport has been lost/disconnected. - """ +class TransportLostError(core.BaseBumbleError, RuntimeError): + """The Transport has been lost/disconnected.""" + + +class TransportInitError(core.BaseBumbleError, RuntimeError): + """Error raised when the transport cannot be initialized.""" + + +class TransportSpecError(core.BaseBumbleError, ValueError): + """Error raised when the transport spec is invalid.""" # ----------------------------------------------------------------------------- @@ -132,7 +139,9 @@ class PacketParser: packet_type ) or self.extended_packet_info.get(packet_type) if self.packet_info is None: - raise ValueError(f'invalid packet type {packet_type}') + raise core.InvalidPacketError( + f'invalid packet type {packet_type}' + ) self.state = PacketParser.NEED_LENGTH self.bytes_needed = self.packet_info[0] + self.packet_info[1] elif self.state == PacketParser.NEED_LENGTH: @@ -178,19 +187,19 @@ class PacketReader: # Get the packet info based on its type packet_info = HCI_PACKET_INFO.get(packet_type[0]) if packet_info is None: - raise ValueError(f'invalid packet type {packet_type[0]} found') + raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found') # Read the header (that includes the length) header_size = packet_info[0] + packet_info[1] header = self.source.read(header_size) if len(header) != header_size: - raise ValueError('packet too short') + raise core.InvalidPacketError('packet too short') # Read the body body_length = struct.unpack_from(packet_info[2], header, packet_info[1])[0] body = self.source.read(body_length) if len(body) != body_length: - raise ValueError('packet too short') + raise core.InvalidPacketError('packet too short') return packet_type + header + body @@ -211,7 +220,7 @@ class AsyncPacketReader: # Get the packet info based on its type packet_info = HCI_PACKET_INFO.get(packet_type[0]) if packet_info is None: - raise ValueError(f'invalid packet type {packet_type[0]} found') + raise core.InvalidPacketError(f'invalid packet type {packet_type[0]} found') # Read the header (that includes the length) header_size = packet_info[0] + packet_info[1] @@ -420,7 +429,7 @@ class SnoopingTransport(Transport): return SnoopingTransport( transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close ) - raise RuntimeError('unexpected code path') # Satisfy the type checker + raise core.UnreachableError() # Satisfy the type checker class Source: sink: TransportSink diff --git a/bumble/transport/pyusb.py b/bumble/transport/pyusb.py index 68a1dfd9..8f327072 100644 --- a/bumble/transport/pyusb.py +++ b/bumble/transport/pyusb.py @@ -29,7 +29,7 @@ from usb.core import USBError from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB -from .common import Transport, ParserSource +from .common import Transport, ParserSource, TransportInitError from .. import hci from ..colors import color @@ -259,7 +259,7 @@ async def open_pyusb_transport(spec: str) -> Transport: device = None if device is None: - raise ValueError('device not found') + raise TransportInitError('device not found') logger.debug(f'USB Device: {device}') # Power Cycle the device diff --git a/bumble/transport/usb.py b/bumble/transport/usb.py index 69e9649c..e3de98c7 100644 --- a/bumble/transport/usb.py +++ b/bumble/transport/usb.py @@ -24,10 +24,9 @@ import platform import usb1 -from bumble.transport.common import Transport, ParserSource +from bumble.transport.common import Transport, ParserSource, TransportInitError from bumble import hci from bumble.colors import color -from bumble.utils import AsyncRunner # ----------------------------------------------------------------------------- @@ -442,7 +441,7 @@ async def open_usb_transport(spec: str) -> Transport: if found is None: context.close() - raise ValueError('device not found') + raise TransportInitError('device not found') logger.debug(f'USB Device: {found}') @@ -507,7 +506,7 @@ async def open_usb_transport(spec: str) -> Transport: endpoints = find_endpoints(found) if endpoints is None: - raise ValueError('no compatible interface found for device') + raise TransportInitError('no compatible interface found for device') (configuration, interface, setting, acl_in, acl_out, events_in) = endpoints logger.debug( f'selected endpoints: configuration={configuration}, '