restructure_for_cloud (#3)

- implement a presentable working version of translator_Server

Reviewed-on: https://gitea.pstruebi.xyz/auracaster/multilang-translator-local/pulls/3
This commit was merged in pull request #3.
This commit is contained in:
2025-03-19 12:58:59 +01:00
parent 7fa677d865
commit 92169ed4ae
31 changed files with 1048 additions and 103 deletions

3
.vscode/tasks.json vendored
View File

@@ -17,7 +17,8 @@
{ {
"label": "pip install -e auracast", "label": "pip install -e auracast",
"type": "shell", "type": "shell",
"command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat" "command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat",
"problemMatcher": []
} }
] ]
} }

View File

@@ -1,37 +0,0 @@
import os
from pydantic import BaseModel
from auracast import auracast_config
# Directory for pre-rendered announcement audio, next to this module.
ANNOUNCEMENT_DIR = os.path.join(os.path.dirname(__file__), 'announcements')
# Repo-local virtualenv (one level up from this package).
VENV_DIR = os.path.join(os.path.dirname(__file__), '../venv')
# piper TTS binary installed into the venv
PIPER_EXE_PATH = f'{VENV_DIR}/bin/piper'


class TranslatorBaseconfig(BaseModel):
    """Base per-language translator settings (defaults target German).

    Combines the auracast BIG config with the LLM translation backend and
    the piper TTS voice used to synthesize that language.
    """
    # BIG (broadcast isochronous group) config; German is the default.
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe()
    # NOTE(review): instruct-tuned model expected by the translator prompts.
    translator_llm: str = 'llama3.2:3b-instruct-q4_0'
    llm_client: str = 'ollama'                          # 'ollama' or 'openwebui'
    llm_host_url: str | None = 'http://localhost:11434'
    llm_host_token: str | None = None                   # only needed for openwebui
    tts_system: str = 'piper'
    tts_model: str = 'de_DE-kerstin-low'
# Per-language presets: each variant pins the matching auracast BIG config
# and a language-appropriate piper voice; everything else inherits defaults.

class TranslatorConfigDe(TranslatorBaseconfig):
    """German preset (higher-quality voice than the base default)."""
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe()
    tts_model: str = 'de_DE-thorsten-high'


class TranslatorConfigEn(TranslatorBaseconfig):
    """English preset."""
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEn()
    tts_model: str = 'en_GB-alba-medium'


class TranslatorConfigFr(TranslatorBaseconfig):
    """French preset."""
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigFr()
    tts_model: str = 'fr_FR-siwis-medium'


class TranslatorConfigEs(TranslatorBaseconfig):
    """Spanish preset."""
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEs()
    tts_model: str = 'es_ES-sharvard-medium'


class TranslatorConfigIt(TranslatorBaseconfig):
    """Italian preset."""
    big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigIt()
    tts_model: str = 'it_IT-paola-medium'

View File

@@ -8,8 +8,9 @@ dependencies = [
"requests==2.32.3", "requests==2.32.3",
"ollama==0.4.7", "ollama==0.4.7",
"aioconsole==0.8.1", "aioconsole==0.8.1",
"piper-phonemize==1.1.0", "fastapi==0.115.11",
"piper-tts==1.2.0", "uvicorn==0.34.0",
"aiohttp==3.9.3",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -17,6 +18,11 @@ test = [
"pytest >= 8.2", "pytest >= 8.2",
] ]
[tool.poetry.group.tts.dependencies]
piper-phonemize = "==1.1.0"
piper-tts = "==1.2.0"
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = [ addopts = [
"--import-mode=importlib","--count=1","-s","-v" "--import-mode=importlib","--count=1","-s","-v"

View File

@@ -0,0 +1,89 @@
from typing import List
import time
import asyncio
import logging as log
from auracast import multicast_client
from auracast import auracast_config
import voice_client
import voice_models
from multilang_translator import translator_config
from multilang_translator.translator import llm_translator
import voice_client.tts_client
import voice_models.request_models
async def announcement_from_german_text(
    config: translator_config.TranslatorConfigGroup,
    text_de
):
    """Translate a German announcement into every configured language and send it.

    For each BIG in ``config.bigs`` the German source text is used verbatim
    (base language) or translated via the configured LLM, synthesized to LC3
    audio by the voice service, and finally pushed to the multicast client
    as one language->audio mapping.

    Args:
        config: group of per-language BIG configurations (incl. sample rate).
        text_de: German source text to announce.
    """
    base_lang = "deu"
    audio_data_dict = {}
    # FIX: the enumerate() index was never used — iterate the bigs directly.
    for big in config.bigs:
        if big.language == base_lang:
            text = text_de
        else:
            text = llm_translator.translate_de_to_x(
                text_de,
                big.language,
                model=big.translator_llm,
                client=big.llm_client,
                host=big.llm_host_url,
                token=big.llm_host_token
            )
        log.info('%s', text)
        request_data = voice_models.request_models.SynthesizeRequest(
            text=text,
            target_sample_rate=config.auracast_sampling_rate_hz,
            framework=big.tts_system,
            model=big.tts_model,
            return_lc3=True
        )
        start = time.time()
        lc3_audio = voice_client.tts_client.request_synthesis(
            request_data
        )
        log.info('Voice synth took %s', time.time() - start)
        audio_data_dict[big.language] = lc3_audio.decode('latin-1')  # TODO: should be .hex in the future
    await multicast_client.send_audio(
        audio_data_dict
    )
async def main():
    """Demo driver: build a DE/EN/FR config, init the multicast client, announce once."""
    log.basicConfig(
        level=log.INFO,
        format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
    )
    config = translator_config.TranslatorConfigGroup(
        bigs=[
            translator_config.TranslatorConfigDe(),
            translator_config.TranslatorConfigEn(),
            translator_config.TranslatorConfigFr(),
        ]
    )
    config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc
    for conf in config.bigs:
        conf.loop = False
        conf.llm_client = 'openwebui' # comment out for local llm
        conf.llm_host_url = 'https://ollama.pstruebi.xyz'
        # SECURITY(review): hardcoded API token committed to the repo —
        # rotate it and load it from an env var / secret store instead.
        conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
    await multicast_client.init(
        config
    )
    await announcement_from_german_text(config, 'Hello')


if __name__ == '__main__':
    asyncio.run(main())

View File

@@ -1,24 +1,17 @@
# -*- coding: utf-8 -*-
"""
list prompt example
"""
from __future__ import print_function, unicode_literals
from typing import List from typing import List
from dataclasses import asdict from dataclasses import asdict
import asyncio import asyncio
from copy import copy
import time import time
import logging as log import logging as log
import aioconsole import aioconsole
import multilang_translator.translator_config as translator_config
from utils import resample
from translator import llm_translator, test_content
from text_to_speech import text_to_speech
from auracast import multicast_control from auracast import multicast_control
from auracast import auracast_config from auracast import auracast_config
from translator.test_content import TESTSENTENCE from voice_provider import text_to_speech
from multilang_translator import translator_config
from multilang_translator.translator import llm_translator
from multilang_translator.translator.test_content import TESTSENTENCE
# TODO: look for a end to end translation solution # TODO: look for a end to end translation solution
@@ -27,33 +20,32 @@ def transcribe():
async def announcement_from_german_text( async def announcement_from_german_text(
global_config: auracast_config.AuracastGlobalConfig, config: translator_config.TranslatorConfigGroup,
translator_config: List[translator_config.TranslatorConfigDe],
caster: multicast_control.Multicaster, caster: multicast_control.Multicaster,
text_de text_de
): ):
base_lang = "deu" base_lang = "deu"
for i, trans in enumerate(translator_config): for i, big in enumerate(config.bigs):
if trans.big.language == base_lang: if big.language == base_lang:
text = text_de text = text_de
else: else:
text = llm_translator.translate_de_to_x( text = llm_translator.translate_de_to_x(
text_de, text_de,
trans.big.language, big.language,
model=trans.translator_llm, model=big.translator_llm,
client = trans.llm_client, client = big.llm_client,
host=trans.llm_host_url, host=big.llm_host_url,
token=trans.llm_host_token token=big.llm_host_token
) )
log.info('%s', text) log.info('%s', text)
lc3_audio = text_to_speech.synthesize( lc3_audio = text_to_speech.synthesize(
text, text,
global_config.auracast_sampling_rate_hz, config.auracast_sampling_rate_hz,
trans.tts_system, big.tts_system,
trans.tts_model, big.tts_model,
return_lc3=True return_lc3=True
) )
caster.big_conf[i].audio_source = lc3_audio caster.big_conf[i].audio_source = lc3_audio
@@ -64,7 +56,7 @@ async def announcement_from_german_text(
log.info("Starting all broadcasts took %s s", round(time.time() - start, 3)) log.info("Starting all broadcasts took %s s", round(time.time() - start, 3))
async def command_line_ui(global_conf, translator_conf, caster: multicast_control.Multicaster): async def command_line_ui(config: translator_config.TranslatorConfigGroup, translator_conf, caster: multicast_control.Multicaster):
while True: while True:
# make a list of all available testsentence # make a list of all available testsentence
sentence_list = list(asdict(TESTSENTENCE).values()) sentence_list = list(asdict(TESTSENTENCE).values())
@@ -86,8 +78,7 @@ async def command_line_ui(global_conf, translator_conf, caster: multicast_contro
elif command.strip().isdigit(): elif command.strip().isdigit():
ind = int(command.strip()) ind = int(command.strip())
await announcement_from_german_text( await announcement_from_german_text(
global_conf, config,
translator_conf,
caster, caster,
sentence_list[ind]) sentence_list[ind])
await asyncio.wait([caster.streamer.task]) await asyncio.wait([caster.streamer.task])
@@ -103,35 +94,41 @@ async def main():
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
) )
global_conf = auracast_config.AuracastGlobalConfig() config = translator_config.TranslatorConfigGroup(
#global_conf.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk bigs=[
global_conf.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc translator_config.TranslatorConfigDe(),
translator_config.TranslatorConfigEn(),
translator_config.TranslatorConfigFr(),
translator_conf = [
translator_config.TranslatorConfigDe(),
translator_config.TranslatorConfigEn(),
translator_config.TranslatorConfigFr(),
#auracast_config.broadcast_es,
#auracast_config.broadcast_it,
] ]
for conf in translator_conf: )
conf.big.loop = False
#config = auracast_config.AuracastGlobalConfig()
#config.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk
config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc
for conf in config.bigs:
conf.loop = False
conf.llm_client = 'openwebui' # comment out for local llm conf.llm_client = 'openwebui' # comment out for local llm
conf.llm_host_url = 'https://ollama.pstruebi.xyz' conf.llm_host_url = 'https://ollama.pstruebi.xyz'
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
caster = multicast_control.Multicaster(global_conf, [conf.big for conf in translator_conf]) caster = multicast_control.Multicaster(
config,
[big for big in config.bigs]
)
await caster.init_broadcast() await caster.init_broadcast()
# await announcement_from_german_text( # await announcement_from_german_text(
# global_conf, # config,
# translator_conf,
# caster, # caster,
# test_content.TESTSENTENCE.DE_HELLO # test_content.TESTSENTENCE.DE_HELLO
# ) # )
# await asyncio.wait([caster.streamer.task]) # await asyncio.wait([caster.streamer.task])
await command_line_ui(global_conf, translator_conf, caster) await command_line_ui(
config,
[big for big in config.bigs],
caster
)
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main()) asyncio.run(main())

View File

@@ -4,6 +4,7 @@ import json
import logging as log import logging as log
import time import time
import ollama import ollama
import aiohttp
from multilang_translator.translator import syspromts from multilang_translator.translator import syspromts
@@ -12,10 +13,6 @@ from multilang_translator.translator import syspromts
# from_='llama3.2', system="You are Mario from Super Mario Bros." # from_='llama3.2', system="You are Mario from Super Mario Bros."
# ) # )
async def chat():
message = {'role': 'user', 'content': 'Why is the sky blue?'}
response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message])
def query_openwebui(model, system, query, url, token): def query_openwebui(model, system, query, url, token):
url = f'{url}/api/chat/completions' url = f'{url}/api/chat/completions'
@@ -50,6 +47,41 @@ def query_ollama(model, system, query, host='http://localhost:11434'):
return response.message.content return response.message.content
async def query_openwebui_async(model, system, query, url, token):
    """Query an OpenWebUI chat-completions endpoint asynchronously.

    Args:
        model: model name to run.
        system: system prompt text.
        query: user message text.
        url: OpenWebUI base URL (``/api/chat/completions`` is appended).
        token: bearer token for the API.

    Returns:
        The assistant reply content string.

    Raises:
        aiohttp.ClientResponseError: on non-2xx HTTP responses.
    """
    url = f'{url}/api/chat/completions'
    headers = {
        'Authorization': f'Bearer {token}',
    }
    payload = {
        'model': model,
        'messages': [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': query}
        ],
    }
    start = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            # FIX: fail fast on HTTP errors; previously an error body produced
            # an opaque KeyError on ['choices'] below.
            response.raise_for_status()
            response_json = await response.json()
    log.info("Translating the text took %s s", round(time.time() - start, 2))
    return response_json['choices'][0]['message']['content']
async def query_ollama_async(model, system, query, host='http://localhost:11434'):
    """Send a system+user chat to an Ollama server and return the reply text."""
    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': query},
    ]
    client = ollama.AsyncClient(host=host)
    response = await client.chat(model=model, messages=messages)
    return response.message.content
def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function
text:str, text:str,
target_language: str, target_language: str,
@@ -70,6 +102,27 @@ def translate_de_to_x( # TODO: use async ollama client later - implenent a trans
log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3)) log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3))
return response return response
async def translate_de_to_x_async(
    text:str,
    target_language: str,
    client='ollama',
    model='llama3.2:3b-instruct-q4_0', # remember to use instruct models
    host = None,
    token = None
):
    """Translate German *text* to *target_language* via the configured LLM backend.

    Async counterpart of translate_de_to_x; dispatches to the ollama or
    openwebui client and logs the elapsed time.
    """
    started = time.time()
    # System prompt is looked up by language code, e.g. TRANSLATOR_DEU_ENG.
    sysprompt = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}")
    if client == 'ollama':
        result = await query_ollama_async(model, sysprompt, text, host=host)
    elif client == 'openwebui':
        result = await query_openwebui_async(model, sysprompt, text, url=host, token=token)
    else:
        raise NotImplementedError('llm client not implemented')
    log.info('Running the translator to %s took %s s', target_language, round(time.time() - started, 3))
    return result
if __name__ == "__main__": if __name__ == "__main__":
import time import time
from multilang_translator.translator import test_content from multilang_translator.translator import test_content

View File

@@ -0,0 +1,94 @@
"""
API client functions for interacting with the Translator API.
"""
import requests
from typing import List, Optional, Dict, Any, Tuple
from enum import Enum
from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup
# This can be overridden through environment variables
API_BASE_URL = "http://localhost:7999"
def get_groups() -> List[EndpointGroup]:
    """Fetch every endpoint group from the API."""
    resp = requests.get(f"{API_BASE_URL}/groups")
    resp.raise_for_status()
    return [EndpointGroup.model_validate(g) for g in resp.json()]
def get_group(group_id: int) -> Optional[EndpointGroup]:
    """Fetch one endpoint group; returns None when the server reports 404."""
    resp = requests.get(f"{API_BASE_URL}/groups/{group_id}")
    if resp.status_code == 404:
        return None
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
def create_group(group: EndpointGroup) -> EndpointGroup:
    """Create a new endpoint group and return the server's stored copy."""
    # mode='json' serialises enums and nested models to plain JSON values.
    resp = requests.post(f"{API_BASE_URL}/groups", json=group.model_dump(mode='json'))
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup:
    """Replace the group stored under *group_id* and return the server's copy."""
    # mode='json' serialises enums and nested models to plain JSON values.
    resp = requests.put(f"{API_BASE_URL}/groups/{group_id}", json=updated_group.model_dump(mode='json'))
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
def delete_group(group_id: int) -> None:
    """Delete the endpoint group with *group_id* on the server."""
    resp = requests.delete(f"{API_BASE_URL}/groups/{group_id}")
    resp.raise_for_status()
def start_announcement(text: str, group_id: int) -> Dict[str, Any]:
    """Kick off an announcement for a group.

    Args:
        text: the announcement text.
        group_id: the endpoint group that should broadcast it.

    Returns:
        The server's status payload as a dictionary.
    """
    # The server route takes these as query parameters, hence params=.
    params = {"text": text, "group_id": group_id}
    resp = requests.post(f"{API_BASE_URL}/announcement", params=params)
    resp.raise_for_status()
    return resp.json()
def get_group_state(group_id: int) -> Tuple[str, float]:
    """Return (state_name, state_value) for the group's current announcement.

    Args:
        group_id: the group whose announcement state is queried.
    """
    resp = requests.get(f"{API_BASE_URL}/groups/{group_id}/state")
    resp.raise_for_status()
    payload = resp.json()
    return payload["name"], payload["value"]
def get_available_endpoints() -> List[Endpoint]:
    """Get all available endpoints.

    The API returns a mapping of endpoint id -> endpoint data; only the
    values are needed to build the Endpoint objects.
    """
    response = requests.get(f"{API_BASE_URL}/endpoints")
    response.raise_for_status()
    endpoints_dict = response.json()
    # FIX: iterate .values() — the keys were unpacked and discarded before.
    return [Endpoint.model_validate(data) for data in endpoints_dict.values()]
def get_available_languages() -> List[str]:
    """Fetch the list of language codes the server can announce in."""
    resp = requests.get(f"{API_BASE_URL}/languages")
    resp.raise_for_status()
    return resp.json()

View File

@@ -0,0 +1,20 @@
import os
from pydantic import BaseModel
# Repo-local virtualenv, two package levels up from this module.
VENV_DIR = os.path.join(os.path.dirname(__file__), './../../venv')


class TranslatorLangConfig(BaseModel):
    """Per-language LLM translation and TTS settings."""
    # TODO: this was migrated to translator_models - remove this
    translator_llm: str = 'llama3.2:3b-instruct-q4_0'
    llm_client: str = 'ollama'                          # 'ollama' or 'openwebui'
    llm_host_url: str | None = 'http://localhost:11434'
    llm_host_token: str | None = None                   # only needed for openwebui
    tts_system: str = 'piper'
    tts_model: str = 'de_DE-kerstin-low'
class TranslatorConfig(BaseModel):
    """Bundle of per-language configs, keyed by language-code attribute name."""
    # Each language overrides only the piper voice; LLM settings use defaults.
    deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
    eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
    fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
    spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
    ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')

View File

@@ -0,0 +1 @@
# Empty file to make the directory a package

View File

@@ -0,0 +1,58 @@
"""
Models for the translator API.
Similar to the models used in auracaster-webui but simplified for the translator middleware.
"""
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class AnnouncementStates(Enum):
    """Announcement lifecycle states; the value doubles as a 0..1 progress fraction."""
    IDLE = 0
    INIT = 0.1
    TRANSLATING = 0.2
    GENERATING_VOICE = 0.4
    ROUTING = 0.6
    BROADCASTING = 0.8
    COMPLETED = 1
    # FIX: ERROR was 0, which made it an *alias* of IDLE (Enum deduplicates
    # equal values), so error states were reported as "IDLE". Use a distinct
    # sentinel value instead.
    ERROR = -1
class Endpoint(BaseModel):
    """Defines an endpoint with its URL and capabilities."""
    id: int
    name: str
    url: str                 # base URL of the multicast service on that endpoint
    max_broadcasts: int = 1  # Maximum number of simultaneous broadcasts
class TranslatorLangConfig(BaseModel):
    """Per-language LLM translation and TTS settings (defaults: remote openwebui)."""
    translator_llm: str = 'llama3.2:3b-instruct-q4_0'
    llm_client: str = 'openwebui' # remote (homserver)
    llm_host_url: str = 'https://ollama.pstruebi.xyz'
    # SECURITY(review): hardcoded API token committed to the repo as a model
    # default — rotate it and load it from an env var / secret store instead.
    llm_host_token: str = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
    # llm_client: str = 'ollama' #local
    # llm_host_url: str | None = 'http://localhost:11434'
    # llm_host_token: str | None = None
    tts_system: str = 'piper'
    tts_model: str = 'de_DE-kerstin-low'
class TranslatorConfig(BaseModel):
    """Bundle of per-language configs, keyed by language-code attribute name."""
    # Each language overrides only the piper voice; LLM settings use defaults.
    deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
    eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
    fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
    spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
    ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')
class EndpointGroup(BaseModel):
    """A named set of endpoints that receive the same multi-language announcement."""
    id: int
    name: str
    languages: List[str]                 # language codes such as "deu", "eng"
    endpoints: List[Endpoint]
    sampling_rate_hz: int = 16000
    translator_config: TranslatorConfig = TranslatorConfig()
    current_state: AnnouncementStates = AnnouncementStates.IDLE
    # NOTE(review): misspelled ("anouncement") but deliberately kept — the name
    # is part of the model's public schema and is read by other code.
    anouncement_start_time: float = 0.0

View File

@@ -0,0 +1 @@
# Empty file to make the directory a package

View File

@@ -0,0 +1,143 @@
"""
Database file for endpoint definitions.
This file contains configurations for auracast endpoints including their IP addresses and capabilities.
"""
from typing import List, Optional
from multilang_translator.translator_models.translator_models import EndpointGroup, Endpoint
# Language codes the translator stack supports.
SUPPORTED_LANGUAGES = ["deu", "eng", "fra", "spa", "ita"]

# Database of endpoints
# FIX(annotation only): dict[int: Endpoint] was a slice expression, not a
# key/value annotation — corrected to dict[int, Endpoint].
ENDPOINTS: dict[int, Endpoint] = {  # for now make sure, .id and key are the same
    0: Endpoint(
        id=0,
        name="Local Endpoint",
        url="http://localhost:5000",
        max_broadcasts=3,
    ),
    1: Endpoint(
        id=1,
        name="Gate 1",
        url="http://pi3:5000",
        max_broadcasts=3,
    ),
    2: Endpoint(
        id=2,
        name="Gate 2",
        url="http://192.168.1.102:5000",
        max_broadcasts=3,
    ),
}

# Database of endpoint groups with default endpoints
ENDPOINT_GROUPS: dict[int, EndpointGroup] = {  # for now make sure , .id and key are the same
    0: EndpointGroup(
        id=0,
        name="Local Group",
        languages=["deu", "eng"],
        endpoints=[ENDPOINTS[0]],
    ),
    1: EndpointGroup(
        id=1,
        name="Gate1",
        languages=["deu", "fra"],
        endpoints=[ENDPOINTS[1]],
    )
}
def get_available_languages() -> List[str]:
    """Get a list of all supported languages.

    Returns the module-level SUPPORTED_LANGUAGES list itself (not a copy).
    """
    return SUPPORTED_LANGUAGES
# Endpoint functions
def get_all_endpoints() -> dict[int, Endpoint]:
    """Return the endpoint table keyed by id.

    NOTE(review): despite the previous List[Endpoint] annotation this returns
    the ENDPOINTS dict itself — the API client expects an id->endpoint mapping,
    so only the annotation is corrected here.
    """
    return ENDPOINTS
def get_endpoint_by_id(endpoint_id: int) -> Optional[Endpoint]:
    """Return the endpoint with *endpoint_id*, or None if unknown.

    FIX: the declared return type is Optional, but ENDPOINTS[endpoint_id]
    raised KeyError for unknown ids; use .get() (mirrors get_group_by_id).
    The annotation is also corrected from str to int — keys are ints.
    """
    return ENDPOINTS.get(endpoint_id)
def add_endpoint(endpoint: Endpoint) -> Endpoint:
    """Register a new endpoint.

    Raises:
        ValueError: if an endpoint with the same id already exists.
    """
    if endpoint.id in ENDPOINTS:
        raise ValueError(f"Endpoint with ID {endpoint.id} already exists")
    ENDPOINTS[endpoint.id] = endpoint
    return endpoint
def update_endpoint(endpoint_id: str, updated_endpoint: Endpoint) -> Endpoint:
    """Replace the endpoint stored under *endpoint_id*.

    The stored object's id field is forced to match the key. Raises
    ValueError when the id is unknown.
    """
    if endpoint_id not in ENDPOINTS:
        raise ValueError(f"Endpoint {endpoint_id} not found")
    updated_endpoint.id = endpoint_id  # keep key and .id in sync
    ENDPOINTS[endpoint_id] = updated_endpoint
    return updated_endpoint
def delete_endpoint(endpoint_id: int) -> None:
    """Delete an endpoint from the database.

    Raises:
        ValueError: if the endpoint is unknown or still referenced by a group.
    """
    if endpoint_id not in ENDPOINTS:
        raise ValueError(f"Endpoint {endpoint_id} not found")
    # Check if this endpoint is used in any groups.
    # FIX: group.endpoints holds Endpoint objects, so `endpoint_id in
    # group.endpoints` compared an id against objects and was always False —
    # in-use endpoints could be deleted. Compare against the .id fields.
    for group in ENDPOINT_GROUPS.values():
        if any(ep.id == endpoint_id for ep in group.endpoints):
            raise ValueError(f"Cannot delete endpoint {endpoint_id}, it is used in group {group.id}")
    del ENDPOINTS[endpoint_id]
# Endpoint Group functions
def get_all_groups() -> List[EndpointGroup]:
    """Return all endpoint groups as a list (snapshot of the dict's values)."""
    return list(ENDPOINT_GROUPS.values())
def get_group_by_id(group_id: int) -> Optional[EndpointGroup]:
    """Return the group with *group_id*, or None if unknown."""
    return ENDPOINT_GROUPS.get(group_id)
def add_group(group: EndpointGroup) -> EndpointGroup:
    """Add a new endpoint group to the database.

    Raises:
        ValueError: on duplicate group id, or if a referenced endpoint is unknown.
    """
    if group.id in ENDPOINT_GROUPS:
        raise ValueError(f"Group with ID {group.id} already exists")
    # Validate that all referenced endpoints exist.
    # FIX: group.endpoints holds Endpoint objects, but the old loop tested the
    # objects themselves against ENDPOINTS' int keys — that check failed for
    # every non-empty group. Test the .id field (as update_group already does).
    for endpoint in group.endpoints:
        if endpoint.id not in ENDPOINTS:
            raise ValueError(f"Endpoint {endpoint.id} not found")
    ENDPOINT_GROUPS[group.id] = group
    return group
def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup:
    """Replace the group stored under *group_id*.

    Validates that every referenced endpoint exists; the stored object's id
    field is forced to match the key. Raises ValueError on unknown group or
    endpoint ids.
    """
    if group_id not in ENDPOINT_GROUPS:
        raise ValueError(f"Group {group_id} not found")
    # Validate that all referenced endpoints exist
    known_ids = ENDPOINTS.keys()
    for endpoint in updated_group.endpoints:
        if endpoint.id not in known_ids:
            raise ValueError(f"Endpoint with id {endpoint.id} not found")
    updated_group.id = group_id  # keep key and .id in sync
    ENDPOINT_GROUPS[group_id] = updated_group
    return updated_group
def delete_group(group_id: int) -> None:
    """Delete an endpoint group; raises ValueError when the id is unknown."""
    if group_id not in ENDPOINT_GROUPS:
        raise ValueError(f"Group {group_id} not found")
    ENDPOINT_GROUPS.pop(group_id)

View File

@@ -0,0 +1,28 @@
"""
Entry point for the Translator API server.
This file starts the FastAPI server with the translator_server.
"""
import uvicorn
import logging as log
import sys
import os
# Add the parent directory to the Python path to find the multilang_translator package
# (three dirname() calls: this file sits three levels below the repo root).
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

if __name__ == "__main__":
    log.basicConfig(
        level=log.INFO,
        format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
    )
    log.info("Starting Translator API server")
    # NOTE(review): reload=True and log_level="debug" are development settings;
    # disable both for production deployments.
    uvicorn.run(
        "multilang_translator.translator_server.translator_server:app",
        host="0.0.0.0",
        port=7999,
        reload=True,
        log_level="debug"
    )

View File

@@ -0,0 +1,346 @@
"""
FastAPI implementation of the Multilang Translator API.
This API mimics the mock_api from auracaster-webui to allow integration.
"""
import time
import logging as log
import asyncio
import random
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Import models
from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup
from multilang_translator.translator import llm_translator
from multilang_translator.translator_server import endpoints_db
from voice_provider import text_to_speech
# Import the endpoints database and multicast client
from auracast import multicast_client, auracast_config
# Create FastAPI app
app = FastAPI()

# Add CORS middleware to allow cross-origin requests.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open — fine for a lab setup, tighten before exposing this service.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Endpoint configuration cache: maps endpoint.id -> the last config group sent
# to that endpoint; init_endpoint uses it to skip redundant re-initialisation.
CURRENT_ENDPOINT_CONFIG = {}
async def init_endpoint(endpoint: Endpoint, languages: list[str], sampling_rate_hz: int):
    """Initialize a specific endpoint for multicast.

    Skips the (expensive) init when the cached config for this endpoint has the
    same language set AND the remote caster reports is_initialized. Otherwise
    builds a fresh AuracastConfigGroup with randomized addresses/ids and sends
    it to the endpoint's multicast service.

    Args:
        endpoint: target endpoint (name + base URL).
        languages: language codes to broadcast (one BIG each).
        sampling_rate_hz: auracast sampling rate for all BIGs.
    """
    current_config = CURRENT_ENDPOINT_CONFIG.get(endpoint.id)
    if current_config is not None:
        current_langs = [big.language for big in current_config.bigs]
        # if languages are unchanged and the caster client status is initialized, skip init
        if current_langs == languages:
            # Get status asynchronously
            status = await multicast_client.get_status(base_url=endpoint.url)
            if status['is_initialized']:
                log.info('Endpoint %s was already initialized', endpoint.name)
                return
    log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}")
    # Load a default config: one AuracastBigConfigXx class per language,
    # looked up by capitalized language code (e.g. AuracastBigConfigDeu?
    # — depends on auracast_config naming; TODO confirm the class names).
    config = auracast_config.AuracastConfigGroup(
        bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
              for lang in languages]
    )
    # overwrite some default configs
    config.transport = 'auto'
    config.auracast_device_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
    config.auracast_sampling_rate_hz = sampling_rate_hz
    # Configure the bigs
    for big in config.bigs:
        big.loop = False
        big.name = endpoint.name
        big.random_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
        big.id = random.randint(0, 2**16)  # TODO: how many bits is this ?
        #big.program_info = big.program_info + ' ' + endpoint.name
    # make async init request
    ret = await multicast_client.init(config, base_url=endpoint.url)
    # if ret != 200: # TODO: this is not working, should probably be handled async
    #     log.error('Init of endpoint %s was unsuccessful', endpoint.name)
    #     raise Exception(f"Init of endpoint {endpoint.name} was unsuccessful")
    # Cache a copy so the next call can detect an unchanged language set.
    CURRENT_ENDPOINT_CONFIG[endpoint.id] = config.model_copy()
    log.info(f"Endpoint {endpoint.name} initialized successfully")
    #else:
    #    log.info('Endpoint %s was already initialized', endpoint.name)
async def make_announcement(text: str, ep_group: EndpointGroup):
    """
    Make an announcement to a group of endpoints.

    Pipeline: init all endpoints -> translate the German text into every other
    group language -> synthesize LC3 audio per language -> broadcast the audio
    map to every endpoint. Each stage persists ep_group.current_state via
    endpoints_db.update_group so other processes can observe progress; a
    background task (monitor_streaming_completion) marks the final state.

    Args:
        text: German announcement text ("" returns an error dict immediately).
        ep_group: target group (languages, endpoints, translator/TTS config).

    Returns:
        {"translations": {lang: text}} on success, {"error": ...} for empty text.
    """
    if text == "":
        log.warning("Announcement text is empty")
        return {"error": "Announcement text is empty"}
    ep_group.current_state = AnnouncementStates.IDLE
    ep_group.anouncement_start_time = time.time()
    # update the database with the new state and start time so this can be read by another process
    endpoints_db.update_group(ep_group.id, ep_group)
    # Initialize all endpoints in the group concurrently
    ep_group.current_state = AnnouncementStates.INIT
    endpoints_db.update_group(ep_group.id, ep_group)
    # Create init tasks and run them concurrently
    init_tasks = [
        init_endpoint(endpoint, ep_group.languages, ep_group.sampling_rate_hz)
        for endpoint in ep_group.endpoints
    ]
    # make sure init finished
    await asyncio.gather(*init_tasks)
    # Translate the text for each language (concurrently)
    base_lang = "deu"  # German is the base language
    target_langs = ep_group.languages.copy()
    if base_lang in target_langs:
        target_langs.remove(base_lang)
    ep_group.current_state = AnnouncementStates.TRANSLATING
    endpoints_db.update_group(ep_group.id, ep_group)
    # Create translation tasks; the base language keeps the original text.
    translations = {base_lang: text}
    translation_tasks = []
    for lang in target_langs:
        # Prepare translation task using this language's LLM settings.
        trans_conf = getattr(ep_group.translator_config, lang)
        task = llm_translator.translate_de_to_x_async(
            text=text,
            target_language=lang,
            client=trans_conf.llm_client,
            model=trans_conf.translator_llm,
            host=trans_conf.llm_host_url,
            token=trans_conf.llm_host_token
        )
        translation_tasks.append(task)
    # Wait for all translations to complete concurrently
    # (gather preserves input order, so results align with target_langs).
    results = await asyncio.gather(*translation_tasks)
    for i, translation in enumerate(results):
        lang = target_langs[i]
        translations[lang] = translation
        log.info(f"Translated to {lang}: {translation}")
    # Generate voices concurrently
    ep_group.current_state = AnnouncementStates.GENERATING_VOICE
    endpoints_db.update_group(ep_group.id, ep_group)
    # Prepare synthesis tasks and run them concurrently
    synth_langs = ep_group.languages
    synthesis_tasks = []
    for lang in synth_langs:
        trans_conf = getattr(ep_group.translator_config, lang)
        task = text_to_speech.synthesize_async(
            translations[lang],
            ep_group.sampling_rate_hz,
            trans_conf.tts_system,
            trans_conf.tts_model,
            return_lc3=True
        )
        synthesis_tasks.append(task)
    # Wait for all synthesis tasks to complete concurrently
    audio = {}
    if synthesis_tasks:
        results = await asyncio.gather(*synthesis_tasks)
        for i, audio_data in enumerate(results):
            audio[synth_langs[i]] = audio_data
    # Start the monitoring coroutine to wait for streaming to complete.
    # It will set the final (COMPLETED/ERROR) state in the background.
    asyncio.create_task(monitor_streaming_completion(ep_group))
    # Broadcast to all endpoints in group concurrently
    broadcast_tasks = []
    for endpoint in ep_group.endpoints:
        log.info(f"Broadcasting to {endpoint.name} for languages: {', '.join(audio.keys())}")
        task = multicast_client.send_audio(audio, base_url=endpoint.url)
        broadcast_tasks.append(task)
    # Wait for all broadcasts to complete
    await asyncio.gather(*broadcast_tasks)
    # Return the translations
    return {"translations": translations}
async def monitor_streaming_completion(ep_group: EndpointGroup):
    """
    Monitor streaming status after audio is sent and update group state when complete.

    Polls every endpoint in *ep_group* once per second: first waits (up to 10 s)
    for any endpoint to report is_streaming, then waits (up to 60 s) for all of
    them to stop. Each state transition is persisted via endpoints_db so other
    processes can read the progress.

    Args:
        ep_group: The endpoint group being monitored
    """
    log.info(f"Starting streaming completion monitoring for endpoint group {ep_group.id}")
    max_completion_time = 60  # seconds
    # First check if we are actually in streaming state
    streaming_started = False
    initial_check_timeout = 10  # seconds
    initial_check_start = time.time()
    # Wait for streaming to start (with timeout)
    while time.time() - initial_check_start < initial_check_timeout:
        # Wait before checking again
        await asyncio.sleep(1)
        any_streaming = False
        for endpoint in ep_group.endpoints:
            status = await multicast_client.get_status(base_url=endpoint.url)
            if status.get("is_streaming", False):
                any_streaming = True
                log.info(f"Streaming confirmed started on endpoint {endpoint.name}")
                break
        if any_streaming:
            streaming_started = True
            break
    if not streaming_started:
        log.warning(f"No endpoints started streaming for group {ep_group.id} after {initial_check_timeout}s")
        # Nothing to wait for — record the failure and stop monitoring.
        ep_group.current_state = AnnouncementStates.ERROR
        endpoints_db.update_group(ep_group.id, ep_group)
        return
    # Update group progress
    ep_group.current_state = AnnouncementStates.BROADCASTING
    endpoints_db.update_group(ep_group.id, ep_group)
    # Now monitor until streaming completes on all endpoints
    check_completion_start_time = time.time()
    completed = [False for _ in ep_group.endpoints]
    # FIX: the condition was `not all(completed) or elapsed > max_completion_time`,
    # which loops forever once the deadline passes (the OR keeps it spinning both
    # while incomplete and after completion past the deadline). Loop while work
    # remains AND the deadline has not been reached.
    while not all(completed) and time.time() - check_completion_start_time < max_completion_time:
        await asyncio.sleep(1)
        # Check status of each endpoint
        for i, endpoint in enumerate(ep_group.endpoints):
            status = await multicast_client.get_status(base_url=endpoint.url)
            completed[i] = not status['is_streaming']
    if all(completed):
        log.info(f"All endpoints completed streaming for group {ep_group.id}")
        # Update group state to completed
        ep_group.current_state = AnnouncementStates.COMPLETED
        endpoints_db.update_group(ep_group.id, ep_group)
        log.info(f"Updated group {ep_group.id} state to COMPLETED")
    else:
        log.error(f"Max wait time reached for group {ep_group.id}. Forcing completion.")
        # FIX: the timeout path previously left the state stuck at BROADCASTING;
        # persist a terminal state so clients stop polling.
        ep_group.current_state = AnnouncementStates.ERROR
        endpoints_db.update_group(ep_group.id, ep_group)
@app.get("/groups")
async def get_groups():
    """Return every configured endpoint group, including its current state."""
    groups = endpoints_db.get_all_groups()
    return groups
@app.post("/groups")
async def create_group(group: endpoints_db.EndpointGroup):
    """Register a new endpoint group.

    Responds with HTTP 400 when the database layer rejects the group
    (surfaced as a ValueError).
    """
    try:
        created = endpoints_db.add_group(group)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return created
@app.get("/groups/{group_id}/state")  # TODO: think about progress tracking
async def get_group_state(group_id: int):
    """Return the current announcement state of an endpoint group.

    Args:
        group_id: Numeric id of the endpoint group.

    Returns:
        Dict with the state enum's ``name`` and ``value``.

    Raises:
        HTTPException: 404 when no group with ``group_id`` exists.
    """
    ep_group = endpoints_db.get_group_by_id(group_id)
    if not ep_group:
        # Fixed message: this handler looks up groups, not single endpoints.
        raise HTTPException(status_code=404, detail=f"Group {group_id} not found")
    return {"name": ep_group.current_state.name, "value": ep_group.current_state.value}
@app.put("/groups/{group_id}")
async def update_group(group_id: int, updated_group: endpoints_db.EndpointGroup):
    """Replace an existing endpoint group with the supplied definition.

    Responds with HTTP 400 when the database layer rejects the update.
    """
    try:
        result = endpoints_db.update_group(group_id, updated_group)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return result
@app.delete("/groups/{group_id}")
async def delete_group(group_id: int):
    """Remove an endpoint group.

    Responds with HTTP 400 when deletion fails (e.g. unknown id).
    """
    try:
        endpoints_db.delete_group(group_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"message": f"Group {group_id} deleted successfully"}
@app.post("/announcement")
async def start_announcement(text: str, group_id: int):
    """Kick off an announcement broadcast to the given endpoint group.

    The announcement runs as a background task; this handler returns
    immediately after scheduling it.
    """
    global announcement_task

    group = endpoints_db.get_group_by_id(group_id)
    if not group:
        raise HTTPException(status_code=400, detail=f"Group {group_id} not found")

    # NOTE(review): a guard rejecting a new announcement while a previous
    # announcement_task is still running was intentionally disabled here.
    announcement_task = asyncio.create_task(make_announcement(text, group))
    return {"status": "Announcement started", "group_id": group_id}
@app.get("/endpoints")
async def get_available_endpoints():
    """List every known endpoint together with its capabilities."""
    endpoints = endpoints_db.get_all_endpoints()
    return endpoints
@app.get("/languages")
async def get_available_languages():
    """List the languages announcements can be produced in."""
    languages = endpoints_db.get_available_languages()
    return languages
if __name__ == "__main__":
    import uvicorn

    # Verbose logging with module/line prefixes for easier debugging.
    log.basicConfig(
        format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s',
        level=log.DEBUG,
    )
    # Pass the app object directly instead of the 'translator_server:app'
    # string: with reload=True, module-level logging does not behave as
    # expected, so auto-reload stays disabled.
    uvicorn.run(app, host="0.0.0.0", port=7999)

View File

@@ -0,0 +1,44 @@
import requests
import numpy as np
import soundfile as sf
from voice_models.request_models import SynthesizeRequest
API_URL = "http://127.0.0.1:8099/synthesize/"
def request_synthesis(request_data: SynthesizeRequest):
    """Send a synthesis request to the voice server and decode the reply.

    Args:
        request_data: Synthesis parameters (text, sample rate, framework,
            model, output format).

    Returns:
        Raw LC3 frames as ``bytes`` when ``request_data.return_lc3`` is set,
        otherwise the PCM audio as a float32 numpy array.

    Raises:
        RuntimeError: When the server answers with a non-200 status.
    """
    response = requests.post(API_URL, json=request_data.model_dump())
    if response.status_code != 200:
        # Fail loudly instead of printing and silently returning None,
        # which previously surfaced as a confusing TypeError at call sites.
        raise RuntimeError(
            f"Error: {response.status_code}, {response.text}"
        )
    response_data = response.json()
    if request_data.return_lc3:
        # The server hex-encodes the LC3 byte stream for JSON transport.
        return bytes.fromhex(response_data["audio_lc3"])
    # PCM path: hex-encoded float32 samples back to a numpy array.
    audio_bytes = bytes.fromhex(response_data["audio_pcm"])
    return np.frombuffer(audio_bytes, dtype=np.float32)
if __name__ == "__main__":
    target_rate = 16000
    # Request a short uncompressed sample and store it as a WAV file.
    request_data = SynthesizeRequest(
        text="Hello, this is a test.",
        target_sample_rate=target_rate,
        framework="piper",
        model="de_DE-kerstin-low",
        return_lc3=False,  # True would yield LC3-compressed bytes instead
    )
    audio = request_synthesis(request_data)
    sf.write('hello.wav', audio, target_rate)

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class SynthesizeRequest(BaseModel):
    """Request payload for the voice server's /synthesize/ endpoint."""

    # Text to be converted to speech.
    text: str
    # Sample rate (Hz) the synthesized audio should be delivered at.
    target_sample_rate: int = 16000
    # TTS framework to run; 'piper' appears to be the only one the
    # synthesizer handles — TODO confirm against voice_provider.
    framework: str = "piper"
    # Voice model identifier for the chosen framework.
    model: str = "en_US-lessac-medium"
    # When True the server returns LC3-compressed audio instead of raw PCM.
    return_lc3: bool = False

View File

@@ -1,27 +1,34 @@
import os import os
import shutil
import subprocess import subprocess
import time import time
import json import json
import logging as log import logging as log
import numpy as np import numpy as np
from multilang_translator import translator_config import asyncio
from multilang_translator.utils.resample import resample_array from voice_provider.utils.resample import resample_array
from multilang_translator.text_to_speech import encode_lc3 from voice_provider.utils.encode_lc3 import encode_lc3
PIPER_EXE = shutil.which('piper')
TTS_DIR = os.path.join(os.path.dirname(__file__)) TTS_DIR = os.path.join(os.path.dirname(__file__))
PIPER_DIR = f'{TTS_DIR}/piper' PIPER_WORKDIR = f'{TTS_DIR}/piper'
if not PIPER_EXE:
PIPER_EXE = f'{TTS_DIR}/../../venv/bin/piper'
def synth_piper(text, model="en_US-lessac-medium"): def synth_piper(text, model="en_US-lessac-medium"):
pwd = os.getcwd() pwd = os.getcwd()
os.chdir(PIPER_DIR) os.chdir(PIPER_WORKDIR)
start = time.time() start = time.time()
# make sure piper has voices.json in working directory, otherwise it attempts to always load models # make sure piper has voices.json in working directory, otherwise it attempts to always load models
ret = subprocess.run( # TODO: wrap this whole thing in a class and open a permanent pipe to the model ret = subprocess.run( # TODO: wrap this whole thing in a class and open a permanent pipe to the model
[translator_config.PIPER_EXE_PATH, [
'--cuda', PIPER_EXE,
'--model', model, '--cuda',
'--output-raw' '--model', model,
'--output-raw'
], ],
input=text.encode('utf-8'), input=text.encode('utf-8'),
capture_output=True capture_output=True
@@ -34,14 +41,19 @@ def synth_piper(text, model="en_US-lessac-medium"):
log.info("Running piper for model %s took %s s", model, round(time.time() - start, 3)) log.info("Running piper for model %s took %s s", model, round(time.time() - start, 3))
with open (f'{PIPER_DIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently with open (f'{PIPER_WORKDIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently
model_json = json.load(f) model_json = json.load(f)
return model_json, audio return model_json, audio
# TODO: framework should probably be a dataclass that holds all the relevant informations, also model def synthesize(
def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium", return_lc3=True): text,
target_sample_rate,
framework,
model="en_US-lessac-medium",
return_lc3=True
):
if framework == 'piper': if framework == 'piper':
model_json, audio_raw = synth_piper(text, model) model_json, audio_raw = synth_piper(text, model)
@@ -59,12 +71,41 @@ def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium",
if return_lc3: if return_lc3:
audio_pcm = (audio * 2**15-1).astype(np.int16) audio_pcm = (audio * 2**15-1).astype(np.int16)
lc3 = encode_lc3.encode(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter lc3 = encode_lc3(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter
return lc3 return lc3
else: else:
return audio return audio
async def synthesize_async(
    text,
    target_sample_rate,
    framework,
    model="en_US-lessac-medium",
    return_lc3=True
):
    """
    Asynchronous version of the synthesize function that runs in a thread pool.

    Args:
        text: Text to synthesize
        target_sample_rate: Target sample rate for the audio
        framework: TTS framework to use (e.g., 'piper')
        model: Model to use for synthesis
        return_lc3: Whether to return LC3-encoded audio

    Returns:
        LC3-encoded audio as string or raw audio as numpy array
    """
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated for this use since Python 3.10.
    loop = asyncio.get_running_loop()
    # Run the CPU-intensive synthesis in the default thread pool so the
    # event loop stays responsive.
    result = await loop.run_in_executor(
        None,
        lambda: synthesize(text, target_sample_rate, framework, model, return_lc3)
    )
    return result
if __name__ == '__main__': if __name__ == '__main__':
import logging import logging
import soundfile as sf import soundfile as sf
@@ -79,5 +120,4 @@ if __name__ == '__main__':
sf.write('hello.wav', audio, target_rate) sf.write('hello.wav', audio, target_rate)
# TODO: "WARNING:piper.download:Wrong size (expected=5952, actual=4158
print('Done.') print('Done.')

View File

@@ -0,0 +1,43 @@
from fastapi import FastAPI, HTTPException
import numpy as np
from voice_models.request_models import SynthesizeRequest
from voice_provider.text_to_speech import synthesize_async
app = FastAPI()
HOST_PORT = 8099
@app.post("/synthesize/")
async def synthesize_speech(request: SynthesizeRequest):
    """Synthesize speech and return hex-encoded audio (LC3 or float32 PCM)."""
    try:
        audio = await synthesize_async(
            text=request.text,
            target_sample_rate=request.target_sample_rate,
            framework=request.framework,
            model=request.model,
            return_lc3=request.return_lc3
        )
        if not request.return_lc3:
            # Raw PCM path: serialize the numpy samples as hex for JSON.
            return {"audio_pcm": audio.astype(np.float32).tobytes().hex()}
        # LC3 path: the synthesizer may hand back bytes or a latin-1 string.
        if isinstance(audio, bytes):
            return {"audio_lc3": audio.hex()}
        if isinstance(audio, str):
            return {"audio_lc3": audio.encode('latin-1').hex()}
        raise ValueError(f"Unexpected audio type: {type(audio)}")
    except Exception as e:
        # Boundary handler: surface any synthesis failure as HTTP 500.
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    # Serve the synthesis API locally on the configured port.
    uvicorn.run(app, host="127.0.0.1", port=HOST_PORT)

View File

View File

@@ -1,7 +1,7 @@
import numpy as np import numpy as np
import lc3 import lc3
def encode( def encode_lc3(
audio: np.array, audio: np.array,
output_sample_rate_hz, output_sample_rate_hz,
octets_per_frame, octets_per_frame,

9
tests/get_group_state.py Normal file
View File

@@ -0,0 +1,9 @@
import requests
import time
if __name__ == '__main__':
    # Poll the translator server for group 0's state twice a second, forever.
    while True:
        state = requests.get('http://localhost:7999/groups/0/state')
        print(state.json())
        time.sleep(0.5)