Merge pull request #772 from zxzxwu/hap

HAP: Slightly Pythonic refactor
This commit is contained in:
zxzxwu
2025-09-05 23:08:09 +08:00
committed by GitHub

View File

@@ -18,7 +18,6 @@
from __future__ import annotations
import asyncio
import functools
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Union
@@ -272,7 +271,7 @@ class HearingAccessService(gatt.TemplateService):
def on_connection(connection: Connection) -> None:
@connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection)
self.currently_connected_clients.discard(connection)
@connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None:
@@ -373,8 +372,7 @@ class HearingAccessService(gatt.TemplateService):
self.preset_records[key]
for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index
]
del presets[num_presets:]
][:num_presets]
if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
@@ -383,7 +381,10 @@ class HearingAccessService(gatt.TemplateService):
async def _read_preset_response(
self, connection: Connection, presets: list[PresetRecord]
):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Read Presets Request operation
# aborted and shall not either continue or restart the operation when the client
# reconnects.
try:
for i, preset in enumerate(presets):
await connection.device.indicate_subscriber(
@@ -404,7 +405,7 @@ class HearingAccessService(gatt.TemplateService):
async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op)
await self._notify_preset_operations(op)
async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset'''
@@ -413,14 +414,14 @@ class HearingAccessService(gatt.TemplateService):
raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
await self._notify_preset_operations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None:
'''Server API to make a preset available'''
preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
await self._notify_preset_operations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset'''
@@ -432,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
)
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection'''
@@ -447,8 +448,10 @@ class HearingAccessService(gatt.TemplateService):
return op.additional_parameters
op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
while len(op_list) > 0:
# If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Preset Changed operation aborted and
# shall continue the operation when the client reconnects.
while op_list:
try:
await connection.device.indicate_subscriber(
connection,
@@ -460,14 +463,15 @@ class HearingAccessService(gatt.TemplateService):
except TimeoutError:
break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values():
historyList.append(op)
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
for history_list in self.preset_changed_operations_history_per_device.values():
history_list.append(op)
for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection)
async def _on_write_preset_name(self, connection: Connection, value: bytes):
del connection # Unused
if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -532,48 +536,51 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index
await self.notify_active_preset()
async def _on_set_active_preset(self, _: Connection, value: bytes):
async def _on_set_active_preset(self, connection: Connection, value: bytes):
del connection # Unused
await self.set_active_preset(value)
async def set_next_or_previous_preset(self, is_previous):
async def set_next_or_previous_preset(self, is_previous: bool) -> None:
'''Set the next or the previous preset as active'''
if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset
next_preset: Optional[PresetRecord] = None
for index, record in sorted(self.preset_records.items(), reverse=is_previous):
if not record.is_available():
continue
if first_preset == None:
first_preset = record
if is_previous:
if index >= self.active_preset_index:
continue
elif index <= self.active_preset_index:
continue
next_preset = record
break
presets = sorted(
[
record
for record in self.preset_records.values()
if record.is_available()
],
key=lambda record: record.index,
)
current_preset = self.preset_records[self.active_preset_index]
current_preset_pos = presets.index(current_preset)
if is_previous:
new_preset = presets[(current_preset_pos - 1) % len(presets)]
else:
new_preset = presets[(current_preset_pos + 1) % len(presets)]
if not first_preset: # If no other preset are available
if current_preset == new_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
self.active_preset_index = new_preset.index
await self.notify_active_preset()
async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(False)
async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None:
async def _on_set_previous_preset(
self, connection: Connection, value: bytes
) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(True)
async def _on_set_active_preset_synchronized_locally(
self, _: Connection, value: bytes
self, connection: Connection, value: bytes
):
del connection # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -584,8 +591,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_active_preset(value)
async def _on_set_next_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -596,8 +604,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
async def _on_set_previous_preset_synchronized_locally(
self, _: Connection, __value__: bytes
self, connection: Connection, value: bytes
):
del connection, value # Unused.
if (
self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -615,11 +624,13 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue
active_preset_index_notification: asyncio.Queue
preset_control_point_indications: asyncio.Queue[bytes]
active_preset_index_notification: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid(
@@ -641,20 +652,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
'B',
)
async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
async def setup_subscription(self) -> None:
await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False
self.preset_control_point_indications.put_nowait,
prefer_notify=False,
)
await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification)
self.active_preset_index_notification.put_nowait
)