Renamed the status message class

This commit is contained in:
dhavan
2023-11-22 17:14:24 +00:00
parent dc410b14c4
commit d6cefdff8e
2 changed files with 25 additions and 26 deletions

View File

@@ -298,7 +298,6 @@ class HID(EventEmitter):
def on_ctrl_pdu(self, pdu: bytes) -> None: def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
# Here we will receive all kinds of packets, parse and then call respective callbacks
param = pdu[0] & 0x0F param = pdu[0] & 0x0F
message_type = pdu[0] >> 4 message_type = pdu[0] >> 4
@@ -369,7 +368,7 @@ class HID(EventEmitter):
class Device(HID): class Device(HID):
class ReportStatus(enum.IntEnum): class GetSetReturn(enum.IntEnum):
FAILURE = 0x00 FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01 REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02 ERR_UNSUPPORTED_REQUEST = 0x02
@@ -377,7 +376,7 @@ class Device(HID):
SUCCESS = 0xff SUCCESS = 0xff
class GetReportStatus(): class GetSetStatus():
def __init__(self) -> None: def __init__(self) -> None:
self.status = 0 self.status = 0
self.data=None self.data=None
@@ -398,7 +397,7 @@ class Device(HID):
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes): def handle_get_report(self, pdu: bytes):
ret = self.GetReportStatus() ret = self.GetSetStatus()
report_type=pdu[0] & 0x03 report_type=pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3 buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1] report_id = pdu[1]
@@ -411,18 +410,18 @@ class Device(HID):
if(self.get_report_cb != None): if(self.get_report_cb != None):
ret = self.get_report_cb(report_id, report_type, buffer_size) ret = self.get_report_cb(report_id, report_type, buffer_size)
if(ret.status == self.ReportStatus.FAILURE): if(ret.status == self.GetSetReturn.FAILURE):
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif(ret.status == self.ReportStatus.SUCCESS): elif(ret.status == self.GetSetReturn.SUCCESS):
data = bytearray() data = bytearray()
data.append(report_id) data.append(report_id)
data.extend(ret.data) data.extend(ret.data)
#TODO Check the data size and MTU size here and only then send out #TODO Check the data size and MTU size here and only then send out
#the message #the message
self.send_control_data(report_type=report_type, data = data) self.send_control_data(report_type=report_type, data = data)
elif(ret.status == self.ReportStatus.REPORT_ID_NOT_FOUND): elif(ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND):
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID) self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif(ret.status == self.ReportStatus.ERR_UNSUPPORTED_REQUEST): elif(ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
else: else:
logger.debug("GetReport callback not registered !!") logger.debug("GetReport callback not registered !!")
@@ -438,7 +437,7 @@ class Device(HID):
report_id = pdu[1] report_id = pdu[1]
report_data = pdu[2:] report_data = pdu[2:]
ret = self.set_report_cb(report_id, report_type, report_data) ret = self.set_report_cb(report_id, report_type, report_data)
if(ret.status == self.ReportStatus.SUCCESS): if(ret.status == self.GetSetReturn.SUCCESS):
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
return return
else: else:
@@ -451,10 +450,10 @@ class Device(HID):
logger.debug("SetReport callback registered successfully") logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes): def handle_get_protocol(self, pdu: bytes):
ret = self.GetReportStatus() ret = self.GetSetStatus()
if(self.get_protocol_cb != None): if(self.get_protocol_cb != None):
ret=self.get_protocol_cb() ret=self.get_protocol_cb()
if(ret.status == self.ReportStatus.SUCCESS): if(ret.status == self.GetSetReturn.SUCCESS):
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
return return
else: else:
@@ -467,10 +466,10 @@ class Device(HID):
logger.debug("GetProtocol callback registered successfully") logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes): def handle_set_protocol(self, pdu: bytes):
ret = self.GetReportStatus() ret = self.GetSetStatus()
if(self.set_protocol_cb != None): if(self.set_protocol_cb != None):
ret=self.set_protocol_cb(pdu[0] & 0x01) ret=self.set_protocol_cb(pdu[0] & 0x01)
if(ret.status == self.ReportStatus.SUCCESS): if(ret.status == self.GetSetReturn.SUCCESS):
return return
else: else:
logger.debug("SetProtocol callback not registered !!") logger.debug("SetProtocol callback not registered !!")

View File

@@ -503,18 +503,18 @@ async def main():
print(f'Received Data, PDU: {pdu.hex()}') print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id,report_type, buffer_size): def on_get_report_cb(report_id,report_type, buffer_size):
retValue = hid_device.GetReportStatus() retValue = hid_device.GetSetStatus()
print("GET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+ print("GET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+
"buffer_size:" + str(buffer_size)) "buffer_size:" + str(buffer_size))
if report_type == Message.ReportType.INPUT_REPORT: if report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1: if report_id == 1:
retValue.data = keyboardData retValue.data = keyboardData
retValue.status = hid_device.ReportStatus.SUCCESS retValue.status = hid_device.GetSetReturn.SUCCESS
elif report_id == 2: elif report_id == 2:
retValue.data = mouseData retValue.data = mouseData
retValue.status = hid_device.ReportStatus.SUCCESS retValue.status = hid_device.GetSetReturn.SUCCESS
else: else:
retValue.status = hid_device.ReportStatus.REPORT_ID_NOT_FOUND retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
if(buffer_size): if(buffer_size):
data_len = buffer_size -1 data_len = buffer_size -1
@@ -523,36 +523,36 @@ async def main():
#This sample app has nothing to do with the report received, to enable PTS #This sample app has nothing to do with the report received, to enable PTS
#testing, we will return single byte random data. #testing, we will return single byte random data.
retValue.data = bytearray([0x11]) retValue.data = bytearray([0x11])
retValue.status = hid_device.ReportStatus.SUCCESS retValue.status = hid_device.GetSetReturn.SUCCESS
elif report_type == Message.ReportType.FEATURE_REPORT: elif report_type == Message.ReportType.FEATURE_REPORT:
#TBD - not requried for PTS testing #TBD - not requried for PTS testing
retValue.status = hid_device.ReportStatus.ERR_UNSUPPORTED_REQUEST retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
else: else:
retValue.status = hid_device.ReportStatus.FAILURE retValue.status = hid_device.GetSetReturn.FAILURE
return retValue return retValue
def on_set_report_cb(report_id, report_type, data): def on_set_report_cb(report_id, report_type, data):
retValue = hid_device.GetReportStatus() retValue = hid_device.GetSetStatus()
print("SET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+ print("SET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+
"data:" + str(data)) "data:" + str(data))
retValue.status = hid_device.ReportStatus.SUCCESS retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue return retValue
def on_get_protocol_cb(): def on_get_protocol_cb():
retValue = hid_device.GetReportStatus() retValue = hid_device.GetSetStatus()
retValue.data=protocol_mode.to_bytes() retValue.data=protocol_mode.to_bytes()
retValue.status=hid_device.ReportStatus.SUCCESS retValue.status=hid_device.GetSetReturn.SUCCESS
return retValue return retValue
def on_set_protocol_cb(protocol): def on_set_protocol_cb(protocol):
retValue = hid_device.GetReportStatus() retValue = hid_device.GetSetStatus()
#We do not support SET_PROTOCOL #We do not support SET_PROTOCOL
print("SET_PROTOCOL report_id: " + str(protocol)) print("SET_PROTOCOL report_id: " + str(protocol))
retValue.status=hid_device.ReportStatus.ERR_UNSUPPORTED_REQUEST retValue.status=hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue return retValue
def on_virtual_cable_unplug_cb(): def on_virtual_cable_unplug_cb():