Files
2026-01-01 03:25:32 +08:00

693 lines
26 KiB
Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
from bumble import att, gatt, gatt_adapters, gatt_client, utils
from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Connection, Device
from bumble.hci import Address
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ErrorCode(utils.OpenIntEnum):
    '''Application error codes used by the Hearing Access Service.

    See Hearing Access Service 2.4. Attribute Profile error codes.
    '''

    INVALID_OPCODE = 0x80
    WRITE_NAME_NOT_ALLOWED = 0x81
    PRESET_SYNCHRONIZATION_NOT_SUPPORTED = 0x82
    PRESET_OPERATION_NOT_POSSIBLE = 0x83
    INVALID_PARAMETERS_LENGTH = 0x84
class HearingAidType(utils.OpenIntEnum):
    '''Hearing aid type, stored in bits 0-1 of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    BINAURAL_HEARING_AID = 0b00
    MONAURAL_HEARING_AID = 0b01
    BANDED_HEARING_AID = 0b10
class PresetSynchronizationSupport(utils.OpenIntEnum):
    '''Preset synchronization flag, bit 2 of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED = 0b0
    PRESET_SYNCHRONIZATION_IS_SUPPORTED = 0b1
class IndependentPresets(utils.OpenIntEnum):
    '''Independent presets flag, bit 3 of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    IDENTICAL_PRESET_RECORD = 0b0
    DIFFERENT_PRESET_RECORD = 0b1
class DynamicPresets(utils.OpenIntEnum):
    '''Dynamic presets flag, bit 4 of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    PRESET_RECORDS_DOES_NOT_CHANGE = 0b0
    PRESET_RECORDS_MAY_CHANGE = 0b1
class WritablePresetsSupport(utils.OpenIntEnum):
    '''Writable presets flag, bit 5 of the Hearing Aid Features value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    WRITABLE_PRESET_RECORDS_NOT_SUPPORTED = 0b0
    WRITABLE_PRESET_RECORDS_SUPPORTED = 0b1
class HearingAidPresetControlPointOpcode(utils.OpenIntEnum):
    '''Opcodes carried in the first byte of a Preset Control Point value.

    See Hearing Access Service 3.3.1 Hearing Aid Preset Control Point operation
    requirements.
    '''

    # fmt: off
    READ_PRESETS_REQUEST                     = 0x01
    READ_PRESET_RESPONSE                     = 0x02
    PRESET_CHANGED                           = 0x03
    WRITE_PRESET_NAME                        = 0x04
    SET_ACTIVE_PRESET                        = 0x05
    SET_NEXT_PRESET                          = 0x06
    SET_PREVIOUS_PRESET                      = 0x07
    SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY   = 0x08
    SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY     = 0x09
    SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY = 0x0A
    # fmt: on
@dataclass
class HearingAidFeatures:
    '''Decoded form of the one-byte Hearing Aid Features characteristic value.

    See Hearing Access Service 3.1. Hearing Aid Features.
    '''

    hearing_aid_type: HearingAidType
    preset_synchronization_support: PresetSynchronizationSupport
    independent_presets: IndependentPresets
    dynamic_presets: DynamicPresets
    writable_presets_support: WritablePresetsSupport

    def __bytes__(self) -> bytes:
        '''Pack the five fields into a single byte.'''
        # (field value, bit offset) pairs, lowest bits first.
        layout = (
            (self.hearing_aid_type, 0),
            (self.preset_synchronization_support, 2),
            (self.independent_presets, 3),
            (self.dynamic_presets, 4),
            (self.writable_presets_support, 5),
        )
        octet = 0
        for field_value, bit_offset in layout:
            octet |= field_value << bit_offset
        return bytes([octet])
def HearingAidFeatures_from_bytes(data: int) -> HearingAidFeatures:
    '''Decode a Hearing Aid Features value into a `HearingAidFeatures`.

    Despite the function name, `data` is the feature byte as an integer, not a
    `bytes` object. Bits 0-1 give the hearing aid type; bits 2 to 5 are the
    single-bit flags, in the same order as the dataclass fields.
    '''
    return HearingAidFeatures(
        hearing_aid_type=HearingAidType(data & 0b11),
        preset_synchronization_support=PresetSynchronizationSupport(
            (data >> 2) & 0b1
        ),
        independent_presets=IndependentPresets((data >> 3) & 0b1),
        dynamic_presets=DynamicPresets((data >> 4) & 0b1),
        writable_presets_support=WritablePresetsSupport((data >> 5) & 0b1),
    )
@dataclass
class PresetChangedOperation:
    '''A Preset Changed operation to be indicated on the control point.

    See Hearing Access Service 3.2.2.2. Preset Changed operation.
    '''

    class ChangeId(utils.OpenIntEnum):
        '''Kind of change carried by a Preset Changed operation.'''

        # fmt: off
        GENERIC_UPDATE            = 0x00
        PRESET_RECORD_DELETED     = 0x01
        PRESET_RECORD_AVAILABLE   = 0x02
        PRESET_RECORD_UNAVAILABLE = 0x03
        # fmt: on

    @dataclass
    class Generic:
        '''Additional parameters of a generic update: an index and a record.'''

        prev_index: int
        preset_record: PresetRecord

        def __bytes__(self) -> bytes:
            '''Serialize as the `prev_index` byte followed by the record.'''
            record_bytes = bytes(self.preset_record)
            return bytes([self.prev_index]) + record_bytes

    change_id: ChangeId
    # Either the generic-update payload or a bare preset index.
    additional_parameters: Generic | int

    def to_bytes(self, is_last: bool) -> bytes:
        '''Serialize the operation as a control point value.

        The layout is: PRESET_CHANGED opcode, change id, `is_last` flag, then
        the additional parameters (a serialized `Generic`, or one index byte).
        '''
        params = self.additional_parameters
        tail = (
            bytes(params)
            if isinstance(params, PresetChangedOperation.Generic)
            else bytes([params])
        )
        head = bytes(
            [
                HearingAidPresetControlPointOpcode.PRESET_CHANGED,
                self.change_id,
                is_last,
            ]
        )
        return head + tail
class PresetChangedOperationDeleted(PresetChangedOperation):
    '''Preset Changed operation with change id PRESET_RECORD_DELETED.'''

    def __init__(self, index) -> None:
        # `index` is stored as the single additional parameter.
        self.additional_parameters = index
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_DELETED
class PresetChangedOperationAvailable(PresetChangedOperation):
    '''Preset Changed operation with change id PRESET_RECORD_AVAILABLE.'''

    def __init__(self, index) -> None:
        # `index` is stored as the single additional parameter.
        self.additional_parameters = index
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_AVAILABLE
class PresetChangedOperationUnavailable(PresetChangedOperation):
    '''Preset Changed operation with change id PRESET_RECORD_UNAVAILABLE.'''

    def __init__(self, index) -> None:
        # `index` is stored as the single additional parameter.
        self.additional_parameters = index
        self.change_id = PresetChangedOperation.ChangeId.PRESET_RECORD_UNAVAILABLE
@dataclass
class PresetRecord:
    '''A preset: index, properties and name.

    See Hearing Access Service 2.8. Preset record.
    '''

    @dataclass
    class Property:
        '''Preset properties, packed into one byte by `__bytes__`.'''

        class Writable(utils.OpenIntEnum):
            CANNOT_BE_WRITTEN = 0b0
            CAN_BE_WRITTEN = 0b1

        class IsAvailable(utils.OpenIntEnum):
            IS_UNAVAILABLE = 0b0
            IS_AVAILABLE = 0b1

        writable: Writable = Writable.CAN_BE_WRITTEN
        is_available: IsAvailable = IsAvailable.IS_AVAILABLE

        def __bytes__(self) -> bytes:
            '''Pack `writable` into bit 0 and `is_available` into bit 1.'''
            flags = self.writable | (self.is_available << 1)
            return bytes([flags])

    index: int
    name: str
    properties: Property = field(default_factory=Property)

    def __bytes__(self) -> bytes:
        '''Serialize as index byte, properties byte, then the UTF-8 name.'''
        return b''.join(
            (
                bytes([self.index]),
                bytes(self.properties),
                self.name.encode('utf-8'),
            )
        )

    def is_available(self) -> bool:
        '''Return True when the record's properties mark it as available.'''
        available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
        return self.properties.is_available == available
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class HearingAccessService(gatt.TemplateService):
    '''GATT server side of the Hearing Access Service.

    The service exposes three characteristics (Hearing Aid Features, Hearing
    Aid Preset Control Point and Active Preset Index), owns the preset records
    and the active preset index, and keeps a per-peer-address history of
    Preset Changed operations that still have to be indicated to that peer.
    '''

    UUID = gatt.GATT_HEARING_ACCESS_SERVICE

    hearing_aid_features_characteristic: gatt.Characteristic[bytes]
    hearing_aid_preset_control_point: gatt.Characteristic[bytes]
    active_preset_index_characteristic: gatt.Characteristic[bytes]

    active_preset_index: int
    # Last active preset index notified to each peer address.
    active_preset_index_per_device: dict[Address, int]

    device: Device

    server_features: HearingAidFeatures
    preset_records: dict[int, PresetRecord]  # key is the preset index
    read_presets_request_in_progress: bool

    other_server_in_binaural_set: HearingAccessService | None = None

    # Preset Changed operations not yet delivered, keyed by peer address.
    preset_changed_operations_history_per_device: dict[
        Address, list[PresetChangedOperation]
    ]

    # Keep an updated list of connected client to send notification to
    currently_connected_clients: set[Connection]

    def __init__(
        self, device: Device, features: HearingAidFeatures, presets: list[PresetRecord]
    ) -> None:
        '''Create the service and register connection listeners on `device`.

        Args:
            device: device whose connection events are observed.
            features: value served by the Hearing Aid Features characteristic.
            presets: initial preset records; the lowest index becomes the
                active preset.

        Raises:
            InvalidArgumentError: if `presets` is empty, or a preset name is
                not 1 to 40 bytes long once UTF-8 encoded.
        '''
        self.active_preset_index_per_device = {}
        self.read_presets_request_in_progress = False
        self.preset_changed_operations_history_per_device = {}
        self.currently_connected_clients = set()
        self.device = device
        self.server_features = features

        if not presets:
            raise InvalidArgumentError(f'Invalid presets: {presets}')

        self.preset_records = {}
        for p in presets:
            if not 1 <= len(p.name.encode()) <= 40:
                raise InvalidArgumentError(f'Invalid name: {p.name}')
            self.preset_records[p.index] = p

        # associate the lowest index as the current active preset at startup
        self.active_preset_index = min(self.preset_records)

        @device.on(device.EVENT_CONNECTION)
        def on_connection(connection: Connection) -> None:
            @connection.on(connection.EVENT_DISCONNECTION)
            def on_disconnection(_reason) -> None:
                self.currently_connected_clients.discard(connection)

            # Each of these events can turn a connection into one that passes
            # the checks in on_incoming_connection, so re-evaluate every time.
            @connection.on(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
            def on_mtu_update(*_: Any) -> None:
                self.on_incoming_connection(connection)

            @connection.on(connection.EVENT_CONNECTION_ENCRYPTION_CHANGE)
            def on_encryption_change(*_: Any) -> None:
                self.on_incoming_connection(connection)

            @connection.on(connection.EVENT_PAIRING)
            def on_pairing(*_: Any) -> None:
                self.on_incoming_connection(connection)

            self.on_incoming_connection(connection)

        self.hearing_aid_features_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
            properties=gatt.Characteristic.Properties.READ,
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=bytes(self.server_features),
        )
        self.hearing_aid_preset_control_point = gatt.Characteristic(
            uuid=gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.WRITE
                | gatt.Characteristic.Properties.INDICATE
            ),
            permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(
                write=self._on_write_hearing_aid_preset_control_point
            ),
        )
        self.active_preset_index_characteristic = gatt.Characteristic(
            uuid=gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC,
            properties=(
                gatt.Characteristic.Properties.READ
                | gatt.Characteristic.Properties.NOTIFY
            ),
            permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
            value=gatt.CharacteristicValue(read=self._on_read_active_preset_index),
        )

        super().__init__(
            [
                self.hearing_aid_features_characteristic,
                self.hearing_aid_preset_control_point,
                self.active_preset_index_characteristic,
            ]
        )

    def on_incoming_connection(self, connection: Connection):
        '''Setup initial operations to handle a remote bonded HAP device

        Nothing is done unless the connection is encrypted, has a
        `peer_resolvable_address`, has an ATT MTU of at least 49 and is not
        already tracked. A peer address seen for the first time only gets an
        empty history. A known peer is sent its pending Preset Changed
        operations, then the active preset index if it is out of date.
        '''
        # TODO Should we filter on HAP device only ?
        if not connection.is_encrypted:
            logging.debug(f'HAS: {connection.peer_address} is not encrypted')
            return

        if not connection.peer_resolvable_address:
            logging.debug(f'HAS: {connection.peer_address} is not paired')
            return

        if connection.att_mtu < 49:
            logging.debug(
                f'HAS: {connection.peer_address} invalid MTU={connection.att_mtu}'
            )
            return

        # The set holds Connection objects, so membership must be tested with
        # the connection itself; testing the peer address never matches and
        # lets the catch-up task below be spawned again on every event.
        if connection in self.currently_connected_clients:
            logging.debug(
                f'HAS: Already connected to {connection.peer_address} nothing to do'
            )
            return

        self.currently_connected_clients.add(connection)
        if (
            connection.peer_address
            not in self.preset_changed_operations_history_per_device
        ):
            self.preset_changed_operations_history_per_device[
                connection.peer_address
            ] = []
            return

        async def on_connection_async() -> None:
            # Send all the PresetChangedOperation that occur when not connected
            await self._preset_changed_operation(connection)

            # Update the active preset index if needed
            await self.notify_active_preset_for_connection(connection)

        connection.cancel_on_disconnection(on_connection_async())

    def _on_read_active_preset_index(self, connection: Connection) -> bytes:
        '''Read handler: return the active preset index as a single byte.'''
        del connection  # Unused
        return bytes([self.active_preset_index])

    # TODO this need to be triggered when device is unbonded
    def on_forget(self, addr: Address) -> None:
        '''Drop the pending Preset Changed history kept for `addr`.

        Addresses with no recorded history are ignored.
        '''
        self.preset_changed_operations_history_per_device.pop(addr, None)

    async def _on_write_hearing_aid_preset_control_point(
        self, connection: Connection, value: bytes
    ):
        '''Dispatch a control point write to the `_on_<opcode name>` handler.

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH when `value` is empty,
                INVALID_OPCODE when no handler exists for the opcode (unknown
                values and the opcodes this class only sends, such as
                READ_PRESET_RESPONSE and PRESET_CHANGED).
        '''
        if not value:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        opcode = HearingAidPresetControlPointOpcode(value[0])
        handler = getattr(self, '_on_' + opcode.name.lower(), None)
        if handler is None:
            raise att.ATT_Error(ErrorCode.INVALID_OPCODE)
        await handler(connection, value)

    async def _on_read_presets_request(self, connection: Connection, value: bytes):
        '''Handle Read Presets Request: `value[1]` is the start index and
        `value[2]` the number of presets requested.

        The matching records are indicated from a spawned task, so this
        handler returns before they are sent.

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS while a previous
                request is still being answered, INVALID_PARAMETERS_LENGTH
                when `value` is shorter than 3 bytes, OUT_OF_RANGE when either
                parameter is 0 or no record has an index >= the start index.
        '''
        if connection.att_mtu < 49:  # 2.5. GATT sub-procedure requirements
            logging.warning(f'HAS require MTU >= 49: {connection}')

        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 3:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        start_index = value[1]
        num_presets = value[2]
        if start_index == 0x00 or num_presets == 0x00:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Sending `num_presets` presets ordered by increasing index field, starting from start_index
        presets = [
            self.preset_records[key]
            for key in sorted(self.preset_records)
            if self.preset_records[key].index >= start_index
        ][:num_presets]
        if not presets:
            raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)

        # Only mark the procedure as running once the request is accepted:
        # the flag is cleared by _read_preset_response, which is not started
        # for a rejected request.
        self.read_presets_request_in_progress = True
        utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))

    async def _read_preset_response(
        self, connection: Connection, presets: list[PresetRecord]
    ):
        '''Indicate each record in `presets` as a Read Preset Response.

        Every indication carries the opcode, an is-last flag (set on the final
        record only) and the serialized record.
        '''
        # If the ATT bearer is terminated before all notifications or indications are
        # sent, then the server shall consider the Read Presets Request operation
        # aborted and shall not either continue or restart the operation when the client
        # reconnects.
        try:
            last_position = len(presets) - 1
            for i, preset in enumerate(presets):
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=bytes(
                        [
                            HearingAidPresetControlPointOpcode.READ_PRESET_RESPONSE,
                            i == last_position,
                        ]
                    )
                    + bytes(preset),
                )
        finally:
            # indicate_subscriber can raise a TimeoutError, we need to gracefully terminate the operation
            self.read_presets_request_in_progress = False

    async def generic_update(self, op: PresetChangedOperation) -> None:
        '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
        await self._notify_preset_operations(op)

    async def delete_preset(self, index: int) -> None:
        '''Server API to delete a preset. It should not be the current active preset

        Raises:
            InvalidStateError: if `index` is the active preset.
        '''
        if index == self.active_preset_index:
            raise InvalidStateError('Cannot delete active preset')

        del self.preset_records[index]
        await self._notify_preset_operations(PresetChangedOperationDeleted(index))

    async def available_preset(self, index: int) -> None:
        '''Server API to make a preset available'''
        preset = self.preset_records[index]
        preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
        await self._notify_preset_operations(PresetChangedOperationAvailable(index))

    async def unavailable_preset(self, index: int) -> None:
        '''Server API to make a preset unavailable. It should not be the current active preset

        Raises:
            InvalidStateError: if `index` is the active preset.
        '''
        if index == self.active_preset_index:
            raise InvalidStateError('Cannot set active preset as unavailable')

        preset = self.preset_records[index]
        preset.properties.is_available = (
            PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
        )
        await self._notify_preset_operations(PresetChangedOperationUnavailable(index))

    async def _preset_changed_operation(self, connection: Connection) -> None:
        '''Send all PresetChangedOperation saved for a given connection'''
        op_list = self.preset_changed_operations_history_per_device.get(
            connection.peer_address, []
        )

        # Notification will be sent in index order
        def get_op_index(op: PresetChangedOperation) -> int:
            if isinstance(op.additional_parameters, PresetChangedOperation.Generic):
                return op.additional_parameters.prev_index
            return op.additional_parameters

        op_list.sort(key=get_op_index)

        # If the ATT bearer is terminated before all notifications or indications are
        # sent, then the server shall consider the Preset Changed operation aborted and
        # shall continue the operation when the client reconnects.
        while op_list:
            try:
                await connection.device.indicate_subscriber(
                    connection,
                    self.hearing_aid_preset_control_point,
                    value=op_list[0].to_bytes(len(op_list) == 1),
                    force=True,  # TODO GATT notification subscription should be persistent
                )
                # Remove item once sent, and keep the non sent item in the list
                op_list.pop(0)
            except TimeoutError:
                break

    async def _notify_preset_operations(self, op: PresetChangedOperation) -> None:
        '''Queue `op` for every known peer, then flush to connected clients.'''
        for history_list in self.preset_changed_operations_history_per_device.values():
            history_list.append(op)

        # Iterate a snapshot: a disconnection handler may remove entries from
        # the set while we are suspended in an await, and mutating a set
        # during iteration raises RuntimeError.
        for connection in list(self.currently_connected_clients):
            if connection not in self.currently_connected_clients:
                continue  # Disconnected while a previous client was served
            await self._preset_changed_operation(connection)

    async def _on_write_preset_name(self, connection: Connection, value: bytes):
        '''Handle Write Preset Name: `value[1]` is the preset index and the
        remaining bytes are the new UTF-8 name.

        Raises:
            att.ATT_Error: PROCEDURE_ALREADY_IN_PROGRESS while a Read Presets
                Request is being answered, INVALID_PARAMETERS_LENGTH when the
                index byte is missing or the name is not 1 to 40 bytes long,
                WRITE_NAME_NOT_ALLOWED when the preset is unknown or cannot be
                written.
        '''
        del connection  # Unused

        if self.read_presets_request_in_progress:
            raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)

        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.writable
            == PresetRecord.Property.Writable.CANNOT_BE_WRITTEN
        ):
            raise att.ATT_Error(ErrorCode.WRITE_NAME_NOT_ALLOWED)

        # The limit is on the encoded size, matching the check in __init__;
        # counting decoded characters would accept names longer than 40 bytes.
        name_bytes = value[2:]
        if not 1 <= len(name_bytes) <= 40:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        preset.name = name_bytes.decode('utf-8')

        await self.generic_update(
            PresetChangedOperation(
                PresetChangedOperation.ChangeId.GENERIC_UPDATE,
                PresetChangedOperation.Generic(index, preset),
            )
        )

    async def notify_active_preset_for_connection(self, connection: Connection) -> None:
        '''Notify `connection` of the active preset index if it changed.

        Nothing is sent when the index last notified to this peer address
        already equals the current one.
        '''
        if (
            self.active_preset_index_per_device.get(connection.peer_address, 0x00)
            == self.active_preset_index
        ):
            # Nothing to do, peer is already updated
            return

        await connection.device.notify_subscriber(
            connection,
            attribute=self.active_preset_index_characteristic,
            value=bytes([self.active_preset_index]),
        )
        self.active_preset_index_per_device[connection.peer_address] = (
            self.active_preset_index
        )

    async def notify_active_preset(self) -> None:
        '''Notify every connected client of the active preset index.'''
        # Snapshot for the same reason as in _notify_preset_operations: the
        # set can shrink while we are suspended in an await.
        for connection in list(self.currently_connected_clients):
            if connection not in self.currently_connected_clients:
                continue  # Disconnected while a previous client was served
            await self.notify_active_preset_for_connection(connection)

    async def set_active_preset(self, value: bytes) -> None:
        '''Make the preset whose index is `value[1]` the active one.

        Raises:
            att.ATT_Error: INVALID_PARAMETERS_LENGTH when the index byte is
                missing, PRESET_OPERATION_NOT_POSSIBLE when the preset is
                unknown or not available.
        '''
        if len(value) < 2:
            raise att.ATT_Error(ErrorCode.INVALID_PARAMETERS_LENGTH)

        index = value[1]
        preset = self.preset_records.get(index, None)
        if (
            not preset
            or preset.properties.is_available
            != PresetRecord.Property.IsAvailable.IS_AVAILABLE
        ):
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        if index == self.active_preset_index:
            # Already at correct value
            return

        self.active_preset_index = index
        await self.notify_active_preset()

    async def _on_set_active_preset(self, connection: Connection, value: bytes):
        '''Control point handler for SET_ACTIVE_PRESET.'''
        del connection  # Unused
        await self.set_active_preset(value)

    async def set_next_or_previous_preset(self, is_previous: bool) -> None:
        '''Set the next or the previous preset as active

        Only available presets are considered, in index order, wrapping
        around at either end.

        Raises:
            att.ATT_Error: PRESET_OPERATION_NOT_POSSIBLE when the active
                index is 0 or no other preset is available.
        '''
        if self.active_preset_index == 0x00:
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        presets = sorted(
            (record for record in self.preset_records.values() if record.is_available()),
            key=lambda record: record.index,
        )
        current_preset = self.preset_records[self.active_preset_index]
        current_preset_pos = presets.index(current_preset)
        step = -1 if is_previous else 1
        new_preset = presets[(current_preset_pos + step) % len(presets)]

        if current_preset == new_preset:  # If no other preset are available
            raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)

        self.active_preset_index = new_preset.index
        await self.notify_active_preset()

    async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None:
        '''Control point handler for SET_NEXT_PRESET.'''
        del connection, value  # Unused.
        await self.set_next_or_previous_preset(False)

    async def _on_set_previous_preset(
        self, connection: Connection, value: bytes
    ) -> None:
        '''Control point handler for SET_PREVIOUS_PRESET.'''
        del connection, value  # Unused.
        await self.set_next_or_previous_preset(True)

    async def _on_set_active_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Control point handler for SET_ACTIVE_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change here, then to `other_server_in_binaural_set` when
        one is set.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the
                server features do not advertise preset synchronization.
        '''
        del connection  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_active_preset(value)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_active_preset(value)

    async def _on_set_next_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Control point handler for SET_NEXT_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change here, then to `other_server_in_binaural_set` when
        one is set.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the
                server features do not advertise preset synchronization.
        '''
        del connection, value  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(False)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(False)

    async def _on_set_previous_preset_synchronized_locally(
        self, connection: Connection, value: bytes
    ):
        '''Control point handler for SET_PREVIOUS_PRESET_SYNCHRONIZED_LOCALLY.

        Applies the change here, then to `other_server_in_binaural_set` when
        one is set.

        Raises:
            att.ATT_Error: PRESET_SYNCHRONIZATION_NOT_SUPPORTED when the
                server features do not advertise preset synchronization.
        '''
        del connection, value  # Unused.
        if (
            self.server_features.preset_synchronization_support
            == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED
        ):
            raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
        await self.set_next_or_previous_preset(True)
        if self.other_server_in_binaural_set:
            await self.other_server_in_binaural_set.set_next_or_previous_preset(True)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
    '''Client-side proxy for a remote Hearing Access Service.

    Wraps the three service characteristics and, once `setup_subscription`
    has run, collects incoming control point indications and active preset
    index notifications into queues.
    '''

    SERVICE_CLASS = HearingAccessService

    hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
    preset_control_point_indications: asyncio.Queue[bytes]
    active_preset_index_notification: asyncio.Queue[bytes]

    def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
        '''Look up the characteristics on `service_proxy`.

        The first characteristic found for each UUID is used. The features
        and active preset index characteristics are wrapped in adapters that
        use the packed format 'B'.
        '''

        def first_characteristic(uuid):
            # Take the first match, as a lookup may return several.
            return service_proxy.get_characteristics_by_uuid(uuid)[0]

        self.service_proxy = service_proxy
        self.preset_control_point_indications = asyncio.Queue()
        self.active_preset_index_notification = asyncio.Queue()

        self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
            first_characteristic(gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC),
            'B',
        )

        self.hearing_aid_preset_control_point = first_characteristic(
            gatt.GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC
        )

        self.active_preset_index = gatt_adapters.PackedCharacteristicProxyAdapter(
            first_characteristic(gatt.GATT_ACTIVE_PRESET_INDEX_CHARACTERISTIC),
            'B',
        )

    async def setup_subscription(self) -> None:
        '''Subscribe to both characteristics, queueing every received value.

        The control point subscription is made with `prefer_notify=False`.
        '''
        on_control_point_value = self.preset_control_point_indications.put_nowait
        await self.hearing_aid_preset_control_point.subscribe(
            on_control_point_value,
            prefer_notify=False,
        )

        on_active_preset_value = self.active_preset_index_notification.put_nowait
        await self.active_preset_index.subscribe(on_active_preset_value)