diff --git a/bumble/hci.py b/bumble/hci.py index a340bd5b..4a6ae58d 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -2291,7 +2291,7 @@ class HCI_Command(HCI_Packet): hci_packet_type = HCI_COMMAND_PACKET command_names: dict[int, str] = {} command_classes: dict[int, type[HCI_Command]] = {} - op_code: int = -1 + op_code: int fields: Fields = () return_parameters_fields: Fields = () _parameters: bytes = b'' @@ -2304,10 +2304,14 @@ class HCI_Command(HCI_Packet): Decorator used to declare and register subclasses ''' - subclass.name = subclass.__name__.upper() - subclass.op_code = key_with_value(subclass.command_names, subclass.name) - if subclass.op_code is None: - raise KeyError(f'command {subclass.name} not found in command_names') + # Subclasses may set parameters as ClassVar, or inferred from class name. + if not hasattr(subclass, 'name'): + subclass.name = subclass.__name__.upper() + if not hasattr(subclass, 'op_code'): + op_code = key_with_value(subclass.command_names, subclass.name) + if op_code is None: + raise KeyError(f'command {subclass.name} not found in command_names') + subclass.op_code = op_code if dataclasses.is_dataclass(subclass): subclass.fields = HCI_Object.fields_from_dataclass(subclass) @@ -2350,11 +2354,12 @@ class HCI_Command(HCI_Packet): command.parameters = parameters return command - @staticmethod - def command_name(op_code): - name = HCI_Command.command_names.get(op_code) - if name is not None: + @classmethod + def command_name(cls, op_code: int) -> str: + if name := cls.command_names.get(op_code): return name + if (subclass := cls.command_classes.get(op_code)) and subclass.name: + return subclass.name return f'[OGF=0x{op_code >> 10:02x}, OCF=0x{op_code & 0x3FF:04x}]' @classmethod @@ -5598,7 +5603,7 @@ class HCI_Event(HCI_Packet): event_names: dict[int, str] = {} event_classes: dict[int, type[HCI_Event]] = {} vendor_factories: list[Callable[[bytes], Optional[HCI_Event]]] = [] - event_code: int = -1 + event_code: int fields: Fields = () _parameters: bytes = b'' @@ -5609,12 +5614,17 @@ class HCI_Event(HCI_Packet): ''' Decorator used to declare and register subclasses ''' + # Subclasses may set parameters as ClassVar, or inferred from class name. + if not hasattr(subclass, 'name'): + subclass.name = subclass.__name__.upper() + if not hasattr(subclass, 'event_code'): + event_code = key_with_value(subclass.event_names, subclass.name) + if event_code is None: + raise KeyError(f'event {subclass.name} not found in event_names') + subclass.event_code = event_code - subclass.name = subclass.__name__.upper() - subclass.event_code = key_with_value(subclass.event_names, subclass.name) - subclass.fields = HCI_Object.fields_from_dataclass(subclass) - if subclass.event_code is None: - raise KeyError(f'event {subclass.name} not found in event_names') + if dataclasses.is_dataclass(subclass): + subclass.fields = HCI_Object.fields_from_dataclass(subclass) # Register a factory for this class cls.event_classes[subclass.event_code] = subclass @@ -5630,9 +5640,11 @@ class HCI_Event(HCI_Packet): and event_name.endswith('_EVENT') } - @staticmethod - def event_name(event_code): - return name_or_number(HCI_Event.event_names, event_code) + @classmethod + def event_name(cls, event_code: int) -> str: + if (subclass := cls.event_classes.get(event_code)) and subclass.name: + return subclass.name + return name_or_number(cls.event_names, event_code) @staticmethod def register_events(symbols: dict[str, Any]) -> None: @@ -5758,7 +5770,7 @@ class HCI_Extended_Event(HCI_Event): subevent_names: dict[int, str] = {} subevent_classes: dict[int, type[HCI_Extended_Event]] = {} - subevent_code: int = -1 + subevent_code: int _parameters: bytes = b'' _ExtendedEvent = TypeVar("_ExtendedEvent", bound="HCI_Extended_Event") @@ -5769,14 +5781,20 @@ class HCI_Extended_Event(HCI_Event): ''' Decorator used to declare and register subclasses ''' - subclass.name = subclass.__name__.upper() - subclass.subevent_code = key_with_value(subclass.subevent_names, subclass.name) - if subclass.subevent_code is None: - raise KeyError(f'subevent {subclass.name} not found in subevent_names') + # Subclasses may set parameters as ClassVar, or inferred from class name. + if not hasattr(subclass, 'name'): + subclass.name = subclass.__name__.upper() + if not hasattr(subclass, 'subevent_code'): + subevent_code = key_with_value(subclass.subevent_names, subclass.name) + if subevent_code is None: + raise KeyError(f'subevent {subclass.name} not found in subevent_names') + subclass.subevent_code = subevent_code + + if dataclasses.is_dataclass(subclass): + subclass.fields = HCI_Object.fields_from_dataclass(subclass) # Register a factory for this class cls.subevent_classes[subclass.subevent_code] = subclass - subclass.fields = HCI_Object.fields_from_dataclass(subclass) return subclass @@ -5793,10 +5811,11 @@ class HCI_Extended_Event(HCI_Event): self._parameters = parameters @classmethod - def subevent_name(cls, subevent_code): - subevent_name = cls.subevent_names.get(subevent_code) - if subevent_name is not None: + def subevent_name(cls, subevent_code: int) -> str: + if subevent_name := cls.subevent_names.get(subevent_code): return subevent_name + if (subclass := cls.subevent_classes.get(subevent_code)) and subclass.name: + return subclass.name return f'{cls.__name__.upper()}[0x{subevent_code:02X}]' diff --git a/tests/hci_test.py b/tests/hci_test.py index b1ca1908..45315883 100644 --- a/tests/hci_test.py +++ b/tests/hci_test.py @@ -239,6 +239,51 @@ def test_HCI_Command(): basic_check(command) +# ----------------------------------------------------------------------------- +def test_custom_command(): + @hci.HCI_Command.command + class CustomCommand(hci.HCI_Command): + op_code = 0x7788 + name = 'Custom Command' + + command = CustomCommand() + basic_check(command) + parsed = hci.HCI_Packet.from_bytes(bytes(command)) + assert isinstance(parsed, CustomCommand) + assert parsed.op_code == 0x7788 + assert parsed.name == 'Custom Command' + + +# ----------------------------------------------------------------------------- +def test_custom_event(): + @hci.HCI_Event.event + class CustomEvent(hci.HCI_Event): + event_code = 0x99 + name = 'Custom Event' + + event = CustomEvent() + basic_check(event) + parsed = hci.HCI_Packet.from_bytes(bytes(event)) + assert isinstance(parsed, CustomEvent) + assert parsed.event_code == 0x99 + assert parsed.name == 'Custom Event' + + +# ----------------------------------------------------------------------------- +def test_custom_le_meta_event(): + @hci.HCI_LE_Meta_Event.event + class CustomEvent(hci.HCI_LE_Meta_Event): + subevent_code = 0xFF + name = 'Custom Extended Event' + + event = CustomEvent() + basic_check(event) + parsed = hci.HCI_Packet.from_bytes(bytes(event)) + assert isinstance(parsed, CustomEvent) + assert parsed.subevent_code == 0xFF + assert parsed.name == 'Custom Extended Event' + + # ----------------------------------------------------------------------------- @pytest.mark.parametrize( "clazz,",