From 9a1f1dc838d8bd1387b38eb84373522ea47b36a3 Mon Sep 17 00:00:00 2001 From: pstruebi Date: Wed, 5 Mar 2025 17:00:23 +0100 Subject: [PATCH] refractoring --- multilang_translator/config.py | 39 ----------- multilang_translator/main_local.py | 45 ++++++++----- .../text_to_speech/text_to_speech.py | 24 ++++--- .../translator/credentials.py | 2 +- .../translator/llm_translator.py | 64 ++++++++++++++----- multilang_translator/translator_config.py | 37 +++++++++++ tests/test_backend.py | 2 +- 7 files changed, 131 insertions(+), 82 deletions(-) delete mode 100644 multilang_translator/config.py create mode 100644 multilang_translator/translator_config.py diff --git a/multilang_translator/config.py b/multilang_translator/config.py deleted file mode 100644 index 597928f..0000000 --- a/multilang_translator/config.py +++ /dev/null @@ -1,39 +0,0 @@ -import os -from dataclasses import dataclass, field -from auracast import auracast_config - -ANNOUNCEMENT_DIR = os.path.join(os.path.dirname(__file__), 'announcements') -VENV_DIR = os.path.join(os.path.dirname(__file__), '../venv') -PIPER_EXE_PATH = f'{VENV_DIR}/bin/piper' - - -# TODO: TRANSLATOR_LLM = 'llama3.2:3b-instruct-q4_0' -@dataclass -class TranslatorConfigDe(): - big: auracast_config.AuracastBigConfig = field(default_factory=auracast_config.AuracastBigConfigDe) - tts_system: str = 'piper' - tts_model: str ='de_DE-kerstin-low' - -@dataclass -class TranslatorConfigEn(): - big: auracast_config.AuracastBigConfig = field(default_factory=auracast_config.AuracastBigConfigEn) - tts_system: str = 'piper' - tts_model: str = 'en_US-lessac-medium' - -@dataclass -class TranslatorConfigFr(): - big: auracast_config.AuracastBigConfig = field(default_factory=auracast_config.AuracastBigConfigFr) - tts_system: str = 'piper' - tts_model: str = 'fr_FR-siwis-medium' - -@dataclass -class TranslatorConfigEs(): - big: auracast_config.AuracastBigConfig = field(default_factory=auracast_config.AuracastBigConfigEs) - tts_system: str = 'piper' - tts_model: str = 'es_ES-sharvard-medium' - -@dataclass -class TranslatorConfigIt(): - big: auracast_config.AuracastBigConfig = field(default_factory=auracast_config.AuracastBigConfigIt) - tts_system: str = 'piper', - tts_model: str = 'it_IT-paola-medium' diff --git a/multilang_translator/main_local.py b/multilang_translator/main_local.py index cdf5084..e59216f 100644 --- a/multilang_translator/main_local.py +++ b/multilang_translator/main_local.py @@ -12,7 +12,7 @@ import time import logging as log import aioconsole -import config +import multilang_translator.translator_config as translator_config from utils import resample from translator import llm_translator, test_content from text_to_speech import text_to_speech @@ -28,7 +28,7 @@ def transcribe(): async def announcement_from_german_text( global_config: auracast_config.AuracastGlobalConfig, - translator_config: List[config.TranslatorConfigDe], + translator_config: List[translator_config.TranslatorConfigDe], caster: multicast_control.Multicaster, text_de ): @@ -40,15 +40,22 @@ async def announcement_from_german_text( if trans.big.language == base_lang: text = text_de else: - text = llm_translator.translate_de_to_x(text_de, trans.big.language, model=TRANSLATOR_LLM) + text = llm_translator.translate_de_to_x( + text_de, + trans.big.language, + model=TRANSLATOR_LLM, + client = trans.llm_client, + host=trans.llm_host_url, + token=trans.llm_host_token + ) log.info('%s', text) lc3_audio = text_to_speech.synthesize( text, global_config.auracast_sampling_rate_hz, - trans.big.tts_system, - trans.big.tts_model, + trans.tts_system, + trans.tts_model, return_lc3=True ) caster.big_conf[i].audio_source = lc3_audio @@ -59,7 +66,7 @@ async def announcement_from_german_text( log.info("Starting all broadcasts took %s s", round(time.time() - start, 3)) -async def command_line_ui(caster: multicast_control.Multicaster): +async def command_line_ui(global_conf, translator_conf, caster: multicast_control.Multicaster): while True: # make a list of all available testsentence sentence_list = list(asdict(TESTSENTENCE).values()) @@ -80,7 +87,11 @@ async def command_line_ui(caster: multicast_control.Multicaster): # Check if command is a single number elif command.strip().isdigit(): ind = int(command.strip()) - await announcement_from_german_text(caster, sentence_list[ind]) + await announcement_from_german_text( + global_conf, + translator_conf, + caster, + sentence_list[ind]) await asyncio.wait([caster.streamer.task]) # Interpret the command as announcement else: @@ -93,33 +104,37 @@ async def main(): format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' ) - global_conf = auracast_config.global_base_config + global_conf = auracast_config.AuracastGlobalConfig() #global_conf.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk global_conf.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc + translator_conf = [ - config.TranslatorConfigDe(), - config.TranslatorConfigEn(), - config.TranslatorConfigFr(), + translator_config.TranslatorConfigDe(), + translator_config.TranslatorConfigEn(), + translator_config.TranslatorConfigFr(), #auracast_config.broadcast_es, #auracast_config.broadcast_it, ] - for i, conf in enumerate(translator_conf): + for conf in translator_conf: conf.big.loop = False + conf.llm_client = 'openwebui' # comment out for local llm + conf.llm_host_url = 'https://ollama.pstruebi.xyz' + conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' caster = multicast_control.Multicaster(global_conf, [conf.big for conf in translator_conf]) await caster.init_broadcast() await announcement_from_german_text( + global_conf, + translator_conf, caster, test_content.TESTSENTENCE.DE_HELLO ) await asyncio.wait([caster.streamer.task]) - #await command_line_ui(caster) + #await command_line_ui(global_conf, translator_conf, caster) if __name__ == '__main__': asyncio.run(main()) - # TODO: integrate this in the LANG_CONFIG dict, better: make a hierachy of dataclasses - # TODO: remove the nececcity for files # TODO: add support for multiple radios \ No newline at end of file diff --git a/multilang_translator/text_to_speech/text_to_speech.py b/multilang_translator/text_to_speech/text_to_speech.py index a2c231d..0b777e3 100644 --- a/multilang_translator/text_to_speech/text_to_speech.py +++ b/multilang_translator/text_to_speech/text_to_speech.py @@ -4,7 +4,7 @@ import time import json import logging as log import numpy as np -from multilang_translator import config +from multilang_translator import translator_config from multilang_translator.utils.resample import resample_array from multilang_translator.text_to_speech import encode_lc3 @@ -12,20 +12,24 @@ TTS_DIR = os.path.join(os.path.dirname(__file__)) PIPER_DIR = f'{TTS_DIR}/piper' os.makedirs(PIPER_DIR, exist_ok=True) -def synth_piper(text, model="en_US-lessac-medium",): +def synth_piper(text, model="en_US-lessac-medium"): + pwd = os.getcwd() + os.chdir(PIPER_DIR) start = time.time() ret = subprocess.run( # TODO: wrap this whole thing in a class and open a permanent pipe to the model - [config.PIPER_EXE_PATH, - '--cuda', - '--data-dir', PIPER_DIR, - '--download-dir', PIPER_DIR, - '--model', model, - '--output-raw' - ], + [translator_config.PIPER_EXE_PATH, + '--cuda', + #'--data-dir', PIPER_DIR, # not working, change workdir instead + #'--download-dir', PIPER_DIR, + #'--model', f'{PIPER_DIR}/{model}.onnx', + '--model', model, + '--output-raw' + ], input=text.encode('utf-8'), capture_output=True ) + os.chdir(pwd) log.warning('Piper stderr:\n%s', ret.stderr) assert ret.returncode == 0, 'Piper returncode was not 0.' @@ -79,5 +83,5 @@ if __name__ == '__main__': sf.write('hello.wav', audio, target_rate) - + # TODO: "WARNING:piper.download:Wrong size (expected=5952, actual=4158 print('Done.') diff --git a/multilang_translator/translator/credentials.py b/multilang_translator/translator/credentials.py index 4dbb3b7..4036a35 100644 --- a/multilang_translator/translator/credentials.py +++ b/multilang_translator/translator/credentials.py @@ -1,2 +1,2 @@ -BASE_URL='https://ollama.hinterwaldner.duckdns.org' +BASE_URL='https://ollama.pstruebi.xyz' TOKEN = 'sk-17124cb84df14cc6ab2d9e17d0724d13' \ No newline at end of file diff --git a/multilang_translator/translator/llm_translator.py b/multilang_translator/translator/llm_translator.py index 07e52b0..68604a5 100644 --- a/multilang_translator/translator/llm_translator.py +++ b/multilang_translator/translator/llm_translator.py @@ -7,43 +7,75 @@ import ollama from multilang_translator.translator import credentials from multilang_translator.translator import syspromts -from multilang_translator.translator import test_content # ollama.create( # TODO: create models on startup # model='example', # from_='llama3.2', system="You are Mario from Super Mario Bros." # ) -def query_model(model, query): - url = f'{credentials.BASE_URL}/api/chat/completions' + +async def chat(): + message = {'role': 'user', 'content': 'Why is the sky blue?'} + response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message]) + + + +def query_openwebui(model, system, query, url, token): + url = f'{url}/api/chat/completions' headers = { - 'Authorization': f'Bearer {credentials.TOKEN}', + 'Authorization': f'Bearer {token}', } payload = { 'model': model, - 'messages': [{'role': 'user', 'content': query}], + 'messages': [ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], } start = time.time() response = requests.post(url, headers=headers, json=payload) log.info("Translating the text took %s s", round(time.time() - start, 2)) - return response.json() + return response.json()['choices'][0]['message']['content'] -def translate_de_to_x(text:str, target_language: str, model='llama3.2:3b-instruct-q4_0'): # remember to use instruct models - start=time.time() - s = getattr(syspromts, f"TRANSLATOR_DE_{target_language.upper()}") +def query_ollama(model, system, query, host=None): + # client = ollama.Client( + # host=host, + # ) + response = ollama.chat( - model = model, - messages = [ - {'role': 'system', 'content': s}, - {'role': 'user', 'content': text} - ], - ) + model = model, + messages = [ + {'role': 'system', 'content': system}, + {'role': 'user', 'content': query} + ], + ) + return response.message.content + +def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function + text:str, + target_language: str, + client='ollama', + model='llama3.2:3b-instruct-q4_0', # remember to use instruct models + host = None, + token = None + ): + start=time.time() + s = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}") + + if client == 'ollama': + response = query_ollama(model, s, text, host=host) + elif client == 'openwebui': + response = query_openwebui(model, s, text, url=host, token=token) + else: raise NotImplementedError('llm client not implemented') + log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) - return response['message']['content'] + return response if __name__ == "__main__": import time + from multilang_translator.translator import test_content + start=time.time() response = translate_de_to_x('Der Zug ist da.', target_language='en', model='llama3.2:1b-instruct-q4_0') diff --git a/multilang_translator/translator_config.py b/multilang_translator/translator_config.py new file mode 100644 index 0000000..2f80f7b --- /dev/null +++ b/multilang_translator/translator_config.py @@ -0,0 +1,37 @@ +import os +from pydantic import BaseModel +from auracast import auracast_config + +ANNOUNCEMENT_DIR = os.path.join(os.path.dirname(__file__), 'announcements') +VENV_DIR = os.path.join(os.path.dirname(__file__), '../venv') +PIPER_EXE_PATH = f'{VENV_DIR}/bin/piper' + +class TranslatorBaseconfig(BaseModel): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe() + translator_llm: str = 'llama3.2:3b-instruct-q4_0' + llm_client: str = 'ollama' + llm_host_url: str | None = None + llm_host_token: str | None = None + tts_system: str = 'piper' + tts_model: str ='de_DE-kerstin-low' + + +class TranslatorConfigDe(TranslatorBaseconfig): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe() + tts_model: str ='de_DE-kerstin-low' + +class TranslatorConfigEn(TranslatorBaseconfig): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEn() + tts_model: str = 'en_US-lessac-medium' + +class TranslatorConfigFr(TranslatorBaseconfig): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigFr() + tts_model: str = 'fr_FR-siwis-medium' + +class TranslatorConfigEs(TranslatorBaseconfig): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEs() + tts_model: str = 'es_ES-sharvard-medium' + +class TranslatorConfigIt(TranslatorBaseconfig): + big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigIt() + tts_model: str = 'it_IT-paola-medium' diff --git a/tests/test_backend.py b/tests/test_backend.py index 4fc577b..d8f96be 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -4,7 +4,7 @@ import time import os import subprocess -from multilang_translator.config import LANG_CONFIG +from multilang_translator.translator_config import LANG_CONFIG from multilang_translator.backend_controller.broadcaster_play_once import broadcaster_play_file from multilang_translator.backend_controller.broadcaster_copy_files import copy_to_broadcaster