mirror of
https://github.com/google/bumble.git
synced 2026-05-09 04:08:02 +00:00
improve type hints for notify/indicate subscriber(s) methods
Pyright expects generic type parameters to be specified for the Attribute class, otherwise it treats the type as Unknown which can trigger reportUnknownMemberType errors. This can be solved by using a generic type parameter for these methods which also has the benefit of making sure that the value parameter has the correct type for the attribute. In some cases, a new local `value_as_bytes` variable is needed to avoid type errors and makes the code less confusing by not overwriting the original `value` variable.
This commit is contained in:
@@ -5618,8 +5618,8 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
async def notify_subscriber(
|
async def notify_subscriber(
|
||||||
self,
|
self,
|
||||||
connection: Connection,
|
connection: Connection,
|
||||||
attribute: Attribute,
|
attribute: Attribute[_T],
|
||||||
value: Any | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -5638,7 +5638,7 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
await self.gatt_server.notify_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def notify_subscribers(
|
async def notify_subscribers(
|
||||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Send a notification to all the subscribers of an attribute.
|
Send a notification to all the subscribers of an attribute.
|
||||||
@@ -5657,8 +5657,8 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
async def indicate_subscriber(
|
async def indicate_subscriber(
|
||||||
self,
|
self,
|
||||||
connection: Connection,
|
connection: Connection,
|
||||||
attribute: Attribute,
|
attribute: Attribute[_T],
|
||||||
value: Any | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -5679,7 +5679,7 @@ class Device(utils.CompositeEventEmitter):
|
|||||||
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
await self.gatt_server.indicate_subscriber(connection, attribute, value, force)
|
||||||
|
|
||||||
async def indicate_subscribers(
|
async def indicate_subscribers(
|
||||||
self, attribute: Attribute, value: Any | None = None, force: bool = False
|
self, attribute: Attribute[_T], value: _T | None = None, force: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Send an indication to all the subscribers of an attribute.
|
Send an indication to all the subscribers of an attribute.
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ GATT_SERVER_DEFAULT_MAX_MTU = 517
|
|||||||
# Helpers
|
# Helpers
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|
||||||
def _bearer_id(bearer: att.Bearer) -> str:
|
def _bearer_id(bearer: att.Bearer) -> str:
|
||||||
if att.is_enhanced_bearer(bearer):
|
if att.is_enhanced_bearer(bearer):
|
||||||
@@ -369,8 +371,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def notify_subscriber(
|
async def notify_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if att.is_enhanced_bearer(bearer) or force:
|
if att.is_enhanced_bearer(bearer) or force:
|
||||||
@@ -390,8 +392,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _notify_single_subscriber(
|
async def _notify_single_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None,
|
value: _T | None,
|
||||||
force: bool,
|
force: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
@@ -411,19 +413,19 @@ class Server(utils.EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get or encode the value
|
# Get or encode the value
|
||||||
value = (
|
value_as_bytes = (
|
||||||
await attribute.read_value(bearer)
|
await attribute.read_value(bearer)
|
||||||
if value is None
|
if value is None
|
||||||
else attribute.encode_value(value)
|
else attribute.encode_value(value)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
if len(value) > bearer.att_mtu - 3:
|
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||||
value = value[: bearer.att_mtu - 3]
|
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||||
|
|
||||||
# Notify
|
# Notify
|
||||||
notification = att.ATT_Handle_Value_Notification(
|
notification = att.ATT_Handle_Value_Notification(
|
||||||
attribute_handle=attribute.handle, attribute_value=value
|
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||||
)
|
)
|
||||||
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}')
|
||||||
self.send_gatt_pdu(bearer, bytes(notification))
|
self.send_gatt_pdu(bearer, bytes(notification))
|
||||||
@@ -431,8 +433,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def indicate_subscriber(
|
async def indicate_subscriber(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if att.is_enhanced_bearer(bearer) or force:
|
if att.is_enhanced_bearer(bearer) or force:
|
||||||
@@ -452,8 +454,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _indicate_single_bearer(
|
async def _indicate_single_bearer(
|
||||||
self,
|
self,
|
||||||
bearer: att.Bearer,
|
bearer: att.Bearer,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None,
|
value: _T | None,
|
||||||
force: bool,
|
force: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Check if there's a subscriber
|
# Check if there's a subscriber
|
||||||
@@ -473,19 +475,19 @@ class Server(utils.EventEmitter):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Get or encode the value
|
# Get or encode the value
|
||||||
value = (
|
value_as_bytes = (
|
||||||
await attribute.read_value(bearer)
|
await attribute.read_value(bearer)
|
||||||
if value is None
|
if value is None
|
||||||
else attribute.encode_value(value)
|
else attribute.encode_value(value)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate if needed
|
# Truncate if needed
|
||||||
if len(value) > bearer.att_mtu - 3:
|
if len(value_as_bytes) > bearer.att_mtu - 3:
|
||||||
value = value[: bearer.att_mtu - 3]
|
value_as_bytes = value_as_bytes[: bearer.att_mtu - 3]
|
||||||
|
|
||||||
# Indicate
|
# Indicate
|
||||||
indication = att.ATT_Handle_Value_Indication(
|
indication = att.ATT_Handle_Value_Indication(
|
||||||
attribute_handle=attribute.handle, attribute_value=value
|
attribute_handle=attribute.handle, attribute_value=value_as_bytes
|
||||||
)
|
)
|
||||||
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}')
|
||||||
|
|
||||||
@@ -510,8 +512,8 @@ class Server(utils.EventEmitter):
|
|||||||
async def _notify_or_indicate_subscribers(
|
async def _notify_or_indicate_subscribers(
|
||||||
self,
|
self,
|
||||||
indicate: bool,
|
indicate: bool,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# Get all the bearers for which there's at least one subscription
|
# Get all the bearers for which there's at least one subscription
|
||||||
@@ -537,8 +539,8 @@ class Server(utils.EventEmitter):
|
|||||||
|
|
||||||
async def notify_subscribers(
|
async def notify_subscribers(
|
||||||
self,
|
self,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
return await self._notify_or_indicate_subscribers(
|
return await self._notify_or_indicate_subscribers(
|
||||||
@@ -547,8 +549,8 @@ class Server(utils.EventEmitter):
|
|||||||
|
|
||||||
async def indicate_subscribers(
|
async def indicate_subscribers(
|
||||||
self,
|
self,
|
||||||
attribute: att.Attribute,
|
attribute: att.Attribute[_T],
|
||||||
value: bytes | None = None,
|
value: _T | None = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
):
|
):
|
||||||
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
return await self._notify_or_indicate_subscribers(True, attribute, value, force)
|
||||||
|
|||||||
Reference in New Issue
Block a user