refractoring
This commit is contained in:
@@ -104,6 +104,9 @@ async def make_announcement(text: str, ep_group: EndpointGroup):
|
||||
for endpoint in ep_group.endpoints
|
||||
]
|
||||
|
||||
# make sure init finished
|
||||
await asyncio.gather(*init_tasks)
|
||||
|
||||
# Translate the text for each language (concurrently)
|
||||
base_lang = "deu" # German is the base language
|
||||
target_langs = ep_group.languages.copy()
|
||||
@@ -118,17 +121,17 @@ async def make_announcement(text: str, ep_group: EndpointGroup):
|
||||
translation_tasks = []
|
||||
|
||||
for lang in target_langs:
|
||||
# Prepare translation task
|
||||
trans_conf = getattr(ep_group.translator_config, lang)
|
||||
task = llm_translator.translate_de_to_x_async(
|
||||
text=text,
|
||||
target_language=lang,
|
||||
client=trans_conf.llm_client,
|
||||
model=trans_conf.translator_llm,
|
||||
host=trans_conf.llm_host_url,
|
||||
token=trans_conf.llm_host_token
|
||||
)
|
||||
translation_tasks.append(task)
|
||||
# Prepare translation task
|
||||
trans_conf = getattr(ep_group.translator_config, lang)
|
||||
task = llm_translator.translate_de_to_x_async(
|
||||
text=text,
|
||||
target_language=lang,
|
||||
client=trans_conf.llm_client,
|
||||
model=trans_conf.translator_llm,
|
||||
host=trans_conf.llm_host_url,
|
||||
token=trans_conf.llm_host_token
|
||||
)
|
||||
translation_tasks.append(task)
|
||||
|
||||
# Wait for all translations to complete concurrently
|
||||
results = await asyncio.gather(*translation_tasks)
|
||||
@@ -141,41 +144,26 @@ async def make_announcement(text: str, ep_group: EndpointGroup):
|
||||
ep_group.current_state = AnnouncementStates.GENERATING_VOICE
|
||||
endpoints_db.update_group(ep_group.id, ep_group)
|
||||
|
||||
# Prepare synthesis jobs
|
||||
audio = {}
|
||||
|
||||
# Define a helper function for voice synthesis
|
||||
async def synthesize_text(lang, text):
|
||||
# Prepare synthesis tasks and run them concurrently
|
||||
synth_langs = ep_group.languages
|
||||
synthesis_tasks = []
|
||||
for lang in synth_langs:
|
||||
trans_conf = getattr(ep_group.translator_config, lang)
|
||||
# Since synthesize is not async, we'll run it in a thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
audio_data = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: text_to_speech.synthesize(
|
||||
text,
|
||||
ep_group.sampling_rate_hz,
|
||||
trans_conf.tts_system,
|
||||
trans_conf.tts_model,
|
||||
return_lc3=True
|
||||
).decode('latin-1')
|
||||
task = text_to_speech.synthesize_async(
|
||||
translations[lang],
|
||||
ep_group.sampling_rate_hz,
|
||||
trans_conf.tts_system,
|
||||
trans_conf.tts_model,
|
||||
return_lc3=True
|
||||
)
|
||||
return lang, audio_data
|
||||
synthesis_tasks.append(task)
|
||||
|
||||
# Create tasks for voice synthesis
|
||||
synthesis_tasks = [
|
||||
synthesize_text(lang, text)
|
||||
for lang, text in translations.items()
|
||||
]
|
||||
|
||||
# Wait for all synthesis tasks to complete
|
||||
synthesis_results = await asyncio.gather(*synthesis_tasks)
|
||||
|
||||
# Build audio dictionary from results
|
||||
for lang, audio_data in synthesis_results:
|
||||
audio[lang] = audio_data
|
||||
|
||||
# make sure init finished before broadcasting
|
||||
await asyncio.gather(*init_tasks)
|
||||
# Wait for all synthesis tasks to complete concurrently
|
||||
audio = {}
|
||||
if synthesis_tasks:
|
||||
results = await asyncio.gather(*synthesis_tasks)
|
||||
for i, audio_data in enumerate(results):
|
||||
audio[synth_langs[i]] = audio_data
|
||||
|
||||
# Start the monitoring coroutine to wait for streaming to complete
|
||||
# This will set the state to COMPLETED when finished
|
||||
|
||||
@@ -5,6 +5,7 @@ import time
|
||||
import json
|
||||
import logging as log
|
||||
import numpy as np
|
||||
import asyncio
|
||||
from voice_provider.utils.resample import resample_array
|
||||
from voice_provider.utils.encode_lc3 import encode_lc3
|
||||
|
||||
@@ -46,7 +47,6 @@ def synth_piper(text, model="en_US-lessac-medium"):
|
||||
return model_json, audio
|
||||
|
||||
|
||||
# TODO: framework should probably be a dataclass that holds all the relevant informations, also model
|
||||
def synthesize(
|
||||
text,
|
||||
target_sample_rate,
|
||||
@@ -77,6 +77,35 @@ def synthesize(
|
||||
return audio
|
||||
|
||||
|
||||
async def synthesize_async(
|
||||
text,
|
||||
target_sample_rate,
|
||||
framework,
|
||||
model="en_US-lessac-medium",
|
||||
return_lc3=True
|
||||
):
|
||||
"""
|
||||
Asynchronous version of the synthesize function that runs in a thread pool.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
target_sample_rate: Target sample rate for the audio
|
||||
framework: TTS framework to use (e.g., 'piper')
|
||||
model: Model to use for synthesis
|
||||
return_lc3: Whether to return LC3-encoded audio
|
||||
|
||||
Returns:
|
||||
LC3-encoded audio as string or raw audio as numpy array
|
||||
"""
|
||||
# Run the CPU-intensive synthesis in a thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: synthesize(text, target_sample_rate, framework, model, return_lc3)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import logging
|
||||
import soundfile as sf
|
||||
|
||||
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException
|
||||
import numpy as np
|
||||
|
||||
from voice_models.request_models import SynthesizeRequest
|
||||
from voice_provider.text_to_speech import synthesize
|
||||
from voice_provider.text_to_speech import synthesize_async
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -12,7 +12,7 @@ HOST_PORT = 8099
|
||||
@app.post("/synthesize/")
|
||||
async def synthesize_speech(request: SynthesizeRequest):
|
||||
try:
|
||||
audio = synthesize(
|
||||
audio = await synthesize_async(
|
||||
text=request.text,
|
||||
target_sample_rate=request.target_sample_rate,
|
||||
framework=request.framework,
|
||||
@@ -21,8 +21,17 @@ async def synthesize_speech(request: SynthesizeRequest):
|
||||
)
|
||||
|
||||
if request.return_lc3:
|
||||
return {"audio_lc3": audio.hex()}
|
||||
# If it's already a string (LC3 data), convert it to bytes for hex encoding
|
||||
if isinstance(audio, str):
|
||||
audio_bytes = audio.encode('latin-1')
|
||||
return {"audio_lc3": audio_bytes.hex()}
|
||||
# If it's already bytes
|
||||
elif isinstance(audio, bytes):
|
||||
return {"audio_lc3": audio.hex()}
|
||||
else:
|
||||
raise ValueError(f"Unexpected audio type: {type(audio)}")
|
||||
else:
|
||||
# If it's numpy array (non-LC3), convert to bytes
|
||||
audio_bytes = audio.astype(np.float32).tobytes()
|
||||
return {"audio_pcm": audio_bytes.hex()}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user