Merge pull request #915 from dlech/notify-subscribers-type-hints

improve type hints for notify/indicate subscriber(s) methods
This commit is contained in:
Gilles Boccon-Gibod
2026-04-27 21:45:38 +02:00
committed by GitHub
2 changed files with 30 additions and 28 deletions

View File

@@ -5618,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
async def notify_subscriber( async def notify_subscriber(
self, self,
connection: Connection, connection: Connection,
attribute: Attribute, attribute: Attribute[_T],
value: Any | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
""" """
@@ -5638,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.notify_subscriber(connection, attribute, value, force) await self.gatt_server.notify_subscriber(connection, attribute, value, force)
async def notify_subscribers( async def notify_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
) -> None: ) -> None:
""" """
Send a notification to all the subscribers of an attribute. Send a notification to all the subscribers of an attribute.
@@ -5657,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
async def indicate_subscriber( async def indicate_subscriber(
self, self,
connection: Connection, connection: Connection,
attribute: Attribute, attribute: Attribute[_T],
value: Any | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
): ):
""" """
@@ -5679,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
await self.gatt_server.indicate_subscriber(connection, attribute, value, force) await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
async def indicate_subscribers( async def indicate_subscribers(
self, attribute: Attribute, value: Any | None = None, force: bool = False self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
): ):
""" """
Send an indication to all the subscribers of an attribute. Send an indication to all the subscribers of an attribute.

View File

@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
# Helpers # Helpers
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
_T = TypeVar('_T')
def _bearer_id(bearer: att.Bearer) -> str: def _bearer_id(bearer: att.Bearer) -> str:
if att.is_enhanced_bearer(bearer): if att.is_enhanced_bearer(bearer):
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
async def notify_subscriber( async def notify_subscriber(
self, self,
bearer: att.Bearer, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
if att.is_enhanced_bearer(bearer) or force: if att.is_enhanced_bearer(bearer) or force:
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
async def _notify_single_subscriber( async def _notify_single_subscriber(
self, self,
bearer: att.Bearer, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None, value: _T | None,
force: bool, force: bool,
) -> None: ) -> None:
# Check if there's a subscriber # Check if there's a subscriber
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
return return
# Get or encode the value # Get or encode the value
value = ( value_as_bytes = (
await attribute.read_value(bearer) await attribute.read_value(bearer)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
# Truncate if needed # Truncate if needed
if len(value) > bearer.att_mtu - 3: if len(value_as_bytes) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3] value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Notify # Notify
notification = att.ATT_Handle_Value_Notification( notification = att.ATT_Handle_Value_Notification(
attribute_handle=attribute.handle, attribute_value=value attribute_handle=attribute.handle, attribute_value=value_as_bytes
) )
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}') logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
self.send_gatt_pdu(bearer, bytes(notification)) self.send_gatt_pdu(bearer, bytes(notification))
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
async def indicate_subscriber( async def indicate_subscriber(
self, self,
bearer: att.Bearer, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
if att.is_enhanced_bearer(bearer) or force: if att.is_enhanced_bearer(bearer) or force:
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
async def _indicate_single_bearer( async def _indicate_single_bearer(
self, self,
bearer: att.Bearer, bearer: att.Bearer,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None, value: _T | None,
force: bool, force: bool,
) -> None: ) -> None:
# Check if there's a subscriber # Check if there's a subscriber
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
return return
# Get or encode the value # Get or encode the value
value = ( value_as_bytes = (
await attribute.read_value(bearer) await attribute.read_value(bearer)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
# Truncate if needed # Truncate if needed
if len(value) > bearer.att_mtu - 3: if len(value_as_bytes) > bearer.att_mtu - 3:
value = value[: bearer.att_mtu - 3] value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
# Indicate # Indicate
indication = att.ATT_Handle_Value_Indication( indication = att.ATT_Handle_Value_Indication(
attribute_handle=attribute.handle, attribute_value=value attribute_handle=attribute.handle, attribute_value=value_as_bytes
) )
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}') logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
async def _notify_or_indicate_subscribers( async def _notify_or_indicate_subscribers(
self, self,
indicate: bool, indicate: bool,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
# Get all the bearers for which there's at least one subscription # Get all the bearers for which there's at least one subscription
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
async def notify_subscribers( async def notify_subscribers(
self, self,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
): ):
return await self._notify_or_indicate_subscribers( return await self._notify_or_indicate_subscribers(
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
async def indicate_subscribers( async def indicate_subscribers(
self, self,
attribute: att.Attribute, attribute: att.Attribute[_T],
value: bytes | None = None, value: _T | None = None,
force: bool = False, force: bool = False,
): ):
return await self._notify_or_indicate_subscribers(True, attribute, value, force) return await self._notify_or_indicate_subscribers(True, attribute, value, force)