diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py index ded78973..216874d8 100644 --- a/bumble/profiles/ascs.py +++ b/bumble/profiles/ascs.py @@ -452,6 +452,16 @@ class AseStateMachine(gatt.Characteristic): self.metadata = le_audio.Metadata.from_bytes(metadata) self.state = self.State.ENABLING + # CIS could be established before enable. + if cis_link := next( + ( + cis_link + for cis_link in self.service.device.cis_links.values() + if cis_link.cig_id == self.cig_id and cis_link.cis_id == self.cis_id + ), + None, + ): + self.on_cis_establishment(cis_link) return (AseResponseCode.SUCCESS, AseReasonCode.NONE) diff --git a/tests/bap_test.py b/tests/bap_test.py index e8e84f9c..84cf8a1e 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -16,7 +16,6 @@ # Imports # ----------------------------------------------------------------------------- import asyncio -import os import functools import pytest import logging @@ -55,7 +54,7 @@ from bumble.profiles.pacs import ( PublishedAudioCapabilitiesServiceProxy, ) from bumble.profiles.le_audio import Metadata -from tests.test_utils import TwoDevices +from tests.test_utils import TwoDevices, async_barrier # ----------------------------------------------------------------------------- @@ -441,15 +440,114 @@ async def test_ascs(): assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE]) assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE]) - await asyncio.sleep(0.001) + await async_barrier() # ----------------------------------------------------------------------------- -async def run(): - await test_pacs() +@pytest.mark.asyncio +async def test_ascs_enable_source_then_sink(): + devices = TwoDevices() + ascs_server = AudioStreamControlService( + device=devices[1], sink_ase_id=[1], source_ase_id=[2] + ) + sink_ase = ascs_server.ase_state_machines[1] + source_ase = ascs_server.ase_state_machines[2] + devices[1].add_service(ascs_server) + condition = asyncio.Condition() + async def on_state_change(): + async with condition: + condition.notify_all() -# ----------------------------------------------------------------------------- -if __name__ == '__main__': - logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) - asyncio.run(run()) + sink_ase.on(sink_ase.EVENT_STATE_CHANGE, on_state_change) + source_ase.on(sink_ase.EVENT_STATE_CHANGE, on_state_change) + + await devices.setup_connection() + peer = device.Peer(devices.connections[0]) + ascs_client = await peer.discover_service_and_create_proxy( + AudioStreamControlServiceProxy + ) + + # Config Codec + config = CodecSpecificConfiguration( + sampling_frequency=SamplingFrequency.FREQ_48000, + frame_duration=FrameDuration.DURATION_10000_US, + audio_channel_allocation=AudioLocation.FRONT_LEFT, + octets_per_codec_frame=120, + codec_frames_per_sdu=1, + ) + await ascs_client.ase_control_point.write_value( + ASE_Config_Codec( + ase_id=[1, 2], + target_latency=[3, 4], + target_phy=[5, 6], + codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)], + codec_specific_configuration=[config, config], + ) + ) + async with condition: + await condition.wait_for( + lambda: ( + sink_ase.state == AseStateMachine.State.CODEC_CONFIGURED + and source_ase.state == AseStateMachine.State.CODEC_CONFIGURED + ) + ) + + # Config QOS + await ascs_client.ase_control_point.write_value( + ASE_Config_QOS( + ase_id=[1, 2], + cig_id=[1, 1], + cis_id=[1, 1], + sdu_interval=[100, 100], + framing=[0, 0], + phy=[1, 1], + max_sdu=[100, 100], + retransmission_number=[16, 16], + max_transport_latency=[150, 150], + presentation_delay=[10, 10], + ) + ) + async with condition: + await condition.wait_for( + lambda: ( + sink_ase.state == AseStateMachine.State.QOS_CONFIGURED + and source_ase.state == AseStateMachine.State.QOS_CONFIGURED + ) + ) + + # Enable ASE 2 + await ascs_client.ase_control_point.write_value( + ASE_Enable(ase_id=[2], metadata=[b'foo']) + ) + await async_barrier() + cis_handles = await devices[0].setup_cig( + device.CigParameters( + cig_id=1, + cis_parameters=[device.CigParameters.CisParameters(cis_id=1)], + sdu_interval_c_to_p=100, + sdu_interval_p_to_c=100, + ) + ) + await devices[0].create_cis([(cis_handles[0], devices.connections[0])]) + + async with condition: + await condition.wait_for( + lambda: (source_ase.state == AseStateMachine.State.ENABLING) + ) + await ascs_client.ase_control_point.write_value( + ASE_Receiver_Start_Ready(ase_id=[2]) + ) + async with condition: + await condition.wait_for( + lambda: (source_ase.state == AseStateMachine.State.STREAMING) + ) + + # Enable ASE 1 + await ascs_client.ase_control_point.write_value( + ASE_Enable(ase_id=[1], metadata=[b'bar']) + ) + async with condition: + await condition.wait_for( + lambda: (sink_ase.state == AseStateMachine.State.STREAMING) + )