# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any

from bumble import att, gatt, gatt_adapters, gatt_client, utils
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Connection, Device
from bumble.hci import Address


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(utils.OpenIntEnum):
    '''See Hearing Access Service 2.4. Attribute Profile error codes.'''

    INVALID_OPCODE = 0x80
    WRITE_NAME_NOT_ALLOWED = 0x81
    PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
    PRESET_OPERATION_NOT_POSSIBLE = 0x83
    INVALID_PARAMETERS_LENGTH = 0x84


class HearingAidType(utils.OpenIntEnum):
    '''See Hearing Access Service 3.1. Hearing Aid Features.'''

    BINAURAL_HEARING_AID = 0b00
    MONAURAL_HEARING_AID = 0b01
    BANDED_HEARING_AID = 0b10


class PresetSynchronizationSupport(utils.OpenIntEnum):
    '''See Hearing Access Service 3.1. Hearing Aid Features.'''

    PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
    PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1


class IndependentPresets(utils.OpenIntEnum):
    '''See Hearing Access Service 3.1. Hearing Aid Features.'''

    IDENTICAL_PRESET_RECORD = 0b0
    DIFFERENT_PRESET_RECORD = 0b1


class DynamicPresets(utils.OpenIntEnum):
    '''See Hearing Access Service 3.1. Hearing Aid Features.'''

    PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
    PRESET_RECORDS_MAY_CHANGE = 0b1


class WritablePresetsSupport(utils.OpenIntEnum):
    '''See Hearing Access Service 3.1. Hearing Aid Features.'''

    WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
    WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1


class HearingAidPresetControlPointOpcode(utils.OpenIntEnum):
    '''See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation requirements.'''

    # fmt: off
    READ_PRESETS_REQUEST                     = 0x01
    READ_PRESET_RESPONSE                     = 0x02
    PRESET_CHANGED                           = 0x03
    WRITE_PRESET_NAME                        = 0x04
    SET_ACTIVE_PRESET                        = 0x05
    SET_NEXT_PRESET                          = 0x06
    SET_PREVIOUS_PRESET                      = 0x07
    SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY   = 0x08
    SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY     = 0x09
    SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A


@dataclass
class HearingAidFeatures:
    '''See Hearing Access Service 3.1. Hearing Aid Features.

    Serializes to a single byte: the hearing aid type occupies bits 0-1 and
    each following field occupies one bit, from bit 2 to bit 5.
    '''

    hearing_aid_type: HearingAidType
    preset_synchronization_support: PresetSynchronizationSupport
    independent_presets: IndependentPresets
    dynamic_presets: DynamicPresets
    writable_presets_support: WritablePresetsSupport

    def __bytes__(self) -> bytes:
        '''Pack all the feature fields into one byte.'''
        return bytes(
            [
                (self.hearing_aid_type << 0)
                | (self.preset_synchronization_support << 2)
                | (self.independent_presets << 3)
                | (self.dynamic_presets << 4)
                | (self.writable_presets_support << 5)
            ]
        )


def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
    '''Decode a packed feature value into a HearingAidFeatures.

    Despite the name, `data` is the integer value of the feature byte (this is
    the inverse of `HearingAidFeatures.__bytes__` applied to its single byte).
    '''
    return HearingAidFeatures(
        HearingAidType(data & 0b11),
        PresetSynchronizationSupport(data >> 2 & 0b1),
        IndependentPresets(data >> 3 & 0b1),
        DynamicPresets(data >> 4 & 0b1),
        WritablePresetsSupport(data >> 5 & 0b1),
    )


@dataclass
class PresetChangedOperation:
    '''See Hearing Access Service 3.2.2.2. Preset Changed operation.'''

    class ChangeId(utils.OpenIntEnum):
        # fmt: off
        GENERIC_UPDATE            = 0x00
        PRESET_RECORD_DELETED     = 0x01
        PRESET_RECORD_AVAILABLE   = 0x02
        PRESET_RECORD_UNAVAILABLE = 0x03

    @dataclass
    class Generic:
        '''Additional parameters of a GENERIC_UPDATE change.'''

        prev_index: int
        preset_record: PresetRecord

        def __bytes__(self) -> bytes:
            '''Serialize as the previous index followed by the preset record.'''
            return bytes([self.prev_index]) + bytes(self.preset_record)

    change_id: ChangeId
    # Either the full generic-update payload, or a bare preset index.
    additional_parameters: Generic | int

    def to_bytes(self, is_last: bool) -> bytes:
        '''Serialize this operation as a PRESET_CHANGED control point value.

        Args:
            is_last: value of the "is last" byte that follows the change id.

        Returns:
            The opcode, the change id, the "is last" byte, then the additional
            parameters.
        '''
        if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
            additional_parameters_bytes = bytes(self.additional_parameters)
        else:
            additional_parameters_bytes = bytes([self.additional_parameters])

        return (
            bytes(
                [
                    HearingAidPresetControlPointOpcode.PRESET_CHANGED,
                    self.change_id,
                    is_last,
                ]
            )
            + additional_parameters_bytes
        )


class PresetChangedOperationDeleted(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` was deleted.'''

    def __init__(self, index: int) -> None:
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
        self.additional_parameters = index


class PresetChangedOperationAvailable(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` is available.'''

    def __init__(self, index: int) -> None:
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
        self.additional_parameters = index


class PresetChangedOperationUnavailable(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` is unavailable.'''

    def __init__(self, index: int) -> None:
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
        self.additional_parameters = index


@dataclass
class PresetRecord:
    '''See Hearing Access Service 2.8. Preset record.'''

    @dataclass
    class Property:
        '''Property bits of a preset record (writable: bit 0, available: bit 1).'''

        class Writable(utils.OpenIntEnum):
            CANNOT_BE_WRITTEN = 0b0
            CAN_BE_WRITTEN = 0b1

        class IsAvailable(utils.OpenIntEnum):
            IS_UNAVAILABLE = 0b0
            IS_AVAILABLE = 0b1

        writable: Writable = Writable.CAN_BE_WRITTEN
        is_available: IsAvailable = IsAvailable.IS_AVAILABLE

        def __bytes__(self) -> bytes:
            '''Pack both properties into one byte.'''
            return bytes([self.writable | (self.is_available << 1)])

    index: int
    name: str
    properties: Property = field(default_factory=Property)

    def __bytes__(self) -> bytes:
        '''Serialize as index, properties byte, then the UTF-8 encoded name.'''
        return bytes([self.index]) + bytes(self.properties) + self.name.encode('utf-8')

    def is_available(self) -> bool:
        '''Return True when the record is flagged as available.'''
        return (
            self.properties.is_available
            == PresetRecord.Property.IsAvailable.IS_AVAILABLE
        )


# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
    '''GATT server side of the Hearing Access Service.

    Exposes the Hearing Aid Features, Hearing Aid Preset Control Point and
    Active Preset Index characteristics, and keeps a per-peer history of the
    Preset Changed operations that still have to be indicated.
    '''

    UUID = gatt.GATT_HEARING_ACCESS_SERVICE

    hearing_aid_features_characteristic: gatt.Characteristic[bytes]
    hearing_aid_preset_control_point: gatt.Characteristic[bytes]
    active_preset_index_characteristic: gatt.Characteristic[bytes]

    active_preset_index: int
    active_preset_index_per_device: dict[Address, int]

    device: Device

    server_features: HearingAidFeatures
    preset_records: dict[int, PresetRecord]  # key is the preset index
    read_presets_request_in_progress: bool

    other_server_in_binaural_set: HearingAccessService | None = None

    preset_changed_operations_history_per_device: dict[
        Address, list[PresetChangedOperation]
    ]

    # Keep an updated list of connected client to send notification to
    currently_connected_clients: set[Connection]

    def __init__(
        self, device: Device, features: HearingAidFeatures, presets: list[PresetRecord]
    ) -> None:
        '''Create the service and register the connection listeners on `device`.

        Args:
            device: local device hosting the service.
            features: value exposed by the Hearing Aid Features characteristic.
            presets: initial preset records; at least one is required and each
                name must encode to between 1 and 40 bytes.

        Raises:
            InvalidArgumentError: if `presets` is empty or a name has an
                invalid length.
        '''
        self.active_preset_index_per_device = {}
        self.read_presets_request_in_progress = False
        self.preset_changed_operations_history_per_device = {}
        self.currently_connected_clients = set()

        self.device = device
        self.server_features = features
        if len(presets) < 1:
            raise InvalidArgumentError(f'Invalid presets: {presets}')

        self.preset_records = {}
        for p in presets:
            name_length = len(p.name.encode())
            if name_length < 1 or name_length > 40:
                raise InvalidArgumentError(f'Invalid name: {p.name}')

            self.preset_records[p.index] = p

        # associate the lowest index as the current active preset at startup
        self.active_preset_index = min(self.preset_records)

        @device.on(device.EVENT_CONNECTION)
        def on_connection(connection: Connection) -> None:
            @connection.on(connection.EVENT_DISCONNECTION)
            def on_disconnection(_reason) -> None:
                self.currently_connected_clients.discard(connection)

            # Each of these events may be the one that makes the connection
            # eligible, so the eligibility check is re-run on all of them.
            @connection.on(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
            def on_mtu_update(*_: Any) -> None:
                self.on_incoming_connection(connection)

            @connection.on(connection.EVENT_CONNECTION_ENCRYPTION_CHANGE)
            def on_encryption_change(*_: Any) -> None:
                self.on_incoming_connection(connection)

            @connection.on(connection.EVENT_PAIRING)
            def on_pairing(*_: Any) -> None:
                self.on_incoming_connection(connection)

            self.on_incoming_connection(connection)

        self.hearing_aid_features_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
            properties=gatt.Characteristic.Properties.READ,
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=bytes(self.server_features),
        )
        self.hearing_aid_preset_control_point = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.WRITE
                | gatt.Characteristic.Properties.INDICATE
            ),
            permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(
                write=self._on_write_hearing_aid_preset_control_point
            ),
        )
        self.active_preset_index_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY
            ),
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
        )

        super().__init__(
            [
                self.hearing_aid_features_characteristic,
                self.hearing_aid_preset_control_point,
                self.active_preset_index_characteristic,
            ]
        )

    def on_incoming_connection(self, connection: Connection):
        '''Setup initial operations to handle a remote bonded HAP device

        The connection is ignored until it is encrypted, has a resolvable peer
        address and an ATT MTU of at least 49. A peer seen for the first time
        only gets an empty history; a known peer gets its pending Preset
        Changed operations and, if needed, the active preset index.
        '''
        # TODO Should we filter on HAP device only ?

        if not connection.is_encrypted:
            logging.debug(f'HAS: {connection.peer_address} is not encrypted')
            return

        if not connection.peer_resolvable_address:
            logging.debug(f'HAS: {connection.peer_address} is not paired')
            return

        if connection.att_mtu < 49:
            logging.debug(
                f'HAS: {connection.peer_address} invalid MTU={connection.att_mtu}'
            )
            return

        # The set holds Connection objects, so membership must be tested with
        # the connection itself (testing the peer address never matches).
        if connection in self.currently_connected_clients:
            logging.debug(
                f'HAS: Already connected to {connection.peer_address} nothing to do'
            )
            return

        self.currently_connected_clients.add(connection)
        if (
            connection.peer_address
            not in self.preset_changed_operations_history_per_device
        ):
            self.preset_changed_operations_history_per_device[
                connection.peer_address
            ] = []
            return

        async def on_connection_async() -> None:
            # Send all the PresetChangedOperation that occur when not connected
            await self._preset_changed_operation(connection)
            # Update the active preset index if needed
            await self.notify_active_preset_for_connection(connection)

        connection.cancel_on_disconnection(on_connection_async())

    def _on_read_active_preset_index(self, connection: Connection) -> bytes:
        '''Read handler of the Active Preset Index characteristic.'''
        del connection  # Unused
        return bytes([self.active_preset_index])

    # TODO this need to be triggered when device is unbonded
    def on_forget(self, addr: Address) -> None:
        '''Drop the Preset Changed history kept for `addr`, if there is one.'''
        self.preset_changed_operations_history_per_device.pop(addr, None)

    async def _on_write_hearing_aid_preset_control_point(
        self, connection: Connection, value: bytes
    ):
        '''Write handler of the control point: dispatch on the opcode byte.

        The handler is looked up by name as `_on_<opcode name in lower case>`.

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH for an empty write,
                INVALID_OPCODE when no handler exists for the opcode (unknown
                values as well as opcodes that have no `_on_...` method here).
        '''
        if not value:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        opcode = HearingAidPresetControlPointOpcode(value[0])
        handler = getattr(self, '_on_' + opcode.name.lower(), None)
        if handler is None:
            raise att.ATT_Error(ErrorCode.INVALID_OPCODE)
        await handler(connection, value)

    async def _on_read_presets_request(self, connection: Connection, value: bytes):
        '''Handle READ_PRESETS_REQUEST (opcode, start index, number of presets).

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS when a previous request
                is still running, INVALID_PARAMETERS_LENGTH when the payload is
                too short, OUT_OF_RANGE when a parameter is zero or no preset
                matches.
        '''
        if connection.att_mtu < 49:  # 2.5. GATT sub-procedure requirements
            logging.warning(f'HAS require MTU >= 49: {connection}')

        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 3:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        start_index = value[1]
        if start_index == 0x00:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        num_presets = value[2]
        if num_presets == 0x00:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Sending `num_presets` presets ordered by increasing index field, starting from start_index
        presets = [
            self.preset_records[key]
            for key in sorted(self.preset_records.keys())
            if self.preset_records[key].index >= start_index
        ][:num_presets]
        if len(presets) == 0:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Only mark the procedure as running once the request has been
        # accepted: the flag is cleared by _read_preset_response, so setting it
        # before a rejection would leave it stuck and block every later request.
        self.read_presets_request_in_progress = True
        utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))

    async def _read_preset_response(
        self, connection: Connection, presets: list[PresetRecord]
    ):
        '''Indicate each record of `presets`, flagging the last one.'''
        # If the ATT bearer is terminated before all notifications or indications are
        # sent, then the server shall consider the Read Presets Request operation
        # aborted and shall not either continue or restart the operation when the client
        # reconnects.
        try:
            for i, preset in enumerate(presets):
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=bytes(
                        [
                            HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
                            i == len(presets) - 1,
                        ]
                    )
                    + bytes(preset),
                )

        finally:
            # indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
            self.read_presets_request_in_progress = False

    async def generic_update(self, op: PresetChangedOperation) -> None:
        '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
        await self._notify_preset_operations(op)

    async def delete_preset(self, index: int) -> None:
        '''Server API to delete a preset. It should not be the current active preset

        Raises:
            InvalidStateError: if `index` is the active preset.
            KeyError: if no preset is stored at `index`.
        '''

        if index == self.active_preset_index:
            raise InvalidStateError('Cannot delete active preset')

        del self.preset_records[index]
        await self._notify_preset_operations(PresetChangedOperationDeleted(index))

    async def available_preset(self, index: int) -> None:
        '''Server API to make a preset available

        Raises:
            KeyError: if no preset is stored at `index`.
        '''

        preset = self.preset_records[index]
        preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
        await self._notify_preset_operations(PresetChangedOperationAvailable(index))

    async def unavailable_preset(self, index: int) -> None:
        '''Server API to make a preset unavailable. It should not be the current active preset

        Raises:
            InvalidStateError: if `index` is the active preset.
            KeyError: if no preset is stored at `index`.
        '''

        if index == self.active_preset_index:
            raise InvalidStateError('Cannot set active preset as unavailable')

        preset = self.preset_records[index]
        preset.properties.is_available = (
            PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
        )
        await self._notify_preset_operations(PresetChangedOperationUnavailable(index))

    async def _preset_changed_operation(self, connection: Connection) -> None:
        '''Send all PresetChangedOperation saved for a given connection'''
        op_list = self.preset_changed_operations_history_per_device.get(
            connection.peer_address, []
        )

        # Notification will be sent in index order
        def get_op_index(op: PresetChangedOperation) -> int:
            if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
                return op.additional_parameters.prev_index
            return op.additional_parameters

        op_list.sort(key=get_op_index)
        # If the ATT bearer is terminated before all notifications or indications are
        # sent, then the server shall consider the Preset Changed operation aborted and
        # shall continue the operation when the client reconnects.
        while op_list:
            try:
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=op_list[0].to_bytes(len(op_list) == 1),
                    force=True,  # TODO GATT notification subscription should be persistent
                )
                # Remove item once sent, and keep the non sent item in the list
                op_list.pop(0)
            except TimeoutError:
                break

    async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
        '''Queue `op` for every known peer, then flush to the connected ones.'''
        for history_list in self.preset_changed_operations_history_per_device.values():
            history_list.append(op)

        # Iterate over a snapshot: the disconnection handler removes entries
        # from the set and may run while an indication is being awaited.
        for connection in tuple(self.currently_connected_clients):
            await self._preset_changed_operation(connection)

    async def _on_write_preset_name(self, connection: Connection, value: bytes):
        '''Handle WRITE_PRESET_NAME (opcode, preset index, UTF-8 name).

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS while a Read Presets
                Request is running, WRITE_NAME_NOT_ALLOWED when the preset does
                not exist or is not writable, INVALID_PARAMETERS_LENGTH when
                the index is missing or the name is not 1 to 40 bytes long.
        '''
        del connection  # Unused

        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.writable
            == PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
        ):
            raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)

        # The limit applies to the encoded name, consistently with the check
        # done on the initial presets in __init__.
        encoded_name = value[2:]
        if not 1 <= len(encoded_name) <= 40:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        preset.name = encoded_name.decode('utf-8')

        await self.generic_update(
            PresetChangedOperation(
                PresetChangedOperation.ChangeId.GENERIC_UPDATE,
                PresetChangedOperation.Generic(index, preset),
            )
        )

    async def notify_active_preset_for_connection(self, connection: Connection) -> None:
        '''Notify `connection` of the active preset index unless already sent.

        The last index notified to each peer address is remembered, so a peer
        is only notified when the active index differs from it.
        '''
        if (
            self.active_preset_index_per_device.get(connection.peer_address, 0x00)
            == self.active_preset_index
        ):
            # Nothing to do, peer is already updated
            return

        await connection.device.notify_subscriber(
            connection,
            attribute=self.active_preset_index_characteristic,
            value=bytes([self.active_preset_index]),
        )
        self.active_preset_index_per_device[connection.peer_address] = (
            self.active_preset_index
        )

    async def notify_active_preset(self) -> None:
        '''Notify every connected client of the active preset index.'''
        # Iterate over a snapshot: the disconnection handler removes entries
        # from the set and may run while a notification is being awaited.
        for connection in tuple(self.currently_connected_clients):
            await self.notify_active_preset_for_connection(connection)

    async def set_active_preset(self, value: bytes) -> None:
        '''Make the preset whose index is `value[1]` the active one.

        Args:
            value: control point payload (opcode byte, then preset index).

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH when the index is missing,
                PRESET_OPERATION_NOT_POSSIBLE when the preset does not exist or
                is not available.
        '''
        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.is_available
            != PresetRecord.Property.IsAvailable.IS_AVAILABLE
        ):
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        if index == self.active_preset_index:
            # Already at correct value
            return

        self.active_preset_index = index
        await self.notify_active_preset()

    async def _on_set_active_preset(self, connection: Connection, value: bytes):
        '''Handle SET_ACTIVE_PRESET.'''
        del connection  # Unused
        await self.set_active_preset(value)

    async def set_next_or_previous_preset(self, is_previous: bool) -> None:
        '''Set the next or the previous preset as active

        Candidates are the available presets in increasing index order, with
        wrap-around at both ends.

        Raises:
            att.ATT_Error: PRESET_OPERATION_NOT_POSSIBLE when there is no
                active preset or no other available preset.
        '''

        if self.active_preset_index == 0x00:
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        presets = sorted(
            [
                record
                for record in self.preset_records.values()
                if record.is_available()
            ],
            key=lambda record: record.index,
        )
        current_preset = self.preset_records[self.active_preset_index]
        current_preset_pos = presets.index(current_preset)
        if is_previous:
            new_preset = presets[(current_preset_pos - 1) % len(presets)]
        else:
            new_preset = presets[(current_preset_pos + 1) % len(presets)]

        if current_preset == new_preset:  # If no other preset are available
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        self.active_preset_index = new_preset.index
        await self.notify_active_preset()

    async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
        '''Handle SET_NEXT_PRESET.'''
        del connection, value  # Unused.
        await self.set_next_or_previous_preset(False)

    async def _on_set_previous_preset(
        self, connection: Connection, value: bytes
    ) -> None:
        '''Handle SET_PREVIOUS_PRESET.'''
        del connection, value  # Unused.
        await self.set_next_or_previous_preset(True)

    async def _on_set_active_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Handle SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change locally, then to the other server of the binaural
        set when one is registered.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the server
                features do not advertise synchronization support.
        '''
        del connection  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_active_preset(value)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_active_preset(value)

    async def _on_set_next_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Handle SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change locally, then to the other server of the binaural
        set when one is registered.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the server
                features do not advertise synchronization support.
        '''
        del connection, value  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(False)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(False)

    async def _on_set_previous_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Handle SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change locally, then to the other server of the binaural
        set when one is registered.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the server
                features do not advertise synchronization support.
        '''
        del connection, value  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(True)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(True)


# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
    '''GATT client side of the Hearing Access Service.

    Values received through the subscriptions made by `setup_subscription` are
    pushed to `preset_control_point_indications` and
    `active_preset_index_notification`.
    '''

    SERVICE_CLASS = HearingAccessService

    hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
    preset_control_point_indications: asyncio.Queue[bytes]
    active_preset_index_notification: asyncio.Queue[bytes]

    def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
        '''Look up the three service characteristics on `service_proxy`.

        The first characteristic found for each UUID is used.
        '''
        self.service_proxy = service_proxy

        self.preset_control_point_indications = asyncio.Queue()
        self.active_preset_index_notification = asyncio.Queue()

        self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
            service_proxy.get_characteristics_by_uuid(
                gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
            )[0],
            'B',
        )

        self.hearing_aid_preset_control_point = (
            service_proxy.get_characteristics_by_uuid(
                gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
            )[0]
        )

        self.active_preset_index = gatt_adapters.PackedCharacteristicProxyAdapter(
            service_proxy.get_characteristics_by_uuid(
                gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
            )[0],
            'B',
        )

    async def setup_subscription(self) -> None:
        '''Subscribe to the control point and to the active preset index.'''
        await self.hearing_aid_preset_control_point.subscribe(
            self.preset_control_point_indications.put_nowait,
            prefer_notify=False,
        )

        await self.active_preset_index.subscribe(
            self.active_preset_index_notification.put_nowait
        )