add support for field arrays in hci packet definitions

This commit is contained in:
Gilles Boccon-Gibod
2023-07-30 22:19:10 -07:00
parent 8eeb58e467
commit bdad225033
2 changed files with 251 additions and 221 deletions

View File

@@ -1445,7 +1445,13 @@ class HCI_Object:
@staticmethod @staticmethod
def init_from_fields(hci_object, fields, values): def init_from_fields(hci_object, fields, values):
if isinstance(values, dict): if isinstance(values, dict):
for field_name, _ in fields: for field in fields:
if isinstance(field, list):
# The field is an array, up-level the array field names
for sub_field_name, _ in field:
setattr(hci_object, sub_field_name, values[sub_field_name])
else:
field_name = field[0]
setattr(hci_object, field_name, values[field_name]) setattr(hci_object, field_name, values[field_name])
else: else:
for field_name, field_value in zip(fields, values): for field_name, field_value in zip(fields, values):
@@ -1457,9 +1463,7 @@ class HCI_Object:
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values()) HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod @staticmethod
def dict_from_bytes(data, offset, fields): def parse_field(data, offset, field_type):
result = collections.OrderedDict()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, and/or size # The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict): if isinstance(field_type, dict):
if 'size' in field_type: if 'size' in field_type:
@@ -1471,57 +1475,67 @@ class HCI_Object:
if field_type == '*': if field_type == '*':
# The rest of the bytes # The rest of the bytes
field_value = data[offset:] field_value = data[offset:]
offset += len(field_value) return (field_value, len(field_value))
elif field_type == 1: if field_type == 1:
# 8-bit unsigned # 8-bit unsigned
field_value = data[offset] return (data[offset], 1)
offset += 1 if field_type == -1:
elif field_type == -1:
# 8-bit signed # 8-bit signed
field_value = struct.unpack_from('b', data, offset)[0] return (struct.unpack_from('b', data, offset)[0], 1)
offset += 1 if field_type == 2:
elif field_type == 2:
# 16-bit unsigned # 16-bit unsigned
field_value = struct.unpack_from('<H', data, offset)[0] return (struct.unpack_from('<H', data, offset)[0], 2)
offset += 2 if field_type == '>2':
elif field_type == '>2':
# 16-bit unsigned big-endian # 16-bit unsigned big-endian
field_value = struct.unpack_from('>H', data, offset)[0] return (struct.unpack_from('>H', data, offset)[0], 2)
offset += 2 if field_type == -2:
elif field_type == -2:
# 16-bit signed # 16-bit signed
field_value = struct.unpack_from('<h', data, offset)[0] return (struct.unpack_from('<h', data, offset)[0], 2)
offset += 2 if field_type == 3:
elif field_type == 3:
# 24-bit unsigned # 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0]) padded = data[offset : offset + 3] + bytes([0])
field_value = struct.unpack('<I', padded)[0] return (struct.unpack('<I', padded)[0], 3)
offset += 3 if field_type == 4:
elif field_type == 4:
# 32-bit unsigned # 32-bit unsigned
field_value = struct.unpack_from('<I', data, offset)[0] return (struct.unpack_from('<I', data, offset)[0], 4)
offset += 4 if field_type == '>4':
elif field_type == '>4':
# 32-bit unsigned big-endian # 32-bit unsigned big-endian
field_value = struct.unpack_from('>I', data, offset)[0] return (struct.unpack_from('>I', data, offset)[0], 4)
offset += 4 if isinstance(field_type, int) and 4 < field_type <= 256:
elif isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes) # Byte array (from 5 up to 256 bytes)
field_value = data[offset : offset + field_type] return (data[offset : offset + field_type], field_type)
offset += field_type if callable(field_type):
elif callable(field_type): new_offset, field_value = field_type(data, offset)
offset, field_value = field_type(data, offset) return (field_value, new_offset - offset)
else:
raise ValueError(f'unknown field type {field_type}') raise ValueError(f'unknown field type {field_type}')
@staticmethod
def dict_from_bytes(data, offset, fields):
result = collections.OrderedDict()
for field in fields:
if isinstance(field, list):
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
result.setdefault(sub_field_name, []).append(value)
offset += size
continue
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(data, offset, field_type)
result[field_name] = field_value result[field_name] = field_value
offset += field_size
return result return result
@staticmethod @staticmethod
def dict_to_bytes(hci_object, fields): def serialize_field(field_value, field_type):
result = bytearray()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, serializer, # The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size # and/or size
serializer = None serializer = None
@@ -1532,7 +1546,6 @@ class HCI_Object:
field_type = field_type['size'] field_type = field_type['size']
# Serialize the field # Serialize the field
field_value = hci_object[field_name]
if serializer: if serializer:
field_bytes = serializer(field_value) field_bytes = serializer(field_value)
elif field_type == 1: elif field_type == 1:
@@ -1572,17 +1585,38 @@ class HCI_Object:
): ):
field_bytes = bytes(field_value) field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256: if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or Pad with zeros if the field is too long or too short # Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type: if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes)) field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type: elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type] field_bytes = field_bytes[:field_type]
else: else:
raise ValueError( raise ValueError(f"don't know how to serialize type {type(field_value)}")
f"don't know how to serialize type {type(field_value)}"
)
result += field_bytes return field_bytes
@staticmethod
def dict_to_bytes(hci_object, fields):
result = bytearray()
for field in fields:
if isinstance(field, list):
# The field is an array. The serialized form starts with a 1-byte
# item count. We use the length of the first array field as the
# array count, since all array fields have the same number of items.
item_count = len(hci_object[field[0][0]])
result += bytes([item_count]) + b''.join(
b''.join(
HCI_Object.serialize_field(
hci_object[sub_field_name][i], sub_field_type
)
for sub_field_name, sub_field_type in field
)
for i in range(item_count)
)
continue
(field_name, field_type) = field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result) return bytes(result)
@@ -1617,48 +1651,73 @@ class HCI_Object:
return str(value) return str(value)
@staticmethod @staticmethod
def format_fields(hci_object, keys, indentation='', value_mappers=None): def stringify_field(
if not keys: field_name, field_type, field_value, indentation, value_mappers
return '' ):
# Measure the widest field name
max_field_name_length = max(
(len(key[0] if isinstance(key, tuple) else key) for key in keys)
)
# Build array of formatted key:value pairs
fields = []
for key in keys:
value_mapper = None value_mapper = None
if isinstance(key, tuple): if isinstance(field_type, dict):
# The key has an associated specifier
key, specifier = key
# Get the value mapper from the specifier # Get the value mapper from the specifier
if isinstance(specifier, dict): value_mapper = field_type.get('mapper')
value_mapper = specifier.get('mapper')
# Get the value for the field
value = hci_object[key]
# Check if there's a matching mapper passed # Check if there's a matching mapper passed
if value_mappers: if value_mappers:
value_mapper = value_mappers.get(key, value_mapper) value_mapper = value_mappers.get(field_name, value_mapper)
# Map the value if we have a mapper # Map the value if we have a mapper
if value_mapper is not None: if value_mapper is not None:
value = value_mapper(value) field_value = value_mapper(field_value)
# Get the string representation of the value # Get the string representation of the value
value_str = HCI_Object.format_field_value( return HCI_Object.format_field_value(
value, indentation=indentation + ' ' field_value, indentation=indentation + ' '
) )
# Add the field to the formatted result @staticmethod
key_str = color(f'{key + ":":{1 + max_field_name_length}}', 'cyan') def format_fields(hci_object, fields, indentation='', value_mappers=None):
fields.append(f'{indentation}{key_str} {value_str}') if not fields:
return ''
return '\n'.join(fields) # Build array of formatted key:value pairs
field_strings = []
for field in fields:
if isinstance(field, list):
for sub_field in field:
sub_field_name, sub_field_type = sub_field
item_count = len(hci_object[sub_field_name])
for i in range(item_count):
field_strings.append(
(
f'{sub_field_name}[{i}]',
HCI_Object.stringify_field(
sub_field_name,
sub_field_type,
hci_object[sub_field_name][i],
indentation,
value_mappers,
),
),
)
continue
field_name, field_type = field
field_value = hci_object[field_name]
field_strings.append(
(
field_name,
HCI_Object.stringify_field(
field_name, field_type, field_value, indentation, value_mappers
),
),
)
# Measure the widest field name
max_field_name_length = max(len(s[0]) for s in field_strings)
sep = ':'
return '\n'.join(
f'{indentation}'
f'{color(f"{field_name + sep:{1 + max_field_name_length}}", "cyan")} {field_value}'
for field_name, field_value in field_strings
)
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.to_bytes()
@@ -3769,9 +3828,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
'advertising_data', 'advertising_data',
{ {
'parser': HCI_Object.parse_length_prefixed_bytes, 'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial( 'serializer': HCI_Object.serialize_length_prefixed_bytes,
HCI_Object.serialize_length_prefixed_bytes
),
}, },
), ),
] ]
@@ -3819,9 +3876,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command):
'scan_response_data', 'scan_response_data',
{ {
'parser': HCI_Object.parse_length_prefixed_bytes, 'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial( 'serializer': HCI_Object.serialize_length_prefixed_bytes,
HCI_Object.serialize_length_prefixed_bytes
),
}, },
), ),
] ]
@@ -3849,73 +3904,21 @@ class HCI_LE_Set_Extended_Scan_Response_Data_Command(HCI_Command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command(fields=None) @HCI_Command.command(
[
('enable', 1),
[
('advertising_handles', 1),
('durations', 2),
('max_extended_advertising_events', 1),
],
]
)
class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command): class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
''' '''
See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command
''' '''
@classmethod
def from_parameters(cls, parameters):
enable = parameters[0]
num_sets = parameters[1]
advertising_handles = []
durations = []
max_extended_advertising_events = []
offset = 2
for _ in range(num_sets):
advertising_handles.append(parameters[offset])
durations.append(struct.unpack_from('<H', parameters, offset + 1)[0])
max_extended_advertising_events.append(parameters[offset + 3])
offset += 4
return cls(
enable, advertising_handles, durations, max_extended_advertising_events
)
def __init__(
self, enable, advertising_handles, durations, max_extended_advertising_events
):
super().__init__(HCI_LE_SET_EXTENDED_ADVERTISING_ENABLE_COMMAND)
self.enable = enable
self.advertising_handles = advertising_handles
self.durations = durations
self.max_extended_advertising_events = max_extended_advertising_events
self.parameters = bytes([enable, len(advertising_handles)]) + b''.join(
[
struct.pack(
'<BHB',
advertising_handles[i],
durations[i],
max_extended_advertising_events[i],
)
for i in range(len(advertising_handles))
]
)
def __str__(self):
fields = [('enable:', self.enable)]
for i, advertising_handle in enumerate(self.advertising_handles):
fields.append(
(f'advertising_handle[{i}]: ', advertising_handle)
)
fields.append((f'duration[{i}]: ', self.durations[i]))
fields.append(
(
f'max_extended_advertising_events[{i}]:',
self.max_extended_advertising_events[i],
)
)
return (
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command( @HCI_Command.command(
@@ -4066,7 +4069,10 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command):
color(self.name, 'green') color(self.name, 'green')
+ ':\n' + ':\n'
+ '\n'.join( + '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields] [
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
) )
) )
@@ -4242,7 +4248,10 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
color(self.name, 'green') color(self.name, 'green')
+ ':\n' + ':\n'
+ '\n'.join( + '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields] [
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
) )
) )

View File

@@ -46,6 +46,7 @@ from bumble.hci import (
HCI_LE_Set_Advertising_Parameters_Command, HCI_LE_Set_Advertising_Parameters_Command,
HCI_LE_Set_Default_PHY_Command, HCI_LE_Set_Default_PHY_Command,
HCI_LE_Set_Event_Mask_Command, HCI_LE_Set_Event_Mask_Command,
HCI_LE_Set_Extended_Advertising_Enable_Command,
HCI_LE_Set_Extended_Scan_Parameters_Command, HCI_LE_Set_Extended_Scan_Parameters_Command,
HCI_LE_Set_Random_Address_Command, HCI_LE_Set_Random_Address_Command,
HCI_LE_Set_Scan_Enable_Command, HCI_LE_Set_Scan_Enable_Command,
@@ -422,6 +423,25 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command():
basic_check(command) basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Extended_Advertising_Enable_Command():
command = HCI_Packet.from_bytes(
bytes.fromhex('0139200e010301050008020600090307000a')
)
assert command.enable == 1
assert command.advertising_handles == [1, 2, 3]
assert command.durations == [5, 6, 7]
assert command.max_extended_advertising_events == [8, 9, 10]
command = HCI_LE_Set_Extended_Advertising_Enable_Command(
enable=1,
advertising_handles=[1, 2, 3],
durations=[5, 6, 7],
max_extended_advertising_events=[8, 9, 10],
)
basic_check(command)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_address(): def test_address():
a = Address('C4:F2:17:1A:1D:BB') a = Address('C4:F2:17:1A:1D:BB')
@@ -478,6 +498,7 @@ def run_test_commands():
test_HCI_LE_Read_Remote_Features_Command() test_HCI_LE_Read_Remote_Features_Command()
test_HCI_LE_Set_Default_PHY_Command() test_HCI_LE_Set_Default_PHY_Command()
test_HCI_LE_Set_Extended_Scan_Parameters_Command() test_HCI_LE_Set_Extended_Scan_Parameters_Command()
test_HCI_LE_Set_Extended_Advertising_Enable_Command()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------