Files
bumble_fork/bumble/profiles/hap.py
2025-08-21 16:38:58 -07:00

661 lines
25 KiB
Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import functools
import logging
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from bumble import att, gatt, gatt_adapters, gatt_client, utils
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Connection, Device
from bumble.hci import Address
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(utils.OpenIntEnum):
    '''
    Application-level ATT error codes used by the Hearing Access Service.

    See Hearing Access Service 2.4. Attribute Profile error codes.
    '''

    # fmt: off
    INVALID_OPCODE                       = 0x80
    WRITE_NAME_NOT_ALLOWED               = 0x81
    PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
    PRESET_OPERATION_NOT_POSSIBLE        = 0x83
    INVALID_PARAMETERS_LENGTH            = 0x84
    # fmt: on
class HearingAidType(utils.OpenIntEnum):
    '''
    Two-bit "Hearing Aid Type" field of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    # fmt: off
    BINAURAL_HEARING_AID = 0b00
    MONAURAL_HEARING_AID = 0b01
    BANDED_HEARING_AID   = 0b10
    # fmt: on
class PresetSynchronizationSupport(utils.OpenIntEnum):
    '''
    One-bit "Preset Synchronization Support" field of the Hearing Aid Features.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    # fmt: off
    PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
    PRESET_SYNCHRONIZATION_IS_SUPPORTED     = 0b1
    # fmt: on
class IndependentPresets(utils.OpenIntEnum):
    '''
    One-bit "Independent Presets" field of the Hearing Aid Features.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    # fmt: off
    IDENTICAL_PRESET_RECORD = 0b0
    DIFFERENT_PRESET_RECORD = 0b1
    # fmt: on
class DynamicPresets(utils.OpenIntEnum):
    '''
    One-bit "Dynamic Presets" field of the Hearing Aid Features.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    # fmt: off
    PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
    PRESET_RECORDS_MAY_CHANGE      = 0b1
    # fmt: on
class WritablePresetsSupport(utils.OpenIntEnum):
    '''
    One-bit "Writable Presets Support" field of the Hearing Aid Features.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    # fmt: off
    WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
    WRITABLE_PRESET_RECORDS_SUPPORTED     = 0b1
    # fmt: on
class HearingAidPresetControlPointOpcode(utils.OpenIntEnum):
    '''
    Opcodes carried in the first octet of a Hearing Aid Preset Control Point value.

    See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation
    requirements.
    '''

    # fmt: off
    READ_PRESETS_REQUEST                     = 0x01
    READ_PRESET_RESPONSE                     = 0x02
    PRESET_CHANGED                           = 0x03
    WRITE_PRESET_NAME                        = 0x04
    SET_ACTIVE_PRESET                        = 0x05
    SET_NEXT_PRESET                          = 0x06
    SET_PREVIOUS_PRESET                      = 0x07
    SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY   = 0x08
    SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY     = 0x09
    SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
    # fmt: on
@dataclass
class HearingAidFeatures:
    '''
    Decoded form of the Hearing Aid Features characteristic value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    hearing_aid_type: HearingAidType
    preset_synchronization_support: PresetSynchronizationSupport
    independent_presets: IndependentPresets
    dynamic_presets: DynamicPresets
    writable_presets_support: WritablePresetsSupport

    def __bytes__(self) -> bytes:
        '''
        Pack the five fields into a single octet.

        Bit layout (LSB first): bits 0-1 hearing aid type, bit 2 preset
        synchronization support, bit 3 independent presets, bit 4 dynamic
        presets, bit 5 writable presets support.
        '''
        packed = self.hearing_aid_type << 0
        packed |= self.preset_synchronization_support << 2
        packed |= self.independent_presets << 3
        packed |= self.dynamic_presets << 4
        packed |= self.writable_presets_support << 5
        return bytes([packed])
def HearingAidFeatures_from_bytes(data: Union[int, bytes]) -> HearingAidFeatures:
    '''
    Decode a Hearing Aid Features octet into a `HearingAidFeatures` object.

    Args:
        data: either the feature octet as an integer (the historical calling
            convention), or a bytes-like object whose first octet is the
            feature octet (e.g. the result of `bytes(features)`).

    Returns:
        The decoded features. Bits above bit 5 are ignored.

    Raises:
        InvalidArgumentError: if `data` is an empty bytes-like object.
    '''
    if isinstance(data, (bytes, bytearray, memoryview)):
        # Despite its name this function used to accept only an int; also accept
        # the raw characteristic value so that it round-trips with __bytes__.
        if len(data) < 1:
            raise InvalidArgumentError('Hearing Aid Features value is empty')
        data = data[0]

    return HearingAidFeatures(
        HearingAidType(data & 0b11),
        PresetSynchronizationSupport((data >> 2) & 0b1),
        IndependentPresets((data >> 3) & 0b1),
        DynamicPresets((data >> 4) & 0b1),
        WritablePresetsSupport((data >> 5) & 0b1),
    )
@dataclass
class PresetChangedOperation:
    '''
    One "Preset Changed" indication to be sent to a client.

    See Hearing Access Service 3.2.2.2. Preset Changed operation.
    '''

    class ChangeId(utils.OpenIntEnum):
        '''Kind of change being reported.'''

        # fmt: off
        GENERIC_UPDATE            = 0x00
        PRESET_RECORD_DELETED     = 0x01
        PRESET_RECORD_AVAILABLE   = 0x02
        PRESET_RECORD_UNAVAILABLE = 0x03
        # fmt: on

    @dataclass
    class Generic:
        '''Additional parameters used with a GENERIC_UPDATE change.'''

        prev_index: int
        preset_record: PresetRecord

        def __bytes__(self) -> bytes:
            '''Serialize as the previous index octet followed by the record.'''
            index_octet = bytes([self.prev_index])
            return index_octet + bytes(self.preset_record)

    change_id: ChangeId
    # Either a `Generic` payload, or a bare preset index for the other changes.
    additional_parameters: Union[Generic, int]

    def to_bytes(self, is_last: bool) -> bytes:
        '''
        Serialize this operation as a control point indication value.

        Args:
            is_last: value of the "is last" octet that follows the change id.

        Returns:
            Opcode, change id and is-last octets followed by the additional
            parameters.
        '''
        parameters = self.additional_parameters
        if isinstance(parameters, PresetChangedOperation.Generic):
            payload = bytes(parameters)
        else:
            # A plain preset index, sent as a single octet.
            payload = bytes([parameters])

        header = bytes(
            [
                HearingAidPresetControlPointOpcode.PRESET_CHANGED,
                self.change_id,
                is_last,
            ]
        )
        return header + payload
class PresetChangedOperationDeleted(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` was deleted.'''

    def __init__(self, index: int) -> None:
        super().__init__(
            change_id=PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED,
            additional_parameters=index,
        )
class PresetChangedOperationAvailable(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` is available.'''

    def __init__(self, index: int) -> None:
        super().__init__(
            change_id=PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE,
            additional_parameters=index,
        )
class PresetChangedOperationUnavailable(PresetChangedOperation):
    '''Preset Changed operation reporting that the preset at `index` is unavailable.'''

    def __init__(self, index: int) -> None:
        super().__init__(
            change_id=PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE,
            additional_parameters=index,
        )
@dataclass
class PresetRecord:
    '''
    A single preset: index, properties and human-readable name.

    See Hearing Access Service 2.8. Preset record.
    '''

    @dataclass
    class Property:
        '''Property bits of a preset record (bit 0: writable, bit 1: available).'''

        class Writable(utils.OpenIntEnum):
            CANNOT_BE_WRITTEN = 0b0
            CAN_BE_WRITTEN = 0b1

        class IsAvailable(utils.OpenIntEnum):
            IS_UNAVAILABLE = 0b0
            IS_AVAILABLE = 0b1

        writable: Writable = Writable.CAN_BE_WRITTEN
        is_available: IsAvailable = IsAvailable.IS_AVAILABLE

        def __bytes__(self) -> bytes:
            '''Pack both flags into one octet.'''
            flags = self.writable | (self.is_available << 1)
            return bytes([flags])

    index: int
    name: str
    # By default a record is both writable and available.
    properties: Property = field(default_factory=Property)

    def __bytes__(self) -> bytes:
        '''Serialize as index octet, properties octet, then the UTF-8 name.'''
        encoded_name = self.name.encode('utf-8')
        return bytes([self.index]) + bytes(self.properties) + encoded_name

    def is_available(self) -> bool:
        '''Return True when the record's availability flag is IS_AVAILABLE.'''
        available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
        return self.properties.is_available == available
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
    '''
    GATT server implementation of the Hearing Access Service.

    The service exposes three characteristics: Hearing Aid Features (read),
    Hearing Aid Preset Control Point (write + indicate) and Active Preset Index
    (read + notify). Preset Changed operations are recorded per peer address so
    that they can be delivered when that peer is connected.
    '''

    UUID = gatt.GATT_HEARING_ACCESS_SERVICE

    hearing_aid_features_characteristic: gatt.Characteristic[bytes]
    hearing_aid_preset_control_point: gatt.Characteristic[bytes]
    active_preset_index_characteristic: gatt.Characteristic[bytes]

    active_preset_index: int
    # Last active preset index that was notified to each peer
    active_preset_index_per_device: dict[Address, int]

    device: Device

    server_features: HearingAidFeatures
    preset_records: dict[int, PresetRecord]  # key is the preset index
    read_presets_request_in_progress: bool

    other_server_in_binaural_set: Optional[HearingAccessService] = None

    # Preset Changed operations not yet delivered, per peer address
    preset_changed_operations_history_per_device: dict[
        Address, list[PresetChangedOperation]
    ]

    # Keep an updated list of connected client to send notification to
    currently_connected_clients: set[Connection]

    def __init__(
        self, device: Device, features: HearingAidFeatures, presets: list[PresetRecord]
    ) -> None:
        '''
        Create the service and register connection listeners on `device`.

        Args:
            device: local device whose connection events are observed.
            features: value exposed by the Hearing Aid Features characteristic.
            presets: initial preset records; must contain at least one record and
                every name must encode to 1-40 octets.

        Raises:
            InvalidArgumentError: if `presets` is empty or a name has an invalid
                length.
        '''
        self.active_preset_index_per_device = {}
        self.read_presets_request_in_progress = False
        self.preset_changed_operations_history_per_device = {}
        self.currently_connected_clients = set()

        self.device = device
        self.server_features = features
        if len(presets) < 1:
            raise InvalidArgumentError(f'Invalid presets: {presets}')

        self.preset_records = {}
        for p in presets:
            if not 1 <= len(p.name.encode()) <= 40:
                raise InvalidArgumentError(f'Invalid name: {p.name}')
            self.preset_records[p.index] = p

        # associate the lowest index as the current active preset at startup
        self.active_preset_index = min(self.preset_records)

        @device.on(device.EVENT_CONNECTION)
        def on_connection(connection: Connection) -> None:
            @connection.on(connection.EVENT_DISCONNECTION)
            def on_disconnection(_reason) -> None:
                # The connection is only tracked once it is paired, so it may not
                # be in the set: discard() instead of remove() avoids a KeyError.
                self.currently_connected_clients.discard(connection)

            @connection.on(connection.EVENT_PAIRING)
            def on_pairing(*_: Any) -> None:
                self.on_incoming_paired_connection(connection)

            if connection.peer_resolvable_address:
                self.on_incoming_paired_connection(connection)

        self.hearing_aid_features_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
            properties=gatt.Characteristic.Properties.READ,
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=bytes(self.server_features),
        )
        self.hearing_aid_preset_control_point = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.WRITE
                | gatt.Characteristic.Properties.INDICATE
            ),
            permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(
                write=self._on_write_hearing_aid_preset_control_point
            ),
        )
        self.active_preset_index_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY
            ),
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
        )

        super().__init__(
            [
                self.hearing_aid_features_characteristic,
                self.hearing_aid_preset_control_point,
                self.active_preset_index_characteristic,
            ]
        )

    def on_incoming_paired_connection(self, connection: Connection) -> None:
        '''Setup initial operations to handle a remote bonded HAP device'''
        # TODO Should we filter on HAP device only ?
        self.currently_connected_clients.add(connection)
        if (
            connection.peer_address
            not in self.preset_changed_operations_history_per_device
        ):
            # First time this peer is seen: start an empty history, nothing to send
            self.preset_changed_operations_history_per_device[
                connection.peer_address
            ] = []
            return

        async def on_connection_async() -> None:
            # Send all the PresetChangedOperation that occur when not connected
            await self._preset_changed_operation(connection)
            # Update the active preset index if needed
            await self.notify_active_preset_for_connection(connection)

        connection.cancel_on_disconnection(on_connection_async())

    def _on_read_active_preset_index(self, connection: Connection) -> bytes:
        '''Read handler of the Active Preset Index characteristic.'''
        del connection  # Unused
        return bytes([self.active_preset_index])

    # TODO this need to be triggered when device is unbonded
    def on_forget(self, addr: Address) -> None:
        '''Drop the pending Preset Changed history kept for `addr`, if any.'''
        # A default is supplied so that forgetting an unknown peer is a no-op
        # rather than a KeyError.
        self.preset_changed_operations_history_per_device.pop(addr, None)

    async def _on_write_hearing_aid_preset_control_point(
        self, connection: Connection, value: bytes
    ):
        '''
        Write handler of the control point: dispatch on the opcode octet.

        The handler is the method named `_on_<opcode name in lower case>`.

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH for an empty value,
                INVALID_OPCODE when no handler exists for the opcode, or any error
                raised by the handler.
        '''
        if len(value) < 1:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        opcode = HearingAidPresetControlPointOpcode(value[0])
        # Opcodes without a handler (server-to-client opcodes, or values that are
        # not declared in the enum) must be rejected with an ATT error rather
        # than with an AttributeError.
        handler = getattr(self, '_on_' + opcode.name.lower(), None)
        if handler is None:
            raise att.ATT_Error(ErrorCode.INVALID_OPCODE)
        await handler(connection, value)

    async def _on_read_presets_request(self, connection: Connection, value: bytes):
        '''
        Handle a Read Presets Request: `value` is opcode, start index, count.

        The selected records are indicated from a spawned task; until that task
        ends `read_presets_request_in_progress` stays True.

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS, INVALID_PARAMETERS_LENGTH
                (value shorter than 3 octets) or OUT_OF_RANGE (zero start index,
                zero count, or no matching record).
        '''
        if connection.att_mtu < 49:  # 2.5. GATT sub-procedure requirements
            logging.warning('HAS require MTU >= 49: %s', connection)

        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 3:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        start_index = value[1]
        if start_index == 0x00:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        num_presets = value[2]
        if num_presets == 0x00:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Sending `num_presets` presets ordered by increasing index field, starting from start_index
        presets = [
            self.preset_records[key]
            for key in sorted(self.preset_records)
            if self.preset_records[key].index >= start_index
        ][:num_presets]
        if not presets:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Only mark the procedure as running once the request has been validated:
        # the flag is cleared by _read_preset_response, which a rejected request
        # never reaches.
        self.read_presets_request_in_progress = True
        utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))

    async def _read_preset_response(
        self, connection: Connection, presets: list[PresetRecord]
    ):
        '''Indicate each record of `presets`, flagging the last one as "is last".'''
        # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
        last = len(presets) - 1
        try:
            for i, preset in enumerate(presets):
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=bytes(
                        [
                            HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
                            i == last,
                        ]
                    )
                    + bytes(preset),
                )
        finally:
            # indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
            self.read_presets_request_in_progress = False

    async def generic_update(self, op: PresetChangedOperation) -> None:
        '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
        await self._notifyPresetOperations(op)

    async def delete_preset(self, index: int) -> None:
        '''Server API to delete a preset. It should not be the current active preset'''
        if index == self.active_preset_index:
            raise InvalidStateError('Cannot delete active preset')
        del self.preset_records[index]
        await self._notifyPresetOperations(PresetChangedOperationDeleted(index))

    async def available_preset(self, index: int) -> None:
        '''Server API to make a preset available'''
        preset = self.preset_records[index]
        preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
        await self._notifyPresetOperations(PresetChangedOperationAvailable(index))

    async def unavailable_preset(self, index: int) -> None:
        '''Server API to make a preset unavailable. It should not be the current active preset'''
        if index == self.active_preset_index:
            raise InvalidStateError('Cannot set active preset as unavailable')
        preset = self.preset_records[index]
        preset.properties.is_available = (
            PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
        )
        await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))

    async def _preset_changed_operation(self, connection: Connection) -> None:
        '''Send all PresetChangedOperation saved for a given connection'''
        op_list = self.preset_changed_operations_history_per_device.get(
            connection.peer_address, []
        )

        # Notification will be sent in index order
        def get_op_index(op: PresetChangedOperation) -> int:
            if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
                return op.additional_parameters.prev_index
            return op.additional_parameters

        op_list.sort(key=get_op_index)
        # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
        while op_list:
            try:
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=op_list[0].to_bytes(len(op_list) == 1),
                )
            except (TimeoutError, asyncio.TimeoutError):
                # asyncio.TimeoutError is only an alias of the builtin since
                # Python 3.11, so both are listed.
                break
            # Remove item once sent, and keep the non sent item in the list
            op_list.pop(0)

    async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
        '''Record `op` for every known peer, then flush to the connected ones.'''
        for historyList in self.preset_changed_operations_history_per_device.values():
            historyList.append(op)

        # Iterate over a snapshot: a disconnection while awaiting would otherwise
        # change the set during iteration and raise a RuntimeError.
        for connection in list(self.currently_connected_clients):
            await self._preset_changed_operation(connection)

    async def _on_write_preset_name(self, connection: Connection, value: bytes):
        '''
        Handle Write Preset Name: `value` is opcode, preset index, new UTF-8 name.

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS while a read presets
                request is running, WRITE_NAME_NOT_ALLOWED for an unknown or
                non-writable preset, INVALID_PARAMETERS_LENGTH when the index is
                missing or the name is not 1-40 octets long.
        '''
        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.writable
            == PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
        ):
            raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)

        # The limit applies to the encoded name, as in __init__: measure octets,
        # not decoded characters.
        name_bytes = value[2:]
        if not 1 <= len(name_bytes) <= 40:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        preset.name = name_bytes.decode('utf-8')

        await self.generic_update(
            PresetChangedOperation(
                PresetChangedOperation.ChangeId.GENERIC_UPDATE,
                PresetChangedOperation.Generic(index, preset),
            )
        )

    async def notify_active_preset_for_connection(self, connection: Connection) -> None:
        '''Notify `connection` of the active preset index unless it already has it.'''
        if (
            self.active_preset_index_per_device.get(connection.peer_address, 0x00)
            == self.active_preset_index
        ):
            # Nothing to do, peer is already updated
            return

        await connection.device.notify_subscriber(
            connection,
            attribute=self.active_preset_index_characteristic,
            value=bytes([self.active_preset_index]),
        )
        self.active_preset_index_per_device[connection.peer_address] = (
            self.active_preset_index
        )

    async def notify_active_preset(self) -> None:
        '''Notify every tracked client of the active preset index.'''
        # Snapshot for the same reason as in _notifyPresetOperations
        for connection in list(self.currently_connected_clients):
            await self.notify_active_preset_for_connection(connection)

    async def set_active_preset(self, value: bytes) -> None:
        '''
        Make the preset whose index is `value[1]` the active preset.

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH when the index octet is
                missing, PRESET_OPERATION_NOT_POSSIBLE when the preset is unknown
                or unavailable.
        '''
        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.is_available
            != PresetRecord.Property.IsAvailable.IS_AVAILABLE
        ):
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        if index == self.active_preset_index:
            # Already at correct value
            return

        self.active_preset_index = index
        await self.notify_active_preset()

    async def _on_set_active_preset(self, _: Connection, value: bytes):
        await self.set_active_preset(value)

    async def set_next_or_previous_preset(self, is_previous: bool) -> None:
        '''Set the next or the previous preset as active'''
        if self.active_preset_index == 0x00:
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        first_preset: Optional[PresetRecord] = None  # To loop to first preset
        next_preset: Optional[PresetRecord] = None
        for index, record in sorted(self.preset_records.items(), reverse=is_previous):
            if not record.is_available():
                continue
            if first_preset is None:
                first_preset = record
            if is_previous:
                if index >= self.active_preset_index:
                    continue
            elif index <= self.active_preset_index:
                continue
            next_preset = record
            break

        if not first_preset:  # If no other preset are available
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        if next_preset:
            self.active_preset_index = next_preset.index
        else:
            # Nothing beyond the active preset in that direction: wrap around
            self.active_preset_index = first_preset.index
        await self.notify_active_preset()

    async def _on_set_next_preset(self, _: Connection, __value__: bytes) -> None:
        await self.set_next_or_previous_preset(False)

    async def _on_set_previous_preset(self, _: Connection, __value__: bytes) -> None:
        await self.set_next_or_previous_preset(True)

    async def _on_set_active_preset_synchronized_locally(
        self, _: Connection, value: bytes
    ):
        '''Set the active preset here and on the other server of the binaural set.'''
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_active_preset(value)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_active_preset(value)

    async def _on_set_next_preset_synchronized_locally(
        self, _: Connection, __value__: bytes
    ):
        '''Select the next preset here and on the other server of the binaural set.'''
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(False)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(False)

    async def _on_set_previous_preset_synchronized_locally(
        self, _: Connection, __value__: bytes
    ):
        '''Select the previous preset here and on the other server of the binaural set.'''
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(True)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(True)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
    '''
    GATT client proxy for a remote Hearing Access Service.

    The first characteristic found for each of the three service UUIDs is used.
    Call `setup_subscription` to start queueing control point indications and
    active preset index notifications.
    '''

    SERVICE_CLASS = HearingAccessService

    hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
    preset_control_point_indications: asyncio.Queue
    active_preset_index_notification: asyncio.Queue

    def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
        self.service_proxy = service_proxy

        features_characteristic = service_proxy.get_characteristics_by_uuid(
            gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC
        )[0]
        # Both single-octet values are exposed through a 'B' packed adapter
        self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
            features_characteristic,
            'B',
        )

        self.hearing_aid_preset_control_point = (
            service_proxy.get_characteristics_by_uuid(
                gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
            )[0]
        )

        active_index_characteristic = service_proxy.get_characteristics_by_uuid(
            gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC
        )[0]
        self.active_preset_index = gatt_adapters.PackedCharacteristicProxyAdapter(
            active_index_characteristic,
            'B',
        )

    async def setup_subscription(self):
        '''
        Create the two queues and subscribe to both characteristics.

        Control point values are requested as indications
        (`prefer_notify=False`); every received value is appended to the
        matching queue without blocking.
        '''
        self.preset_control_point_indications = asyncio.Queue()
        self.active_preset_index_notification = asyncio.Queue()

        def enqueue_control_point_indication(data: bytes):
            self.preset_control_point_indications.put_nowait(data)

        def enqueue_active_preset_index(data: bytes):
            self.active_preset_index_notification.put_nowait(data)

        control_point_subscriber = functools.partial(enqueue_control_point_indication)
        await self.hearing_aid_preset_control_point.subscribe(
            control_point_subscriber, prefer_notify=False
        )

        active_index_subscriber = functools.partial(enqueue_active_preset_index)
        await self.active_preset_index.subscribe(active_index_subscriber)