From 29e53b89d41fde9d27e7f42dca0f2f15ad00d003 Mon Sep 17 00:00:00 2001 From: pstruebi Date: Wed, 19 Mar 2025 12:52:31 +0100 Subject: [PATCH] refactoring --- .../translator_server/translator_server.py | 74 ++++++++----------- src/voice_provider/text_to_speech.py | 31 +++++++- src/voice_provider/tts_server.py | 15 +++- 3 files changed, 73 insertions(+), 47 deletions(-) diff --git a/src/multilang_translator/translator_server/translator_server.py b/src/multilang_translator/translator_server/translator_server.py index 27fb2d3..4d1d386 100644 --- a/src/multilang_translator/translator_server/translator_server.py +++ b/src/multilang_translator/translator_server/translator_server.py @@ -104,6 +104,9 @@ async def make_announcement(text: str, ep_group: EndpointGroup): for endpoint in ep_group.endpoints ] + # make sure init finished + await asyncio.gather(*init_tasks) + # Translate the text for each language (concurrently) base_lang = "deu" # German is the base language target_langs = ep_group.languages.copy() @@ -118,17 +121,17 @@ async def make_announcement(text: str, ep_group: EndpointGroup): translation_tasks = [] for lang in target_langs: - # Prepare translation task - trans_conf = getattr(ep_group.translator_config, lang) - task = llm_translator.translate_de_to_x_async( - text=text, - target_language=lang, - client=trans_conf.llm_client, - model=trans_conf.translator_llm, - host=trans_conf.llm_host_url, - token=trans_conf.llm_host_token - ) - translation_tasks.append(task) + # Prepare translation task + trans_conf = getattr(ep_group.translator_config, lang) + task = llm_translator.translate_de_to_x_async( + text=text, + target_language=lang, + client=trans_conf.llm_client, + model=trans_conf.translator_llm, + host=trans_conf.llm_host_url, + token=trans_conf.llm_host_token + ) + translation_tasks.append(task) # Wait for all translations to complete concurrently results = await asyncio.gather(*translation_tasks) @@ -141,41 +144,26 @@ async def make_announcement(text: str, 
ep_group: EndpointGroup): ep_group.current_state = AnnouncementStates.GENERATING_VOICE endpoints_db.update_group(ep_group.id, ep_group) - # Prepare synthesis jobs - audio = {} - - # Define a helper function for voice synthesis - async def synthesize_text(lang, text): + # Prepare synthesis tasks and run them concurrently + synth_langs = ep_group.languages + synthesis_tasks = [] + for lang in synth_langs: trans_conf = getattr(ep_group.translator_config, lang) - # Since synthesize is not async, we'll run it in a thread pool - loop = asyncio.get_event_loop() - audio_data = await loop.run_in_executor( - None, - lambda: text_to_speech.synthesize( - text, - ep_group.sampling_rate_hz, - trans_conf.tts_system, - trans_conf.tts_model, - return_lc3=True - ).decode('latin-1') + task = text_to_speech.synthesize_async( + translations[lang], + ep_group.sampling_rate_hz, + trans_conf.tts_system, + trans_conf.tts_model, + return_lc3=True ) - return lang, audio_data + synthesis_tasks.append(task) - # Create tasks for voice synthesis - synthesis_tasks = [ - synthesize_text(lang, text) - for lang, text in translations.items() - ] - - # Wait for all synthesis tasks to complete - synthesis_results = await asyncio.gather(*synthesis_tasks) - - # Build audio dictionary from results - for lang, audio_data in synthesis_results: - audio[lang] = audio_data - - # make sure init finished before broadcasting - await asyncio.gather(*init_tasks) + # Wait for all synthesis tasks to complete concurrently + audio = {} + if synthesis_tasks: + results = await asyncio.gather(*synthesis_tasks) + for i, audio_data in enumerate(results): + audio[synth_langs[i]] = audio_data # Start the monitoring coroutine to wait for streaming to complete # This will set the state to COMPLETED when finished diff --git a/src/voice_provider/text_to_speech.py b/src/voice_provider/text_to_speech.py index bc2ac01..c076c24 100644 --- a/src/voice_provider/text_to_speech.py +++ b/src/voice_provider/text_to_speech.py @@ -5,6 +5,7 
@@ import time import json import logging as log import numpy as np +import asyncio from voice_provider.utils.resample import resample_array from voice_provider.utils.encode_lc3 import encode_lc3 @@ -46,7 +47,6 @@ def synth_piper(text, model="en_US-lessac-medium"): return model_json, audio -# TODO: framework should probably be a dataclass that holds all the relevant informations, also model def synthesize( text, target_sample_rate, @@ -77,6 +77,35 @@ def synthesize( return audio +async def synthesize_async( + text, + target_sample_rate, + framework, + model="en_US-lessac-medium", + return_lc3=True + ): + """ + Asynchronous version of the synthesize function that runs in a thread pool. + + Args: + text: Text to synthesize + target_sample_rate: Target sample rate for the audio + framework: TTS framework to use (e.g., 'piper') + model: Model to use for synthesis + return_lc3: Whether to return LC3-encoded audio + + Returns: + LC3-encoded audio as string or raw audio as numpy array + """ + # Run the CPU-intensive synthesis in a thread pool + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + lambda: synthesize(text, target_sample_rate, framework, model, return_lc3) + ) + return result + + if __name__ == '__main__': import logging import soundfile as sf diff --git a/src/voice_provider/tts_server.py b/src/voice_provider/tts_server.py index 15cb75c..6ce8fe1 100644 --- a/src/voice_provider/tts_server.py +++ b/src/voice_provider/tts_server.py @@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException import numpy as np from voice_models.request_models import SynthesizeRequest -from voice_provider.text_to_speech import synthesize +from voice_provider.text_to_speech import synthesize_async app = FastAPI() @@ -12,7 +12,7 @@ HOST_PORT = 8099 @app.post("/synthesize/") async def synthesize_speech(request: SynthesizeRequest): try: - audio = synthesize( + audio = await synthesize_async( text=request.text, target_sample_rate=request.target_sample_rate, 
framework=request.framework, @@ -21,8 +21,17 @@ async def synthesize_speech(request: SynthesizeRequest): ) if request.return_lc3: - return {"audio_lc3": audio.hex()} + # If it's already a string (LC3 data), convert it to bytes for hex encoding + if isinstance(audio, str): + audio_bytes = audio.encode('latin-1') + return {"audio_lc3": audio_bytes.hex()} + # If it's already bytes + elif isinstance(audio, bytes): + return {"audio_lc3": audio.hex()} + else: + raise ValueError(f"Unexpected audio type: {type(audio)}") else: + # If it's numpy array (non-LC3), convert to bytes audio_bytes = audio.astype(np.float32).tobytes() return {"audio_pcm": audio_bytes.hex()}