diff --git a/bumble/hid.py b/bumble/hid.py index 32053233..23fdfb3a 100644 --- a/bumble/hid.py +++ b/bumble/hid.py @@ -308,16 +308,9 @@ class HID(EventEmitter): elif message_type == Message.MessageType.GET_REPORT: logger.debug('<<< HID GET REPORT') self.handle_get_report(pdu) - elif message_type == Message.MessageType.SET_REPORT: logger.debug('<<< HID SET REPORT') - report_type = pdu[0] & 3 - report = pdu[2:] - report_id = pdu[1] - logger.debug(report_id) - logger.debug(report_type) - #TODO: to check for size mentioned in report descriptor - self.emit('set_report', report_id, report) + self.handle_set_report(pdu) elif message_type == Message.MessageType.GET_PROTOCOL: logger.debug('<<< HID GET PROTOCOL') self.emit('get_protocol') @@ -433,11 +426,26 @@ class Device(HID): self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) else: logger.debug("GetReport callback not registered !!") - def register_get_report_cb(self,cb): self.get_report_cb=cb logger.debug("GetReport callback registered successfully") + + def handle_set_report(self, pdu: bytes): + if(self.set_report_cb != None): + report_type=pdu[0] & 0x03 + report_id = pdu[1] + report_data = pdu[2:] + ret = self.set_report_cb(report_id, report_type, report_data) + if(ret.status == self.ReportStatus.SUCCESS): + self.send_handshake_message(Message.Handshake.SUCCESSFUL) + else: + self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER) + + def register_set_report_cb(self, cb): + self.set_report_cb=cb + logger.debug("SetReport callback registered successfully") + # ----------------------------------------------------------------------------- class Host(HID): def __init__(self, device: Device) -> None: diff --git a/examples/run_hid_device.py b/examples/run_hid_device.py index 40d58a97..6e855cee 100644 --- a/examples/run_hid_device.py +++ b/examples/run_hid_device.py @@ -520,9 +520,9 @@ async def main(): def on_get_report_cb(report_id,report_type, buffer_size): retValue = hid_device.GetReportStatus() - + print("GET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+ + "buffer_size:" + str(buffer_size)) if report_type == Message.ReportType.INPUT_REPORT: - print("GET_REPORT - inputType") if report_id == 1: retValue.data = keyboardData retValue.status = hid_device.ReportStatus.SUCCESS @@ -536,7 +536,6 @@ async def main(): data_len = buffer_size -1 retValue.data = retValue.data[:data_len] elif report_type == Message.ReportType.OUTPUT_REPORT: - print("GET_REPORT - outputType") #This sample app has nothing to do with the report received, to enable PTS #testing, we will return single byte random data. retValue.data = bytearray([0x11]) @@ -544,13 +543,20 @@ async def main(): elif report_type == Message.ReportType.FEATURE_REPORT: #TBD - not requried for PTS testing - print("GET_REPORT - FeatureReport") retValue.status = hid_device.ReportStatus.ERR_UNSUPPORTED_REQUEST else: retValue.status = hid_device.ReportStatus.FAILURE return retValue + + def on_set_report_cb(report_id, report_type, data): + retValue = hid_device.GetReportStatus() + print("SET_REPORT report_id: " + str(report_id) +"report_type: "+ str(report_type)+ + "data:" + str(data)) + retValue.status = hid_device.ReportStatus.SUCCESS + return retValue + def on_set_protocol_cb(param): if HID_BOOT_DEVICE: @@ -577,11 +583,11 @@ async def main(): # Register for call backs hid_device.on('interrupt_data', on_hid_data_cb) - hid_device.on('set_report', on_set_report_cb) hid_device.on('get_protocol', on_get_protocol_cb) hid_device.on('set_protocol', on_set_protocol_cb) hid_device.register_get_report_cb(on_get_report_cb) + hid_device.register_set_report_cb(on_set_report_cb) # Register for virtual cable unplug call back hid_device.on('virtual_cable_unplug', on_virtual_cable_unplug_cb)