improve type hints for notify/indicate subscriber(s) methods

Pyright expects generic type parameters to be specified for the
Attribute class, otherwise it treats the type as Unknown which can
trigger reportUnknownMemberType errors.

This can be solved by using a generic type parameter for these methods
which also has the benefit of making sure that the value parameter has
the correct type for the attribute.

In some cases, a new local `value_as_bytes` variable is needed to avoid
type errors and makes the code less confusing by not overwriting the
original `value` variable.
This commit is contained in:
David Lechner
2026-04-25 16:34:28 -05:00
parent 27d02ef18d
commit baa5257780
2 changed files with 30 additions and 28 deletions

View File

@@ -5618,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
"""
@@ -5638,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
) -> None:
"""
Send a notification to all the subscribers of an attribute.
@@ -5657,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber(
self,
connection: Connection,
attribute: Attribute,
value: Any | None = None,
attribute: Attribute[_T],
value: _T | None = None,
force: bool = False,
):
"""
@@ -5679,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
):
"""
Send an indication to all the subscribers of an attribute.

View File

@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers
# -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer):
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Notify
notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification))
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
if att.is_enhanced_bearer(bearer) or force:
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer(
self,
bearer: att.Bearer,
attribute: att.Attribute,
value: bytes | None,
attribute: att.Attribute[_T],
value: _T | None,
force: bool,
) -> None:
# Check if there's a subscriber
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
return
# Get or encode the value
value = (
value_as_bytes = (
await attribute.read_value(bearer)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3]
if len(value_as_bytes) > bearer.att_mtu - 3:
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Indicate
indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value
attribute_handle=attribute.handle, attribute_value=value_as_bytes
)
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers(
self,
indicate: bool,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
) -> None:
# Get all the bearers for which there's at least one subscription
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers(
self,
attribute: att.Attribute,
value: bytes | None = None,
attribute: att.Attribute[_T],
value: _T | None = None,
force: bool = False,
):
return await self._notify_or_indicate_subscribers(True, attribute, value, force)