diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 43682fb..6f9b75b 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -17,7 +17,8 @@ { "label": "pip install -e auracast", "type": "shell", - "command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat" + "command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat", + "problemMatcher": [] } ] } \ No newline at end of file diff --git a/multilang_translator/translator_config.py b/multilang_translator/translator_config.py deleted file mode 100644 index 4a1b112..0000000 --- a/multilang_translator/translator_config.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from pydantic import BaseModel -from auracast import auracast_config - -ANNOUNCEMENT_DIR = os.path.join(os.path.dirname(__file__), 'announcements') -VENV_DIR = os.path.join(os.path.dirname(__file__), '../venv') -PIPER_EXE_PATH = f'{VENV_DIR}/bin/piper' - -class TranslatorBaseconfig(BaseModel): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe() - translator_llm: str = 'llama3.2:3b-instruct-q4_0' - llm_client: str = 'ollama' - llm_host_url: str | None = 'http://localhost:11434' - llm_host_token: str | None = None - tts_system: str = 'piper' - tts_model: str ='de_DE-kerstin-low' - - -class TranslatorConfigDe(TranslatorBaseconfig): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe() - tts_model: str ='de_DE-thorsten-high' - -class TranslatorConfigEn(TranslatorBaseconfig): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEn() - tts_model: str = 'en_GB-alba-medium' - -class TranslatorConfigFr(TranslatorBaseconfig): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigFr() - tts_model: str = 'fr_FR-siwis-medium' - -class TranslatorConfigEs(TranslatorBaseconfig): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEs() - tts_model: str = 'es_ES-sharvard-medium' - -class TranslatorConfigIt(TranslatorBaseconfig): - big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigIt() - tts_model: str = 'it_IT-paola-medium' diff --git a/pyproject.toml b/pyproject.toml index f67cc2d..192c1bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,9 @@ dependencies = [ "requests==2.32.3", "ollama==0.4.7", "aioconsole==0.8.1", - "piper-phonemize==1.1.0", - "piper-tts==1.2.0", + "fastapi==0.115.11", + "uvicorn==0.34.0", + "aiohttp==3.9.3", ] [project.optional-dependencies] @@ -17,6 +18,11 @@ test = [ "pytest >= 8.2", ] +[tool.poetry.group.tts.dependencies] +piper-phonemize = "==1.1.0" +piper-tts = "==1.2.0" + + [tool.pytest.ini_options] addopts = [ "--import-mode=importlib","--count=1","-s","-v" diff --git a/multilang_translator/__init__.py b/src/multilang_translator/__init__.py similarity index 100% rename from multilang_translator/__init__.py rename to src/multilang_translator/__init__.py diff --git a/multilang_translator/encode/encode_lc3.py b/src/multilang_translator/encode/encode_lc3.py similarity index 100% rename from multilang_translator/encode/encode_lc3.py rename to src/multilang_translator/encode/encode_lc3.py diff --git a/src/multilang_translator/main_cloud.py b/src/multilang_translator/main_cloud.py new file mode 100644 index 0000000..f51806b --- /dev/null +++ b/src/multilang_translator/main_cloud.py @@ -0,0 +1,89 @@ +from typing import List +import time +import asyncio + +import logging as log + +from auracast import multicast_client +from auracast import auracast_config + +import voice_client +import voice_models + +from multilang_translator import translator_config +from multilang_translator.translator import llm_translator +import voice_client.tts_client +import voice_models.request_models + + +async def announcement_from_german_text( + config: translator_config.TranslatorConfigGroup, + text_de + ): + base_lang = "deu" + + audio_data_dict = {} + for i, big in enumerate(config.bigs): + if big.language == base_lang: + text = text_de + else: + text = llm_translator.translate_de_to_x( + text_de, + big.language, + model=big.translator_llm, + client = big.llm_client, + host=big.llm_host_url, + token=big.llm_host_token + ) + + log.info('%s', text) + request_data = voice_models.request_models.SynthesizeRequest( + text=text, + target_sample_rate=config.auracast_sampling_rate_hz, + framework=big.tts_system, + model=big.tts_model, + return_lc3=True + ) + start = time.time() + lc3_audio = voice_client.tts_client.request_synthesis( + request_data + ) + log.info('Voice synth took %s', time.time() - start) + audio_data_dict[big.language] = lc3_audio.decode('latin-1') # TODO: should be .hex in the future + + await multicast_client.send_audio( + audio_data_dict + ) + + +async def main(): + log.basicConfig( + level=log.INFO, + format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' + ) + + config = translator_config.TranslatorConfigGroup( + bigs=[ + translator_config.TranslatorConfigDe(), + translator_config.TranslatorConfigEn(), + translator_config.TranslatorConfigFr(), + ] + ) + + config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc + + for conf in config.bigs: + conf.loop = False + conf.llm_client = 'openwebui' # comment out for local llm + conf.llm_host_url = 'https://ollama.pstruebi.xyz' + conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' + + await multicast_client.init( + config + ) + + await announcement_from_german_text(config, 'Hello') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/multilang_translator/main_local.py b/src/multilang_translator/main_local.py similarity index 58% rename from multilang_translator/main_local.py rename to src/multilang_translator/main_local.py index 1b3687c..6829ac8 100644 --- a/multilang_translator/main_local.py +++ b/src/multilang_translator/main_local.py @@ -1,24 +1,17 @@ -# -*- coding: utf-8 -*- -""" -list prompt example -""" -from __future__ import print_function, unicode_literals - from typing import List from dataclasses import asdict import asyncio -from copy import copy import time import logging as log import aioconsole -import multilang_translator.translator_config as translator_config -from utils import resample -from translator import llm_translator, test_content -from text_to_speech import text_to_speech from auracast import multicast_control from auracast import auracast_config -from translator.test_content import TESTSENTENCE +from voice_provider import text_to_speech + +from multilang_translator import translator_config +from multilang_translator.translator import llm_translator +from multilang_translator.translator.test_content import TESTSENTENCE # TODO: look for a end to end translation solution @@ -27,33 +20,32 @@ def transcribe(): async def announcement_from_german_text( - global_config: auracast_config.AuracastGlobalConfig, - translator_config: List[translator_config.TranslatorConfigDe], + config: translator_config.TranslatorConfigGroup, caster: multicast_control.Multicaster, text_de ): base_lang = "deu" - for i, trans in enumerate(translator_config): - if trans.big.language == base_lang: + for i, big in enumerate(config.bigs): + if big.language == base_lang: text = text_de else: text = llm_translator.translate_de_to_x( text_de, - trans.big.language, - model=trans.translator_llm, - client = trans.llm_client, - host=trans.llm_host_url, - token=trans.llm_host_token + big.language, + model=big.translator_llm, + client = big.llm_client, + host=big.llm_host_url, + token=big.llm_host_token ) log.info('%s', text) lc3_audio = text_to_speech.synthesize( text, - global_config.auracast_sampling_rate_hz, - trans.tts_system, - trans.tts_model, + config.auracast_sampling_rate_hz, + big.tts_system, + big.tts_model, return_lc3=True ) caster.big_conf[i].audio_source = lc3_audio @@ -64,7 +56,7 @@ async def announcement_from_german_text( log.info("Starting all broadcasts took %s s", round(time.time() - start, 3)) -async def command_line_ui(global_conf, translator_conf, caster: multicast_control.Multicaster): +async def command_line_ui(config: translator_config.TranslatorConfigGroup, translator_conf, caster: multicast_control.Multicaster): while True: # make a list of all available testsentence sentence_list = list(asdict(TESTSENTENCE).values()) @@ -86,8 +78,7 @@ async def command_line_ui(global_conf, translator_conf, caster: multicast_contro elif command.strip().isdigit(): ind = int(command.strip()) await announcement_from_german_text( - global_conf, - translator_conf, + config, caster, sentence_list[ind]) await asyncio.wait([caster.streamer.task]) @@ -103,35 +94,41 @@ async def main(): format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' ) - global_conf = auracast_config.AuracastGlobalConfig() - #global_conf.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk - global_conf.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc - - - translator_conf = [ - translator_config.TranslatorConfigDe(), - translator_config.TranslatorConfigEn(), - translator_config.TranslatorConfigFr(), - #auracast_config.broadcast_es, - #auracast_config.broadcast_it, + config = translator_config.TranslatorConfigGroup( + bigs=[ + translator_config.TranslatorConfigDe(), + translator_config.TranslatorConfigEn(), + translator_config.TranslatorConfigFr(), ] - for conf in translator_conf: - conf.big.loop = False + ) + + #config = auracast_config.AuracastGlobalConfig() + #config.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk + config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc + + for conf in config.bigs: + conf.loop = False conf.llm_client = 'openwebui' # comment out for local llm conf.llm_host_url = 'https://ollama.pstruebi.xyz' conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' - caster = multicast_control.Multicaster(global_conf, [conf.big for conf in translator_conf]) + caster = multicast_control.Multicaster( + config, + [big for big in config.bigs] + ) await caster.init_broadcast() # await announcement_from_german_text( - # global_conf, - # translator_conf, + # config, # caster, # test_content.TESTSENTENCE.DE_HELLO # ) # await asyncio.wait([caster.streamer.task]) - await command_line_ui(global_conf, translator_conf, caster) + await command_line_ui( + config, + [big for big in config.bigs], + caster + ) if __name__ == '__main__': asyncio.run(main()) diff --git a/multilang_translator/text_to_speech/__init__.py b/src/multilang_translator/translator/__init__.py similarity index 100% rename from multilang_translator/text_to_speech/__init__.py rename to src/multilang_translator/translator/__init__.py diff --git a/multilang_translator/translator/llm_translator.py b/src/multilang_translator/translator/llm_translator.py similarity index 59% rename from multilang_translator/translator/llm_translator.py rename to src/multilang_translator/translator/llm_translator.py index 93f6eb6..6014e76 100644 --- a/multilang_translator/translator/llm_translator.py +++ b/src/multilang_translator/translator/llm_translator.py @@ -4,6 +4,7 @@ import json import logging as log import time import ollama +import aiohttp from multilang_translator.translator import syspromts @@ -12,10 +13,6 @@ from multilang_translator.translator import syspromts # from_='llama3.2', system="You are Mario from Super Mario Bros." # ) -async def chat(): - message = {'role': 'user', 'content': 'Why is the sky blue?'} - response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message]) - def query_openwebui(model, system, query, url, token): url = f'{url}/api/chat/completions' @@ -50,6 +47,41 @@ def query_ollama(model, system, query, host='http://localhost:11434'): return response.message.content +async def query_openwebui_async(model, system, query, url, token): + url = f'{url}/api/chat/completions' + headers = { + 'Authorization': f'Bearer {token}', + } + payload = { + 'model': model, + 'messages': [ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], + } + start = time.time() + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=payload) as response: + response_json = await response.json() + log.info("Translating the text took %s s", round(time.time() - start, 2)) + return response_json['choices'][0]['message']['content'] + + +async def query_ollama_async(model, system, query, host='http://localhost:11434'): + client = ollama.AsyncClient( + host=host, + ) + + response = await client.chat( + model=model, + messages=[ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], + ) + return response.message.content + + def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function text:str, target_language: str, @@ -70,6 +102,27 @@ def translate_de_to_x( # TODO: use async ollama client later - implenent a trans log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) return response +async def translate_de_to_x_async( + text:str, + target_language: str, + client='ollama', + model='llama3.2:3b-instruct-q4_0', # remember to use instruct models + host = None, + token = None + ): + start=time.time() + s = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}") + + if client == 'ollama': + response = await query_ollama_async(model, s, text, host=host) + elif client == 'openwebui': + response = await query_openwebui_async(model, s, text, url=host, token=token) + else: raise NotImplementedError('llm client not implemented') + + log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) + return response + + if __name__ == "__main__": import time from multilang_translator.translator import test_content diff --git a/multilang_translator/translator/syspromts.py b/src/multilang_translator/translator/syspromts.py similarity index 100% rename from multilang_translator/translator/syspromts.py rename to src/multilang_translator/translator/syspromts.py diff --git a/multilang_translator/translator/test_content.py b/src/multilang_translator/translator/test_content.py similarity index 100% rename from multilang_translator/translator/test_content.py rename to src/multilang_translator/translator/test_content.py diff --git a/multilang_translator/translator/__init__.py b/src/multilang_translator/translator_client/__init__.py similarity index 100% rename from multilang_translator/translator/__init__.py rename to src/multilang_translator/translator_client/__init__.py diff --git a/src/multilang_translator/translator_client/translator_client.py b/src/multilang_translator/translator_client/translator_client.py new file mode 100644 index 0000000..6355b7c --- /dev/null +++ b/src/multilang_translator/translator_client/translator_client.py @@ -0,0 +1,94 @@ +""" +API client functions for interacting with the Translator API. +""" +import requests +from typing import List, Optional, Dict, Any, Tuple +from enum import Enum + + +from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup + + +# This can be overridden through environment variables +API_BASE_URL = "http://localhost:7999" + +def get_groups() -> List[EndpointGroup]: + """Get all endpoint groups.""" + response = requests.get(f"{API_BASE_URL}/groups") + response.raise_for_status() + return [EndpointGroup.model_validate(group) for group in response.json()] + +def get_group(group_id: int) -> Optional[EndpointGroup]: + """Get a specific endpoint group by ID.""" + response = requests.get(f"{API_BASE_URL}/groups/{group_id}") + if response.status_code == 404: + return None + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def create_group(group: EndpointGroup) -> EndpointGroup: + """Create a new endpoint group.""" + # Convert the model to a dict with enum values as their primitive values + payload = group.model_dump(mode='json') + response = requests.post(f"{API_BASE_URL}/groups", json=payload) + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup: + """Update an existing endpoint group.""" + # Convert the model to a dict with enum values as their primitive values + payload = updated_group.model_dump(mode='json') + response = requests.put(f"{API_BASE_URL}/groups/{group_id}", json=payload) + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def delete_group(group_id: int) -> None: + """Delete an endpoint group.""" + response = requests.delete(f"{API_BASE_URL}/groups/{group_id}") + response.raise_for_status() + +def start_announcement(text: str, group_id: int) -> Dict[str, Any]: + """ + Start a new announcement. + + Args: + text: The text content of the announcement + group_id: The ID of the endpoint group to send the announcement to + + Returns: + Dictionary with status information + """ + response = requests.post(f"{API_BASE_URL}/announcement", params={"text": text, "group_id": group_id}) + response.raise_for_status() + return response.json() + +def get_group_state(group_id: int) -> Tuple[str, float]: + """ + Get the status of the current announcement for a specific group. + + Args: + group_id: The ID of the group to check the announcement status for + + Returns: + Tuple containing (state_name, state_value) + """ + response = requests.get(f"{API_BASE_URL}/groups/{group_id}/state") + response.raise_for_status() + state_data = response.json() + return (state_data["name"], state_data["value"]) + + +def get_available_endpoints() -> List[Endpoint]: + """Get all available endpoints.""" + response = requests.get(f"{API_BASE_URL}/endpoints") + response.raise_for_status() + endpoints_dict = response.json() + # API returns a dictionary with endpoint IDs as keys + # Convert this to a list of Endpoint objects + return [Endpoint.model_validate(endpoint_data) for endpoint_id, endpoint_data in endpoints_dict.items()] + +def get_available_languages() -> List[str]: + """Get all available languages for announcements.""" + response = requests.get(f"{API_BASE_URL}/languages") + response.raise_for_status() + return response.json() diff --git a/src/multilang_translator/translator_config.py b/src/multilang_translator/translator_config.py new file mode 100644 index 0000000..0066b1b --- /dev/null +++ b/src/multilang_translator/translator_config.py @@ -0,0 +1,20 @@ +import os +from pydantic import BaseModel + +VENV_DIR = os.path.join(os.path.dirname(__file__), './../../venv') + +class TranslatorLangConfig(BaseModel): + translator_llm: str = 'llama3.2:3b-instruct-q4_0' # TODO: this was migrated to translator_models - remove this + llm_client: str = 'ollama' + llm_host_url: str | None = 'http://localhost:11434' + llm_host_token: str | None = None + tts_system: str = 'piper' + tts_model: str ='de_DE-kerstin-low' + +class TranslatorConfig(BaseModel): + deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high') + eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium') + fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium') + spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium') + ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium') + diff --git a/src/multilang_translator/translator_models/__init__.py b/src/multilang_translator/translator_models/__init__.py new file mode 100644 index 0000000..71bc2fd --- /dev/null +++ b/src/multilang_translator/translator_models/__init__.py @@ -0,0 +1 @@ +# Empty file to make the directory a package diff --git a/src/multilang_translator/translator_models/translator_models.py b/src/multilang_translator/translator_models/translator_models.py new file mode 100644 index 0000000..938c30c --- /dev/null +++ b/src/multilang_translator/translator_models/translator_models.py @@ -0,0 +1,58 @@ +""" +Models for the translator API. +Similar to the models used in auracaster-webui but simplified for the translator middleware. +""" +from enum import Enum +from typing import List, Optional +from pydantic import BaseModel + + +class AnnouncementStates(Enum): + IDLE = 0 + INIT = 0.1 + TRANSLATING = 0.2 + GENERATING_VOICE = 0.4 + ROUTING = 0.6 + BROADCASTING = 0.8 + COMPLETED = 1 + ERROR = 0 + + +class Endpoint(BaseModel): + """Defines an endpoint with its URL and capabilities.""" + id: int + name: str + url: str + max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts + +class TranslatorLangConfig(BaseModel): + translator_llm: str = 'llama3.2:3b-instruct-q4_0' + + llm_client: str = 'openwebui' # remote (homserver) + llm_host_url: str = 'https://ollama.pstruebi.xyz' + llm_host_token: str = 'sk-17124cb84df14cc6ab2d9e17d0724d13' + # llm_client: str = 'ollama' #local + # llm_host_url: str | None = 'http://localhost:11434' + # llm_host_token: str | None = None + + tts_system: str = 'piper' + tts_model: str ='de_DE-kerstin-low' + + +class TranslatorConfig(BaseModel): + deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high') + eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium') + fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium') + spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium') + ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium') + + +class EndpointGroup(BaseModel): + id: int + name: str + languages: List[str] + endpoints: List[Endpoint] + sampling_rate_hz: int = 16000 + translator_config: TranslatorConfig = TranslatorConfig() + current_state: AnnouncementStates = AnnouncementStates.IDLE + anouncement_start_time: float = 0.0 diff --git a/src/multilang_translator/translator_server/__init__.py b/src/multilang_translator/translator_server/__init__.py new file mode 100644 index 0000000..71bc2fd --- /dev/null +++ b/src/multilang_translator/translator_server/__init__.py @@ -0,0 +1 @@ +# Empty file to make the directory a package diff --git a/src/multilang_translator/translator_server/endpoints_db.py b/src/multilang_translator/translator_server/endpoints_db.py new file mode 100644 index 0000000..e4e0255 --- /dev/null +++ b/src/multilang_translator/translator_server/endpoints_db.py @@ -0,0 +1,143 @@ +""" +Database file for endpoint definitions. +This file contains configurations for auracast endpoints including their IP addresses and capabilities. +""" +from typing import List, Optional +from multilang_translator.translator_models.translator_models import EndpointGroup, Endpoint + + +SUPPORTED_LANGUAGES = ["deu", "eng", "fra", "spa", "ita"] + +# Database of endpoints +ENDPOINTS: dict[int: Endpoint] = { # for now make sure, .id and key are the same + 0: Endpoint( + id=0, + name="Local Endpoint", + url="http://localhost:5000", + max_broadcasts=3, + ), + 1: Endpoint( + id=1, + name="Gate 1", + url="http://pi3:5000", + max_broadcasts=3, + ), + 2: Endpoint( + id=2, + name="Gate 2", + url="http://192.168.1.102:5000", + max_broadcasts=3, + ), +} + +# Database of endpoint groups with default endpoints +ENDPOINT_GROUPS: dict[int:EndpointGroup] = { # for now make sure , .id and key are the same + 0: EndpointGroup( + id=0, + name="Local Group", + languages=["deu", "eng"], + endpoints=[ENDPOINTS[0]], + ), + 1: EndpointGroup( + id=1, + name="Gate1", + languages=["deu", "fra"], + endpoints=[ENDPOINTS[1]], + ) +} + +def get_available_languages() -> List[str]: + """Get a list of all supported languages.""" + return SUPPORTED_LANGUAGES + +# Endpoint functions +def get_all_endpoints() -> List[Endpoint]: + """Get all active endpoints.""" + return ENDPOINTS + +def get_endpoint_by_id(endpoint_id: str) -> Optional[Endpoint]: + """Get an endpoint by its ID.""" + return ENDPOINTS[endpoint_id] + + +def add_endpoint(endpoint: Endpoint) -> Endpoint: + """Add a new endpoint to the database.""" + if endpoint.id in ENDPOINTS: + raise ValueError(f"Endpoint with ID {endpoint.id} already exists") + ENDPOINTS[endpoint.id] = endpoint + return endpoint + + +def update_endpoint(endpoint_id: str, updated_endpoint: Endpoint) -> Endpoint: + """Update an existing endpoint in the database.""" + if endpoint_id not in ENDPOINTS: + raise ValueError(f"Endpoint {endpoint_id} not found") + + # Ensure the ID is preserved + updated_endpoint.id = endpoint_id + ENDPOINTS[endpoint_id] = updated_endpoint + return updated_endpoint + + +def delete_endpoint(endpoint_id: str) -> None: + """Delete an endpoint from the database.""" + if endpoint_id not in ENDPOINTS: + raise ValueError(f"Endpoint {endpoint_id} not found") + + # Check if this endpoint is used in any groups + for group in ENDPOINT_GROUPS.values(): + if endpoint_id in group.endpoints: + raise ValueError(f"Cannot delete endpoint {endpoint_id}, it is used in group {group.id}") + + del ENDPOINTS[endpoint_id] + + +# Endpoint Group functions +def get_all_groups() -> List[EndpointGroup]: + """Get all endpoint groups.""" + return list(ENDPOINT_GROUPS.values()) + + +def get_group_by_id(group_id: int) -> Optional[EndpointGroup]: + """Get an endpoint group by its ID.""" + return ENDPOINT_GROUPS.get(group_id) + + +def add_group(group: EndpointGroup) -> EndpointGroup: + """Add a new endpoint group to the database.""" + if group.id in ENDPOINT_GROUPS: + raise ValueError(f"Group with ID {group.id} already exists") + + # Validate that all referenced endpoints exist + for endpoint_id in group.endpoints: + if endpoint_id not in ENDPOINTS: + raise ValueError(f"Endpoint {endpoint_id} not found") + + ENDPOINT_GROUPS[group.id] = group + return group + + +def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup: + """Update an existing endpoint group in the database.""" + if group_id not in ENDPOINT_GROUPS: + raise ValueError(f"Group {group_id} not found") + + # Validate that all referenced endpoints exist + for endpoint in updated_group.endpoints: + if endpoint.id not in ENDPOINTS.keys(): + raise ValueError(f"Endpoint with id {endpoint.id} not found") + + # Ensure the ID is preserved + updated_group.id = group_id + ENDPOINT_GROUPS[group_id] = updated_group + return updated_group + + +def delete_group(group_id: int) -> None: + """Delete an endpoint group from the database.""" + if group_id not in ENDPOINT_GROUPS: + raise ValueError(f"Group {group_id} not found") + + del ENDPOINT_GROUPS[group_id] + + diff --git a/src/multilang_translator/translator_server/main_translator_api.py b/src/multilang_translator/translator_server/main_translator_api.py new file mode 100644 index 0000000..dae4df8 --- /dev/null +++ b/src/multilang_translator/translator_server/main_translator_api.py @@ -0,0 +1,28 @@ +""" +Entry point for the Translator API server. +This file starts the FastAPI server with the translator_server. +""" +import uvicorn +import logging as log +import sys +import os + +# Add the parent directory to the Python path to find the multilang_translator package +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir))) +if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + +if __name__ == "__main__": + log.basicConfig( + level=log.INFO, + format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' + ) + log.info("Starting Translator API server") + uvicorn.run( + "multilang_translator.translator_server.translator_server:app", + host="0.0.0.0", + port=7999, + reload=True, + log_level="debug" + ) diff --git a/src/multilang_translator/translator_server/translator_server.py b/src/multilang_translator/translator_server/translator_server.py new file mode 100644 index 0000000..4d1d386 --- /dev/null +++ b/src/multilang_translator/translator_server/translator_server.py @@ -0,0 +1,346 @@ +""" +FastAPI implementation of the Multilang Translator API. +This API mimics the mock_api from auracaster-webui to allow integration. +""" +import time +import logging as log +import asyncio +import random + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware + +# Import models +from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup +from multilang_translator.translator import llm_translator +from multilang_translator.translator_server import endpoints_db +from voice_provider import text_to_speech + +# Import the endpoints database and multicast client +from auracast import multicast_client, auracast_config + +# Create FastAPI app +app = FastAPI() + +# Add CORS middleware to allow cross-origin requests +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Endpoint configuration cache +CURRENT_ENDPOINT_CONFIG = {} + +async def init_endpoint(endpoint: Endpoint, languages: list[str], sampling_rate_hz: int): + """Initialize a specific endpoint for multicast.""" + + current_config = CURRENT_ENDPOINT_CONFIG.get(endpoint.id) + + if current_config is not None: + current_langs = [big.language for big in current_config.bigs] + # if languages are unchanged and the caster client status is initiailized, skip init + if current_langs == languages: + # Get status asynchronously + status = await multicast_client.get_status(base_url=endpoint.url) + if status['is_initialized']: + log.info('Endpoint %s was already initialized', endpoint.name) + return + + log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}") + # Load a default config + config = auracast_config.AuracastConfigGroup( + bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")() + for lang in languages] + ) + + # overwrite some default configs + config.transport = 'auto' + config.auracast_device_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6)) + config.auracast_sampling_rate_hz = sampling_rate_hz + + # Configure the bigs + for big in config.bigs: + big.loop = False + big.name = endpoint.name + big.random_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6)) + big.id = random.randint(0, 2**16) #TODO: how many bits is this ? + #big.program_info = big.program_info + ' ' + endpoint.name + + # make async init request + ret = await multicast_client.init(config, base_url=endpoint.url) + # if ret != 200: # TODO: this is not working, should probably be handled async + # log.error('Init of endpoint %s was unsucessfull', endpoint.name) + # raise Exception(f"Init was of endpoint {endpoint.name} was unsucessfull") + CURRENT_ENDPOINT_CONFIG[endpoint.id] = config.model_copy() + log.info(f"Endpoint {endpoint.name} initialized successfully") + #else: + # log.info('Endpoint %s was already initialized', endpoint.name) + + +async def make_announcement(text: str, ep_group: EndpointGroup): + """ + Make an announcement to a group of endpoints. + """ + + if text == "": + log.warning("Announcement text is empty") + return {"error": "Announcement text is empty"} + + ep_group.current_state = AnnouncementStates.IDLE + ep_group.anouncement_start_time = time.time() + # update the database with the new state and start time so this can be read by another process + endpoints_db.update_group(ep_group.id, ep_group) + + # Initialize all endpoints in the group concurrently + ep_group.current_state = AnnouncementStates.INIT + endpoints_db.update_group(ep_group.id, ep_group) + + # Create init tasks and run them concurrently + init_tasks = [ + init_endpoint(endpoint, ep_group.languages, ep_group.sampling_rate_hz) + for endpoint in ep_group.endpoints + ] + + # make sure init finished + await asyncio.gather(*init_tasks) + + # Translate the text for each language (concurrently) + base_lang = "deu" # German is the base language + target_langs = ep_group.languages.copy() + if base_lang in target_langs: + target_langs.remove(base_lang) + + ep_group.current_state = AnnouncementStates.TRANSLATING + endpoints_db.update_group(ep_group.id, ep_group) + + # Create translation tasks + translations = {base_lang: text} + translation_tasks = [] + + for lang in target_langs: + # Prepare translation task + trans_conf = getattr(ep_group.translator_config, lang) + task = llm_translator.translate_de_to_x_async( + text=text, + target_language=lang, + client=trans_conf.llm_client, + model=trans_conf.translator_llm, + host=trans_conf.llm_host_url, + token=trans_conf.llm_host_token + ) + translation_tasks.append(task) + + # Wait for all translations to complete concurrently + results = await asyncio.gather(*translation_tasks) + for i, translation in enumerate(results): + lang = target_langs[i] + translations[lang] = translation + log.info(f"Translated to {lang}: {translation}") + + # Generate voices concurrently + ep_group.current_state = AnnouncementStates.GENERATING_VOICE + endpoints_db.update_group(ep_group.id, ep_group) + + # Prepare synthesis tasks and run them concurrently + synth_langs = ep_group.languages + synthesis_tasks = [] + for lang in synth_langs: + trans_conf = getattr(ep_group.translator_config, lang) + task = text_to_speech.synthesize_async( + translations[lang], + ep_group.sampling_rate_hz, + trans_conf.tts_system, + trans_conf.tts_model, + return_lc3=True + ) + synthesis_tasks.append(task) + + # Wait for all synthesis tasks to complete concurrently + audio = {} + if synthesis_tasks: + results = await asyncio.gather(*synthesis_tasks) + for i, audio_data in enumerate(results): + audio[synth_langs[i]] = audio_data + + # Start the monitoring coroutine to wait for streaming to complete + # This will set the state to COMPLETED when finished + asyncio.create_task(monitor_streaming_completion(ep_group)) + + # Broadcast to all endpoints in group concurrently + broadcast_tasks = [] + for endpoint in ep_group.endpoints: + log.info(f"Broadcasting to {endpoint.name} for languages: {', '.join(audio.keys())}") + task = multicast_client.send_audio(audio, base_url=endpoint.url) + broadcast_tasks.append(task) + + # Wait for all broadcasts to complete + await asyncio.gather(*broadcast_tasks) + + # Return the translations + return {"translations": translations} + + +async def monitor_streaming_completion(ep_group: EndpointGroup): + """ + Monitor streaming status after audio is sent and update group state when complete. + + Args: + ep_group: The endpoint group being monitored + """ + log.info(f"Starting streaming completion monitoring for endpoint group {ep_group.id}") + + + # Set a shorter timeout as requested + max_completion_time = 60 # seconds + + # First check if we are actually in streaming state + streaming_started = False + initial_check_timeout = 10 # seconds + initial_check_start = time.time() + + # Wait for streaming to start (with timeout) + while time.time() - initial_check_start < initial_check_timeout: + # Wait before checking again + await asyncio.sleep(1) + + any_streaming = False + for endpoint in ep_group.endpoints: + status = await multicast_client.get_status(base_url=endpoint.url) + if status.get("is_streaming", False): + any_streaming = True + log.info(f"Streaming confirmed started on endpoint {endpoint.name}") + break + + if any_streaming: + streaming_started = True + break + + if not streaming_started: + log.warning(f"No endpoints started streaming for group {ep_group.id} after {initial_check_timeout}s") + # Still update to completed since there's nothing to wait for + ep_group.current_state = AnnouncementStates.ERROR + endpoints_db.update_group(ep_group.id, ep_group) + return + + # Update group progress + ep_group.current_state = AnnouncementStates.BROADCASTING + endpoints_db.update_group(ep_group.id, ep_group) + + # Now monitor until streaming completes on all endpoints + check_completion_start_time = time.time() + completed = [False for _ in ep_group.endpoints] + while not all(completed) or time.time() - check_completion_start_time > max_completion_time: + await asyncio.sleep(1) + + # Check status of each endpoint + for i, endpoint in enumerate(ep_group.endpoints): + status = await multicast_client.get_status(base_url=endpoint.url) + completed[i] = not status['is_streaming'] + + if all(completed): + log.info(f"All endpoints completed streaming for group {ep_group.id}") + # Update group state to completed + ep_group.current_state = AnnouncementStates.COMPLETED + endpoints_db.update_group(ep_group.id, ep_group) + log.info(f"Updated group {ep_group.id} state to COMPLETED") + + else: + log.error(f"Max wait time reached for group {ep_group.id}. Forcing completion.") + + +@app.get("/groups") +async def get_groups(): + """Get all endpoint groups with their current status.""" + return endpoints_db.get_all_groups() + + +@app.post("/groups") +async def create_group(group: endpoints_db.EndpointGroup): + """Add a new endpoint group.""" + try: + return endpoints_db.add_group(group) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@app.get("/groups/{group_id}/state") # TODO: think about progress tracking +async def get_group_state(group_id: int): + """Get the status of a specific endpoint.""" + # Check if the endpoint exists + ep_group = endpoints_db.get_group_by_id(group_id) + if not ep_group: + raise HTTPException(status_code=404, detail=f"Endpoint {group_id} not found") + + return {"name": ep_group.current_state.name, "value": ep_group.current_state.value} + + +@app.put("/groups/{group_id}") +async def update_group(group_id: int, updated_group: endpoints_db.EndpointGroup): + """Update an existing endpoint group.""" + try: + return endpoints_db.update_group(group_id, updated_group) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.delete("/groups/{group_id}") +async def delete_group(group_id: int): + """Delete an endpoint group.""" + try: + endpoints_db.delete_group(group_id) + return {"message": f"Group {group_id} deleted successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@app.post("/announcement") +async def start_announcement(text: str, group_id: int): + """Start a new announcement to the specified endpoint group.""" + global announcement_task + + # Get the group from active groups or database + group = endpoints_db.get_group_by_id(group_id) + if not group: + raise HTTPException(status_code=400, detail=f"Group {group_id} not found") + + # Check if we're already processing an announcement + #if announcement_task and not announcement_task.done(): + # raise HTTPException(status_code=400, detail="Already processing an announcement") + + # Start the announcement task + announcement_task = asyncio.create_task(make_announcement(text, group)) + return {"status": "Announcement started", "group_id": group_id} + + + +@app.get("/endpoints") +async def get_available_endpoints(): + """Get all available endpoints with their capabilities.""" + return endpoints_db.get_all_endpoints() + + +@app.get("/languages") +async def get_available_languages(): + """Get all available languages for announcements.""" + return endpoints_db.get_available_languages() + + + +if __name__ == "__main__": + import uvicorn + log.basicConfig( + level=log.DEBUG, + format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' + ) + # with reload=True logging of modules does not function as expected + uvicorn.run( + app, + #'translator_server:app', + host="0.0.0.0", + port=7999, + #reload=True, + #log_config=None, + #log_level="info" + ) diff --git a/src/voice_client/tts_client.py b/src/voice_client/tts_client.py new file mode 100644 index 0000000..3ee7732 --- /dev/null +++ b/src/voice_client/tts_client.py @@ -0,0 +1,44 @@ +import requests +import numpy as np +import soundfile as sf + +from voice_models.request_models import SynthesizeRequest + + +API_URL = "http://127.0.0.1:8099/synthesize/" + +def request_synthesis(request_data: SynthesizeRequest): + response = requests.post(API_URL, json=request_data.model_dump()) + + if response.status_code == 200: + response_data = response.json() + + if request_data.return_lc3: + # Save LC3 audio as binary file + lc3_bytes = bytes.fromhex(response_data["audio_lc3"]) + return lc3_bytes + + else: + # Convert hex-encoded PCM bytes back to numpy array and save as WAV + audio_bytes = bytes.fromhex(response_data["audio_pcm"]) + audio_array = np.frombuffer(audio_bytes, dtype=np.float32) + return audio_array + + else: + print(f"Error: {response.status_code}, {response.text}") + +if __name__ == "__main__": + + target_rate=16000 + + # Example request + request_data = SynthesizeRequest( + text="Hello, this is a test.", + target_sample_rate=target_rate, + framework="piper", + model="de_DE-kerstin-low", + return_lc3=False # Set to True to receive LC3 compressed output + ) + + audio = request_synthesis(request_data) + sf.write('hello.wav', audio, target_rate) diff --git a/src/voice_models/request_models.py b/src/voice_models/request_models.py new file mode 100644 index 0000000..69edbc0 --- /dev/null +++ b/src/voice_models/request_models.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +class SynthesizeRequest(BaseModel): + text: str + target_sample_rate: int = 16000 + framework: str = "piper" + model: str = "en_US-lessac-medium" + return_lc3: bool = False + diff --git a/multilang_translator/utils/__init__.py b/src/voice_provider/__init__.py similarity index 100% rename from multilang_translator/utils/__init__.py rename to src/voice_provider/__init__.py diff --git a/multilang_translator/text_to_speech/piper/voices.json b/src/voice_provider/piper/voices.json similarity index 100% rename from multilang_translator/text_to_speech/piper/voices.json rename to src/voice_provider/piper/voices.json diff --git a/multilang_translator/text_to_speech/piper_welcome.sh b/src/voice_provider/piper_welcome.sh similarity index 100% rename from multilang_translator/text_to_speech/piper_welcome.sh rename to src/voice_provider/piper_welcome.sh diff --git a/multilang_translator/text_to_speech/text_to_speech.py b/src/voice_provider/text_to_speech.py similarity index 53% rename from multilang_translator/text_to_speech/text_to_speech.py rename to src/voice_provider/text_to_speech.py index 50ed861..c076c24 100644 --- a/multilang_translator/text_to_speech/text_to_speech.py +++ b/src/voice_provider/text_to_speech.py @@ -1,27 +1,34 @@ import os +import shutil import subprocess import time import json import logging as log import numpy as np -from multilang_translator import translator_config -from multilang_translator.utils.resample import resample_array -from multilang_translator.text_to_speech import encode_lc3 +import asyncio +from voice_provider.utils.resample import resample_array +from voice_provider.utils.encode_lc3 import encode_lc3 + +PIPER_EXE = shutil.which('piper') TTS_DIR = os.path.join(os.path.dirname(__file__)) -PIPER_DIR = f'{TTS_DIR}/piper' +PIPER_WORKDIR = f'{TTS_DIR}/piper' + +if not PIPER_EXE: + PIPER_EXE = f'{TTS_DIR}/../../venv/bin/piper' def synth_piper(text, model="en_US-lessac-medium"): pwd = os.getcwd() - os.chdir(PIPER_DIR) + os.chdir(PIPER_WORKDIR) start = time.time() # make sure piper has voices.json in working directory, otherwise it attempts to always load models ret = subprocess.run( # TODO: wrap this whole thing in a class and open a permanent pipe to the model - [translator_config.PIPER_EXE_PATH, - '--cuda', - '--model', model, - '--output-raw' + [ + PIPER_EXE, + '--cuda', + '--model', model, + '--output-raw' ], input=text.encode('utf-8'), capture_output=True @@ -34,14 +41,19 @@ def synth_piper(text, model="en_US-lessac-medium"): log.info("Running piper for model %s took %s s", model, round(time.time() - start, 3)) - with open (f'{PIPER_DIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently + with open (f'{PIPER_WORKDIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently model_json = json.load(f) return model_json, audio -# TODO: framework should probably be a dataclass that holds all the relevant informations, also model -def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium", return_lc3=True): +def synthesize( + text, + target_sample_rate, + framework, + model="en_US-lessac-medium", + return_lc3=True + ): if framework == 'piper': model_json, audio_raw = synth_piper(text, model) @@ -59,12 +71,41 @@ def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium", if return_lc3: audio_pcm = (audio * 2**15-1).astype(np.int16) - lc3 = encode_lc3.encode(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter + lc3 = encode_lc3(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter return lc3 else: return audio +async def synthesize_async( + text, + target_sample_rate, + framework, + model="en_US-lessac-medium", + return_lc3=True + ): + """ + Asynchronous version of the synthesize function that runs in a thread pool. + + Args: + text: Text to synthesize + target_sample_rate: Target sample rate for the audio + framework: TTS framework to use (e.g., 'piper') + model: Model to use for synthesis + return_lc3: Whether to return LC3-encoded audio + + Returns: + LC3-encoded audio as string or raw audio as numpy array + """ + # Run the CPU-intensive synthesis in a thread pool + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + lambda: synthesize(text, target_sample_rate, framework, model, return_lc3) + ) + return result + + if __name__ == '__main__': import logging import soundfile as sf @@ -79,5 +120,4 @@ if __name__ == '__main__': sf.write('hello.wav', audio, target_rate) - # TODO: "WARNING:piper.download:Wrong size (expected=5952, actual=4158 print('Done.') diff --git a/src/voice_provider/tts_server.py b/src/voice_provider/tts_server.py new file mode 100644 index 0000000..6ce8fe1 --- /dev/null +++ b/src/voice_provider/tts_server.py @@ -0,0 +1,43 @@ +from fastapi import FastAPI, HTTPException +import numpy as np + +from voice_models.request_models import SynthesizeRequest +from voice_provider.text_to_speech import synthesize_async + +app = FastAPI() + +HOST_PORT = 8099 + + +@app.post("/synthesize/") +async def synthesize_speech(request: SynthesizeRequest): + try: + audio = await synthesize_async( + text=request.text, + target_sample_rate=request.target_sample_rate, + framework=request.framework, + model=request.model, + return_lc3=request.return_lc3 + ) + + if request.return_lc3: + # If it's already a string (LC3 data), convert it to bytes for hex encoding + if isinstance(audio, str): + audio_bytes = audio.encode('latin-1') + return {"audio_lc3": audio_bytes.hex()} + # If it's already bytes + elif isinstance(audio, bytes): + return {"audio_lc3": audio.hex()} + else: + raise ValueError(f"Unexpected audio type: {type(audio)}") + else: + # If it's numpy array (non-LC3), convert to bytes + audio_bytes = audio.astype(np.float32).tobytes() + return {"audio_pcm": audio_bytes.hex()} + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="127.0.0.1", port=HOST_PORT) \ No newline at end of file diff --git a/src/voice_provider/utils/__init__.py b/src/voice_provider/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/multilang_translator/text_to_speech/encode_lc3.py b/src/voice_provider/utils/encode_lc3.py similarity index 98% rename from multilang_translator/text_to_speech/encode_lc3.py rename to src/voice_provider/utils/encode_lc3.py index c120771..57e3c2e 100644 --- a/multilang_translator/text_to_speech/encode_lc3.py +++ b/src/voice_provider/utils/encode_lc3.py @@ -1,7 +1,7 @@ import numpy as np import lc3 -def encode( +def encode_lc3( audio: np.array, output_sample_rate_hz, octets_per_frame, diff --git a/multilang_translator/utils/resample.py b/src/voice_provider/utils/resample.py similarity index 100% rename from multilang_translator/utils/resample.py rename to src/voice_provider/utils/resample.py diff --git a/tests/get_group_state.py b/tests/get_group_state.py new file mode 100644 index 0000000..25befce --- /dev/null +++ b/tests/get_group_state.py @@ -0,0 +1,9 @@ +import requests +import time + +if __name__ == '__main__': + # get the group state every 0.5s + while True: + response = requests.get('http://localhost:7999/groups/0/state') + print(response.json()) + time.sleep(0.5) \ No newline at end of file