restructure_for_cloud (#3)
- implement a presentable working version of translator_Server Reviewed-on: https://gitea.pstruebi.xyz/auracaster/multilang-translator-local/pulls/3
This commit was merged in pull request #3.
This commit is contained in:
3
.vscode/tasks.json
vendored
3
.vscode/tasks.json
vendored
@@ -17,7 +17,8 @@
|
||||
{
|
||||
"label": "pip install -e auracast",
|
||||
"type": "shell",
|
||||
"command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat"
|
||||
"command": "./venv/bin/python -m pip install -e ../bumble-auracast --config-settings editable_mode=compat",
|
||||
"problemMatcher": []
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
from auracast import auracast_config
|
||||
|
||||
ANNOUNCEMENT_DIR = os.path.join(os.path.dirname(__file__), 'announcements')
|
||||
VENV_DIR = os.path.join(os.path.dirname(__file__), '../venv')
|
||||
PIPER_EXE_PATH = f'{VENV_DIR}/bin/piper'
|
||||
|
||||
class TranslatorBaseconfig(BaseModel):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe()
|
||||
translator_llm: str = 'llama3.2:3b-instruct-q4_0'
|
||||
llm_client: str = 'ollama'
|
||||
llm_host_url: str | None = 'http://localhost:11434'
|
||||
llm_host_token: str | None = None
|
||||
tts_system: str = 'piper'
|
||||
tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
|
||||
class TranslatorConfigDe(TranslatorBaseconfig):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigDe()
|
||||
tts_model: str ='de_DE-thorsten-high'
|
||||
|
||||
class TranslatorConfigEn(TranslatorBaseconfig):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEn()
|
||||
tts_model: str = 'en_GB-alba-medium'
|
||||
|
||||
class TranslatorConfigFr(TranslatorBaseconfig):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigFr()
|
||||
tts_model: str = 'fr_FR-siwis-medium'
|
||||
|
||||
class TranslatorConfigEs(TranslatorBaseconfig):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigEs()
|
||||
tts_model: str = 'es_ES-sharvard-medium'
|
||||
|
||||
class TranslatorConfigIt(TranslatorBaseconfig):
|
||||
big: auracast_config.AuracastBigConfig = auracast_config.AuracastBigConfigIt()
|
||||
tts_model: str = 'it_IT-paola-medium'
|
||||
@@ -8,8 +8,9 @@ dependencies = [
|
||||
"requests==2.32.3",
|
||||
"ollama==0.4.7",
|
||||
"aioconsole==0.8.1",
|
||||
"piper-phonemize==1.1.0",
|
||||
"piper-tts==1.2.0",
|
||||
"fastapi==0.115.11",
|
||||
"uvicorn==0.34.0",
|
||||
"aiohttp==3.9.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -17,6 +18,11 @@ test = [
|
||||
"pytest >= 8.2",
|
||||
]
|
||||
|
||||
[tool.poetry.group.tts.dependencies]
|
||||
piper-phonemize = "==1.1.0"
|
||||
piper-tts = "==1.2.0"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = [
|
||||
"--import-mode=importlib","--count=1","-s","-v"
|
||||
|
||||
89
src/multilang_translator/main_cloud.py
Normal file
89
src/multilang_translator/main_cloud.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from typing import List
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import logging as log
|
||||
|
||||
from auracast import multicast_client
|
||||
from auracast import auracast_config
|
||||
|
||||
import voice_client
|
||||
import voice_models
|
||||
|
||||
from multilang_translator import translator_config
|
||||
from multilang_translator.translator import llm_translator
|
||||
import voice_client.tts_client
|
||||
import voice_models.request_models
|
||||
|
||||
|
||||
async def announcement_from_german_text(
    config: translator_config.TranslatorConfigGroup,
    text_de: str,
):
    """Translate a German announcement into every configured language,
    synthesize audio for each, and push the audio to the multicast client.

    Args:
        config: Group config; one entry in ``config.bigs`` per target language.
        text_de: The announcement text in German (the base language).

    Side effects:
        Calls the LLM translator, the TTS service, and finally
        ``multicast_client.send_audio`` with a dict mapping language code to
        the LC3 audio payload.
    """
    base_lang = "deu"  # German is the source language; it is never re-translated

    audio_data_dict = {}
    for big in config.bigs:  # was enumerate(); the index was never used
        if big.language == base_lang:
            text = text_de
        else:
            text = llm_translator.translate_de_to_x(
                text_de,
                big.language,
                model=big.translator_llm,
                client=big.llm_client,
                host=big.llm_host_url,
                token=big.llm_host_token
            )

        log.info('%s', text)
        request_data = voice_models.request_models.SynthesizeRequest(
            text=text,
            target_sample_rate=config.auracast_sampling_rate_hz,
            framework=big.tts_system,
            model=big.tts_model,
            return_lc3=True
        )
        start = time.time()
        lc3_audio = voice_client.tts_client.request_synthesis(
            request_data
        )
        log.info('Voice synth took %s', time.time() - start)
        audio_data_dict[big.language] = lc3_audio.decode('latin-1')  # TODO: should be .hex in the future

    await multicast_client.send_audio(
        audio_data_dict
    )
|
||||
|
||||
|
||||
async def main():
    """Configure logging and the broadcast group, then run one test announcement."""
    import os  # local import: only needed for the token override below

    log.basicConfig(
        level=log.INFO,
        format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
    )

    config = translator_config.TranslatorConfigGroup(
        bigs=[
            translator_config.TranslatorConfigDe(),
            translator_config.TranslatorConfigEn(),
            translator_config.TranslatorConfigFr(),
        ]
    )

    config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc

    # SECURITY: the fallback literal below is a leaked API token checked into
    # the repository -- rotate it and drop the fallback. Prefer the
    # LLM_HOST_TOKEN environment variable.
    llm_token = os.environ.get('LLM_HOST_TOKEN', 'sk-17124cb84df14cc6ab2d9e17d0724d13')

    for conf in config.bigs:
        conf.loop = False
        conf.llm_client = 'openwebui'  # comment out for local llm
        conf.llm_host_url = 'https://ollama.pstruebi.xyz'
        conf.llm_host_token = llm_token

    await multicast_client.init(
        config
    )

    await announcement_from_german_text(config, 'Hello')
|
||||
@@ -1,24 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
list prompt example
|
||||
"""
|
||||
from __future__ import print_function, unicode_literals
|
||||
|
||||
from typing import List
|
||||
from dataclasses import asdict
|
||||
import asyncio
|
||||
from copy import copy
|
||||
import time
|
||||
import logging as log
|
||||
import aioconsole
|
||||
|
||||
import multilang_translator.translator_config as translator_config
|
||||
from utils import resample
|
||||
from translator import llm_translator, test_content
|
||||
from text_to_speech import text_to_speech
|
||||
from auracast import multicast_control
|
||||
from auracast import auracast_config
|
||||
from translator.test_content import TESTSENTENCE
|
||||
from voice_provider import text_to_speech
|
||||
|
||||
from multilang_translator import translator_config
|
||||
from multilang_translator.translator import llm_translator
|
||||
from multilang_translator.translator.test_content import TESTSENTENCE
|
||||
|
||||
# TODO: look for a end to end translation solution
|
||||
|
||||
@@ -27,33 +20,32 @@ def transcribe():
|
||||
|
||||
|
||||
async def announcement_from_german_text(
|
||||
global_config: auracast_config.AuracastGlobalConfig,
|
||||
translator_config: List[translator_config.TranslatorConfigDe],
|
||||
config: translator_config.TranslatorConfigGroup,
|
||||
caster: multicast_control.Multicaster,
|
||||
text_de
|
||||
):
|
||||
base_lang = "deu"
|
||||
|
||||
for i, trans in enumerate(translator_config):
|
||||
if trans.big.language == base_lang:
|
||||
for i, big in enumerate(config.bigs):
|
||||
if big.language == base_lang:
|
||||
text = text_de
|
||||
else:
|
||||
text = llm_translator.translate_de_to_x(
|
||||
text_de,
|
||||
trans.big.language,
|
||||
model=trans.translator_llm,
|
||||
client = trans.llm_client,
|
||||
host=trans.llm_host_url,
|
||||
token=trans.llm_host_token
|
||||
big.language,
|
||||
model=big.translator_llm,
|
||||
client = big.llm_client,
|
||||
host=big.llm_host_url,
|
||||
token=big.llm_host_token
|
||||
)
|
||||
|
||||
log.info('%s', text)
|
||||
|
||||
lc3_audio = text_to_speech.synthesize(
|
||||
text,
|
||||
global_config.auracast_sampling_rate_hz,
|
||||
trans.tts_system,
|
||||
trans.tts_model,
|
||||
config.auracast_sampling_rate_hz,
|
||||
big.tts_system,
|
||||
big.tts_model,
|
||||
return_lc3=True
|
||||
)
|
||||
caster.big_conf[i].audio_source = lc3_audio
|
||||
@@ -64,7 +56,7 @@ async def announcement_from_german_text(
|
||||
log.info("Starting all broadcasts took %s s", round(time.time() - start, 3))
|
||||
|
||||
|
||||
async def command_line_ui(global_conf, translator_conf, caster: multicast_control.Multicaster):
|
||||
async def command_line_ui(config: translator_config.TranslatorConfigGroup, translator_conf, caster: multicast_control.Multicaster):
|
||||
while True:
|
||||
# make a list of all available testsentence
|
||||
sentence_list = list(asdict(TESTSENTENCE).values())
|
||||
@@ -86,8 +78,7 @@ async def command_line_ui(global_conf, translator_conf, caster: multicast_contro
|
||||
elif command.strip().isdigit():
|
||||
ind = int(command.strip())
|
||||
await announcement_from_german_text(
|
||||
global_conf,
|
||||
translator_conf,
|
||||
config,
|
||||
caster,
|
||||
sentence_list[ind])
|
||||
await asyncio.wait([caster.streamer.task])
|
||||
@@ -103,35 +94,41 @@ async def main():
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
|
||||
global_conf = auracast_config.AuracastGlobalConfig()
|
||||
#global_conf.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk
|
||||
global_conf.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc
|
||||
|
||||
|
||||
translator_conf = [
|
||||
translator_config.TranslatorConfigDe(),
|
||||
translator_config.TranslatorConfigEn(),
|
||||
translator_config.TranslatorConfigFr(),
|
||||
#auracast_config.broadcast_es,
|
||||
#auracast_config.broadcast_it,
|
||||
config = translator_config.TranslatorConfigGroup(
|
||||
bigs=[
|
||||
translator_config.TranslatorConfigDe(),
|
||||
translator_config.TranslatorConfigEn(),
|
||||
translator_config.TranslatorConfigFr(),
|
||||
]
|
||||
for conf in translator_conf:
|
||||
conf.big.loop = False
|
||||
)
|
||||
|
||||
#config = auracast_config.AuracastGlobalConfig()
|
||||
#config.transport='serial:/dev/serial/by-id/usb-SEGGER_J-Link_001057705357-if02,1000000,rtscts' # transport for nrf54l15dk
|
||||
config.transport='serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' #nrf52dongle hci_uart usb cdc
|
||||
|
||||
for conf in config.bigs:
|
||||
conf.loop = False
|
||||
conf.llm_client = 'openwebui' # comment out for local llm
|
||||
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
caster = multicast_control.Multicaster(global_conf, [conf.big for conf in translator_conf])
|
||||
caster = multicast_control.Multicaster(
|
||||
config,
|
||||
[big for big in config.bigs]
|
||||
)
|
||||
await caster.init_broadcast()
|
||||
|
||||
# await announcement_from_german_text(
|
||||
# global_conf,
|
||||
# translator_conf,
|
||||
# config,
|
||||
# caster,
|
||||
# test_content.TESTSENTENCE.DE_HELLO
|
||||
# )
|
||||
# await asyncio.wait([caster.streamer.task])
|
||||
await command_line_ui(global_conf, translator_conf, caster)
|
||||
await command_line_ui(
|
||||
config,
|
||||
[big for big in config.bigs],
|
||||
caster
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging as log
|
||||
import time
|
||||
import ollama
|
||||
import aiohttp
|
||||
|
||||
from multilang_translator.translator import syspromts
|
||||
|
||||
@@ -12,10 +13,6 @@ from multilang_translator.translator import syspromts
|
||||
# from_='llama3.2', system="You are Mario from Super Mario Bros."
|
||||
# )
|
||||
|
||||
async def chat():
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message])
|
||||
|
||||
|
||||
def query_openwebui(model, system, query, url, token):
|
||||
url = f'{url}/api/chat/completions'
|
||||
@@ -50,6 +47,41 @@ def query_ollama(model, system, query, host='http://localhost:11434'):
|
||||
return response.message.content
|
||||
|
||||
|
||||
async def query_openwebui_async(model, system, query, url, token):
    """Send one chat-completion request to an OpenWebUI server (async).

    Args:
        model: LLM model name to run.
        system: System prompt.
        query: User message.
        url: Base URL of the OpenWebUI instance (``/api/chat/completions`` is appended).
        token: Bearer token for authentication.

    Returns:
        The assistant message content as a string.

    Raises:
        aiohttp.ClientResponseError: if the server replies with an HTTP error
            status (previously an error body surfaced as an opaque KeyError).
    """
    url = f'{url}/api/chat/completions'
    headers = {
        'Authorization': f'Bearer {token}',
    }
    payload = {
        'model': model,
        'messages': [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': query}
        ],
    }
    start = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            # Fail fast on 4xx/5xx instead of KeyError-ing on the error body.
            response.raise_for_status()
            response_json = await response.json()
    log.info("Translating the text took %s s", round(time.time() - start, 2))
    return response_json['choices'][0]['message']['content']
|
||||
|
||||
|
||||
async def query_ollama_async(model, system, query, host='http://localhost:11434'):
    """Run one system+user chat turn against an Ollama server and return the reply text."""
    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': query},
    ]
    llm = ollama.AsyncClient(host=host)
    reply = await llm.chat(model=model, messages=messages)
    return reply.message.content
|
||||
|
||||
|
||||
def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function
|
||||
text:str,
|
||||
target_language: str,
|
||||
@@ -70,6 +102,27 @@ def translate_de_to_x( # TODO: use async ollama client later - implenent a trans
|
||||
log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3))
|
||||
return response
|
||||
|
||||
async def translate_de_to_x_async(
    text: str,
    target_language: str,
    client='ollama',
    model='llama3.2:3b-instruct-q4_0', # remember to use instruct models
    host = None,
    token = None
):
    """Asynchronously translate German *text* into *target_language*.

    Dispatches to the configured LLM backend ('ollama' or 'openwebui') using
    the language-specific system prompt from ``syspromts``.
    """
    start = time.time()
    system_prompt = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}")

    if client == 'ollama':
        translated = await query_ollama_async(model, system_prompt, text, host=host)
    elif client == 'openwebui':
        translated = await query_openwebui_async(model, system_prompt, text, url=host, token=token)
    else:
        raise NotImplementedError('llm client not implemented')

    log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3))
    return translated
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
from multilang_translator.translator import test_content
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
API client functions for interacting with the Translator API.
|
||||
"""
|
||||
import requests
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from enum import Enum
|
||||
|
||||
|
||||
from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup
|
||||
|
||||
|
||||
# Base URL of the Translator API. Currently hard-coded; TODO: allow override
# through an environment variable as originally intended.
|
||||
API_BASE_URL = "http://localhost:7999"
|
||||
|
||||
def get_groups() -> List[EndpointGroup]:
    """Get all endpoint groups."""
    resp = requests.get(f"{API_BASE_URL}/groups")
    resp.raise_for_status()
    raw_groups = resp.json()
    return [EndpointGroup.model_validate(raw) for raw in raw_groups]
|
||||
|
||||
def get_group(group_id: int) -> Optional[EndpointGroup]:
    """Get a specific endpoint group by ID; returns None for a 404."""
    resp = requests.get(f"{API_BASE_URL}/groups/{group_id}")
    if resp.status_code == 404:
        return None
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
|
||||
|
||||
def create_group(group: EndpointGroup) -> EndpointGroup:
    """Create a new endpoint group and return the server's view of it."""
    # mode='json' serializes enums to their primitive values
    body = group.model_dump(mode='json')
    resp = requests.post(f"{API_BASE_URL}/groups", json=body)
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
|
||||
|
||||
def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup:
    """Update an existing endpoint group and return the stored result."""
    # mode='json' serializes enums to their primitive values
    body = updated_group.model_dump(mode='json')
    resp = requests.put(f"{API_BASE_URL}/groups/{group_id}", json=body)
    resp.raise_for_status()
    return EndpointGroup.model_validate(resp.json())
|
||||
|
||||
def delete_group(group_id: int) -> None:
    """Delete an endpoint group; raises on HTTP error."""
    resp = requests.delete(f"{API_BASE_URL}/groups/{group_id}")
    resp.raise_for_status()
|
||||
|
||||
def start_announcement(text: str, group_id: int) -> Dict[str, Any]:
    """
    Start a new announcement.

    Args:
        text: The text content of the announcement
        group_id: The ID of the endpoint group to send the announcement to

    Returns:
        Dictionary with status information
    """
    params = {"text": text, "group_id": group_id}
    resp = requests.post(f"{API_BASE_URL}/announcement", params=params)
    resp.raise_for_status()
    return resp.json()
|
||||
|
||||
def get_group_state(group_id: int) -> Tuple[str, float]:
    """
    Get the status of the current announcement for a specific group.

    Args:
        group_id: The ID of the group to check the announcement status for

    Returns:
        Tuple containing (state_name, state_value)
    """
    resp = requests.get(f"{API_BASE_URL}/groups/{group_id}/state")
    resp.raise_for_status()
    state = resp.json()
    return state["name"], state["value"]
|
||||
|
||||
|
||||
def get_available_endpoints() -> List[Endpoint]:
    """Get all available endpoints.

    The API returns a mapping of endpoint ID -> endpoint data; only the
    values are needed to build the Endpoint models (the IDs are repeated
    inside each payload).
    """
    response = requests.get(f"{API_BASE_URL}/endpoints")
    response.raise_for_status()
    endpoints_dict = response.json()
    # .values() instead of .items(): the old loop bound the key and discarded it
    return [Endpoint.model_validate(endpoint_data) for endpoint_data in endpoints_dict.values()]
|
||||
|
||||
def get_available_languages() -> List[str]:
    """Get all available languages for announcements."""
    resp = requests.get(f"{API_BASE_URL}/languages")
    resp.raise_for_status()
    return resp.json()
|
||||
20
src/multilang_translator/translator_config.py
Normal file
20
src/multilang_translator/translator_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
|
||||
VENV_DIR = os.path.join(os.path.dirname(__file__), './../../venv')
|
||||
|
||||
class TranslatorLangConfig(BaseModel):
    """Per-language settings for the translate + text-to-speech pipeline."""
    # LLM used for the translation step.
    translator_llm: str = 'llama3.2:3b-instruct-q4_0' # TODO: this was migrated to translator_models - remove this
    # Which LLM backend to talk to ('ollama' here; other values exist elsewhere in the project).
    llm_client: str = 'ollama'
    # Backend URL; None only makes sense for clients that have their own default.
    llm_host_url: str | None = 'http://localhost:11434'
    # Auth token for remote backends; unused for local ollama.
    llm_host_token: str | None = None
    # TTS engine and voice model used for synthesis.
    tts_system: str = 'piper'
    tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
class TranslatorConfig(BaseModel):
    """One TranslatorLangConfig per supported language, keyed by ISO 639-3 code.

    Each language overrides only the TTS voice; LLM settings come from the
    TranslatorLangConfig defaults.
    """
    deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
    eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
    fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
    spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
    ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')
|
||||
|
||||
1
src/multilang_translator/translator_models/__init__.py
Normal file
1
src/multilang_translator/translator_models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty file to make the directory a package
|
||||
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Models for the translator API.
|
||||
Similar to the models used in auracaster-webui but simplified for the translator middleware.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnnouncementStates(Enum):
    """Lifecycle of an announcement; ``value`` doubles as a rough progress fraction."""
    IDLE = 0
    INIT = 0.1
    TRANSLATING = 0.2
    GENERATING_VOICE = 0.4
    ROUTING = 0.6
    BROADCASTING = 0.8
    COMPLETED = 1
    # BUG FIX: this was 0, which made ERROR an *alias* of IDLE (Enum members
    # with equal values collapse into one), so the error state could never be
    # represented or compared. -1 keeps the member distinct and clearly
    # outside the 0..1 progress range.
    ERROR = -1
||||
|
||||
|
||||
class Endpoint(BaseModel):
    """Defines an endpoint with its URL and capabilities."""
    # Unique endpoint identifier; endpoints_db keeps this equal to its dict key.
    id: int
    # Human-readable name (e.g. "Gate 1").
    name: str
    # Base URL of the endpoint's multicast service.
    url: str
    max_broadcasts: int = 1  # Maximum number of simultaneous broadcasts
||||
|
||||
class TranslatorLangConfig(BaseModel):
    """Per-language LLM + TTS settings used by the translator server."""
    translator_llm: str = 'llama3.2:3b-instruct-q4_0'

    # SECURITY NOTE(review): a live-looking API token is hard-coded below and
    # committed to the repository -- rotate it and load from the environment.
    llm_client: str = 'openwebui' # remote (homserver)
    llm_host_url: str = 'https://ollama.pstruebi.xyz'
    llm_host_token: str = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
    # llm_client: str = 'ollama' #local
    # llm_host_url: str | None = 'http://localhost:11434'
    # llm_host_token: str | None = None

    # TTS engine and default voice model.
    tts_system: str = 'piper'
    tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
|
||||
class TranslatorConfig(BaseModel):
    """One TranslatorLangConfig per supported language (ISO 639-3 keys);
    each language only overrides the TTS voice model."""
    deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
    eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
    fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
    spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
    ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')
|
||||
|
||||
|
||||
class EndpointGroup(BaseModel):
    """A named group of endpoints that receive the same announcements."""
    id: int
    name: str
    # Languages (ISO 639-3) this group announces in.
    languages: List[str]
    # Full Endpoint objects, not just IDs.
    endpoints: List[Endpoint]
    sampling_rate_hz: int = 16000
    translator_config: TranslatorConfig = TranslatorConfig()
    # Progress of the currently running announcement, if any.
    current_state: AnnouncementStates = AnnouncementStates.IDLE
    # NOTE(review): field name has a typo ("anouncement"); renaming would break
    # existing callers, so it is only flagged here.
    anouncement_start_time: float = 0.0
|
||||
1
src/multilang_translator/translator_server/__init__.py
Normal file
1
src/multilang_translator/translator_server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty file to make the directory a package
|
||||
143
src/multilang_translator/translator_server/endpoints_db.py
Normal file
143
src/multilang_translator/translator_server/endpoints_db.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Database file for endpoint definitions.
|
||||
This file contains configurations for auracast endpoints including their IP addresses and capabilities.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from multilang_translator.translator_models.translator_models import EndpointGroup, Endpoint
|
||||
|
||||
|
||||
SUPPORTED_LANGUAGES = ["deu", "eng", "fra", "spa", "ita"]

# Database of endpoints.
# Invariant: each dict key equals the .id of its Endpoint value.
# FIX: annotation was dict[int: Endpoint] -- a slice expression, not a
# key/value type pair; the correct generic form is dict[int, Endpoint].
ENDPOINTS: dict[int, Endpoint] = {
    0: Endpoint(
        id=0,
        name="Local Endpoint",
        url="http://localhost:5000",
        max_broadcasts=3,
    ),
    1: Endpoint(
        id=1,
        name="Gate 1",
        url="http://pi3:5000",
        max_broadcasts=3,
    ),
    2: Endpoint(
        id=2,
        name="Gate 2",
        url="http://192.168.1.102:5000",
        max_broadcasts=3,
    ),
}

# Database of endpoint groups with default endpoints.
# Invariant: each dict key equals the .id of its EndpointGroup value.
ENDPOINT_GROUPS: dict[int, EndpointGroup] = {
    0: EndpointGroup(
        id=0,
        name="Local Group",
        languages=["deu", "eng"],
        endpoints=[ENDPOINTS[0]],
    ),
    1: EndpointGroup(
        id=1,
        name="Gate1",
        languages=["deu", "fra"],
        endpoints=[ENDPOINTS[1]],
    )
}
|
||||
|
||||
def get_available_languages() -> List[str]:
    """Return the codes of all languages the translator supports."""
    return SUPPORTED_LANGUAGES
|
||||
|
||||
# Endpoint functions
|
||||
def get_all_endpoints() -> dict[int, Endpoint]:
    """Return the endpoint database, keyed by endpoint ID.

    Note: the old annotation claimed ``List[Endpoint]`` but the function has
    always returned the dict itself -- and the HTTP client expects the
    mapping shape -- so only the annotation is corrected here.
    """
    return ENDPOINTS
|
||||
|
||||
def get_endpoint_by_id(endpoint_id: int) -> Optional[Endpoint]:
    """Return the endpoint with *endpoint_id*, or None if it is unknown.

    Fixes: plain indexing raised KeyError for missing IDs, contradicting the
    Optional return type; the parameter annotation said str although the
    database keys are ints.
    """
    return ENDPOINTS.get(endpoint_id)
|
||||
|
||||
|
||||
def add_endpoint(endpoint: Endpoint) -> Endpoint:
    """Register a new endpoint; duplicate IDs are rejected."""
    if endpoint.id in ENDPOINTS:
        raise ValueError(f"Endpoint with ID {endpoint.id} already exists")
    ENDPOINTS[endpoint.id] = endpoint
    return endpoint
|
||||
|
||||
|
||||
def update_endpoint(endpoint_id: int, updated_endpoint: Endpoint) -> Endpoint:
    """Replace the endpoint stored under *endpoint_id*.

    Fixes the parameter annotation (keys are ints, not str). The stored
    endpoint's .id is forced to *endpoint_id* so the key/.id invariant holds.
    """
    if endpoint_id not in ENDPOINTS:
        raise ValueError(f"Endpoint {endpoint_id} not found")

    # Ensure the ID is preserved
    updated_endpoint.id = endpoint_id
    ENDPOINTS[endpoint_id] = updated_endpoint
    return updated_endpoint
|
||||
|
||||
|
||||
def delete_endpoint(endpoint_id: int) -> None:
    """Delete an endpoint, refusing if any group still references it.

    Fixes: ``endpoint_id in group.endpoints`` compared an int ID against
    Endpoint objects and was therefore always False, so the in-use guard
    never fired and referenced endpoints could be deleted.
    """
    if endpoint_id not in ENDPOINTS:
        raise ValueError(f"Endpoint {endpoint_id} not found")

    # Check if this endpoint is used in any groups (compare by .id, since
    # group.endpoints holds Endpoint models rather than raw IDs).
    for group in ENDPOINT_GROUPS.values():
        if any(ep.id == endpoint_id for ep in group.endpoints):
            raise ValueError(f"Cannot delete endpoint {endpoint_id}, it is used in group {group.id}")

    del ENDPOINTS[endpoint_id]
|
||||
|
||||
|
||||
# Endpoint Group functions
|
||||
def get_all_groups() -> List[EndpointGroup]:
    """Return every registered endpoint group."""
    return [*ENDPOINT_GROUPS.values()]
|
||||
|
||||
|
||||
def get_group_by_id(group_id: int) -> Optional[EndpointGroup]:
    """Look up an endpoint group by ID, or None when absent."""
    group = ENDPOINT_GROUPS.get(group_id)
    return group
|
||||
|
||||
|
||||
def add_group(group: EndpointGroup) -> EndpointGroup:
    """Register a new endpoint group after validating its endpoints.

    Fixes: the old validation iterated ``group.endpoints`` (Endpoint models)
    but tested the objects themselves against the int keys of ENDPOINTS, so
    the existence check never fired; compare by ``.id`` as update_group does.
    """
    if group.id in ENDPOINT_GROUPS:
        raise ValueError(f"Group with ID {group.id} already exists")

    # Validate that all referenced endpoints exist
    for endpoint in group.endpoints:
        if endpoint.id not in ENDPOINTS:
            raise ValueError(f"Endpoint {endpoint.id} not found")

    ENDPOINT_GROUPS[group.id] = group
    return group
|
||||
|
||||
|
||||
def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup:
    """Replace the group stored under *group_id* after validating its endpoints."""
    if group_id not in ENDPOINT_GROUPS:
        raise ValueError(f"Group {group_id} not found")

    # Every endpoint referenced by the new group must already be registered.
    for endpoint in updated_group.endpoints:
        if endpoint.id not in ENDPOINTS:
            raise ValueError(f"Endpoint with id {endpoint.id} not found")

    updated_group.id = group_id  # keep the stored ID consistent with the key
    ENDPOINT_GROUPS[group_id] = updated_group
    return updated_group
|
||||
|
||||
|
||||
def delete_group(group_id: int) -> None:
    """Remove the endpoint group with *group_id*; unknown IDs raise ValueError."""
    if group_id not in ENDPOINT_GROUPS:
        raise ValueError(f"Group {group_id} not found")
    del ENDPOINT_GROUPS[group_id]
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Entry point for the Translator API server.
|
||||
This file starts the FastAPI server with the translator_server.
|
||||
"""
|
||||
import uvicorn
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the parent directory to the Python path to find the multilang_translator package
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.basicConfig(
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
log.info("Starting Translator API server")
|
||||
uvicorn.run(
|
||||
"multilang_translator.translator_server.translator_server:app",
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
reload=True,
|
||||
log_level="debug"
|
||||
)
|
||||
346
src/multilang_translator/translator_server/translator_server.py
Normal file
346
src/multilang_translator/translator_server/translator_server.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
FastAPI implementation of the Multilang Translator API.
|
||||
This API mimics the mock_api from auracaster-webui to allow integration.
|
||||
"""
|
||||
import time
|
||||
import logging as log
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# Import models
|
||||
from multilang_translator.translator_models.translator_models import AnnouncementStates, Endpoint, EndpointGroup
|
||||
from multilang_translator.translator import llm_translator
|
||||
from multilang_translator.translator_server import endpoints_db
|
||||
from voice_provider import text_to_speech
|
||||
|
||||
# Import the endpoints database and multicast client
|
||||
from auracast import multicast_client, auracast_config
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware to allow cross-origin requests
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Endpoint configuration cache
|
||||
CURRENT_ENDPOINT_CONFIG = {}
|
||||
|
||||
async def init_endpoint(endpoint: Endpoint, languages: list[str], sampling_rate_hz: int):
    """Initialize one endpoint for multicast.

    Builds an auracast config for the requested *languages*, pushes it to the
    endpoint via ``multicast_client.init`` and caches it in
    CURRENT_ENDPOINT_CONFIG. If the cached config already covers the same
    languages and the endpoint reports itself initialized, the call is a no-op.

    Args:
        endpoint: Target endpoint (its .url is the service base URL).
        languages: Language codes; each must have a matching
            ``AuracastBigConfig<Lang>`` class in auracast_config.
        sampling_rate_hz: Sampling rate to configure for the broadcast.
    """

    current_config = CURRENT_ENDPOINT_CONFIG.get(endpoint.id)

    if current_config is not None:
        current_langs = [big.language for big in current_config.bigs]
        # Skip re-init only when the language set is unchanged AND the remote
        # caster confirms it is still initialized.
        if current_langs == languages:
            status = await multicast_client.get_status(base_url=endpoint.url)
            if status['is_initialized']:
                log.info('Endpoint %s was already initialized', endpoint.name)
                return

    log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}")
    # Load a default config; the per-language config classes are looked up by
    # naming convention (e.g. "deu" -> AuracastBigConfigDeu).
    config = auracast_config.AuracastConfigGroup(
        bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
              for lang in languages]
    )

    # overwrite some default configs: random device address avoids collisions
    # between endpoints -- TODO confirm the address format expectations.
    config.transport = 'auto'
    config.auracast_device_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
    config.auracast_sampling_rate_hz = sampling_rate_hz

    # Configure the bigs (one broadcast group per language)
    for big in config.bigs:
        big.loop = False
        big.name = endpoint.name
        big.random_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
        big.id = random.randint(0, 2**16) #TODO: how many bits is this ?
        #big.program_info = big.program_info + ' ' + endpoint.name

    # make async init request; the return value is currently not validated
    ret = await multicast_client.init(config, base_url=endpoint.url)
    # if ret != 200: # TODO: this is not working, should probably be handled async
    #     log.error('Init of endpoint %s was unsucessfull', endpoint.name)
    #     raise Exception(f"Init was of endpoint {endpoint.name} was unsucessfull")
    # Cache a copy so later calls can detect an unchanged language set.
    CURRENT_ENDPOINT_CONFIG[endpoint.id] = config.model_copy()
    log.info(f"Endpoint {endpoint.name} initialized successfully")
    #else:
    #    log.info('Endpoint %s was already initialized', endpoint.name)
|
||||
|
||||
|
||||
async def make_announcement(text: str, ep_group: EndpointGroup):
    """
    Make an announcement to a group of endpoints.

    Drives the full pipeline for one announcement: init all endpoints,
    translate the (German) source text into the group's languages,
    synthesize audio per language, then broadcast to every endpoint.
    Group state transitions (IDLE -> INIT -> TRANSLATING ->
    GENERATING_VOICE -> BROADCASTING/COMPLETED via the monitor task)
    are persisted to the database after each step so other processes
    can observe progress.

    Args:
        text: Announcement text in the base language (German).
        ep_group: The endpoint group to announce to.

    Returns:
        {"translations": {lang: text}} on success, or
        {"error": ...} if the input text is empty.
    """

    # Guard clause: nothing to announce.
    if text == "":
        log.warning("Announcement text is empty")
        return {"error": "Announcement text is empty"}

    ep_group.current_state = AnnouncementStates.IDLE
    # NOTE: attribute name carries a typo ("anouncement") — defined on the
    # model elsewhere, so it must be kept as-is here.
    ep_group.anouncement_start_time = time.time()
    # update the database with the new state and start time so this can be read by another process
    endpoints_db.update_group(ep_group.id, ep_group)

    # Initialize all endpoints in the group concurrently
    ep_group.current_state = AnnouncementStates.INIT
    endpoints_db.update_group(ep_group.id, ep_group)

    # Create init tasks and run them concurrently
    init_tasks = [
        init_endpoint(endpoint, ep_group.languages, ep_group.sampling_rate_hz)
        for endpoint in ep_group.endpoints
    ]

    # make sure init finished
    await asyncio.gather(*init_tasks)

    # Translate the text for each language (concurrently).
    # "deu" is presumably an ISO 639-3 code — TODO confirm it matches
    # the codes stored in ep_group.languages.
    base_lang = "deu" # German is the base language
    target_langs = ep_group.languages.copy()
    if base_lang in target_langs:
        # The base language needs no translation; the input text is used directly.
        target_langs.remove(base_lang)

    ep_group.current_state = AnnouncementStates.TRANSLATING
    endpoints_db.update_group(ep_group.id, ep_group)

    # Create translation tasks; the base language maps to the original text.
    translations = {base_lang: text}
    translation_tasks = []

    for lang in target_langs:
        # Prepare translation task; per-language LLM settings come from
        # the group's translator_config (one attribute per language code).
        trans_conf = getattr(ep_group.translator_config, lang)
        task = llm_translator.translate_de_to_x_async(
            text=text,
            target_language=lang,
            client=trans_conf.llm_client,
            model=trans_conf.translator_llm,
            host=trans_conf.llm_host_url,
            token=trans_conf.llm_host_token
        )
        translation_tasks.append(task)

    # Wait for all translations to complete concurrently.
    # gather() preserves input order, so results[i] belongs to target_langs[i].
    results = await asyncio.gather(*translation_tasks)
    for i, translation in enumerate(results):
        lang = target_langs[i]
        translations[lang] = translation
        log.info(f"Translated to {lang}: {translation}")

    # Generate voices concurrently
    ep_group.current_state = AnnouncementStates.GENERATING_VOICE
    endpoints_db.update_group(ep_group.id, ep_group)

    # Prepare synthesis tasks and run them concurrently
    # (synthesis covers ALL group languages, including the base language).
    synth_langs = ep_group.languages
    synthesis_tasks = []
    for lang in synth_langs:
        trans_conf = getattr(ep_group.translator_config, lang)
        task = text_to_speech.synthesize_async(
            translations[lang],
            ep_group.sampling_rate_hz,
            trans_conf.tts_system,
            trans_conf.tts_model,
            return_lc3=True
        )
        synthesis_tasks.append(task)

    # Wait for all synthesis tasks to complete concurrently.
    # Same order guarantee as above: results[i] belongs to synth_langs[i].
    audio = {}
    if synthesis_tasks:
        results = await asyncio.gather(*synthesis_tasks)
        for i, audio_data in enumerate(results):
            audio[synth_langs[i]] = audio_data

    # Start the monitoring coroutine to wait for streaming to complete.
    # This will set the state to COMPLETED when finished. It is deliberately
    # started BEFORE the audio is sent so it can observe streaming start.
    asyncio.create_task(monitor_streaming_completion(ep_group))

    # Broadcast to all endpoints in group concurrently
    broadcast_tasks = []
    for endpoint in ep_group.endpoints:
        log.info(f"Broadcasting to {endpoint.name} for languages: {', '.join(audio.keys())}")
        task = multicast_client.send_audio(audio, base_url=endpoint.url)
        broadcast_tasks.append(task)

    # Wait for all broadcasts to complete
    await asyncio.gather(*broadcast_tasks)

    # Return the translations
    return {"translations": translations}
|
||||
|
||||
|
||||
async def monitor_streaming_completion(ep_group: EndpointGroup):
    """
    Monitor streaming status after audio is sent and update group state when complete.

    Phase 1 waits (up to ``initial_check_timeout`` seconds) for at least one
    endpoint to report that it is streaming; if none does, the group is marked
    ERROR and the coroutine returns. Phase 2 polls all endpoints until every
    one has stopped streaming, or until ``max_completion_time`` seconds have
    elapsed, then marks the group COMPLETED (or logs an error on timeout).

    Args:
        ep_group: The endpoint group being monitored.
    """
    log.info(f"Starting streaming completion monitoring for endpoint group {ep_group.id}")

    # Upper bound for phase 2 (waiting for streaming to finish).
    max_completion_time = 60  # seconds

    # --- Phase 1: wait for streaming to start (with timeout) ---
    streaming_started = False
    initial_check_timeout = 10  # seconds
    initial_check_start = time.time()

    while time.time() - initial_check_start < initial_check_timeout:
        # Wait before checking again
        await asyncio.sleep(1)

        any_streaming = False
        for endpoint in ep_group.endpoints:
            status = await multicast_client.get_status(base_url=endpoint.url)
            if status.get("is_streaming", False):
                any_streaming = True
                log.info(f"Streaming confirmed started on endpoint {endpoint.name}")
                break

        if any_streaming:
            streaming_started = True
            break

    if not streaming_started:
        log.warning(f"No endpoints started streaming for group {ep_group.id} after {initial_check_timeout}s")
        # Nothing ever streamed: flag the group as errored and stop monitoring.
        ep_group.current_state = AnnouncementStates.ERROR
        endpoints_db.update_group(ep_group.id, ep_group)
        return

    # Update group progress
    ep_group.current_state = AnnouncementStates.BROADCASTING
    endpoints_db.update_group(ep_group.id, ep_group)

    # --- Phase 2: monitor until streaming completes on all endpoints ---
    check_completion_start_time = time.time()
    completed = [False for _ in ep_group.endpoints]
    # BUGFIX: the original condition used `or ... > max_completion_time`,
    # which kept the loop running FOREVER once the timeout elapsed (the
    # second clause became permanently true). The loop must continue only
    # while work remains AND the time budget is not exhausted.
    while not all(completed) and time.time() - check_completion_start_time < max_completion_time:
        await asyncio.sleep(1)

        # Check status of each endpoint
        for i, endpoint in enumerate(ep_group.endpoints):
            status = await multicast_client.get_status(base_url=endpoint.url)
            # .get() for consistency with phase 1 (a missing key counts as
            # "not streaming" instead of raising KeyError).
            completed[i] = not status.get('is_streaming', False)

    if all(completed):
        log.info(f"All endpoints completed streaming for group {ep_group.id}")
        # Update group state to completed
        ep_group.current_state = AnnouncementStates.COMPLETED
        endpoints_db.update_group(ep_group.id, ep_group)
        log.info(f"Updated group {ep_group.id} state to COMPLETED")
    else:
        log.error(f"Max wait time reached for group {ep_group.id}. Forcing completion.")
|
||||
|
||||
|
||||
@app.get("/groups")
|
||||
async def get_groups():
|
||||
"""Get all endpoint groups with their current status."""
|
||||
return endpoints_db.get_all_groups()
|
||||
|
||||
|
||||
@app.post("/groups")
|
||||
async def create_group(group: endpoints_db.EndpointGroup):
|
||||
"""Add a new endpoint group."""
|
||||
try:
|
||||
return endpoints_db.add_group(group)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@app.get("/groups/{group_id}/state") # TODO: think about progress tracking
|
||||
async def get_group_state(group_id: int):
|
||||
"""Get the status of a specific endpoint."""
|
||||
# Check if the endpoint exists
|
||||
ep_group = endpoints_db.get_group_by_id(group_id)
|
||||
if not ep_group:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {group_id} not found")
|
||||
|
||||
return {"name": ep_group.current_state.name, "value": ep_group.current_state.value}
|
||||
|
||||
|
||||
@app.put("/groups/{group_id}")
|
||||
async def update_group(group_id: int, updated_group: endpoints_db.EndpointGroup):
|
||||
"""Update an existing endpoint group."""
|
||||
try:
|
||||
return endpoints_db.update_group(group_id, updated_group)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: int):
|
||||
"""Delete an endpoint group."""
|
||||
try:
|
||||
endpoints_db.delete_group(group_id)
|
||||
return {"message": f"Group {group_id} deleted successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/announcement")
|
||||
async def start_announcement(text: str, group_id: int):
|
||||
"""Start a new announcement to the specified endpoint group."""
|
||||
global announcement_task
|
||||
|
||||
# Get the group from active groups or database
|
||||
group = endpoints_db.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail=f"Group {group_id} not found")
|
||||
|
||||
# Check if we're already processing an announcement
|
||||
#if announcement_task and not announcement_task.done():
|
||||
# raise HTTPException(status_code=400, detail="Already processing an announcement")
|
||||
|
||||
# Start the announcement task
|
||||
announcement_task = asyncio.create_task(make_announcement(text, group))
|
||||
return {"status": "Announcement started", "group_id": group_id}
|
||||
|
||||
|
||||
|
||||
@app.get("/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints with their capabilities."""
|
||||
return endpoints_db.get_all_endpoints()
|
||||
|
||||
|
||||
@app.get("/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return endpoints_db.get_available_languages()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
log.basicConfig(
|
||||
level=log.DEBUG,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
# with reload=True logging of modules does not function as expected
|
||||
uvicorn.run(
|
||||
app,
|
||||
#'translator_server:app',
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
#reload=True,
|
||||
#log_config=None,
|
||||
#log_level="info"
|
||||
)
|
||||
44
src/voice_client/tts_client.py
Normal file
44
src/voice_client/tts_client.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import requests
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from voice_models.request_models import SynthesizeRequest
|
||||
|
||||
|
||||
API_URL = "http://127.0.0.1:8099/synthesize/"
|
||||
|
||||
def request_synthesis(request_data: SynthesizeRequest):
    """
    Request speech synthesis from the local TTS server.

    Args:
        request_data: The synthesis parameters (text, sample rate,
            framework, model, LC3 flag).

    Returns:
        LC3-encoded audio as ``bytes`` if ``request_data.return_lc3`` is
        set, otherwise the raw PCM samples as a float32 numpy array.
        Returns ``None`` when the server responds with a non-200 status.
    """
    # FIX: the original call had no timeout, so a hung server would block
    # this client forever. 60 s leaves room for slow synthesis.
    response = requests.post(API_URL, json=request_data.model_dump(), timeout=60)

    if response.status_code == 200:
        response_data = response.json()

        if request_data.return_lc3:
            # LC3 audio arrives hex-encoded; decode back to raw bytes.
            lc3_bytes = bytes.fromhex(response_data["audio_lc3"])
            return lc3_bytes
        else:
            # Convert hex-encoded PCM bytes back to a float32 numpy array.
            audio_bytes = bytes.fromhex(response_data["audio_pcm"])
            audio_array = np.frombuffer(audio_bytes, dtype=np.float32)
            return audio_array
    else:
        print(f"Error: {response.status_code}, {response.text}")
        # FIX: make the error-path return value explicit instead of
        # falling off the end of the function.
        return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
target_rate=16000
|
||||
|
||||
# Example request
|
||||
request_data = SynthesizeRequest(
|
||||
text="Hello, this is a test.",
|
||||
target_sample_rate=target_rate,
|
||||
framework="piper",
|
||||
model="de_DE-kerstin-low",
|
||||
return_lc3=False # Set to True to receive LC3 compressed output
|
||||
)
|
||||
|
||||
audio = request_synthesis(request_data)
|
||||
sf.write('hello.wav', audio, target_rate)
|
||||
9
src/voice_models/request_models.py
Normal file
9
src/voice_models/request_models.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SynthesizeRequest(BaseModel):
    """Request body for the TTS server's /synthesize/ endpoint."""
    # Text to synthesize.
    text: str
    # Output sample rate in Hz; audio is resampled to this rate.
    target_sample_rate: int = 16000
    # TTS framework to use (only 'piper' is implemented server-side).
    framework: str = "piper"
    # Voice model identifier understood by the chosen framework.
    model: str = "en_US-lessac-medium"
    # If True the server returns LC3-compressed audio; otherwise raw PCM.
    return_lc3: bool = False
|
||||
|
||||
@@ -1,27 +1,34 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import logging as log
|
||||
import numpy as np
|
||||
from multilang_translator import translator_config
|
||||
from multilang_translator.utils.resample import resample_array
|
||||
from multilang_translator.text_to_speech import encode_lc3
|
||||
import asyncio
|
||||
from voice_provider.utils.resample import resample_array
|
||||
from voice_provider.utils.encode_lc3 import encode_lc3
|
||||
|
||||
PIPER_EXE = shutil.which('piper')
|
||||
|
||||
TTS_DIR = os.path.join(os.path.dirname(__file__))
|
||||
PIPER_DIR = f'{TTS_DIR}/piper'
|
||||
PIPER_WORKDIR = f'{TTS_DIR}/piper'
|
||||
|
||||
if not PIPER_EXE:
|
||||
PIPER_EXE = f'{TTS_DIR}/../../venv/bin/piper'
|
||||
|
||||
def synth_piper(text, model="en_US-lessac-medium"):
|
||||
pwd = os.getcwd()
|
||||
os.chdir(PIPER_DIR)
|
||||
os.chdir(PIPER_WORKDIR)
|
||||
start = time.time()
|
||||
|
||||
# make sure piper has voices.json in working directory, otherwise it attempts to always load models
|
||||
ret = subprocess.run( # TODO: wrap this whole thing in a class and open a permanent pipe to the model
|
||||
[translator_config.PIPER_EXE_PATH,
|
||||
'--cuda',
|
||||
'--model', model,
|
||||
'--output-raw'
|
||||
[
|
||||
PIPER_EXE,
|
||||
'--cuda',
|
||||
'--model', model,
|
||||
'--output-raw'
|
||||
],
|
||||
input=text.encode('utf-8'),
|
||||
capture_output=True
|
||||
@@ -34,14 +41,19 @@ def synth_piper(text, model="en_US-lessac-medium"):
|
||||
|
||||
log.info("Running piper for model %s took %s s", model, round(time.time() - start, 3))
|
||||
|
||||
with open (f'{PIPER_DIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently
|
||||
with open (f'{PIPER_WORKDIR}/{model}.onnx.json') as f: # TODO: wrap everyth0ing into a class, store the json permanently
|
||||
model_json = json.load(f)
|
||||
|
||||
return model_json, audio
|
||||
|
||||
|
||||
# TODO: framework should probably be a dataclass that holds all the relevant informations, also model
|
||||
def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium", return_lc3=True):
|
||||
def synthesize(
|
||||
text,
|
||||
target_sample_rate,
|
||||
framework,
|
||||
model="en_US-lessac-medium",
|
||||
return_lc3=True
|
||||
):
|
||||
|
||||
if framework == 'piper':
|
||||
model_json, audio_raw = synth_piper(text, model)
|
||||
@@ -59,12 +71,41 @@ def synthesize(text, target_sample_rate, framework, model="en_US-lessac-medium",
|
||||
|
||||
if return_lc3:
|
||||
audio_pcm = (audio * 2**15-1).astype(np.int16)
|
||||
lc3 = encode_lc3.encode(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter
|
||||
lc3 = encode_lc3(audio_pcm, target_sample_rate, 40) # TODO: octetts per frame should be parameter
|
||||
return lc3
|
||||
else:
|
||||
return audio
|
||||
|
||||
|
||||
async def synthesize_async(
    text,
    target_sample_rate,
    framework,
    model="en_US-lessac-medium",
    return_lc3=True
):
    """
    Asynchronous version of the synthesize function that runs in a thread pool.

    Args:
        text: Text to synthesize
        target_sample_rate: Target sample rate for the audio
        framework: TTS framework to use (e.g., 'piper')
        model: Model to use for synthesis
        return_lc3: Whether to return LC3-encoded audio

    Returns:
        LC3-encoded audio as string or raw audio as numpy array
    """
    # Run the CPU-intensive synthesis in a thread pool.
    # FIX: asyncio.get_event_loop() is deprecated inside a coroutine since
    # Python 3.10; get_running_loop() is the correct call here.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(
        None,
        lambda: synthesize(text, target_sample_rate, framework, model, return_lc3)
    )
    return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import logging
|
||||
import soundfile as sf
|
||||
@@ -79,5 +120,4 @@ if __name__ == '__main__':
|
||||
|
||||
sf.write('hello.wav', audio, target_rate)
|
||||
|
||||
# TODO: "WARNING:piper.download:Wrong size (expected=5952, actual=4158
|
||||
print('Done.')
|
||||
43
src/voice_provider/tts_server.py
Normal file
43
src/voice_provider/tts_server.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import numpy as np
|
||||
|
||||
from voice_models.request_models import SynthesizeRequest
|
||||
from voice_provider.text_to_speech import synthesize_async
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
HOST_PORT = 8099
|
||||
|
||||
|
||||
@app.post("/synthesize/")
|
||||
async def synthesize_speech(request: SynthesizeRequest):
|
||||
try:
|
||||
audio = await synthesize_async(
|
||||
text=request.text,
|
||||
target_sample_rate=request.target_sample_rate,
|
||||
framework=request.framework,
|
||||
model=request.model,
|
||||
return_lc3=request.return_lc3
|
||||
)
|
||||
|
||||
if request.return_lc3:
|
||||
# If it's already a string (LC3 data), convert it to bytes for hex encoding
|
||||
if isinstance(audio, str):
|
||||
audio_bytes = audio.encode('latin-1')
|
||||
return {"audio_lc3": audio_bytes.hex()}
|
||||
# If it's already bytes
|
||||
elif isinstance(audio, bytes):
|
||||
return {"audio_lc3": audio.hex()}
|
||||
else:
|
||||
raise ValueError(f"Unexpected audio type: {type(audio)}")
|
||||
else:
|
||||
# If it's numpy array (non-LC3), convert to bytes
|
||||
audio_bytes = audio.astype(np.float32).tobytes()
|
||||
return {"audio_pcm": audio_bytes.hex()}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="127.0.0.1", port=HOST_PORT)
|
||||
0
src/voice_provider/utils/__init__.py
Normal file
0
src/voice_provider/utils/__init__.py
Normal file
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import lc3
|
||||
|
||||
def encode(
|
||||
def encode_lc3(
|
||||
audio: np.array,
|
||||
output_sample_rate_hz,
|
||||
octets_per_frame,
|
||||
9
tests/get_group_state.py
Normal file
9
tests/get_group_state.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import requests
|
||||
import time
|
||||
|
||||
if __name__ == '__main__':
    # Poll the group state every 0.5 s; stop with Ctrl-C.
    try:
        while True:
            # FIX: added a timeout so a hung server cannot block the
            # polling loop indefinitely.
            response = requests.get('http://localhost:7999/groups/0/state', timeout=5)
            print(response.json())
            time.sleep(0.5)
    except KeyboardInterrupt:
        # Clean exit for this helper script instead of a traceback.
        pass
|
||||
Reference in New Issue
Block a user