Merge pull request #885 from zxzxwu/match-case

Replace long if-else with match-case
This commit is contained in:
Josh Wu
2026-02-11 13:12:38 +08:00
committed by GitHub
11 changed files with 451 additions and 432 deletions

View File

@@ -267,26 +267,27 @@ class MediaCodecInformation:
def create( def create(
cls, media_codec_type: int, data: bytes cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes: ) -> MediaCodecInformation | bytes:
if media_codec_type == CodecType.SBC: match media_codec_type:
return SbcMediaCodecInformation.from_bytes(data) case CodecType.SBC:
elif media_codec_type == CodecType.MPEG_2_4_AAC: return SbcMediaCodecInformation.from_bytes(data)
return AacMediaCodecInformation.from_bytes(data) case CodecType.MPEG_2_4_AAC:
elif media_codec_type == CodecType.NON_A2DP: return AacMediaCodecInformation.from_bytes(data)
vendor_media_codec_information = ( case CodecType.NON_A2DP:
VendorSpecificMediaCodecInformation.from_bytes(data) vendor_media_codec_information = (
) VendorSpecificMediaCodecInformation.from_bytes(data)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
) )
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information return vendor_media_codec_information
@classmethod @classmethod

View File

@@ -27,7 +27,7 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
are ignored [..], unless they are embedded in numeric or string constants" are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = [] tokens: list[bytearray] = []
in_quotes = False in_quotes = False
token = bytearray() token = bytearray()
for b in buffer: for b in buffer:
@@ -40,23 +40,24 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
tokens.append(token[1:-1]) tokens.append(token[1:-1])
token = bytearray() token = bytearray()
else: else:
if char == b' ': match char:
pass case b' ':
elif char == b',' or char == b')': pass
tokens.append(token) case b',' | b')':
tokens.append(char) tokens.append(token)
token = bytearray() tokens.append(char)
elif char == b'(': token = bytearray()
if len(token) > 0: case b'(':
raise AtParsingError("open_paren following regular character") if len(token) > 0:
tokens.append(char) raise AtParsingError("open_paren following regular character")
elif char == b'"': tokens.append(char)
if len(token) > 0: case b'"':
raise AtParsingError("quote following regular character") if len(token) > 0:
in_quotes = True raise AtParsingError("quote following regular character")
token.extend(char) in_quotes = True
else: token.extend(char)
token.extend(char) case _:
token.extend(char)
tokens.append(token) tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0] return [bytes(token) for token in tokens if len(token) > 0]
@@ -71,18 +72,19 @@ def parse_parameters(buffer: bytes) -> list[bytes | list]:
current: bytes | list = b'' current: bytes | list = b''
for token in tokens: for token in tokens:
if token == b',': match token:
accumulator[-1].append(current) case b',':
current = b'' accumulator[-1].append(current)
elif token == b'(': current = b''
accumulator.append([]) case b'(':
elif token == b')': accumulator.append([])
if len(accumulator) < 2: case b')':
raise AtParsingError("close_paren without matching open_paren") if len(accumulator) < 2:
accumulator[-1].append(current) raise AtParsingError("close_paren without matching open_paren")
current = accumulator.pop() accumulator[-1].append(current)
else: current = accumulator.pop()
current = token case _:
current = token
accumulator[-1].append(current) accumulator[-1].append(current)
if len(accumulator) > 1: if len(accumulator) > 1:

View File

@@ -954,12 +954,13 @@ class Attribute(utils.EventEmitter, Generic[_T]):
self.permissions = permissions self.permissions = permissions
# Convert the type to a UUID object if it isn't already # Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str): match attribute_type:
self.type = UUID(attribute_type) case str():
elif isinstance(attribute_type, bytes): self.type = UUID(attribute_type)
self.type = UUID.from_bytes(attribute_type) case bytes():
else: self.type = UUID.from_bytes(attribute_type)
self.type = attribute_type case _:
self.type = attribute_type
self.value = value self.value = value
@@ -994,30 +995,31 @@ class Attribute(utils.EventEmitter, Generic[_T]):
) )
value: _T | None value: _T | None
if isinstance(self.value, AttributeValue): match self.value:
try: case AttributeValue():
read_value = self.value.read(connection) try:
if inspect.isawaitable(read_value): read_value = self.value.read(connection)
value = await read_value if inspect.isawaitable(read_value):
else: value = await read_value
value = read_value else:
except ATT_Error as error: value = read_value
raise ATT_Error( except ATT_Error as error:
error_code=error.error_code, att_handle=self.handle raise ATT_Error(
) from error error_code=error.error_code, att_handle=self.handle
elif isinstance(self.value, AttributeValueV2): ) from error
try: case AttributeValueV2():
read_value = self.value.read(bearer) try:
if inspect.isawaitable(read_value): read_value = self.value.read(bearer)
value = await read_value if inspect.isawaitable(read_value):
else: value = await read_value
value = read_value else:
except ATT_Error as error: value = read_value
raise ATT_Error( except ATT_Error as error:
error_code=error.error_code, att_handle=self.handle raise ATT_Error(
) from error error_code=error.error_code, att_handle=self.handle
else: ) from error
value = self.value case _:
value = self.value
self.emit(self.EVENT_READ, connection, b'' if value is None else value) self.emit(self.EVENT_READ, connection, b'' if value is None else value)
@@ -1049,26 +1051,27 @@ class Attribute(utils.EventEmitter, Generic[_T]):
decoded_value = self.decode_value(value) decoded_value = self.decode_value(value)
if isinstance(self.value, AttributeValue): match self.value:
try: case AttributeValue():
result = self.value.write(connection, decoded_value) try:
if inspect.isawaitable(result): result = self.value.write(connection, decoded_value)
await result if inspect.isawaitable(result):
except ATT_Error as error: await result
raise ATT_Error( except ATT_Error as error:
error_code=error.error_code, att_handle=self.handle raise ATT_Error(
) from error error_code=error.error_code, att_handle=self.handle
elif isinstance(self.value, AttributeValueV2): ) from error
try: case AttributeValueV2():
result = self.value.write(bearer, decoded_value) try:
if inspect.isawaitable(result): result = self.value.write(bearer, decoded_value)
await result if inspect.isawaitable(result):
except ATT_Error as error: await result
raise ATT_Error( except ATT_Error as error:
error_code=error.error_code, att_handle=self.handle raise ATT_Error(
) from error error_code=error.error_code, att_handle=self.handle
else: ) from error
self.value = decoded_value case _:
self.value = decoded_value
self.emit(self.EVENT_WRITE, connection, decoded_value) self.emit(self.EVENT_WRITE, connection, decoded_value)

View File

@@ -403,14 +403,15 @@ class Controller:
) )
# If the packet is a command, invoke the handler for this packet # If the packet is a command, invoke the handler for this packet
if isinstance(packet, hci.HCI_Command): match packet:
self.on_hci_command_packet(packet) case hci.HCI_Command():
elif isinstance(packet, hci.HCI_AclDataPacket): self.on_hci_command_packet(packet)
self.on_hci_acl_data_packet(packet) case hci.HCI_AclDataPacket():
elif isinstance(packet, hci.HCI_Event): self.on_hci_acl_data_packet(packet)
self.on_hci_event_packet(packet) case hci.HCI_Event():
else: self.on_hci_event_packet(packet)
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') case _:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command: hci.HCI_Command) -> None: def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
handler_name = f'on_{command.name.lower()}' handler_name = f'on_{command.name.lower()}'
@@ -517,26 +518,28 @@ class Controller:
logger.error("Cannot find a connection for %s", sender_address) logger.error("Cannot find a connection for %s", sender_address)
return return
if isinstance(packet, ll.TerminateInd): match packet:
self.on_le_disconnected(connection, packet.error_code) case ll.TerminateInd():
elif isinstance(packet, ll.CisReq): self.on_le_disconnected(connection, packet.error_code)
self.on_le_cis_request(connection, packet.cig_id, packet.cis_id) case ll.CisReq():
elif isinstance(packet, ll.CisRsp): self.on_le_cis_request(connection, packet.cig_id, packet.cis_id)
self.on_le_cis_established(packet.cig_id, packet.cis_id) case ll.CisRsp():
connection.send_ll_control_pdu(ll.CisInd(packet.cig_id, packet.cis_id)) self.on_le_cis_established(packet.cig_id, packet.cis_id)
elif isinstance(packet, ll.CisInd): connection.send_ll_control_pdu(ll.CisInd(packet.cig_id, packet.cis_id))
self.on_le_cis_established(packet.cig_id, packet.cis_id) case ll.CisInd():
elif isinstance(packet, ll.CisTerminateInd): self.on_le_cis_established(packet.cig_id, packet.cis_id)
self.on_le_cis_disconnected(packet.cig_id, packet.cis_id) case ll.CisTerminateInd():
elif isinstance(packet, ll.EncReq): self.on_le_cis_disconnected(packet.cig_id, packet.cis_id)
self.on_le_encrypted(connection) case ll.EncReq():
self.on_le_encrypted(connection)
def on_ll_advertising_pdu(self, packet: ll.AdvertisingPdu) -> None: def on_ll_advertising_pdu(self, packet: ll.AdvertisingPdu) -> None:
logger.debug("[%s] <<< Advertising PDU: %s", self.name, packet) logger.debug("[%s] <<< Advertising PDU: %s", self.name, packet)
if isinstance(packet, ll.ConnectInd): match packet:
self.on_le_connect_ind(packet) case ll.ConnectInd():
elif isinstance(packet, (ll.AdvInd, ll.AdvExtInd)): self.on_le_connect_ind(packet)
self.on_advertising_pdu(packet) case ll.AdvInd() | ll.AdvExtInd():
self.on_advertising_pdu(packet)
def on_le_connect_ind(self, packet: ll.ConnectInd) -> None: def on_le_connect_ind(self, packet: ll.ConnectInd) -> None:
''' '''
@@ -894,51 +897,52 @@ class Controller:
return future return future
def on_lmp_packet(self, sender_address: hci.Address, packet: lmp.Packet): def on_lmp_packet(self, sender_address: hci.Address, packet: lmp.Packet):
if isinstance(packet, (lmp.LmpAccepted, lmp.LmpAcceptedExt)): match packet:
if future := self.classic_pending_commands.setdefault( case lmp.LmpAccepted() | lmp.LmpAcceptedExt():
sender_address, {} if future := self.classic_pending_commands.setdefault(
).get(packet.response_opcode): sender_address, {}
future.set_result(hci.HCI_SUCCESS) ).get(packet.response_opcode):
else: future.set_result(hci.HCI_SUCCESS)
else:
logger.error("!!! Unhandled packet: %s", packet)
case lmp.LmpNotAccepted() | lmp.LmpNotAcceptedExt():
if future := self.classic_pending_commands.setdefault(
sender_address, {}
).get(packet.response_opcode):
future.set_result(packet.error_code)
else:
logger.error("!!! Unhandled packet: %s", packet)
case lmp.LmpHostConnectionReq():
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ACL
)
case lmp.LmpScoLinkReq():
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.SCO
)
case lmp.LmpEscoLinkReq():
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ESCO
)
case lmp.LmpDetach():
self.on_classic_disconnected(
sender_address, hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
)
case lmp.LmpSwitchReq():
self.on_classic_role_change_request(sender_address)
case lmp.LmpRemoveScoLinkReq() | lmp.LmpRemoveEscoLinkReq():
self.on_classic_sco_disconnected(sender_address, packet.error_code)
case lmp.LmpNameReq():
self.on_classic_remote_name_request(sender_address, packet.name_offset)
case lmp.LmpNameRes():
self.on_classic_remote_name_response(
sender_address,
packet.name_offset,
packet.name_length,
packet.name_fregment,
)
case _:
logger.error("!!! Unhandled packet: %s", packet) logger.error("!!! Unhandled packet: %s", packet)
elif isinstance(packet, (lmp.LmpNotAccepted, lmp.LmpNotAcceptedExt)):
if future := self.classic_pending_commands.setdefault(
sender_address, {}
).get(packet.response_opcode):
future.set_result(packet.error_code)
else:
logger.error("!!! Unhandled packet: %s", packet)
elif isinstance(packet, (lmp.LmpHostConnectionReq)):
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ACL
)
elif isinstance(packet, (lmp.LmpScoLinkReq)):
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.SCO
)
elif isinstance(packet, (lmp.LmpEscoLinkReq)):
self.on_classic_connection_request(
sender_address, hci.HCI_Connection_Complete_Event.LinkType.ESCO
)
elif isinstance(packet, (lmp.LmpDetach)):
self.on_classic_disconnected(
sender_address, hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
)
elif isinstance(packet, (lmp.LmpSwitchReq)):
self.on_classic_role_change_request(sender_address)
elif isinstance(packet, (lmp.LmpRemoveScoLinkReq, lmp.LmpRemoveEscoLinkReq)):
self.on_classic_sco_disconnected(sender_address, packet.error_code)
elif isinstance(packet, lmp.LmpNameReq):
self.on_classic_remote_name_request(sender_address, packet.name_offset)
elif isinstance(packet, lmp.LmpNameRes):
self.on_classic_remote_name_response(
sender_address,
packet.name_offset,
packet.name_length,
packet.name_fregment,
)
else:
logger.error("!!! Unhandled packet: %s", packet)
def on_classic_connection_request( def on_classic_connection_request(
self, peer_address: hci.Address, link_type: int self, peer_address: hci.Address, link_type: int

View File

@@ -280,14 +280,15 @@ class UUID:
if not force_128: if not force_128:
return self.uuid_bytes return self.uuid_bytes
if len(self.uuid_bytes) == 2: match len(self.uuid_bytes):
return self.BASE_UUID + self.uuid_bytes + bytes([0, 0]) case 2:
elif len(self.uuid_bytes) == 4: return self.BASE_UUID + self.uuid_bytes + bytes([0, 0])
return self.BASE_UUID + self.uuid_bytes case 4:
elif len(self.uuid_bytes) == 16: return self.BASE_UUID + self.uuid_bytes
return self.uuid_bytes case 16:
else: return self.uuid_bytes
assert False, "unreachable" case _:
assert False, "unreachable"
def to_pdu_bytes(self) -> bytes: def to_pdu_bytes(self) -> bytes:
''' '''
@@ -1769,66 +1770,71 @@ class AdvertisingData:
@classmethod @classmethod
def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str: def ad_data_to_string(cls, ad_type: int, ad_data: bytes) -> str:
if ad_type == AdvertisingData.FLAGS: match ad_type:
ad_type_str = 'Flags' case AdvertisingData.FLAGS:
ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True) ad_type_str = 'Flags'
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.flags_to_string(ad_data[0], short=True)
ad_type_str = 'Complete List of 16-bit Service Class UUIDs' case AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) ad_type_str = 'Complete List of 16-bit Service Class UUIDs'
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs' case AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2) ad_type_str = 'Incomplete List of 16-bit Service Class UUIDs'
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 2)
ad_type_str = 'Complete List of 32-bit Service Class UUIDs' case AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) ad_type_str = 'Complete List of 32-bit Service Class UUIDs'
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs' case AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4) ad_type_str = 'Incomplete List of 32-bit Service Class UUIDs'
elif ad_type == AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 4)
ad_type_str = 'Complete List of 128-bit Service Class UUIDs' case AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) ad_type_str = 'Complete List of 128-bit Service Class UUIDs'
elif ad_type == AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs' case AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS:
ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16) ad_type_str = 'Incomplete List of 128-bit Service Class UUIDs'
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: ad_data_str = AdvertisingData.uuid_list_to_string(ad_data, 16)
ad_type_str = 'Service Data' case AdvertisingData.SERVICE_DATA_16_BIT_UUID:
uuid = UUID.from_bytes(ad_data[:2]) ad_type_str = 'Service Data'
ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}' uuid = UUID.from_bytes(ad_data[:2])
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID: ad_data_str = f'service={uuid}, data={ad_data[2:].hex()}'
ad_type_str = 'Service Data' case AdvertisingData.SERVICE_DATA_32_BIT_UUID:
uuid = UUID.from_bytes(ad_data[:4]) ad_type_str = 'Service Data'
ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}' uuid = UUID.from_bytes(ad_data[:4])
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID: ad_data_str = f'service={uuid}, data={ad_data[4:].hex()}'
ad_type_str = 'Service Data' case AdvertisingData.SERVICE_DATA_128_BIT_UUID:
uuid = UUID.from_bytes(ad_data[:16]) ad_type_str = 'Service Data'
ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}' uuid = UUID.from_bytes(ad_data[:16])
elif ad_type == AdvertisingData.SHORTENED_LOCAL_NAME: ad_data_str = f'service={uuid}, data={ad_data[16:].hex()}'
ad_type_str = 'Shortened Local Name' case AdvertisingData.SHORTENED_LOCAL_NAME:
ad_data_str = f'"{ad_data.decode("utf-8")}"' ad_type_str = 'Shortened Local Name'
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"' ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError: case AdvertisingData.COMPLETE_LOCAL_NAME:
ad_type_str = 'Complete Local Name'
try:
ad_data_str = f'"{ad_data.decode("utf-8")}"'
except UnicodeDecodeError:
ad_data_str = ad_data.hex()
case AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
case AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
)
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
case AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(
struct.unpack_from('<H', ad_data, 0)[0]
)
ad_data_str = str(appearance)
case AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
case _:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex() ad_data_str = ad_data.hex()
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
ad_type_str = 'TX Power Level'
ad_data_str = str(ad_data[0])
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
ad_type_str = 'Manufacturer Specific Data'
company_id = struct.unpack_from('<H', ad_data, 0)[0]
company_name = COMPANY_IDENTIFIERS.get(company_id, f'0x{company_id:04X}')
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance'
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0])
ad_data_str = str(appearance)
elif ad_type == AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
else:
ad_type_str = AdvertisingData.Type(ad_type).name
ad_data_str = ad_data.hex()
return f'[{ad_type_str}]: {ad_data_str}' return f'[{ad_type_str}]: {ad_data_str}'

View File

@@ -201,50 +201,51 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
value = data[2 : 2 + value_length] value = data[2 : 2 + value_length]
typed_value: Any typed_value: Any
if value_type == ValueType.END: match value_type:
break case ValueType.END:
break
if value_type in (ValueType.CNVI, ValueType.CNVR): case ValueType.CNVI | ValueType.CNVR:
(v,) = struct.unpack("<I", value) (v,) = struct.unpack("<I", value)
typed_value = ( typed_value = (
(((v >> 0) & 0xF) << 12) (((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0) | (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4) | (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8) | (((v >> 24) & 0xF) << 8)
) )
elif value_type == ValueType.HARDWARE_INFO: case ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value) (v,) = struct.unpack("<I", value)
typed_value = HardwareInfo( typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F) HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
) )
elif value_type in ( case (
ValueType.USB_VENDOR_ID, ValueType.USB_VENDOR_ID
ValueType.USB_PRODUCT_ID, | ValueType.USB_PRODUCT_ID
ValueType.DEVICE_REVISION, | ValueType.DEVICE_REVISION
): ):
(typed_value,) = struct.unpack("<H", value) (typed_value,) = struct.unpack("<H", value)
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION: case ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0]) typed_value = ModeOfOperation(value[0])
elif value_type in ( case (
ValueType.BUILD_TYPE, ValueType.BUILD_TYPE
ValueType.BUILD_NUMBER, | ValueType.BUILD_NUMBER
ValueType.SECURE_BOOT, | ValueType.SECURE_BOOT
ValueType.OTP_LOCK, | ValueType.OTP_LOCK
ValueType.API_LOCK, | ValueType.API_LOCK
ValueType.DEBUG_LOCK, | ValueType.DEBUG_LOCK
ValueType.SECURE_BOOT_ENGINE_TYPE, | ValueType.SECURE_BOOT_ENGINE_TYPE
): ):
typed_value = value[0] typed_value = value[0]
elif value_type == ValueType.TIMESTAMP: case ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1]) typed_value = Timestamp(value[0], value[1])
elif value_type == ValueType.FIRMWARE_BUILD: case ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2])) typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
elif value_type == ValueType.BLUETOOTH_ADDRESS: case ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address( typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
) )
else: case _:
typed_value = value typed_value = value
result.append((value_type, typed_value)) result.append((value_type, typed_value))
data = data[2 + value_length :] data = data[2 + value_length :]

View File

@@ -31,6 +31,7 @@ from typing import (
ClassVar, ClassVar,
Generic, Generic,
Literal, Literal,
SupportsBytes,
TypeVar, TypeVar,
cast, cast,
) )
@@ -1860,44 +1861,46 @@ class HCI_Object:
field_type = field_type['parser'] field_type = field_type['parser']
# Parse the field # Parse the field
if field_type == '*': match field_type:
# The rest of the bytes case '*':
field_value = data[offset:] # The rest of the bytes
return (field_value, len(field_value)) field_value = data[offset:]
if field_type == 'v': return (field_value, len(field_value))
# Variable-length bytes field, with 1-byte length at the beginning case 'v':
field_length = data[offset] # Variable-length bytes field, with 1-byte length at the beginning
offset += 1 field_length = data[offset]
field_value = data[offset : offset + field_length] offset += 1
return (field_value, field_length + 1) field_value = data[offset : offset + field_length]
if field_type == 1: return (field_value, field_length + 1)
# 8-bit unsigned case 1:
return (data[offset], 1) # 8-bit unsigned
if field_type == -1: return (data[offset], 1)
# 8-bit signed case -1:
return (struct.unpack_from('b', data, offset)[0], 1) # 8-bit signed
if field_type == 2: return (struct.unpack_from('b', data, offset)[0], 1)
# 16-bit unsigned case 2:
return (struct.unpack_from('<H', data, offset)[0], 2) # 16-bit unsigned
if field_type == '>2': return (struct.unpack_from('<H', data, offset)[0], 2)
# 16-bit unsigned big-endian case '>2':
return (struct.unpack_from('>H', data, offset)[0], 2) # 16-bit unsigned big-endian
if field_type == -2: return (struct.unpack_from('>H', data, offset)[0], 2)
# 16-bit signed case -2:
return (struct.unpack_from('<h', data, offset)[0], 2) # 16-bit signed
if field_type == 3: return (struct.unpack_from('<h', data, offset)[0], 2)
# 24-bit unsigned case 3:
padded = data[offset : offset + 3] + bytes([0]) # 24-bit unsigned
return (struct.unpack('<I', padded)[0], 3) padded = data[offset : offset + 3] + bytes([0])
if field_type == 4: return (struct.unpack('<I', padded)[0], 3)
# 32-bit unsigned case 4:
return (struct.unpack_from('<I', data, offset)[0], 4) # 32-bit unsigned
if field_type == '>4': return (struct.unpack_from('<I', data, offset)[0], 4)
# 32-bit unsigned big-endian case '>4':
return (struct.unpack_from('>I', data, offset)[0], 4) # 32-bit unsigned big-endian
if isinstance(field_type, int) and 4 < field_type <= 256: return (struct.unpack_from('>I', data, offset)[0], 4)
# Byte array (from 5 up to 256 bytes) case int() if 4 < field_type <= 256:
return (data[offset : offset + field_type], field_type) # Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type): if callable(field_type):
new_offset, field_value = field_type(data, offset) new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset) return (field_value, new_offset - offset)
@@ -1954,60 +1957,58 @@ class HCI_Object:
# Serialize the field # Serialize the field
if serializer: if serializer:
field_bytes = serializer(field_value) return serializer(field_value)
elif field_type == 1: match field_type:
# 8-bit unsigned case 1:
field_bytes = bytes([field_value]) # 8-bit unsigned
elif field_type == -1: return bytes([field_value])
# 8-bit signed case -1:
field_bytes = struct.pack('b', field_value) # 8-bit signed
elif field_type == 2: return struct.pack('b', field_value)
# 16-bit unsigned case 2:
field_bytes = struct.pack('<H', field_value) # 16-bit unsigned
elif field_type == '>2': return struct.pack('<H', field_value)
# 16-bit unsigned big-endian case '>2':
field_bytes = struct.pack('>H', field_value) # 16-bit unsigned big-endian
elif field_type == -2: return struct.pack('>H', field_value)
# 16-bit signed case -2:
field_bytes = struct.pack('<h', field_value) # 16-bit signed
elif field_type == 3: return struct.pack('<h', field_value)
# 24-bit unsigned case 3:
field_bytes = struct.pack('<I', field_value)[0:3] # 24-bit unsigned
elif field_type == 4: return struct.pack('<I', field_value)[0:3]
# 32-bit unsigned case 4:
field_bytes = struct.pack('<I', field_value) # 32-bit unsigned
elif field_type == '>4': return struct.pack('<I', field_value)
# 32-bit unsigned big-endian case '>4':
field_bytes = struct.pack('>I', field_value) # 32-bit unsigned big-endian
elif field_type == '*': return struct.pack('>I', field_value)
if isinstance(field_value, int): case '*':
if 0 <= field_value <= 255: if isinstance(field_value, int):
field_bytes = bytes([field_value]) if 0 <= field_value <= 255:
return bytes([field_value])
else:
raise InvalidArgumentError('value too large for *-typed field')
else: else:
raise InvalidArgumentError('value too large for *-typed field') return bytes(field_value)
else: case 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_bytes = bytes(field_value) field_bytes = bytes(field_value)
elif field_type == 'v': field_length = len(field_bytes)
# Variable-length bytes field, with 1-byte length at the beginning return bytes([field_length]) + field_bytes
field_bytes = bytes(field_value) if isinstance(field_value, (bytes, bytearray, SupportsBytes)):
field_length = len(field_bytes)
field_bytes = bytes([field_length]) + field_bytes
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, '__bytes__'
):
field_bytes = bytes(field_value) field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256: if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short # Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type: if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes)) return field_bytes + bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type: elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type] return field_bytes[:field_type]
else: return field_bytes
raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
return field_bytes raise InvalidArgumentError(
f"don't know how to serialize type {type(field_value)}"
)
@staticmethod @staticmethod
def dict_to_bytes(hci_object, object_fields): def dict_to_bytes(hci_object, object_fields):

View File

@@ -22,7 +22,7 @@ import collections
import dataclasses import dataclasses
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload from typing import TYPE_CHECKING, Any, TypeVar, overload
from bumble import drivers, hci, utils from bumble import drivers, hci, utils
from bumble.colors import color from bumble.colors import color
@@ -1002,18 +1002,19 @@ class Host(utils.EventEmitter):
self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST) self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
# If the packet is a command, invoke the handler for this packet # If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == hci.HCI_COMMAND_PACKET: match packet:
self.on_hci_command_packet(cast(hci.HCI_Command, packet)) case hci.HCI_Command():
elif packet.hci_packet_type == hci.HCI_EVENT_PACKET: self.on_hci_command_packet(packet)
self.on_hci_event_packet(cast(hci.HCI_Event, packet)) case hci.HCI_Event():
elif packet.hci_packet_type == hci.HCI_ACL_DATA_PACKET: self.on_hci_event_packet(packet)
self.on_hci_acl_data_packet(cast(hci.HCI_AclDataPacket, packet)) case hci.HCI_AclDataPacket():
elif packet.hci_packet_type == hci.HCI_SYNCHRONOUS_DATA_PACKET: self.on_hci_acl_data_packet(packet)
self.on_hci_sco_data_packet(cast(hci.HCI_SynchronousDataPacket, packet)) case hci.HCI_SynchronousDataPacket():
elif packet.hci_packet_type == hci.HCI_ISO_DATA_PACKET: self.on_hci_sco_data_packet(packet)
self.on_hci_iso_data_packet(cast(hci.HCI_IsoDataPacket, packet)) case hci.HCI_IsoDataPacket():
else: self.on_hci_iso_data_packet(packet)
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') case _:
logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
def on_hci_command_packet(self, command: hci.HCI_Command) -> None: def on_hci_command_packet(self, command: hci.HCI_Command) -> None:
logger.warning(f'!!! unexpected command packet: {command}') logger.warning(f'!!! unexpected command packet: {command}')

View File

@@ -664,46 +664,44 @@ class AudioStreamControlService(gatt.TemplateService):
responses = [] responses = []
logger.debug(f'*** ASCS Write {operation} ***') logger.debug(f'*** ASCS Write {operation} ***')
if isinstance(operation, ASE_Config_Codec): match operation:
for ase_id, *args in zip( case ASE_Config_Codec():
operation.ase_id, for ase_id, *args in zip(
operation.target_latency, operation.ase_id,
operation.target_phy, operation.target_latency,
operation.codec_id, operation.target_phy,
operation.codec_specific_configuration, operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Config_QOS():
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Enable() | ASE_Update_Metadata():
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case (
ASE_Receiver_Start_Ready()
| ASE_Disable()
| ASE_Receiver_Stop_Ready()
| ASE_Release()
): ):
responses.append(self.on_operation(operation.op_code, ase_id, args)) for ase_id in operation.ase_id:
elif isinstance(operation, ASE_Config_QOS): responses.append(self.on_operation(operation.op_code, ase_id, []))
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(operation, (ASE_Enable, ASE_Update_Metadata)):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif isinstance(
operation,
(
ASE_Receiver_Start_Ready,
ASE_Disable,
ASE_Receiver_Stop_Ready,
ASE_Release,
),
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes( control_point_notification = bytes(
[operation.op_code, len(responses)] [operation.op_code, len(responses)]

View File

@@ -333,17 +333,18 @@ class CodecSpecificCapabilities:
value = int.from_bytes(data[offset : offset + length - 1], 'little') value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1 offset += length - 1
if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY: match type:
supported_sampling_frequencies = SupportedSamplingFrequency(value) case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
elif type == CodecSpecificCapabilities.Type.FRAME_DURATION: supported_sampling_frequencies = SupportedSamplingFrequency(value)
supported_frame_durations = SupportedFrameDuration(value) case CodecSpecificCapabilities.Type.FRAME_DURATION:
elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: supported_frame_durations = SupportedFrameDuration(value)
supported_audio_channel_count = bits_to_channel_counts(value) case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: supported_audio_channel_count = bits_to_channel_counts(value)
min_octets_per_sample = value & 0xFFFF case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
max_octets_per_sample = value >> 16 min_octets_per_sample = value & 0xFFFF
elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU: max_octets_per_sample = value >> 16
supported_max_codec_frames_per_sdu = value case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised. # It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment # pylint: disable=possibly-used-before-assignment,used-before-assignment

View File

@@ -55,14 +55,15 @@ class GenericAccessService(TemplateService):
def __init__( def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0 self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0
): ):
if isinstance(appearance, int): match appearance:
appearance_int = appearance case int():
elif isinstance(appearance, tuple): appearance_int = appearance
appearance_int = (appearance[0] << 6) | appearance[1] case tuple():
elif isinstance(appearance, Appearance): appearance_int = (appearance[0] << 6) | appearance[1]
appearance_int = int(appearance) case Appearance():
else: appearance_int = int(appearance)
raise TypeError() case _:
raise TypeError()
self.device_name_characteristic = Characteristic( self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,