Submitting review comment fix: header function and spacing

This commit is contained in:
skarnataki
2023-10-10 11:43:57 +00:00
committed by Lucas Abel
parent fc1bf36ace
commit 493f4f8b95
2 changed files with 45 additions and 60 deletions

View File

@@ -30,6 +30,8 @@ from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError # type:
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -45,7 +47,7 @@ HID_INTERRUPT_PSM = 0x0013
class Message(): class Message():
message_type: MessageType
# Report types # Report types
class ReportType(enum.IntEnum): class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00 OTHER_REPORT = 0x00
@@ -83,6 +85,11 @@ class Message():
EXIT_SUSPEND = 0x04 EXIT_SUSPEND = 0x04
VIRTUAL_CABLE_UNPLUG = 0x05 VIRTUAL_CABLE_UNPLUG = 0x05
# Class Method to derive header
@classmethod
def header( cls , lower_bits : int = 0x00 ) -> bytes :
return bytes([(cls.message_type << 4) | lower_bits])
# HIDP messages # HIDP messages
@dataclass @dataclass
@@ -90,58 +97,54 @@ class GetReportMessage(Message):
report_type : int report_type : int
report_id : int report_id : int
buffer_size : int buffer_size : int
message_type = Message.MessageType.GET_REPORT
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
if(self.report_type == Message.ReportType.OTHER_REPORT):
param = self.report_type
else:
param = 0x08 | self.report_type
header = ((Message.MessageType.GET_REPORT << 4) | param)
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(header)
packet_bytes.append(self.report_id) packet_bytes.append(self.report_id)
packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)]) packet_bytes.extend([(self.buffer_size & 0xff), ((self.buffer_size >> 8) & 0xff)])
return bytes(packet_bytes) if(self.report_type == Message.ReportType.OTHER_REPORT):
return self.header(self.report_type) + packet_bytes
else:
return self.header(0x08 | self.report_type) + packet_bytes
@dataclass @dataclass
class SetReportMessage(Message): class SetReportMessage(Message):
report_type: int report_type: int
data : bytes data : bytes
message_type = Message.MessageType.SET_REPORT
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
header = ((Message.MessageType.SET_REPORT << 4) | self.report_type) return self.header(self.report_type) + self.data
packet_bytes = bytearray()
packet_bytes.append(header)
packet_bytes.extend(self.data)
return bytes(packet_bytes)
@dataclass @dataclass
class GetProtocolMessage(Message): class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
header = (Message.MessageType.GET_PROTOCOL << 4) return self.header()
packet_bytes = bytearray()
packet_bytes.append(header)
return bytes(packet_bytes)
@dataclass @dataclass
class SetProtocolMessage(Message): class SetProtocolMessage(Message):
protocol_mode: int protocol_mode: int
message_type = Message.MessageType.SET_PROTOCOL
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
header = (Message.MessageType.SET_PROTOCOL << 4 | self.protocol_mode) return self.header(self.protocol_mode)
packet_bytes = bytearray()
packet_bytes.append(header)
packet_bytes.append(self.protocol_mode)
return bytes(packet_bytes)
@dataclass @dataclass
class SendData(Message): class SendData(Message):
message_type = Message.MessageType.DATA
data : bytes data : bytes
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
header = ((Message.MessageType.DATA << 4) | Message.ReportType.OUTPUT_REPORT) return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
packet_bytes = bytearray()
packet_bytes.append(header)
packet_bytes.extend(self.data)
return bytes(packet_bytes)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(EventEmitter): class Host(EventEmitter):
l2cap_channel: Optional[l2cap.Channel] l2cap_channel: Optional[l2cap.Channel]
@@ -240,32 +243,31 @@ class Host(EventEmitter):
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu) self.emit('data', pdu)
def on_intr_pdu(self, pdu: bytes) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("data", pdu) self.emit("data", pdu)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(report_type = report_type , report_id = report_id , buffer_size = buffer_size) msg = GetReportMessage(report_type = report_type , report_id = report_id , buffer_size = buffer_size)
hid_message = msg.__bytes__() hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def set_report(self, report_type: int, data: bytes): def set_report(self, report_type: int, data: bytes):
msg = SetReportMessage(report_type= report_type,data = data) msg = SetReportMessage(report_type= report_type,data = data)
hid_message = msg.__bytes__() hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}') logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def get_protocol(self): def get_protocol(self):
msg = GetProtocolMessage() msg = GetProtocolMessage()
hid_message = msg.__bytes__() hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
def set_protocol(self, protocol_mode: int): def set_protocol(self, protocol_mode: int):
msg = SetProtocolMessage(protocol_mode= protocol_mode) msg = SetProtocolMessage(protocol_mode= protocol_mode)
hid_message = msg.__bytes__() hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) # type: ignore self.send_pdu_on_ctrl(hid_message) # type: ignore
@@ -277,7 +279,7 @@ class Host(EventEmitter):
def send_data(self, data): def send_data(self, data):
msg = SendData(data) msg = SendData(data)
hid_message = msg.__bytes__() hid_message = bytes(msg)
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}') logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message) # type: ignore self.send_pdu_on_intr(hid_message) # type: ignore

View File

@@ -20,7 +20,6 @@ import sys
import os import os
import logging import logging
from bumble.colors import color from bumble.colors import color
import bumble.core import bumble.core
@@ -72,16 +71,18 @@ SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210 SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_hid_device_sdp_record(device, connection): async def get_hid_device_sdp_record(device, connection):
# Connect to the SDP Server # Connect to the SDP Server
sdp_client = SDP_Client(device) sdp_client = SDP_Client(device)
await sdp_client.connect(connection) await sdp_client.connect(connection)
if sdp_client: if sdp_client:
print(color('Connected ith SDP Server', 'blue')) print(color('Connected to SDP Server', 'blue'))
else: else:
print(color('Failed to connect with SDP Server', 'red')) print(color('Failed to connect to SDP Server', 'red'))
# List BT HID Device service in the root browse group # List BT HID Device service in the root browse group
service_record_handles = await sdp_client.search_services( service_record_handles = await sdp_client.search_services(
@@ -89,9 +90,8 @@ async def get_hid_device_sdp_record(device, connection):
) )
if (len(service_record_handles) < 1): if (len(service_record_handles) < 1):
print(color('BT HID Device service not found on peer device!!!!','red'))
await sdp_client.disconnect() await sdp_client.disconnect()
return raise Exception(color(f'BT HID Device service not found on peer device!!!!','red'))
# For BT_HUMAN_INTERFACE_DEVICE_SERVICE service, get all its attributes # For BT_HUMAN_INTERFACE_DEVICE_SERVICE service, get all its attributes
for service_record_handle in service_record_handles: for service_record_handle in service_record_handles:
@@ -133,89 +133,69 @@ async def get_hid_device_sdp_record(device, connection):
elif attribute.id == SDP_HID_SERVICE_NAME_ATTRIBUTE_ID : elif attribute.id == SDP_HID_SERVICE_NAME_ATTRIBUTE_ID :
print(color(' Service Name: ', 'cyan'), attribute.value.value) print(color(' Service Name: ', 'cyan'), attribute.value.value)
HID_Service_Name = attribute.value.value
elif attribute.id == SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID : elif attribute.id == SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID :
print(color(' Service Description: ', 'cyan'), attribute.value.value) print(color(' Service Description: ', 'cyan'), attribute.value.value)
HID_Service_Description = attribute.value.value
elif attribute.id == SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID : elif attribute.id == SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID :
print(color(' Provider Name: ', 'cyan'), attribute.value.value) print(color(' Provider Name: ', 'cyan'), attribute.value.value)
HID_Provider_Name = attribute.value.value
elif attribute.id == SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID : elif attribute.id == SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID :
print(color(' Release Number: ', 'cyan'), hex(attribute.value.value)) print(color(' Release Number: ', 'cyan'), hex(attribute.value.value))
HID_Device_Release_Number = attribute.value.value
elif attribute.id == SDP_HID_PARSER_VERSION_ATTRIBUTE_ID : elif attribute.id == SDP_HID_PARSER_VERSION_ATTRIBUTE_ID :
print(color(' HID Parser Version: ', 'cyan'), hex(attribute.value.value)) print(color(' HID Parser Version: ', 'cyan'), hex(attribute.value.value))
HID_Parser_Version = attribute.value.value
elif attribute.id == SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID : elif attribute.id == SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID :
print(color(' HIDDeviceSubclass: ', 'cyan'), hex(attribute.value.value)) print(color(' HIDDeviceSubclass: ', 'cyan'), hex(attribute.value.value))
HID_Device_Subclass = attribute.value.value
elif attribute.id == SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID :
print(color(' HIDCountryCode: ', 'cyan'), hex(attribute.value.value)) print(color(' HIDCountryCode: ', 'cyan'), hex(attribute.value.value))
HID_Country_Code = attribute.value.value
elif attribute.id == SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID :
print(color(' HIDVirtualCable: ', 'cyan'), attribute.value.value) print(color(' HIDVirtualCable: ', 'cyan'), attribute.value.value)
HID_Virtual_Cable = attribute.value.value
elif attribute.id == SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID :
print(color(' HIDReconnectInitiate: ', 'cyan'), attribute.value.value) print(color(' HIDReconnectInitiate: ', 'cyan'), attribute.value.value)
HID_Reconnect_Initiate = attribute.value.value
elif attribute.id == SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID : elif attribute.id == SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID :
print(color(' HID Report Descriptor type: ', 'cyan'), hex(attribute.value.value[0].value[0].value)) print(color(' HID Report Descriptor type: ', 'cyan'), hex(attribute.value.value[0].value[0].value))
print(color(' HID Report DescriptorList: ', 'cyan'), attribute.value.value[0].value[1].value) print(color(' HID Report DescriptorList: ', 'cyan'), attribute.value.value[0].value[1].value)
HID_Descriptor_Type = attribute.value.value[0].value[0].value
HID_Report_Descriptor_List = attribute.value.value[0].value[1].value
elif attribute.id == SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID : elif attribute.id == SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID :
print(color(' HID LANGID Base Language: ', 'cyan'), hex(attribute.value.value[0].value[0].value)) print(color(' HID LANGID Base Language: ', 'cyan'), hex(attribute.value.value[0].value[0].value))
print(color(' HID LANGID Base Bluetooth String Offset: ', 'cyan'), hex(attribute.value.value[0].value[1].value)) print(color(' HID LANGID Base Bluetooth String Offset: ', 'cyan'), hex(attribute.value.value[0].value[1].value))
HID_LANGID_Base_Language = attribute.value.value[0].value[0].value
HID_LANGID_Base_Bluetooth_String_Offset = attribute.value.value[0].value[1].value
elif attribute.id == SDP_HID_BATTERY_POWER_ATTRIBUTE_ID : elif attribute.id == SDP_HID_BATTERY_POWER_ATTRIBUTE_ID :
print(color(' HIDBatteryPower: ', 'cyan'), attribute.value.value) print(color(' HIDBatteryPower: ', 'cyan'), attribute.value.value)
HID_Battery_Power = attribute.value.value
elif attribute.id == SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID :
print(color(' HIDRemoteWake: ', 'cyan'), attribute.value.value) print(color(' HIDRemoteWake: ', 'cyan'), attribute.value.value)
HID_Remote_Wake = attribute.value.value
elif attribute.id == SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID : elif attribute.id == SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID :
print(color(' HIDProfileVersion : ', 'cyan'), hex(attribute.value.value)) print(color(' HIDProfileVersion : ', 'cyan'), hex(attribute.value.value))
HID_Profile_Version = attribute.value.value
elif attribute.id == SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID : elif attribute.id == SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID :
print(color(' HIDSupervisionTimeout: ', 'cyan'), hex(attribute.value.value)) print(color(' HIDSupervisionTimeout: ', 'cyan'), hex(attribute.value.value))
HID_Supervision_Timeout = attribute.value.value
elif attribute.id == SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID :
print(color(' HIDNormallyConnectable: ', 'cyan'), attribute.value.value) print(color(' HIDNormallyConnectable: ', 'cyan'), attribute.value.value)
HID_Normally_Connectable = attribute.value.value
elif attribute.id == SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID : elif attribute.id == SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID :
print(color(' HIDBootDevice: ', 'cyan'), attribute.value.value) print(color(' HIDBootDevice: ', 'cyan'), attribute.value.value)
HID_Boot_Device = attribute.value.value
elif attribute.id == SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID : elif attribute.id == SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID :
print(color(' HIDSSRHostMaxLatency: ', 'cyan'), hex(attribute.value.value)) print(color(' HIDSSRHostMaxLatency: ', 'cyan'), hex(attribute.value.value))
HID_SSR_Host_Max_Latency = attribute.value.value
elif attribute.id == SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID : elif attribute.id == SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID :
print(color(' HIDSSRHostMinTimeout: ', 'cyan'), hex(attribute.value.value)) print(color(' HIDSSRHostMinTimeout: ', 'cyan'), hex(attribute.value.value))
HID_SSR_Host_Min_Timeout = attribute.value.value
else: else:
print(color(f' Warning: Attribute ID: {attribute.id} match not found.\n Attribute Info: {attribute}', 'yellow')) print(color(f' Warning: Attribute ID: {attribute.id} match not found.\n Attribute Info: {attribute}', 'yellow'))
await sdp_client.disconnect() await sdp_client.disconnect()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_stream_reader(pipe) -> asyncio.StreamReader: async def get_stream_reader(pipe) -> asyncio.StreamReader:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -224,6 +204,7 @@ async def get_stream_reader(pipe) -> asyncio.StreamReader:
await loop.connect_read_pipe(lambda: protocol, pipe) await loop.connect_read_pipe(lambda: protocol, pipe)
return reader return reader
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
@@ -461,5 +442,7 @@ async def main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())