# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""LE Audio - Broadcast Audio Scan Service"""

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import logging
import struct
from typing import ClassVar, List, Optional, Sequence

from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
from bumble import utils

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
class ApplicationError(utils.OpenIntEnum):
    """Application error codes defined by this service."""

    OPCODE_NOT_SUPPORTED = 0x80
    INVALID_SOURCE_ID = 0x81


# -----------------------------------------------------------------------------
def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes:
    """Serialize subgroups as a one-byte count followed by each subgroup.

    Each subgroup is written as a 32-bit little-endian ``bis_sync``, a one-byte
    metadata length, then the metadata bytes.
    """
    # NOTE(review): the struct format and packed fields were lost in the
    # original text; reconstructed to mirror decode_subgroups() - confirm.
    return bytes([len(subgroups)]) + b"".join(
        struct.pack("<IB", subgroup.bis_sync, len(subgroup.metadata))
        + subgroup.metadata
        for subgroup in subgroups
    )


def decode_subgroups(data: bytes) -> List[SubgroupInfo]:
    """Parse the subgroup list produced by :func:`encode_subgroups`.

    ``data[0]`` is the number of subgroups; the subgroup records follow.
    """
    num_subgroups = data[0]
    offset = 1
    subgroups = []
    for _ in range(num_subgroups):
        # NOTE(review): everything after this unpack call was lost in the
        # original text; the record layout below (4-byte bis_sync, 1-byte
        # metadata length, metadata) is reconstructed - confirm.
        bis_sync = struct.unpack("<I", data[offset : offset + 4])[0]
        metadata_length = data[offset + 4]
        metadata = data[offset + 5 : offset + 5 + metadata_length]
        offset += 5 + metadata_length
        subgroups.append(SubgroupInfo(bis_sync, metadata))

    return subgroups


# -----------------------------------------------------------------------------
# NOTE(review): the definitions of PeriodicAdvertisingSyncParams, SubgroupInfo
# and the head of ControlPointOperation (including OpCode) were lost in the
# original text. They are reconstructed from the names the rest of this module
# references; member values are presumed to follow the Broadcast Audio Scan
# Service specification - confirm against upstream.
class PeriodicAdvertisingSyncParams(utils.OpenIntEnum):
    """PA_Sync parameter carried by the Add Source / Modify Source operations."""

    DO_NOT_SYNCHRONIZE_TO_PA = 0x00
    SYNCHRONIZE_TO_PA_PAST_AVAILABLE = 0x01
    SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE = 0x02


@dataclasses.dataclass
class SubgroupInfo:
    """One broadcast subgroup: a BIS sync bitfield plus its metadata bytes."""

    # presumably "no preference" for BIS selection - TODO confirm
    ANY_BIS: ClassVar[int] = 0xFFFFFFFF

    bis_sync: int
    metadata: bytes


class ControlPointOperation:
    """Base class for operations written to the scan control point.

    An operation serializes to its one-byte op code followed by its raw
    parameters.
    """

    class OpCode(utils.OpenIntEnum):
        REMOTE_SCAN_STOPPED = 0x00
        REMOTE_SCAN_STARTED = 0x01
        ADD_SOURCE = 0x02
        MODIFY_SOURCE = 0x03
        SET_BROADCAST_CODE = 0x04
        REMOVE_SOURCE = 0x05

    op_code: OpCode
    parameters: bytes

    @classmethod
    def from_bytes(cls, data: bytes) -> ControlPointOperation:
        """Parse a serialized operation into the matching subclass instance.

        Raises:
            core.InvalidArgumentError: if the op code is not a known one.
        """
        op_code = data[0]

        if op_code == cls.OpCode.REMOTE_SCAN_STOPPED:
            return RemoteScanStoppedOperation()

        if op_code == cls.OpCode.REMOTE_SCAN_STARTED:
            return RemoteScanStartedOperation()

        if op_code == cls.OpCode.ADD_SOURCE:
            return AddSourceOperation.from_parameters(data[1:])

        if op_code == cls.OpCode.MODIFY_SOURCE:
            return ModifySourceOperation.from_parameters(data[1:])

        if op_code == cls.OpCode.SET_BROADCAST_CODE:
            return SetBroadcastCodeOperation.from_parameters(data[1:])

        if op_code == cls.OpCode.REMOVE_SOURCE:
            return RemoveSourceOperation.from_parameters(data[1:])

        raise core.InvalidArgumentError("invalid op code")

    def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None:
        self.op_code = op_code
        self.parameters = parameters

    def __bytes__(self) -> bytes:
        return bytes([self.op_code]) + self.parameters


class RemoteScanStoppedOperation(ControlPointOperation):
    """Parameterless REMOTE_SCAN_STOPPED operation."""

    def __init__(self) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED)


class RemoteScanStartedOperation(ControlPointOperation):
    """Parameterless REMOTE_SCAN_STARTED operation."""

    def __init__(self) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED)


class AddSourceOperation(ControlPointOperation):
    """ADD_SOURCE operation.

    Parameter layout: address type (1), address (6), advertising SID (1),
    broadcast ID (3, little-endian), PA sync (1), PA interval (2), subgroups.
    """

    advertiser_address: hci.Address
    advertising_sid: int
    broadcast_id: int
    pa_sync: PeriodicAdvertisingSyncParams
    pa_interval: int
    subgroups: Sequence[SubgroupInfo]

    @classmethod
    def from_parameters(cls, parameters: bytes) -> AddSourceOperation:
        """Build an instance from the bytes that follow the op code."""
        # Bypass __init__ so the received parameters are kept verbatim.
        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE
        instance.parameters = parameters
        instance.advertiser_address = hci.Address.parse_address_preceded_by_type(
            parameters, 1
        )[1]
        instance.advertising_sid = parameters[7]
        instance.broadcast_id = int.from_bytes(parameters[8:11], "little")
        instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11])
        # NOTE(review): from this unpack call through the start of the pack
        # call in __init__, the original text was lost; reconstructed from the
        # offsets above and the call made by the service proxy - confirm.
        instance.pa_interval = struct.unpack("<H", parameters[12:14])[0]
        instance.subgroups = decode_subgroups(parameters[14:])
        return instance

    def __init__(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        super().__init__(
            ControlPointOperation.OpCode.ADD_SOURCE,
            struct.pack(
                "<B6sB3sBH",
                advertiser_address.address_type,
                bytes(advertiser_address),
                advertising_sid,
                broadcast_id.to_bytes(3, "little"),
                pa_sync,
                pa_interval,
            )
            + encode_subgroups(subgroups),
        )
        self.advertiser_address = advertiser_address
        self.advertising_sid = advertising_sid
        self.broadcast_id = broadcast_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)


class ModifySourceOperation(ControlPointOperation):
    """MODIFY_SOURCE operation.

    Parameter layout: source ID (1), PA sync (1), PA interval (2), subgroups.
    """

    source_id: int
    pa_sync: PeriodicAdvertisingSyncParams
    pa_interval: int
    subgroups: Sequence[SubgroupInfo]

    @classmethod
    def from_parameters(cls, parameters: bytes) -> ModifySourceOperation:
        """Build an instance from the bytes that follow the op code."""
        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1])
        # NOTE(review): from this unpack call through the start of the pack
        # call in __init__, the original text was lost; reconstructed from the
        # offsets above and the call made by the service proxy - confirm.
        instance.pa_interval = struct.unpack("<H", parameters[2:4])[0]
        instance.subgroups = decode_subgroups(parameters[4:])
        return instance

    def __init__(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        super().__init__(
            ControlPointOperation.OpCode.MODIFY_SOURCE,
            struct.pack("<BBH", source_id, pa_sync, pa_interval)
            + encode_subgroups(subgroups),
        )
        self.source_id = source_id
        self.pa_sync = pa_sync
        self.pa_interval = pa_interval
        self.subgroups = list(subgroups)


class SetBroadcastCodeOperation(ControlPointOperation):
    """SET_BROADCAST_CODE operation: source ID (1) then a 16-byte code."""

    source_id: int
    broadcast_code: bytes

    @classmethod
    def from_parameters(cls, parameters: bytes) -> SetBroadcastCodeOperation:
        """Build an instance from the bytes that follow the op code."""
        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        instance.broadcast_code = parameters[1:17]
        return instance

    def __init__(
        self,
        source_id: int,
        broadcast_code: bytes,
    ) -> None:
        """
        Raises:
            core.InvalidArgumentError: if ``broadcast_code`` is not 16 bytes.
        """
        # Validate first so an invalid code never produces a payload.
        if len(broadcast_code) != 16:
            raise core.InvalidArgumentError("broadcast_code must be 16 bytes")

        super().__init__(
            ControlPointOperation.OpCode.SET_BROADCAST_CODE,
            bytes([source_id]) + broadcast_code,
        )
        self.source_id = source_id
        self.broadcast_code = broadcast_code


class RemoveSourceOperation(ControlPointOperation):
    """REMOVE_SOURCE operation: a single source ID byte."""

    source_id: int

    @classmethod
    def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation:
        """Build an instance from the bytes that follow the op code."""
        instance = cls.__new__(cls)
        instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE
        instance.parameters = parameters
        instance.source_id = parameters[0]
        return instance

    def __init__(self, source_id: int) -> None:
        super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id]))
        self.source_id = source_id


@dataclasses.dataclass
class BroadcastReceiveState:
    """Decoded value of a Broadcast Receive State characteristic.

    Layout: source ID (1), address type (1), address (6), advertising SID (1),
    broadcast ID (3, little-endian), PA sync state (1), BIG encryption (1),
    then a 16-byte bad code only when BIG encryption is BAD_CODE, then the
    subgroups.
    """

    class PeriodicAdvertisingSyncState(utils.OpenIntEnum):
        NOT_SYNCHRONIZED_TO_PA = 0x00
        SYNCINFO_REQUEST = 0x01
        SYNCHRONIZED_TO_PA = 0x02
        FAILED_TO_SYNCHRONIZE_TO_PA = 0x03
        NO_PAST = 0x04

    class BigEncryption(utils.OpenIntEnum):
        NOT_ENCRYPTED = 0x00
        BROADCAST_CODE_REQUIRED = 0x01
        DECRYPTING = 0x02
        BAD_CODE = 0x03

    source_id: int
    source_address: hci.Address
    source_adv_sid: int
    broadcast_id: int
    pa_sync_state: PeriodicAdvertisingSyncState
    big_encryption: BigEncryption
    bad_code: bytes
    subgroups: List[SubgroupInfo]

    @classmethod
    def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
        """Parse a characteristic value; an empty value yields ``None``."""
        if not data:
            return None

        source_id = data[0]
        _, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
        source_adv_sid = data[8]
        broadcast_id = int.from_bytes(data[9:12], "little")
        pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12])
        big_encryption = cls.BigEncryption(data[13])
        if big_encryption == cls.BigEncryption.BAD_CODE:
            # The 16-byte bad code is only present in this state.
            bad_code = data[14:30]
            subgroups = decode_subgroups(data[30:])
        else:
            bad_code = b""
            subgroups = decode_subgroups(data[14:])

        return cls(
            source_id,
            source_address,
            source_adv_sid,
            broadcast_id,
            pa_sync_state,
            big_encryption,
            bad_code,
            subgroups,
        )

    def __bytes__(self) -> bytes:
        # NOTE(review): the body of this pack call was lost in the original
        # text; reconstructed as the inverse of from_bytes() - confirm.
        return (
            struct.pack(
                "<BB6sB3sBB",
                self.source_id,
                self.source_address.address_type,
                bytes(self.source_address),
                self.source_adv_sid,
                self.broadcast_id.to_bytes(3, "little"),
                self.pa_sync_state,
                self.big_encryption,
            )
            + self.bad_code
            + encode_subgroups(self.subgroups)
        )


# -----------------------------------------------------------------------------
# NOTE(review): nearly all of this class was lost in the original text (only
# the trailing "-> None: pass" of its last method survives). The name comes
# from BroadcastAudioScanServiceProxy.SERVICE_CLASS; the characteristic
# properties, permissions and initial value below are a best-effort
# reconstruction - confirm against upstream before relying on them.
class BroadcastAudioScanService(gatt.TemplateService):
    """Server side of the Broadcast Audio Scan Service."""

    UUID = gatt.GATT_BROADCAST_AUDIO_SCAN_SERVICE

    def __init__(self):
        self.broadcast_audio_scan_control_point_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC,
            gatt.Characteristic.Properties.WRITE
            | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
            gatt.Characteristic.Permissions.WRITEABLE,
            gatt.CharacteristicValue(
                write=self.on_broadcast_audio_scan_control_point_write
            ),
        )

        self.broadcast_receive_state_characteristic = gatt.Characteristic(
            gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC,
            gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.NOTIFY,
            gatt.Characteristic.Permissions.READABLE,
            b"",  # BroadcastReceiveState.from_bytes() decodes an empty value as None
        )

        super().__init__(
            [
                self.broadcast_audio_scan_control_point_characteristic,
                self.broadcast_receive_state_characteristic,
            ]
        )

    def on_broadcast_audio_scan_control_point_write(
        self, connection: device.Connection, value: bytes
    ) -> None:
        # Control point writes are currently accepted and ignored.
        pass


# -----------------------------------------------------------------------------
class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
    """Client-side proxy for a remote Broadcast Audio Scan Service."""

    SERVICE_CLASS = BroadcastAudioScanService

    broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy
    broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter]

    def __init__(self, service_proxy: gatt_client.ServiceProxy):
        """Locate the control point and receive state characteristics.

        Raises:
            gatt.InvalidServiceError: if either characteristic is missing.
        """
        self.service_proxy = service_proxy

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Audio Scan Control Point characteristic not found"
            )
        self.broadcast_audio_scan_control_point = characteristics[0]

        if not (
            characteristics := service_proxy.get_characteristics_by_uuid(
                gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
            )
        ):
            raise gatt.InvalidServiceError(
                "Broadcast Receive State characteristic not found"
            )
        # One adapter per receive state characteristic found, each decoding
        # its value with BroadcastReceiveState.from_bytes.
        self.broadcast_receive_states = [
            gatt.DelegatedCharacteristicAdapter(
                characteristic, decode=BroadcastReceiveState.from_bytes
            )
            for characteristic in characteristics
        ]

    async def send_control_point_operation(
        self, operation: ControlPointOperation
    ) -> None:
        """Write the serialized operation to the control point, with response."""
        await self.broadcast_audio_scan_control_point.write_value(
            bytes(operation), with_response=True
        )

    async def remote_scan_started(self) -> None:
        """Send a REMOTE_SCAN_STARTED operation."""
        await self.send_control_point_operation(RemoteScanStartedOperation())

    async def remote_scan_stopped(self) -> None:
        """Send a REMOTE_SCAN_STOPPED operation."""
        await self.send_control_point_operation(RemoteScanStoppedOperation())

    async def add_source(
        self,
        advertiser_address: hci.Address,
        advertising_sid: int,
        broadcast_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send an ADD_SOURCE operation built from the given fields."""
        await self.send_control_point_operation(
            AddSourceOperation(
                advertiser_address,
                advertising_sid,
                broadcast_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def modify_source(
        self,
        source_id: int,
        pa_sync: PeriodicAdvertisingSyncParams,
        pa_interval: int,
        subgroups: Sequence[SubgroupInfo],
    ) -> None:
        """Send a MODIFY_SOURCE operation built from the given fields."""
        await self.send_control_point_operation(
            ModifySourceOperation(
                source_id,
                pa_sync,
                pa_interval,
                subgroups,
            )
        )

    async def remove_source(self, source_id: int) -> None:
        """Send a REMOVE_SOURCE operation for ``source_id``."""
        await self.send_control_point_operation(RemoveSourceOperation(source_id))