Review comment Fix 3 - rename json file and usage of Optional in parameters

This commit is contained in:
skarnataki
2023-12-15 09:42:57 +00:00
parent 93c0875740
commit 9da2e32ad7
3 changed files with 23 additions and 16 deletions

View File

@@ -27,6 +27,7 @@ from typing import Optional, Callable, TYPE_CHECKING
from bumble import l2cap, device from bumble import l2cap, device
from bumble.colors import color from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -203,6 +204,7 @@ class HID(EventEmitter):
def __init__(self, device: device.Device, role: Role) -> None: def __init__(self, device: device.Device, role: Role) -> None:
super().__init__() super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device self.device = device
self.role = role self.role = role
@@ -213,19 +215,19 @@ class HID(EventEmitter):
device.on('connection', self.on_device_connection) device.on('connection', self.on_device_connection)
def handle_get_report(self, pdu: bytes): def handle_get_report(self, pdu: bytes):
return pass
def handle_set_report(self, pdu: bytes): def handle_set_report(self, pdu: bytes):
return pass
def handle_get_protocol(self, pdu: bytes): def handle_get_protocol(self, pdu: bytes):
return pass
def handle_set_protocol(self, pdu: bytes): def handle_set_protocol(self, pdu: bytes):
return pass
def send_handshake_message(self, result_code: int): def send_handshake_message(self, result_code: int):
return pass
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
@@ -271,6 +273,7 @@ class HID(EventEmitter):
def on_device_connection(self, connection: device.Connection) -> None: def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection) connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None: def on_device_disconnection(self, reason: int) -> None:
@@ -388,7 +391,7 @@ class Device(HID):
get_report_cb: Optional[Callable[[int, int, int], None]] = None get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int, bytes], None]] = None set_protocol_cb: Optional[Callable[[int], None]] = None
def send_handshake_message(self, result_code: int) -> None: def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code) msg = SendHandshakeMessage(result_code)
@@ -417,7 +420,7 @@ class Device(HID):
buffer_size = 0 buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size) ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE: if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN) self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS: elif ret.status == self.GetSetReturn.SUCCESS:
@@ -435,7 +438,7 @@ class Device(HID):
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST: elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb): def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb self.get_report_cb = cb
logger.debug("GetReport callback registered successfully") logger.debug("GetReport callback registered successfully")
@@ -449,6 +452,7 @@ class Device(HID):
report_data = pdu[2:] report_data = pdu[2:]
report_size = len(report_data) + 1 report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data) ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER: elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
@@ -458,7 +462,9 @@ class Device(HID):
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(self, cb): def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb self.set_report_cb = cb
logger.debug("SetReport callback registered successfully") logger.debug("SetReport callback registered successfully")
@@ -468,12 +474,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.get_protocol_cb() ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data) self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb): def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully") logger.debug("GetProtocol callback registered successfully")
@@ -483,12 +490,13 @@ class Device(HID):
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return return
ret = self.set_protocol_cb(pdu[0] & 0x01) ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS: if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL) self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else: else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb): def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully") logger.debug("SetProtocol callback registered successfully")

View File

@@ -496,8 +496,8 @@ async def main():
' web (run a keyboard with keypress input from a web page, ' ' web (run a keyboard with keypress input from a web page, '
'see keyboard.html' 'see keyboard.html'
) )
print('example: python run_hid_device.py classic3.json usb:0 web') print('example: python run_hid_device.py hid_keyboard.json usb:0 web')
print('example: python run_hid_device.py classic3.json usb:0 test-mode') print('example: python run_hid_device.py hid_keyboard.json usb:0 test-mode')
return return
@@ -732,13 +732,12 @@ async def main():
else: else:
print("Invalid option selected.") print("Invalid option selected.")
if (len(sys.argv) > 3) and (command == 'test-mode'): if (len(sys.argv) > 3) and (sys.argv[3] == 'test-mode'):
# Test mode for PTS/Unit testing # Test mode for PTS/Unit testing
command = sys.argv[3]
await menu() await menu()
else: else:
# default option is using keyboard.html (web) # default option is using keyboard.html (web)
print("Command incorrect. Switching to default") print("Executing in Web mode")
await keyboard_device(hid_device) await keyboard_device(hid_device)
await hci_source.wait_for_termination() await hci_source.wait_for_termination()