diff --git a/bumble/controller.py b/bumble/controller.py index 7856113..39b330c 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -108,7 +108,9 @@ class Connection: def on_hci_acl_data_packet(self, packet): self.assembler.feed_packet(packet) self.controller.send_hci_packet( - HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)]) + HCI_Number_Of_Completed_Packets_Event( + connection_handles=[self.handle], num_completed_packets=[1] + ) ) def on_acl_pdu(self, data): diff --git a/bumble/hci.py b/bumble/hci.py index e4065c9..5720b5d 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -23,7 +23,8 @@ import functools import logging import secrets import struct -from typing import Any, Callable, Iterable, Optional, Union, ClassVar +from collections.abc import Sequence +from typing import Any, Callable, Iterable, Optional, Union, TypeVar, ClassVar from typing_extensions import Self from bumble import crypto @@ -120,6 +121,7 @@ def phy_list_to_bits(phys: Optional[Iterable[Phy]]) -> int: # - an integer [1, 4] for 1-byte, 2-byte or 4-byte unsigned little-endian integers # - an integer [-2, -1] for 1-byte, 2-byte signed little-endian integers FieldSpec = Union[dict[str, Any], Callable[[bytes, int], tuple[int, Any]], str, int] +Fields = Sequence[Union[tuple[str, FieldSpec], 'Fields']] @dataclasses.dataclass @@ -2271,29 +2273,30 @@ class HCI_Command(HCI_Packet): hci_packet_type = HCI_COMMAND_PACKET command_names: dict[int, str] = {} command_classes: dict[int, type[HCI_Command]] = {} - op_code: int + op_code: int = -1 + fields: Fields = () + return_parameters_fields: Fields = () @staticmethod - def command(fields=(), return_parameters_fields=()): + def command( + fields: Optional[Fields] = None, + return_parameters_fields: Optional[Fields] = None, + ): ''' Decorator used to declare and register subclasses ''' - def inner(cls): + Command = TypeVar("Command", bound=HCI_Command) + + def inner(cls: type[Command]) -> type[Command]: cls.name = cls.__name__.upper() cls.op_code = key_with_value(cls.command_names, cls.name) if cls.op_code is None: raise KeyError(f'command {cls.name} not found in command_names') - cls.fields = fields - cls.return_parameters_fields = return_parameters_fields - - # Patch the __init__ method to fix the op_code if fields is not None: - - def init(self, parameters=None, **kwargs): - return HCI_Command.__init__(self, cls.op_code, parameters, **kwargs) - - cls.__init__ = init + cls.fields = fields + if return_parameters_fields is not None: + cls.return_parameters_fields = return_parameters_fields # Register a factory for this class HCI_Command.command_classes[cls.op_code] = cls @@ -2325,16 +2328,15 @@ class HCI_Command(HCI_Packet): cls = HCI_Command.command_classes.get(op_code) if cls is None: # No class registered, just use a generic instance - return HCI_Command(op_code, parameters) + return HCI_Command(parameters, op_code=op_code) - # Create a new instance - if (fields := getattr(cls, 'fields', None)) is not None: - self = cls.__new__(cls) - HCI_Command.__init__(self, op_code, parameters) - HCI_Object.init_from_bytes(self, parameters, 0, fields) - return self + return cls.from_parameters(parameters) - return cls.from_parameters(parameters) # type: ignore + @classmethod + def from_parameters(cls, parameters: bytes) -> HCI_Command: + command = cls(parameters) + HCI_Object.init_from_bytes(command, parameters, 0, cls.fields) + return command @staticmethod def command_name(op_code): @@ -2357,18 +2359,22 @@ class HCI_Command(HCI_Packet): return_parameters.fields = cls.return_parameters_fields return return_parameters - def __init__(self, op_code=-1, parameters=None, **kwargs): - # Since the legacy implementation relies on an __init__ injector, typing always - # complains that positional argument op_code is not passed, so here sets a - # default value to allow building derived HCI_Command without op_code. - assert op_code != -1 - super().__init__(HCI_Command.command_name(op_code)) - if (fields := getattr(self, 'fields', None)) and kwargs: - HCI_Object.init_from_fields(self, fields, kwargs) + def __init__( + self, + parameters: Optional[bytes] = None, + *, + op_code: Optional[int] = None, + **kwargs, + ) -> None: + # op_code should be set in cls. + if op_code is not None: + self.op_code = op_code + super().__init__(HCI_Command.command_name(self.op_code)) + if self.fields and kwargs: + HCI_Object.init_from_fields(self, self.fields, kwargs) if parameters is None: - parameters = HCI_Object.dict_to_bytes(kwargs, fields) - self.op_code = op_code - self.parameters = parameters + parameters = HCI_Object.dict_to_bytes(kwargs, self.fields) + self.parameters = parameters or b'' def __bytes__(self): parameters = b'' if self.parameters is None else self.parameters @@ -2379,8 +2385,8 @@ class HCI_Command(HCI_Packet): def __str__(self): result = color(self.name, 'green') - if fields := getattr(self, 'fields', None): - result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ') + if self.fields: + result += ':\n' + HCI_Object.format_fields(self.__dict__, self.fields, ' ') else: if self.parameters: result += f': {self.parameters.hex()}' @@ -2924,6 +2930,7 @@ class HCI_Set_Connectionless_Peripheral_Broadcast_Receive_Command(HCI_Command): # ----------------------------------------------------------------------------- +@HCI_Command.command() class HCI_Start_Synchronization_Train_Command(HCI_Command): ''' See Bluetooth spec @ 7.1.51 Start Synchronization Train Command @@ -4360,13 +4367,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command): }, ), ('fragment_preference', 1), - ( - 'advertising_data', - { - 'parser': HCI_Object.parse_length_prefixed_bytes, - 'serializer': HCI_Object.serialize_length_prefixed_bytes, - }, - ), + ('advertising_data', 'v'), ] ) class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command): @@ -4397,13 +4398,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command): }, ), ('fragment_preference', 1), - ( - 'scan_response_data', - { - 'parser': HCI_Object.parse_length_prefixed_bytes, - 'serializer': HCI_Object.serialize_length_prefixed_bytes, - }, - ), + ('scan_response_data', 'v'), ] ) class HCI_LE_Set_Extended_Scan_Response_Data_Command(HCI_Command): @@ -4507,13 +4502,7 @@ class HCI_LE_Set_Periodic_Advertising_Parameters_Command(HCI_Command): ).name, }, ), - ( - 'advertising_data', - { - 'parser': HCI_Object.parse_length_prefixed_bytes, - 'serializer': HCI_Object.serialize_length_prefixed_bytes, - }, - ), + ('advertising_data', 'v'), ] ) class HCI_LE_Set_Periodic_Advertising_Data_Command(HCI_Command): @@ -4549,8 +4538,10 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command): EXTENDED_UNFILTERED_POLICY = 0x02 EXTENDED_FILTERED_POLICY = 0x03 + op_code = HCI_LE_SET_EXTENDED_SCAN_PARAMETERS_COMMAND + @classmethod - def from_parameters(cls, parameters): + def from_parameters(cls, parameters: bytes) -> Self: own_address_type = parameters[0] scanning_filter_policy = parameters[1] scanning_phys = parameters[2] @@ -4586,7 +4577,7 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command): scan_intervals, scan_windows, ): - super().__init__(HCI_LE_SET_EXTENDED_SCAN_PARAMETERS_COMMAND) + super().__init__() self.own_address_type = own_address_type self.scanning_filter_policy = scanning_filter_policy self.scanning_phys = scanning_phys @@ -4660,8 +4651,10 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command): See Bluetooth spec @ 7.8.66 LE Extended Create Connection Command ''' + op_code = HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND + @classmethod - def from_parameters(cls, parameters): + def from_parameters(cls, parameters: bytes) -> Self: initiator_filter_policy = parameters[0] own_address_type = parameters[1] peer_address_type = parameters[2] @@ -4708,7 +4701,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command): min_ce_lengths, max_ce_lengths, ): - super().__init__(HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND) + super().__init__() self.initiator_filter_policy = initiator_filter_policy self.own_address_type = own_address_type self.peer_address_type = peer_address_type @@ -5560,25 +5553,24 @@ class HCI_Event(HCI_Packet): event_names: dict[int, str] = {} event_classes: dict[int, type[HCI_Event]] = {} vendor_factories: list[Callable[[bytes], Optional[HCI_Event]]] = [] + event_code: int = -1 + fields: Fields = () @staticmethod - def event(fields=()): + def event(fields: Optional[Fields] = ()): ''' Decorator used to declare and register subclasses ''' - def inner(cls): + Event = TypeVar("Event", bound=HCI_Event) + + def inner(cls: type[Event]) -> type[Event]: cls.name = cls.__name__.upper() cls.event_code = key_with_value(cls.event_names, cls.name) if cls.event_code is None: raise KeyError(f'event {cls.name} not found in event_names') - cls.fields = fields - - # Patch the __init__ method to fix the event_code - def init(self, parameters=None, **kwargs): - return HCI_Event.__init__(self, cls.event_code, parameters, **kwargs) - - cls.__init__ = init + if fields is not None: + cls.fields = fields # Register a factory for this class HCI_Event.event_classes[cls.event_code] = cls @@ -5638,7 +5630,7 @@ class HCI_Event(HCI_Packet): if len(parameters) != length: raise InvalidPacketError('invalid packet length') - subclass: Any + subclass: Optional[type[HCI_Event]] if event_code == HCI_LE_META_EVENT: # We do this dispatch here and not in the subclass in order to avoid call # loops @@ -5646,7 +5638,9 @@ class HCI_Event(HCI_Packet): subclass = HCI_LE_Meta_Event.subevent_classes.get(subevent_code) if subclass is None: # No class registered, just use a generic class instance - return HCI_LE_Meta_Event(subevent_code, parameters) + return HCI_LE_Meta_Event( + subevent_code=subevent_code, parameters=parameters + ) elif event_code == HCI_VENDOR_EVENT: # Invoke all the registered factories to see if any of them can handle # the event @@ -5661,31 +5655,33 @@ class HCI_Event(HCI_Packet): subclass = HCI_Event.event_classes.get(event_code) if subclass is None: # No class registered, just use a generic class instance - return HCI_Event(event_code, parameters) + return HCI_Event(event_code=event_code, parameters=parameters) # Invoke the factory to create a new instance - return subclass.from_parameters(parameters) # type: ignore + return subclass.from_parameters(parameters) @classmethod - def from_parameters(cls, parameters): - self = cls.__new__(cls) - HCI_Event.__init__(self, self.event_code, parameters) - if fields := getattr(self, 'fields', None): - HCI_Object.init_from_bytes(self, parameters, 0, fields) + def from_parameters(cls, parameters: bytes) -> Self: + self = cls(parameters) + if self.fields: + HCI_Object.init_from_bytes(self, parameters, 0, self.fields) return self - def __init__(self, event_code=-1, parameters=None, **kwargs): - # Since the legacy implementation relies on an __init__ injector, typing always - # complains that positional argument event_code is not passed, so here sets a - # default value to allow building derived HCI_Event without event_code. - assert event_code != -1 - super().__init__(HCI_Event.event_name(event_code)) - if (fields := getattr(self, 'fields', None)) and kwargs: - HCI_Object.init_from_fields(self, fields, kwargs) + def __init__( + self, + parameters: Optional[bytes] = None, + *, + event_code: Optional[int] = None, + **kwargs, + ): + if event_code is not None: + self.event_code = event_code + super().__init__(HCI_Event.event_name(self.event_code)) + if self.fields and kwargs: + HCI_Object.init_from_fields(self, self.fields, kwargs) if parameters is None: - parameters = HCI_Object.dict_to_bytes(kwargs, fields) - self.event_code = event_code - self.parameters = parameters + parameters = HCI_Object.dict_to_bytes(kwargs, self.fields) + self.parameters = parameters or b'' def __bytes__(self): parameters = b'' if self.parameters is None else self.parameters @@ -5693,8 +5689,8 @@ class HCI_Event(HCI_Packet): def __str__(self): result = color(self.name, 'magenta') - if fields := getattr(self, 'fields', None): - result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ') + if self.fields: + result += ':\n' + HCI_Object.format_fields(self.__dict__, self.fields, ' ') else: if self.parameters: result += f': {self.parameters.hex()}' @@ -5712,27 +5708,23 @@ class HCI_Extended_Event(HCI_Event): subevent_names: dict[int, str] = {} subevent_classes: dict[int, type[HCI_Extended_Event]] = {} + subevent_code: int = -1 @classmethod - def event(cls, fields=()): + def event(cls, fields: Optional[Fields] = None): ''' Decorator used to declare and register subclasses ''' - def inner(cls): + ExtendedEvent = TypeVar("ExtendedEvent", bound=HCI_Extended_Event) + + def inner(cls: type[ExtendedEvent]) -> type[ExtendedEvent]: cls.name = cls.__name__.upper() cls.subevent_code = key_with_value(cls.subevent_names, cls.name) if cls.subevent_code is None: raise KeyError(f'subevent {cls.name} not found in subevent_names') - cls.fields = fields - - # Patch the __init__ method to fix the subevent_code - original_init = cls.__init__ - - def init(self, parameters=None, **kwargs): - return original_init(self, cls.subevent_code, parameters, **kwargs) - - cls.__init__ = init + if fields is not None: + cls.fields = fields # Register a factory for this class cls.subevent_classes[cls.subevent_code] = cls @@ -5778,23 +5770,28 @@ class HCI_Extended_Event(HCI_Event): @classmethod def from_parameters(cls, parameters: bytes) -> HCI_Extended_Event: """Factory method for subclasses (the subevent code has already been parsed)""" - self = cls.__new__(cls) - HCI_Extended_Event.__init__(self, self.subevent_code, parameters) - if fields := getattr(self, 'fields', None): - HCI_Object.init_from_bytes(self, parameters, 1, fields) - return self + event = cls(parameters) + if event.fields: + HCI_Object.init_from_bytes(event, parameters, 1, event.fields) + return event - def __init__(self, subevent_code=None, parameters=None, **kwargs): - assert subevent_code is not None - self.subevent_code = subevent_code - if parameters is None and (fields := getattr(self, 'fields', None)) and kwargs: - parameters = bytes([subevent_code]) + HCI_Object.dict_to_bytes( - kwargs, fields + def __init__( + self, + parameters: Optional[bytes] = None, + *, + subevent_code: Optional[int] = None, + **kwargs, + ) -> None: + if subevent_code is not None: + self.subevent_code = subevent_code + if parameters is None and self.fields and kwargs: + parameters = bytes([self.subevent_code]) + HCI_Object.dict_to_bytes( + kwargs, self.fields ) - super().__init__(self.event_code, parameters, **kwargs) + super().__init__(parameters, **kwargs) # Override the name in order to adopt the subevent name instead - self.name = self.subevent_name(subevent_code) + self.name = self.subevent_name(self.subevent_code) # ----------------------------------------------------------------------------- @@ -5848,6 +5845,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event): ''' subevent_code = HCI_LE_ADVERTISING_REPORT_EVENT + name = 'HCI_LE_ADVERTISING_REPORT_EVENT' # Event Types ADV_IND = 0x00 @@ -5869,13 +5867,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event): ('event_type', 1), ('address_type', Address.ADDRESS_TYPE_SPEC), ('address', Address.parse_address_preceded_by_type), - ( - 'data', - { - 'parser': HCI_Object.parse_length_prefixed_bytes, - 'serializer': HCI_Object.serialize_length_prefixed_bytes, - }, - ), + ('data', 'v'), ('rssi', -1), ] @@ -5907,7 +5899,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event): return name_or_number(cls.EVENT_TYPE_NAMES, event_type) @classmethod - def from_parameters(cls, parameters): + def from_parameters(cls, parameters: bytes) -> Self: num_reports = parameters[1] reports = [] offset = 2 @@ -5926,7 +5918,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event): [bytes(report) for report in reports] ) - super().__init__(self.subevent_code, parameters) + super().__init__(parameters) def __str__(self): reports = '\n'.join( @@ -6082,6 +6074,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): ''' subevent_code = HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT + name = 'HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT' # Event types flags CONNECTABLE_ADVERTISING = 0 @@ -6130,13 +6123,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): ('periodic_advertising_interval', 2), ('direct_address_type', Address.ADDRESS_TYPE_SPEC), ('direct_address', Address.parse_address_preceded_by_type), - ( - 'data', - { - 'parser': HCI_Object.parse_length_prefixed_bytes, - 'serializer': HCI_Object.serialize_length_prefixed_bytes, - }, - ), + ('data', 'v'), ] @classmethod @@ -6194,9 +6181,9 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): return f'0x{event_type:04X} [{",".join(event_type_flags)}]{legacy_info_string}' @classmethod - def from_parameters(cls, parameters): + def from_parameters(cls, parameters: bytes) -> Self: num_reports = parameters[1] - reports = [] + reports: list[HCI_LE_Extended_Advertising_Report_Event.Report] = [] offset = 2 for _ in range(num_reports): report = cls.Report.from_parameters(parameters, offset) @@ -6205,7 +6192,9 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): return cls(reports) - def __init__(self, reports): + def __init__( + self, reports: Sequence[HCI_LE_Extended_Advertising_Report_Event.Report] + ): self.reports = reports[:] # Serialize the fields @@ -6213,7 +6202,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event): [HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT, len(reports)] ) + b''.join([bytes(report) for report in reports]) - super().__init__(self.subevent_code, parameters) + super().__init__(parameters) def __str__(self): reports = '\n'.join( @@ -6856,50 +6845,27 @@ class HCI_Inquiry_Complete_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.registered +@HCI_Event.event( + [ + [ + ('bd_addr', Address.parse_address), + ('page_scan_repetition_mode', 1), + ('reserved', 1), + ('reserved', 1), + ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), + ('clock_offset', 2), + ] + ] +) class HCI_Inquiry_Result_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.2 Inquiry Result Event ''' - RESPONSE_FIELDS = [ - ('bd_addr', Address.parse_address), - ('page_scan_repetition_mode', 1), - ('reserved', 1), - ('reserved', 1), - ('class_of_device', {'size': 3, 'mapper': map_class_of_device}), - ('clock_offset', 2), - ] - - @staticmethod - def from_parameters(parameters): - num_responses = parameters[0] - responses = [] - offset = 1 - for _ in range(num_responses): - response = HCI_Object.from_bytes( - parameters, offset, HCI_Inquiry_Result_Event.RESPONSE_FIELDS - ) - offset += 14 - responses.append(response) - - return HCI_Inquiry_Result_Event(responses) - - def __init__(self, responses): - self.responses = responses[:] - - # Serialize the fields - parameters = bytes([HCI_INQUIRY_RESULT_EVENT, len(responses)]) + b''.join( - [bytes(response) for response in responses] - ) - - super().__init__(HCI_INQUIRY_RESULT_EVENT, parameters) - - def __str__(self): - responses = '\n'.join( - [response.to_string(indentation=' ') for response in self.responses] - ) - return f'{color("HCI_INQUIRY_RESULT_EVENT", "magenta")}:\n{responses}' + bd_addr: list[Address] + page_scan_repetition_mode: list[int] + class_of_device: list[int] + clock_offset: list[int] # ----------------------------------------------------------------------------- @@ -7147,29 +7113,28 @@ class HCI_Command_Complete_Event(HCI_Event): @staticmethod def from_parameters(parameters): - self = HCI_Command_Complete_Event.__new__(HCI_Command_Complete_Event) - HCI_Event.__init__(self, self.event_code, parameters) + event = HCI_Command_Complete_Event(parameters) HCI_Object.init_from_bytes( - self, parameters, 0, HCI_Command_Complete_Event.fields + event, parameters, 0, HCI_Command_Complete_Event.fields ) # Parse the return parameters if ( - isinstance(self.return_parameters, bytes) - and len(self.return_parameters) == 1 + isinstance(event.return_parameters, bytes) + and len(event.return_parameters) == 1 ): # All commands with 1-byte return parameters return a 'status' field, # convert it to an integer - self.return_parameters = self.return_parameters[0] + event.return_parameters = event.return_parameters[0] else: - cls = HCI_Command.command_classes.get(self.command_opcode) + cls = HCI_Command.command_classes.get(event.command_opcode) if cls: # Try to parse the return parameters bytes into an object. - return_parameters = cls.parse_return_parameters(self.return_parameters) + return_parameters = cls.parse_return_parameters(event.return_parameters) if return_parameters is not None: - self.return_parameters = return_parameters + event.return_parameters = return_parameters - return self + return event def __str__(self): return f'{color(self.name, "magenta")}:\n' + HCI_Object.format_fields( @@ -7222,56 +7187,21 @@ class HCI_Role_Change_Event(HCI_Event): # ----------------------------------------------------------------------------- -@HCI_Event.registered +@HCI_Event.event( + [ + [ + ('connection_handles', 2), + ('num_completed_packets', 2), + ] + ] +) class HCI_Number_Of_Completed_Packets_Event(HCI_Event): ''' See Bluetooth spec @ 7.7.19 Number Of Completed Packets Event ''' - @classmethod - def from_parameters(cls, parameters): - self = cls.__new__(cls) - self.parameters = parameters - num_handles = parameters[0] - self.connection_handles = [] - self.num_completed_packets = [] - for i in range(num_handles): - self.connection_handles.append( - struct.unpack_from(' 0 + assert isinstance(clazz.name, str) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "clazz,", + [ + clazz[1] + for clazz in inspect.getmembers(hci) + if isinstance(clazz[1], type) + and clazz[1] is not hci.HCI_Event + and issubclass(clazz[1], hci.HCI_Event) + and not issubclass(clazz[1], hci.HCI_Extended_Event) + ], +) +def test_hci_event_subclasses_event_code(clazz: type[hci.HCI_Event]): + assert clazz.event_code > 0 + assert isinstance(clazz.name, str) + + +# ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "clazz,", + [ + clazz[1] + for clazz in inspect.getmembers(hci) + if isinstance(clazz[1], type) + and issubclass(clazz[1], hci.HCI_Extended_Event) + and clazz[1] not in (hci.HCI_Extended_Event, hci.HCI_LE_Meta_Event) + ], +) +def test_hci_extended_event_subclasses_event_code(clazz: type[hci.HCI_Extended_Event]): + assert clazz.event_code > 0 + assert clazz.subevent_code > 0 + assert isinstance(clazz.name, str) + + # ----------------------------------------------------------------------------- def test_HCI_PIN_Code_Request_Reply_Command(): pin_code = b'1234'