Merge pull request #772 from zxzxwu/hap

HAP: Slightly Pythonic refactor
This commit is contained in:
zxzxwu
2025-09-05 23:08:09 +08:00
committed by GitHub

View File

@@ -18,7 +18,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Optional, Union from typing import Any, Optional, Union
@@ -272,7 +271,7 @@ class HearingAccessService(gatt.TemplateService):
def on_connection(connection: Connection) -> None: def on_connection(connection: Connection) -> None:
@connection.on(connection.EVENT_DISCONNECTION) @connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None: def on_disconnection(_reason) -> None:
self.currently_connected_clients.remove(connection) self.currently_connected_clients.discard(connection)
@connection.on(connection.EVENT_PAIRING) @connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None: def on_pairing(*_: Any) -> None:
@@ -373,8 +372,7 @@ class HearingAccessService(gatt.TemplateService):
self.preset_records[key] self.preset_records[key]
for key in sorted(self.preset_records.keys()) for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index if self.preset_records[key].index >= start_index
] ][:num_presets]
del presets[num_presets:]
if len(presets) == 0: if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE) raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
@@ -383,7 +381,10 @@ class HearingAccessService(gatt.TemplateService):
async def _read_preset_response( async def _read_preset_response(
self, connection: Connection, presets: list[PresetRecord] self, connection: Connection, presets: list[PresetRecord]
): ):
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects. # If the ATT bearer is terminated before all notifications or indications are
# sent, then the server shall consider the Read Presets Request operation
# aborted and shall not either continue or restart the operation when the client
# reconnects.
try: try:
for i, preset in enumerate(presets): for i, preset in enumerate(presets):
await connection.device.indicate_subscriber( await connection.device.indicate_subscriber(
@@ -404,7 +405,7 @@ class HearingAccessService(gatt.TemplateService):
async def generic_update(self, op: PresetChangedOperation) -> None: async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent''' '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notifyPresetOperations(op) await self._notify_preset_operations(op)
async def delete_preset(self, index: int) -> None: async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset''' '''Server API to delete a preset. It should not be the current active preset'''
@@ -413,14 +414,14 @@ class HearingAccessService(gatt.TemplateService):
raise InvalidStateError('Cannot delete active preset') raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index] del self.preset_records[index]
await self._notifyPresetOperations(PresetChangedOperationDeleted(index)) await self._notify_preset_operations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None: async def available_preset(self, index: int) -> None:
'''Server API to make a preset available''' '''Server API to make a preset available'''
preset = self.preset_records[index] preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notifyPresetOperations(PresetChangedOperationAvailable(index)) await self._notify_preset_operations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None: async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset''' '''Server API to make a preset unavailable. It should not be the current active preset'''
@@ -432,7 +433,7 @@ class HearingAccessService(gatt.TemplateService):
preset.properties.is_available = ( preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
) )
await self._notifyPresetOperations(PresetChangedOperationUnavailable(index)) await self._notify_preset_operations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None: async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection''' '''Send all PresetChangedOperation saved for a given connection'''
@@ -447,8 +448,10 @@ class HearingAccessService(gatt.TemplateService):
return op.additional_parameters return op.additional_parameters
op_list.sort(key=get_op_index) op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects. # If the ATT bearer is terminated before all notifications or indications are
while len(op_list) > 0: # sent, then the server shall consider the Preset Changed operation aborted and
# shall continue the operation when the client reconnects.
while op_list:
try: try:
await connection.device.indicate_subscriber( await connection.device.indicate_subscriber(
connection, connection,
@@ -460,14 +463,15 @@ class HearingAccessService(gatt.TemplateService):
except TimeoutError: except TimeoutError:
break break
async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None: async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
for historyList in self.preset_changed_operations_history_per_device.values(): for history_list in self.preset_changed_operations_history_per_device.values():
historyList.append(op) history_list.append(op)
for connection in self.currently_connected_clients: for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection) await self._preset_changed_operation(connection)
async def _on_write_preset_name(self, connection: Connection, value: bytes): async def _on_write_preset_name(self, connection: Connection, value: bytes):
del connection # Unused
if self.read_presets_request_in_progress: if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS) raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -532,48 +536,51 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index self.active_preset_index = index
await self.notify_active_preset() await self.notify_active_preset()
async def _on_set_active_preset(self, _: Connection, value: bytes): async def _on_set_active_preset(self, connection: Connection, value: bytes):
del connection # Unused
await self.set_active_preset(value) await self.set_active_preset(value)
async def set_next_or_previous_preset(self, is_previous): async def set_next_or_previous_preset(self, is_previous: bool) -> None:
'''Set the next or the previous preset as active''' '''Set the next or the previous preset as active'''
if self.active_preset_index == 0x00: if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
first_preset: Optional[PresetRecord] = None # To loop to first preset presets = sorted(
next_preset: Optional[PresetRecord] = None [
for index, record in sorted(self.preset_records.items(), reverse=is_previous): record
if not record.is_available(): for record in self.preset_records.values()
continue if record.is_available()
if first_preset == None: ],
first_preset = record key=lambda record: record.index,
if is_previous: )
if index >= self.active_preset_index: current_preset = self.preset_records[self.active_preset_index]
continue current_preset_pos = presets.index(current_preset)
elif index <= self.active_preset_index: if is_previous:
continue new_preset = presets[(current_preset_pos - 1) % len(presets)]
next_preset = record else:
break new_preset = presets[(current_preset_pos + 1) % len(presets)]
if not first_preset: # If no other preset are available if current_preset == new_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
if next_preset: self.active_preset_index = new_preset.index
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
await self.notify_active_preset() await self.notify_active_preset()
async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None: async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(False) await self.set_next_or_previous_preset(False)
async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None: async def _on_set_previous_preset(
self, connection: Connection, value: bytes
) -> None:
del connection, value # Unused.
await self.set_next_or_previous_preset(True) await self.set_next_or_previous_preset(True)
async def _on_set_active_preset_synchronized_locally( async def _on_set_active_preset_synchronized_locally(
self, _: Connection, value: bytes self, connection: Connection, value: bytes
): ):
del connection # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -584,8 +591,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_active_preset(value) await self.other_server_in_binaural_set.set_active_preset(value)
async def _on_set_next_preset_synchronized_locally( async def _on_set_next_preset_synchronized_locally(
self, _: Connection, __value__: bytes self, connection: Connection, value: bytes
): ):
del connection, value # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -596,8 +604,9 @@ class HearingAccessService(gatt.TemplateService):
await self.other_server_in_binaural_set.set_next_or_previous_preset(False) await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
async def _on_set_previous_preset_synchronized_locally( async def _on_set_previous_preset_synchronized_locally(
self, _: Connection, __value__: bytes self, connection: Connection, value: bytes
): ):
del connection, value # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
@@ -615,11 +624,13 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue preset_control_point_indications: asyncio.Queue[bytes]
active_preset_index_notification: asyncio.Queue active_preset_index_notification: asyncio.Queue[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter( self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_characteristics_by_uuid(
@@ -641,20 +652,12 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
'B', 'B',
) )
async def setup_subscription(self): async def setup_subscription(self) -> None:
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
await self.hearing_aid_preset_control_point.subscribe( await self.hearing_aid_preset_control_point.subscribe(
functools.partial(on_preset_control_point_indication), prefer_notify=False self.preset_control_point_indications.put_nowait,
prefer_notify=False,
) )
await self.active_preset_index.subscribe( await self.active_preset_index.subscribe(
functools.partial(on_active_preset_index_notification) self.active_preset_index_notification.put_nowait
) )