mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
Review comment fix HID device
This commit is contained in:
160
bumble/hid.py
160
bumble/hid.py
@@ -19,17 +19,15 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
import enum
|
import enum
|
||||||
|
import struct
|
||||||
|
|
||||||
from pyee import EventEmitter
|
from pyee import EventEmitter
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from bumble import l2cap
|
from bumble import l2cap, device
|
||||||
from bumble.colors import color
|
from bumble.colors import color
|
||||||
from bumble.core import InvalidStateError, ProtocolError
|
from bumble.core import InvalidStateError, ProtocolError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from bumble.device import Device, Connection
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Logging
|
# Logging
|
||||||
@@ -105,10 +103,11 @@ class GetReportMessage(Message):
|
|||||||
if self.buffer_size == 0:
|
if self.buffer_size == 0:
|
||||||
return self.header(self.report_type) + packet_bytes
|
return self.header(self.report_type) + packet_bytes
|
||||||
else:
|
else:
|
||||||
packet_bytes.extend(
|
return (
|
||||||
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
|
self.header(0x08 | self.report_type)
|
||||||
|
+ packet_bytes
|
||||||
|
+ struct.pack("<H", self.buffer_size)
|
||||||
)
|
)
|
||||||
return self.header(0x08 | self.report_type) + packet_bytes
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -128,10 +127,7 @@ class SendControlData(Message):
|
|||||||
message_type = Message.MessageType.DATA
|
message_type = Message.MessageType.DATA
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
packet_bytes = bytearray()
|
return self.header(self.report_type) + self.data
|
||||||
|
|
||||||
packet_bytes.extend(self.data)
|
|
||||||
return self.header(self.report_type) + packet_bytes
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -151,17 +147,6 @@ class SetProtocolMessage(Message):
|
|||||||
return self.header(self.protocol_mode)
|
return self.header(self.protocol_mode)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GetProtocolReplyMessage(Message):
|
|
||||||
protocol_mode: int
|
|
||||||
message_type = Message.MessageType.DATA
|
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
|
||||||
packet_bytes = bytearray()
|
|
||||||
packet_bytes.append(self.protocol_mode)
|
|
||||||
return self.header(Message.ReportType.OTHER_REPORT) + packet_bytes
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Suspend(Message):
|
class Suspend(Message):
|
||||||
message_type = Message.MessageType.CONTROL
|
message_type = Message.MessageType.CONTROL
|
||||||
@@ -215,7 +200,7 @@ class HID(EventEmitter):
|
|||||||
HOST = 0x00
|
HOST = 0x00
|
||||||
DEVICE = 0x01
|
DEVICE = 0x01
|
||||||
|
|
||||||
def __init__(self, device: Device, role: int) -> None:
|
def __init__(self, device: device.Device, role: Role) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.connection = None
|
self.connection = None
|
||||||
@@ -273,11 +258,11 @@ class HID(EventEmitter):
|
|||||||
self.l2cap_ctrl_channel = None
|
self.l2cap_ctrl_channel = None
|
||||||
await channel.disconnect()
|
await channel.disconnect()
|
||||||
|
|
||||||
def on_device_connection(self, connection: Connection) -> None:
|
def on_device_connection(self, connection: device.Connection) -> None:
|
||||||
self.connection = connection # type: ignore[assignment]
|
self.connection = connection # type: ignore[assignment]
|
||||||
self.remote_device_bd_address = (
|
self.remote_device_bd_address = (
|
||||||
connection.peer_address
|
connection.peer_address # type: ignore[assignment]
|
||||||
) # type: ignore[assignment]
|
)
|
||||||
connection.on('disconnection', self.on_disconnection)
|
connection.on('disconnection', self.on_disconnection)
|
||||||
|
|
||||||
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||||
@@ -341,9 +326,9 @@ class HID(EventEmitter):
|
|||||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||||
else:
|
else:
|
||||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||||
self.send_handshake_message(
|
self.send_handshake_message( # type: ignore[attr-defined]
|
||||||
Message.Handshake.ERR_UNSUPPORTED_REQUEST
|
Message.Handshake.ERR_UNSUPPORTED_REQUEST
|
||||||
) # type: ignore[attr-defined]
|
)
|
||||||
|
|
||||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||||
@@ -378,7 +363,7 @@ class HID(EventEmitter):
|
|||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class Device(HID): # type: ignore[no-redef]
|
class Device(HID):
|
||||||
class GetSetReturn(enum.IntEnum):
|
class GetSetReturn(enum.IntEnum):
|
||||||
FAILURE = 0x00
|
FAILURE = 0x00
|
||||||
REPORT_ID_NOT_FOUND = 0x01
|
REPORT_ID_NOT_FOUND = 0x01
|
||||||
@@ -389,10 +374,10 @@ class Device(HID): # type: ignore[no-redef]
|
|||||||
|
|
||||||
class GetSetStatus:
|
class GetSetStatus:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.data: bytes
|
self.data = bytearray()
|
||||||
self.status = 0
|
self.status = 0
|
||||||
|
|
||||||
def __init__(self, device: Device) -> None:
|
def __init__(self, device: device.Device) -> None:
|
||||||
super().__init__(device, HID.Role.DEVICE)
|
super().__init__(device, HID.Role.DEVICE)
|
||||||
self.get_report_cb = None
|
self.get_report_cb = None
|
||||||
self.set_report_cb = None
|
self.set_report_cb = None
|
||||||
@@ -412,7 +397,10 @@ class Device(HID): # type: ignore[no-redef]
|
|||||||
self.send_pdu_on_ctrl(hid_message)
|
self.send_pdu_on_ctrl(hid_message)
|
||||||
|
|
||||||
def handle_get_report(self, pdu: bytes):
|
def handle_get_report(self, pdu: bytes):
|
||||||
ret = self.GetSetStatus()
|
if self.get_report_cb is None:
|
||||||
|
logger.debug("GetReport callback not registered !!")
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
return
|
||||||
report_type = pdu[0] & 0x03
|
report_type = pdu[0] & 0x03
|
||||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||||
report_id = pdu[1]
|
report_id = pdu[1]
|
||||||
@@ -422,32 +410,23 @@ class Device(HID): # type: ignore[no-redef]
|
|||||||
else:
|
else:
|
||||||
buffer_size = 0
|
buffer_size = 0
|
||||||
|
|
||||||
if self.get_report_cb != None:
|
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||||
ret = self.get_report_cb(
|
|
||||||
report_id, report_type, buffer_size
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
if ret.status == self.GetSetReturn.FAILURE:
|
if ret.status == self.GetSetReturn.FAILURE:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||||
data = bytearray()
|
data = bytearray()
|
||||||
data.append(report_id)
|
data.append(report_id)
|
||||||
data.extend(ret.data)
|
data.extend(ret.data)
|
||||||
if (
|
if len(data) < self.l2cap_ctrl_channel.mtu:
|
||||||
len(data) < self.l2cap_ctrl_channel.mtu
|
self.send_control_data(report_type=report_type, data=data)
|
||||||
): # type: ignore[union-attr]
|
else:
|
||||||
self.send_control_data(report_type=report_type, data=data)
|
|
||||||
else:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
|
||||||
|
|
||||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
|
||||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||||
else:
|
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||||
logger.debug("GetReport callback not registered !!")
|
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||||
|
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_report_cb(self, cb):
|
def register_get_report_cb(self, cb):
|
||||||
@@ -455,56 +434,55 @@ class Device(HID): # type: ignore[no-redef]
|
|||||||
logger.debug("GetReport callback registered successfully")
|
logger.debug("GetReport callback registered successfully")
|
||||||
|
|
||||||
def handle_set_report(self, pdu: bytes):
|
def handle_set_report(self, pdu: bytes):
|
||||||
if self.set_report_cb != None:
|
if self.set_report_cb is None:
|
||||||
report_type = pdu[0] & 0x03
|
|
||||||
report_id = pdu[1]
|
|
||||||
report_data = pdu[2:]
|
|
||||||
report_size = len(pdu[1:])
|
|
||||||
ret = self.set_report_cb(
|
|
||||||
report_id, report_type, report_size, report_data
|
|
||||||
) # type: ignore
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
|
||||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
|
||||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
|
||||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
|
||||||
else:
|
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
|
||||||
else:
|
|
||||||
logger.debug("SetReport callback not registered !!")
|
logger.debug("SetReport callback not registered !!")
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
return
|
||||||
|
report_type = pdu[0] & 0x03
|
||||||
|
report_id = pdu[1]
|
||||||
|
report_data = pdu[2:]
|
||||||
|
report_size = len(pdu[1:])
|
||||||
|
ret = self.set_report_cb(
|
||||||
|
report_id, report_type, report_size, report_data
|
||||||
|
) # type: ignore
|
||||||
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
|
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||||
|
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||||
|
else:
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_report_cb(self, cb):
|
def register_set_report_cb(self, cb):
|
||||||
self.set_report_cb = cb
|
self.set_report_cb = cb
|
||||||
logger.debug("SetReport callback registered successfully")
|
logger.debug("SetReport callback registered successfully")
|
||||||
|
|
||||||
def handle_get_protocol(self, pdu: bytes):
|
def handle_get_protocol(self, pdu: bytes):
|
||||||
ret = self.GetSetStatus()
|
if self.get_protocol_cb is None:
|
||||||
if self.get_protocol_cb != None:
|
|
||||||
ret = self.get_protocol_cb() # type: ignore
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
|
||||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
logger.debug("GetProtocol callback not registered !!")
|
logger.debug("GetProtocol callback not registered !!")
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
return
|
||||||
|
ret = self.get_protocol_cb()
|
||||||
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
|
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||||
|
else:
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_get_protocol_cb(self, cb):
|
def register_get_protocol_cb(self, cb):
|
||||||
self.get_protocol_cb = cb
|
self.get_protocol_cb = cb
|
||||||
logger.debug("GetProtocol callback registered successfully")
|
logger.debug("GetProtocol callback registered successfully")
|
||||||
|
|
||||||
def handle_set_protocol(self, pdu: bytes):
|
def handle_set_protocol(self, pdu: bytes):
|
||||||
ret = self.GetSetStatus()
|
if self.set_protocol_cb is None:
|
||||||
if self.set_protocol_cb != None:
|
|
||||||
ret = self.set_protocol_cb(pdu[0] & 0x01) # type: ignore
|
|
||||||
if ret.status == self.GetSetReturn.SUCCESS:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
logger.debug("SetProtocol callback not registered !!")
|
logger.debug("SetProtocol callback not registered !!")
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
return
|
||||||
|
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||||
|
if ret.status == self.GetSetReturn.SUCCESS:
|
||||||
|
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||||
|
else:
|
||||||
|
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||||
|
|
||||||
def register_set_protocol_cb(self, cb):
|
def register_set_protocol_cb(self, cb):
|
||||||
self.set_protocol_cb = cb
|
self.set_protocol_cb = cb
|
||||||
@@ -513,7 +491,7 @@ class Device(HID): # type: ignore[no-redef]
|
|||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
class Host(HID):
|
class Host(HID):
|
||||||
def __init__(self, device: Device) -> None:
|
def __init__(self, device: device.Device) -> None:
|
||||||
super().__init__(device, HID.Role.HOST)
|
super().__init__(device, HID.Role.HOST)
|
||||||
|
|
||||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user