diff --git a/pyproject.toml b/pyproject.toml index 8746a35..192c1bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "aioconsole==0.8.1", "fastapi==0.115.11", "uvicorn==0.34.0", + "aiohttp==3.9.3", ] [project.optional-dependencies] diff --git a/src/multilang_translator/main_cloud.py b/src/multilang_translator/main_cloud.py index 5585cd3..f51806b 100644 --- a/src/multilang_translator/main_cloud.py +++ b/src/multilang_translator/main_cloud.py @@ -1,5 +1,6 @@ from typing import List import time +import asyncio import logging as log @@ -15,7 +16,7 @@ import voice_client.tts_client import voice_models.request_models -def announcement_from_german_text( +async def announcement_from_german_text( config: translator_config.TranslatorConfigGroup, text_de ): @@ -50,12 +51,12 @@ def announcement_from_german_text( log.info('Voice synth took %s', time.time() - start) audio_data_dict[big.language] = lc3_audio.decode('latin-1') # TODO: should be .hex in the future - multicast_client.send_audio( + await multicast_client.send_audio( audio_data_dict ) -if __name__ == '__main__': +async def main(): log.basicConfig( level=log.INFO, format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' @@ -77,10 +78,12 @@ if __name__ == '__main__': conf.llm_host_url = 'https://ollama.pstruebi.xyz' conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' - multicast_client.request_init( + await multicast_client.init( config ) - announcement_from_german_text(config, 'Hello') + await announcement_from_german_text(config, 'Hello') - # TODO: make everything async + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/multilang_translator/translator/llm_translator.py b/src/multilang_translator/translator/llm_translator.py index 93f6eb6..6014e76 100644 --- a/src/multilang_translator/translator/llm_translator.py +++ b/src/multilang_translator/translator/llm_translator.py @@ -4,6 +4,7 @@ import json import logging as log import time import ollama +import 
aiohttp from multilang_translator.translator import syspromts @@ -12,10 +13,6 @@ from multilang_translator.translator import syspromts # from_='llama3.2', system="You are Mario from Super Mario Bros." # ) -async def chat(): - message = {'role': 'user', 'content': 'Why is the sky blue?'} - response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message]) - def query_openwebui(model, system, query, url, token): url = f'{url}/api/chat/completions' @@ -50,6 +47,41 @@ def query_ollama(model, system, query, host='http://localhost:11434'): return response.message.content +async def query_openwebui_async(model, system, query, url, token): + url = f'{url}/api/chat/completions' + headers = { + 'Authorization': f'Bearer {token}', + } + payload = { + 'model': model, + 'messages': [ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], + } + start = time.time() + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload) as response: + response_json = await response.json() + log.info("Translating the text took %s s", round(time.time() - start, 2)) + return response_json['choices'][0]['message']['content'] + + +async def query_ollama_async(model, system, query, host='http://localhost:11434'): + client = ollama.AsyncClient( + host=host, + ) + + response = await client.chat( + model=model, + messages=[ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], + ) + return response.message.content + + def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function text:str, target_language: str, @@ -70,6 +102,27 @@ def translate_de_to_x( # TODO: use async ollama client later - implenent a trans log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) return response +async def translate_de_to_x_async( + text:str, + target_language: str, + client='ollama', + 
model='llama3.2:3b-instruct-q4_0', # remember to use instruct models + host = None, + token = None + ): + start=time.time() + s = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}") + + if client == 'ollama': + response = await query_ollama_async(model, s, text, host=host) + elif client == 'openwebui': + response = await query_openwebui_async(model, s, text, url=host, token=token) + else: raise NotImplementedError('llm client not implemented') + + log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) + return response + + if __name__ == "__main__": import time from multilang_translator.translator import test_content diff --git a/src/multilang_translator/translator_models/translator_models.py b/src/multilang_translator/translator_models/translator_models.py index 390a666..938c30c 100644 --- a/src/multilang_translator/translator_models/translator_models.py +++ b/src/multilang_translator/translator_models/translator_models.py @@ -25,7 +25,6 @@ class Endpoint(BaseModel): url: str max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts - class TranslatorLangConfig(BaseModel): translator_llm: str = 'llama3.2:3b-instruct-q4_0' @@ -53,6 +52,7 @@ class EndpointGroup(BaseModel): name: str languages: List[str] endpoints: List[Endpoint] + sampling_rate_hz: int = 16000 translator_config: TranslatorConfig = TranslatorConfig() current_state: AnnouncementStates = AnnouncementStates.IDLE anouncement_start_time: float = 0.0 diff --git a/src/multilang_translator/translator_server/endpoints_db.py b/src/multilang_translator/translator_server/endpoints_db.py index 7f8d981..b38565c 100644 --- a/src/multilang_translator/translator_server/endpoints_db.py +++ b/src/multilang_translator/translator_server/endpoints_db.py @@ -55,7 +55,6 @@ def get_all_endpoints() -> List[Endpoint]: """Get all active endpoints.""" return ENDPOINTS - def get_endpoint_by_id(endpoint_id: str) -> Optional[Endpoint]: """Get an endpoint by its 
ID.""" return ENDPOINTS[endpoint_id] diff --git a/src/multilang_translator/translator_server/translator_server.py b/src/multilang_translator/translator_server/translator_server.py index 9b946fd..cbc107b 100644 --- a/src/multilang_translator/translator_server/translator_server.py +++ b/src/multilang_translator/translator_server/translator_server.py @@ -3,7 +3,7 @@ FastAPI implementation of the Multilang Translator API. This API mimics the mock_api from auracaster-webui to allow integration. """ import time -import logging as log +import logging as log import asyncio import random @@ -31,10 +31,10 @@ app.add_middleware( allow_headers=["*"], ) -# Endpoint configuration cache +# Endpoint configuration cache CURRENT_ENDPOINT_CONFIG = {} -def init_endpoint(endpoint: Endpoint, languages: list[str]): +async def init_endpoint(endpoint: Endpoint, languages: list[str], sampling_rate_hz: int): """Initialize a specific endpoint for multicast.""" current_config = CURRENT_ENDPOINT_CONFIG.get(endpoint.id) @@ -42,20 +42,24 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]): if current_config is not None: current_langs = [big.language for big in current_config.bigs] # if languages are unchanged and the caster client status is initiailized, skip init - if current_langs == languages and multicast_client.get_status(base_url=endpoint.url)['is_initialized']: - log.info('Endpoint %s was already initialized', endpoint.name) - return + if current_langs == languages: + # Get status asynchronously + status = await multicast_client.get_status(base_url=endpoint.url) + if status['is_initialized']: + log.info('Endpoint %s was already initialized', endpoint.name) + return log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}") # Load a default config config = auracast_config.AuracastConfigGroup( - bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")() + bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")() for lang in languages] ) # 
overwrite some default configs config.transport = 'auto' config.auracast_device_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6)) + config.auracast_sampling_rate_hz = sampling_rate_hz # Configure the bigs for big in config.bigs: @@ -63,9 +67,10 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]): big.name = endpoint.name big.random_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6)) big.id = random.randint(0, 2**16) #TODO: how many bits is this ? - #big.program_info = big.program_info + ' ' + endpoint.name + #big.program_info = big.program_info + ' ' + endpoint.name - ret = multicast_client.init(config, base_url=endpoint.url) + # make async init request + ret = await multicast_client.init(config, base_url=endpoint.url) # if ret != 200: # TODO: this is not working, should probably be handled async # log.error('Init of endpoint %s was unsucessfull', endpoint.name) # raise Exception(f"Init was of endpoint {endpoint.name} was unsucessfull") @@ -75,7 +80,6 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]): # log.info('Endpoint %s was already initialized', endpoint.name) - async def make_announcement(text: str, ep_group: EndpointGroup): """ Make an announcement to a group of endpoints. 
@@ -89,93 +93,89 @@ async def make_announcement(text: str, ep_group: EndpointGroup): ep_group.anouncement_start_time = time.time() # update the database with the new state and start time so this can be read by another process endpoints_db.update_group(ep_group.id, ep_group) - - + # Initialize all endpoints in the group if they were not initalized before for endpoint in ep_group.endpoints: ep_group.current_state = AnnouncementStates.INIT endpoints_db.update_group(ep_group.id, ep_group) - init_endpoint(endpoint, ep_group.languages) - + await init_endpoint(endpoint, ep_group.languages, ep_group.sampling_rate_hz) + # Translate the text for each language base_lang = "deu" # German is the base language ep_group.current_state = AnnouncementStates.TRANSLATING endpoints_db.update_group(ep_group.id, ep_group) - + translations = {base_lang: text} for lang in ep_group.languages: if lang != base_lang: # Translate the text - lang_conf = getattr(ep_group.translator_config, lang) - translation = llm_translator.translate_de_to_x( + trans_conf = getattr(ep_group.translator_config, lang) + translation = await llm_translator.translate_de_to_x_async( text=text, target_language=lang, - client=lang_conf.llm_client, - model=lang_conf.translator_llm, - host=lang_conf.llm_host_url, - token=lang_conf.llm_host_token + client=trans_conf.llm_client, + model=trans_conf.translator_llm, + host=trans_conf.llm_host_url, + token=trans_conf.llm_host_token ) - translations[lang] = translation log.info(f"Translated to {lang}: {translation}") - # Generate voices ep_group.current_state = AnnouncementStates.GENERATING_VOICE endpoints_db.update_group(ep_group.id, ep_group) - + audio = {} - # Convert each translation to audio - for lang, text in translations.items(): + for lang, text in translations.items(): # Get the appropriate language configuration - lang_conf = getattr(ep_group.translator_config, lang) - + trans_conf = getattr(ep_group.translator_config, lang) + # Convert text to LC3-encoded audio using 
the configuration's TTS settings - audio[lang] = text_to_speech.synthesize( + audio[lang] = text_to_speech.synthesize( translations[lang], - 16000, # Sample rate from auracast config # TODO: take sampling rate from auracast config - lang_conf.tts_system, # TTS system from config - lang_conf.tts_model, # TTS model from config + ep_group.sampling_rate_hz, # Sample rate from the endpoint group config + trans_conf.tts_system, # TTS system from config + trans_conf.tts_model, # TTS model from config return_lc3=True ).decode('latin-1') - + # Add to audio data dictionary (decode bytes to string for JSON serialization) # Start the monitoring coroutine to wait for streaming to complete # This will set the state to COMPLETED when finished asyncio.create_task(monitor_streaming_completion(ep_group)) - - + + # Broadcast to all endpoints in group for endpoint in ep_group.endpoints: # Send the audio data to the server using the existing send_audio function log.info(f"Broadcastcasting to {endpoint.name} for languages: {', '.join(audio.keys())}") - multicast_client.send_audio(audio, base_url=endpoint.url) + await multicast_client.send_audio(audio, base_url=endpoint.url) # Return the translations - return {"translations": translations} + return {"translations": translations} async def monitor_streaming_completion(ep_group: EndpointGroup): """ Monitor streaming status after audio is sent and update group state when complete. 
- + Args: ep_group: The endpoint group being monitored """ log.info(f"Starting streaming completion monitoring for endpoint group {ep_group.id}") - - + + # Set a shorter timeout as requested max_completion_time = 60 # seconds - + # First check if we are actually in streaming state streaming_started = False initial_check_timeout = 10 # seconds initial_check_start = time.time() - + # Wait for streaming to start (with timeout) while time.time() - initial_check_start < initial_check_timeout: # Wait before checking again @@ -183,7 +183,7 @@ async def monitor_streaming_completion(ep_group: EndpointGroup): any_streaming = False for endpoint in ep_group.endpoints: - status = multicast_client.get_status(base_url=endpoint.url) + status = await multicast_client.get_status(base_url=endpoint.url) if status.get("is_streaming", False): any_streaming = True log.info(f"Streaming confirmed started on endpoint {endpoint.name}") @@ -192,7 +192,6 @@ async def monitor_streaming_completion(ep_group: EndpointGroup): if any_streaming: streaming_started = True break - if not streaming_started: log.warning(f"No endpoints started streaming for group {ep_group.id} after {initial_check_timeout}s") @@ -200,20 +199,21 @@ async def monitor_streaming_completion(ep_group: EndpointGroup): ep_group.current_state = AnnouncementStates.ERROR endpoints_db.update_group(ep_group.id, ep_group) return - + # Update group progress ep_group.current_state = AnnouncementStates.BROADCASTING endpoints_db.update_group(ep_group.id, ep_group) - + # Now monitor until streaming completes on all endpoints check_completion_start_time = time.time() completed = [False for _ in ep_group.endpoints] while not all(completed) or time.time() - check_completion_start_time > max_completion_time: await asyncio.sleep(1) - + # Check status of each endpoint for i, endpoint in enumerate(ep_group.endpoints): - completed[i] = not multicast_client.get_status(base_url=endpoint.url)['is_streaming'] + status = await 
multicast_client.get_status(base_url=endpoint.url) + completed[i] = not status['is_streaming'] if all(completed): log.info(f"All endpoints completed streaming for group {ep_group.id}") @@ -247,7 +247,7 @@ async def get_group_state(group_id: int): ep_group = endpoints_db.get_group_by_id(group_id) if not ep_group: raise HTTPException(status_code=404, detail=f"Endpoint {group_id} not found") - + return {"name": ep_group.current_state.name, "value": ep_group.current_state.value} @@ -274,16 +274,16 @@ async def delete_group(group_id: int): async def start_announcement(text: str, group_id: int): """Start a new announcement to the specified endpoint group.""" global announcement_task - + # Get the group from active groups or database group = endpoints_db.get_group_by_id(group_id) if not group: raise HTTPException(status_code=400, detail=f"Group {group_id} not found") - + # Check if we're already processing an announcement #if announcement_task and not announcement_task.done(): # raise HTTPException(status_code=400, detail="Already processing an announcement") - + # Start the announcement task announcement_task = asyncio.create_task(make_announcement(text, group)) return {"status": "Announcement started", "group_id": group_id} @@ -312,10 +312,10 @@ if __name__ == "__main__": # with reload=True logging of modules does not function as expected uvicorn.run( app, - #'translator_server:app', - host="0.0.0.0", - port=7999, - #reload=True, + #'translator_server:app', + host="0.0.0.0", + port=7999, + #reload=True, #log_config=None, #log_level="info" ) diff --git a/tests/get_group_state.py b/tests/get_group_state.py new file mode 100644 index 0000000..25befce --- /dev/null +++ b/tests/get_group_state.py @@ -0,0 +1,9 @@ +import requests +import time + +if __name__ == '__main__': + # get the group state every 0.5s + while True: + response = requests.get('http://localhost:7999/groups/0/state') + print(response.json()) + time.sleep(0.5) \ No newline at end of file