build a minimal working example for making an announcement
This commit is contained in:
@@ -2,60 +2,144 @@
|
||||
Database file for endpoint definitions.
|
||||
This file contains configurations for auracast endpoints including their IP addresses and capabilities.
|
||||
"""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
from pydantic import BaseModel
|
||||
from multilang_translator.translator_api.translator_models import EndpointGroup, Endpoint
|
||||
|
||||
|
||||
class EndpointDefinition(BaseModel):
|
||||
"""Defines an endpoint with its URL and capabilities."""
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts
|
||||
requires_authentication: bool = False # Whether authentication is required
|
||||
|
||||
SUPPORTED_LANGUAGES = ["deu", "eng", "fra", "spa", "ita"]
|
||||
|
||||
# Database of endpoints
|
||||
ENDPOINTS: Dict[str, EndpointDefinition] = {
|
||||
"endpoint0": EndpointDefinition(
|
||||
id="endpoint0",
|
||||
ENDPOINTS: dict[int: Endpoint] = { # for now make sure, .id and key are the same
|
||||
0: Endpoint(
|
||||
id=0,
|
||||
name="Local Endpoint",
|
||||
url="http://localhost:5000",
|
||||
max_broadcasts=2,
|
||||
max_broadcasts=3,
|
||||
),
|
||||
"endpoint1": EndpointDefinition(
|
||||
id="endpoint1",
|
||||
1: Endpoint(
|
||||
id=1,
|
||||
name="Gate 1 Endpoint",
|
||||
url="http://192.168.1.101:5000",
|
||||
max_broadcasts=2,
|
||||
max_broadcasts=3,
|
||||
),
|
||||
"endpoint2": EndpointDefinition(
|
||||
id="endpoint2",
|
||||
2: Endpoint(
|
||||
id=2,
|
||||
name="Gate 2 Endpoint",
|
||||
url="http://192.168.1.102:5000",
|
||||
max_broadcasts=1,
|
||||
max_broadcasts=3,
|
||||
),
|
||||
}
|
||||
|
||||
# Database of endpoint groups with default endpoints
|
||||
ENDPOINT_GROUPS: dict[int:EndpointGroup] = { # for now make sure , .id and key are the same
|
||||
0: EndpointGroup(
|
||||
id=0,
|
||||
name="Gate1",
|
||||
languages=["deu", "eng"],
|
||||
endpoints=[ENDPOINTS[0]],
|
||||
),
|
||||
1: EndpointGroup(
|
||||
id=1,
|
||||
name="Gate2",
|
||||
languages=["deu", "eng", "fra"],
|
||||
endpoints=[ENDPOINTS[2]],
|
||||
)
|
||||
}
|
||||
|
||||
def get_all_endpoints() -> List[EndpointDefinition]:
|
||||
def get_available_languages() -> List[str]:
|
||||
"""Get a list of all supported languages."""
|
||||
return SUPPORTED_LANGUAGES
|
||||
|
||||
# Endpoint functions
|
||||
def get_all_endpoints() -> List[Endpoint]:
|
||||
"""Get all active endpoints."""
|
||||
return ENDPOINTS.values()
|
||||
return ENDPOINTS
|
||||
|
||||
def get_endpoint_by_id(endpoint_id: str) -> Optional[EndpointDefinition]:
|
||||
|
||||
def get_endpoint_by_id(endpoint_id: str) -> Optional[Endpoint]:
|
||||
"""Get an endpoint by its ID."""
|
||||
return ENDPOINTS.get(endpoint_id)
|
||||
return ENDPOINTS[endpoint_id]
|
||||
|
||||
def get_max_broadcasts(endpoint_id: str) -> int:
|
||||
"""Get the maximum number of simultaneous broadcasts for a specific endpoint."""
|
||||
endpoint = get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
return endpoint.max_broadcasts
|
||||
|
||||
def get_endpoint_url(endpoint_id: str) -> str:
|
||||
"""Get the full URL for an endpoint."""
|
||||
endpoint = get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
def add_endpoint(endpoint: Endpoint) -> Endpoint:
|
||||
"""Add a new endpoint to the database."""
|
||||
if endpoint.id in ENDPOINTS:
|
||||
raise ValueError(f"Endpoint with ID {endpoint.id} already exists")
|
||||
ENDPOINTS[endpoint.id] = endpoint
|
||||
return endpoint
|
||||
|
||||
|
||||
def update_endpoint(endpoint_id: str, updated_endpoint: Endpoint) -> Endpoint:
|
||||
"""Update an existing endpoint in the database."""
|
||||
if endpoint_id not in ENDPOINTS:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
return endpoint.url
|
||||
|
||||
# Ensure the ID is preserved
|
||||
updated_endpoint.id = endpoint_id
|
||||
ENDPOINTS[endpoint_id] = updated_endpoint
|
||||
return updated_endpoint
|
||||
|
||||
|
||||
def delete_endpoint(endpoint_id: str) -> None:
|
||||
"""Delete an endpoint from the database."""
|
||||
if endpoint_id not in ENDPOINTS:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
|
||||
# Check if this endpoint is used in any groups
|
||||
for group in ENDPOINT_GROUPS.values():
|
||||
if endpoint_id in group.endpoints:
|
||||
raise ValueError(f"Cannot delete endpoint {endpoint_id}, it is used in group {group.id}")
|
||||
|
||||
del ENDPOINTS[endpoint_id]
|
||||
|
||||
|
||||
# Endpoint Group functions
|
||||
def get_all_groups() -> List[EndpointGroup]:
|
||||
"""Get all endpoint groups."""
|
||||
return list(ENDPOINT_GROUPS.values())
|
||||
|
||||
|
||||
def get_group_by_id(group_id: int) -> Optional[EndpointGroup]:
|
||||
"""Get an endpoint group by its ID."""
|
||||
return ENDPOINT_GROUPS.get(group_id)
|
||||
|
||||
|
||||
def add_group(group: EndpointGroup) -> EndpointGroup:
|
||||
"""Add a new endpoint group to the database."""
|
||||
if group.id in ENDPOINT_GROUPS:
|
||||
raise ValueError(f"Group with ID {group.id} already exists")
|
||||
|
||||
# Validate that all referenced endpoints exist
|
||||
for endpoint_id in group.endpoints:
|
||||
if endpoint_id not in ENDPOINTS:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
|
||||
ENDPOINT_GROUPS[group.id] = group
|
||||
return group
|
||||
|
||||
|
||||
def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup:
|
||||
"""Update an existing endpoint group in the database."""
|
||||
if group_id not in ENDPOINT_GROUPS:
|
||||
raise ValueError(f"Group {group_id} not found")
|
||||
|
||||
# Validate that all referenced endpoints exist
|
||||
for endpoint in updated_group.endpoints:
|
||||
if endpoint.id not in ENDPOINTS.keys():
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
|
||||
# Ensure the ID is preserved
|
||||
updated_group.id = group_id
|
||||
ENDPOINT_GROUPS[group_id] = updated_group
|
||||
return updated_group
|
||||
|
||||
|
||||
def delete_group(group_id: int) -> None:
|
||||
"""Delete an endpoint group from the database."""
|
||||
if group_id not in ENDPOINT_GROUPS:
|
||||
raise ValueError(f"Group {group_id} not found")
|
||||
|
||||
del ENDPOINT_GROUPS[group_id]
|
||||
|
||||
|
||||
|
||||
@@ -4,22 +4,18 @@ This API mimics the mock_api from auracaster-webui to allow integration.
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import logging as log
|
||||
from typing import List, Dict, Optional
|
||||
import logging as log
|
||||
|
||||
# Import models
|
||||
from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates
|
||||
from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDeu, TranslatorConfigEng, TranslatorConfigFra
|
||||
from multilang_translator.translator_api.translator_models import AnnouncementStates, Endpoint, EndpointGroup
|
||||
from multilang_translator.translator import llm_translator
|
||||
from auracast import multicast_control
|
||||
from voice_provider import text_to_speech
|
||||
|
||||
# Import the endpoints database and multicast client
|
||||
from multilang_translator.translator_api import endpoints_db
|
||||
from auracast import multicast_client
|
||||
from auracast import multicast_client, auracast_config
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI()
|
||||
@@ -33,530 +29,226 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global variables
|
||||
config = TranslatorConfigGroup(
|
||||
bigs=[
|
||||
TranslatorConfigDeu(),
|
||||
TranslatorConfigEng(),
|
||||
TranslatorConfigFra(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the transport
|
||||
config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc
|
||||
|
||||
# Configure the bigs
|
||||
for conf in config.bigs:
|
||||
conf.loop = False
|
||||
conf.llm_client = 'openwebui' # comment out for local llm
|
||||
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
# Get available endpoints and languages from the database
|
||||
AVAILABLE_ENDPOINTS = endpoints_db.get_all_endpoints()
|
||||
AVAILABLE_LANGUAGES = ['deu', 'eng', 'fra']
|
||||
AVAILABLE_LANGUAGES = endpoints_db.get_available_languages()
|
||||
|
||||
# Default endpoint groups
|
||||
endpoint_groups = [
|
||||
EndpointGroup(
|
||||
id=1,
|
||||
name="Gate1",
|
||||
endpoints=["endpoint1", "endpoint2"],
|
||||
languages=["deu", "eng"]
|
||||
),
|
||||
EndpointGroup(
|
||||
id=2,
|
||||
name="Gate2",
|
||||
endpoints=["endpoint3"],
|
||||
languages=["deu", "eng", "fra"]
|
||||
)
|
||||
]
|
||||
CURRENT_ENDPOINT_CONFIG = {}
|
||||
|
||||
# Announcement state tracking
|
||||
active_group_id = None
|
||||
last_completed_group_id = None
|
||||
announcement_task = None
|
||||
caster = None
|
||||
reset_task = None
|
||||
|
||||
# Initialization flag
|
||||
initialized = False
|
||||
|
||||
# Dictionary to track endpoint status and active broadcasts
|
||||
endpoint_status = {}
|
||||
|
||||
|
||||
async def init_caster():
|
||||
"""Initialize the multicaster if not already initialized."""
|
||||
global caster, initialized
|
||||
|
||||
if not initialized:
|
||||
log.info("Initializing caster...")
|
||||
try:
|
||||
caster = multicast_control.Multicaster(
|
||||
config,
|
||||
[big for big in config.bigs]
|
||||
)
|
||||
await caster.init_broadcast()
|
||||
initialized = True
|
||||
log.info("Caster initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize caster: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
async def init_endpoint(endpoint_id: str):
|
||||
def init_endpoint(endpoint: Endpoint, languages: list[str]):
|
||||
"""Initialize a specific endpoint for multicast."""
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
|
||||
log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.ip_address}:{endpoint.port}")
|
||||
|
||||
# Update the BASE_URL in multicast_client for this endpoint
|
||||
multicast_client.BASE_URL = f"http://{endpoint.ip_address}:{endpoint.port}"
|
||||
|
||||
|
||||
log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}")
|
||||
# Create a config for this endpoint
|
||||
endpoint_config = multicast_client.AuracastConfigGroup(
|
||||
config = multicast_client.AuracastConfigGroup(
|
||||
bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
|
||||
for lang in endpoint.supported_languages]
|
||||
for lang in languages]
|
||||
)
|
||||
endpoint_config.transport = config.transport
|
||||
|
||||
|
||||
# some default configs (for now)
|
||||
# Configure the transport
|
||||
config.transport = 'auto'
|
||||
|
||||
# Configure the bigs
|
||||
for conf in endpoint_config.bigs:
|
||||
conf.loop = False
|
||||
|
||||
try:
|
||||
# Initialize the endpoint
|
||||
multicast_client.request_init(endpoint_config)
|
||||
endpoint_status[endpoint_id] = {
|
||||
"active": True,
|
||||
"broadcasts": 0,
|
||||
"max_broadcasts": endpoint.max_broadcasts
|
||||
}
|
||||
log.info(f"Endpoint {endpoint_id} initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize endpoint {endpoint_id}: {e}")
|
||||
endpoint_status[endpoint_id] = {
|
||||
"active": False,
|
||||
"error": str(e)
|
||||
}
|
||||
raise e
|
||||
for conf in config.bigs: # TODO: this is now part of the endpoint group config
|
||||
conf.loop = False
|
||||
# conf.llm_client = 'openwebui' # comment out for local llm
|
||||
# conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
# conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
# Initialize the endpoint if config changed or if it's not already initialized
|
||||
if not multicast_client.get_status(base_url=endpoint.url)['is_initialized'] or config != CURRENT_ENDPOINT_CONFIG.get(endpoint.id):
|
||||
multicast_client.request_init(config, base_url=endpoint.url)
|
||||
CURRENT_ENDPOINT_CONFIG[endpoint.id] = config
|
||||
else:
|
||||
log.info('Endpoint %s was already initialized', endpoint.name)
|
||||
|
||||
log.info(f"Endpoint {endpoint.name} initialized successfully")
|
||||
|
||||
|
||||
async def process_announcement(text: str, group: EndpointGroup):
|
||||
async def make_announcement(text: str, ep_group: endpoints_db.EndpointGroup):
|
||||
"""
|
||||
Process an announcement using the multilang_translator.
|
||||
This function now uses the multicast_client to send announcements to endpoints.
|
||||
"""
|
||||
global active_group_id, last_completed_group_id, caster, reset_task
|
||||
|
||||
ep_group.current_state = AnnouncementStates.IDLE
|
||||
ep_group.anouncement_start_time = time.time()
|
||||
# update the database with the new state and start time so this can be read by another process
|
||||
endpoints_db.update_group(ep_group.id, ep_group)
|
||||
|
||||
|
||||
# Make sure the caster is initialized
|
||||
if not initialized:
|
||||
await init_caster()
|
||||
|
||||
try:
|
||||
# Set start time and track the active group
|
||||
group.parameters.text = text
|
||||
group.parameters.languages = group.languages
|
||||
group.parameters.start_time = time.time()
|
||||
active_group_id = group.id
|
||||
|
||||
# Update status to translating
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
group.progress.progress = 0.2
|
||||
|
||||
# Initialize all endpoints in the group
|
||||
endpoints_to_broadcast = []
|
||||
for endpoint_id in group.endpoints:
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
log.warning(f"Endpoint {endpoint_id} not found, skipping")
|
||||
continue
|
||||
|
||||
# Check if the endpoint supports all the required languages
|
||||
if not all(lang in endpoint.supported_languages for lang in group.languages):
|
||||
log.warning(f"Endpoint {endpoint_id} does not support all required languages, skipping")
|
||||
continue
|
||||
|
||||
# Check if the endpoint has room for another broadcast
|
||||
if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]:
|
||||
if endpoint_status[endpoint_id]["broadcasts"] >= endpoint_status[endpoint_id]["max_broadcasts"]:
|
||||
log.warning(f"Endpoint {endpoint_id} already at maximum broadcasts, skipping")
|
||||
continue
|
||||
|
||||
endpoint_status[endpoint_id]["broadcasts"] += 1
|
||||
endpoints_to_broadcast.append(endpoint)
|
||||
else:
|
||||
# Initialize the endpoint if not already active
|
||||
try:
|
||||
await init_endpoint(endpoint_id)
|
||||
endpoint_status[endpoint_id]["broadcasts"] += 1
|
||||
endpoints_to_broadcast.append(endpoint)
|
||||
except Exception as e:
|
||||
log.error(f"Could not initialize endpoint {endpoint_id}: {e}")
|
||||
continue
|
||||
|
||||
if not endpoints_to_broadcast:
|
||||
raise HTTPException(status_code=400, detail="No valid endpoints available for broadcast")
|
||||
|
||||
# Translate the text for each language
|
||||
base_lang = "deu" # German is the base language
|
||||
audio_data = {}
|
||||
|
||||
for i, big in enumerate(config.bigs):
|
||||
# Check if this language is in the requested languages
|
||||
if big.language not in group.languages:
|
||||
continue
|
||||
|
||||
# Translate if not the base language
|
||||
if big.language == base_lang:
|
||||
translated_text = text
|
||||
else:
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
translated_text = llm_translator.translate_de_to_x(
|
||||
text,
|
||||
big.language,
|
||||
model=big.translator_llm,
|
||||
client=big.llm_client,
|
||||
host=big.llm_host_url,
|
||||
token=big.llm_host_token
|
||||
)
|
||||
|
||||
log.info(f'Translated text ({big.language}): {translated_text}')
|
||||
|
||||
# Update status to generating voice
|
||||
group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value
|
||||
group.progress.progress = 0.4
|
||||
|
||||
# This will use the voice_provider's text_to_speech.synthesize function
|
||||
from voice_provider import text_to_speech
|
||||
lc3_audio = text_to_speech.synthesize(
|
||||
translated_text,
|
||||
config.auracast_sampling_rate_hz,
|
||||
big.tts_system,
|
||||
big.tts_model,
|
||||
return_lc3=True
|
||||
# Initialize all endpoints in the group if they were not initalized before
|
||||
for endpoint in ep_group.endpoints:
|
||||
init_endpoint(endpoint, ep_group.languages)
|
||||
|
||||
# Translate the text for each language
|
||||
base_lang = "deu" # German is the base language
|
||||
ep_group.current_state = AnnouncementStates.TRANSLATING
|
||||
endpoints_db.update_group(ep_group.id, ep_group)
|
||||
|
||||
translations = {base_lang: text}
|
||||
for lang in ep_group.languages:
|
||||
if lang != base_lang:
|
||||
# Translate the text
|
||||
lang_conf = getattr(ep_group.translator_config, lang)
|
||||
translation = llm_translator.translate_de_to_x(
|
||||
text=text,
|
||||
target_language=lang,
|
||||
client=lang_conf.llm_client,
|
||||
model=lang_conf.translator_llm,
|
||||
host=lang_conf.llm_host_url,
|
||||
token=lang_conf.llm_host_token
|
||||
)
|
||||
|
||||
# Add the audio to the audio_data dictionary
|
||||
audio_data[big.language] = lc3_audio.decode('latin-1') if isinstance(lc3_audio, bytes) else lc3_audio
|
||||
|
||||
# Set the audio source for this language (for the traditional caster)
|
||||
caster.big_conf[i].audio_source = lc3_audio
|
||||
|
||||
# Update status to routing
|
||||
group.progress.current_state = AnnouncementStates.ROUTING.value
|
||||
group.progress.progress = 0.6
|
||||
await asyncio.sleep(0.5) # Small delay for routing
|
||||
|
||||
# Update status to active and start streaming
|
||||
group.progress.current_state = AnnouncementStates.ACTIVE.value
|
||||
group.progress.progress = 0.8
|
||||
|
||||
# Send the audio to each endpoint using multicast_client
|
||||
broadcast_tasks = []
|
||||
for endpoint in endpoints_to_broadcast:
|
||||
# Set the BASE_URL for this endpoint
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint.id)
|
||||
|
||||
# Create a task to send the audio to this endpoint
|
||||
broadcast_tasks.append(asyncio.create_task(
|
||||
asyncio.to_thread(multicast_client.send_audio, audio_data)
|
||||
))
|
||||
|
||||
# Also start the traditional caster
|
||||
caster.start_streaming()
|
||||
|
||||
# Wait for all broadcasts to complete
|
||||
await asyncio.gather(*broadcast_tasks)
|
||||
|
||||
# Wait for streaming to complete with the traditional caster
|
||||
await caster.streamer.task
|
||||
|
||||
# Update endpoint status
|
||||
for endpoint in endpoints_to_broadcast:
|
||||
if endpoint.id in endpoint_status and endpoint_status[endpoint.id]["broadcasts"] > 0:
|
||||
endpoint_status[endpoint.id]["broadcasts"] -= 1
|
||||
|
||||
# Update status to complete
|
||||
group.progress.current_state = AnnouncementStates.COMPLETE.value
|
||||
group.progress.progress = 1.0
|
||||
last_completed_group_id = group.id
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
# After a while, reset to idle state (in a separate task)
|
||||
async def reset_to_idle():
|
||||
log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state")
|
||||
await asyncio.sleep(10) # Keep completed state visible for 10 seconds
|
||||
log.info(f"Resetting group {group.id} to IDLE state")
|
||||
# Use direct value lookup for the state comparison
|
||||
if group.progress.current_state == AnnouncementStates.COMPLETE.value:
|
||||
group.progress.current_state = AnnouncementStates.IDLE.value
|
||||
group.progress.progress = 0.0
|
||||
log.info(f"Group {group.id} state reset to IDLE")
|
||||
|
||||
# Create and save the task so it won't be garbage collected
|
||||
reset_task = asyncio.create_task(reset_to_idle())
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during announcement processing: {e}")
|
||||
group.progress.current_state = AnnouncementStates.ERROR.value
|
||||
group.progress.error = str(e)
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
translations[lang] = translation
|
||||
log.info(f"Translated to {lang}: {translation}")
|
||||
|
||||
|
||||
@app.get("/api/groups")
|
||||
# Generate voices
|
||||
ep_group.current_state = AnnouncementStates.GENERATING_VOICE
|
||||
endpoints_db.update_group(ep_group.id, ep_group)
|
||||
|
||||
audio = {}
|
||||
for lang, text in translations.items():
|
||||
# Get the appropriate language configuration
|
||||
lang_conf = getattr(ep_group.translator_config, lang)
|
||||
|
||||
# Convert text to LC3-encoded audio using the configuration's TTS settings
|
||||
audio[lang] = text_to_speech.synthesize(
|
||||
translations[lang],
|
||||
16000, # Sample rate from auracast config # TODO: take sampling rate from auracast config
|
||||
lang_conf.tts_system, # TTS system from config
|
||||
lang_conf.tts_model, # TTS model from config
|
||||
return_lc3=True
|
||||
).decode('latin-1')
|
||||
|
||||
# Add to audio data dictionary (decode bytes to string for JSON serialization)
|
||||
|
||||
# Update group progress
|
||||
ep_group.current_state = AnnouncementStates.BROADCASTING
|
||||
endpoints_db.update_group(ep_group.id, ep_group)
|
||||
|
||||
|
||||
# Broadcast to all endpoints in group
|
||||
for endpoint in ep_group.endpoints:
|
||||
# Send the audio data to the server using the existing send_audio function
|
||||
log.info(f"Broadcastcasting to {endpoint.name} for languages: {', '.join(audio.keys())}")
|
||||
multicast_client.send_audio(audio, base_url=endpoint.url)
|
||||
|
||||
|
||||
# Update group state to completed
|
||||
#ep_group.current_state = AnnouncementStates.COMPLETED # TODO: somehow the group state needs to be updated to completed after a broadcast
|
||||
|
||||
# Schedule group endpoint reset
|
||||
# if ep_group.reset_task and not ep_group.reset_task.done():
|
||||
# ep_group.reset_task.cancel()
|
||||
|
||||
# ep_group.reset_task = asyncio.create_task(reset_endpoints_after_delay(ep_group.endpoints, 60))
|
||||
|
||||
# Return the translations
|
||||
return {"translations": translations}
|
||||
|
||||
|
||||
async def reset_endpoints_after_delay(endpoint_ids, delay_seconds):
|
||||
"""Reset endpoints after a delay."""
|
||||
await asyncio.sleep(delay_seconds)
|
||||
|
||||
for endpoint_id in endpoint_ids:
|
||||
if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]:
|
||||
try:
|
||||
endpoint_status[endpoint_id]["broadcasts"] = 0
|
||||
log.info(f"Reset broadcasts count for endpoint {endpoint_id}")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to reset endpoint {endpoint_id}: {e}")
|
||||
|
||||
|
||||
@app.get("/groups")
|
||||
async def get_groups():
|
||||
"""Get all endpoint groups."""
|
||||
return endpoint_groups
|
||||
"""Get all endpoint groups with their current status."""
|
||||
return endpoints_db.get_all_groups()
|
||||
|
||||
|
||||
@app.post("/api/groups")
|
||||
async def create_group(group: EndpointGroup):
|
||||
|
||||
@app.post("/groups")
|
||||
async def create_group(group: endpoints_db.EndpointGroup):
|
||||
"""Add a new endpoint group."""
|
||||
# Validate endpoints
|
||||
for endpoint_id in group.endpoints:
|
||||
if endpoint_id not in AVAILABLE_ENDPOINTS:
|
||||
raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available")
|
||||
|
||||
# Validate languages
|
||||
for language in group.languages:
|
||||
if language not in AVAILABLE_LANGUAGES:
|
||||
raise HTTPException(status_code=400, detail=f"Language {language} is not available")
|
||||
|
||||
# Add the group
|
||||
endpoint_groups.append(group)
|
||||
return group
|
||||
try:
|
||||
return endpoints_db.add_group(group)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.put("/api/groups/{group_id}")
|
||||
async def update_group(group_id: int, updated_group: EndpointGroup):
|
||||
@app.put("/groups/{group_id}")
|
||||
async def update_group(group_id: int, updated_group: endpoints_db.EndpointGroup):
|
||||
"""Update an existing endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
# Validate endpoints
|
||||
for endpoint_id in updated_group.endpoints:
|
||||
if endpoint_id not in AVAILABLE_ENDPOINTS:
|
||||
raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available")
|
||||
|
||||
# Validate languages
|
||||
for language in updated_group.languages:
|
||||
if language not in AVAILABLE_LANGUAGES:
|
||||
raise HTTPException(status_code=400, detail=f"Language {language} is not available")
|
||||
|
||||
# Update the group
|
||||
endpoint_groups[i] = updated_group
|
||||
return updated_group
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
try:
|
||||
return endpoints_db.update_group(group_id, updated_group)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/api/groups/{group_id}")
|
||||
@app.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: int):
|
||||
"""Delete an endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
del endpoint_groups[i]
|
||||
return {"detail": f"Group with id {group_id} deleted"}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
try:
|
||||
endpoints_db.delete_group(group_id)
|
||||
return {"message": f"Group {group_id} deleted successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/api/announcements")
|
||||
@app.post("/announcement")
|
||||
async def start_announcement(text: str, group_id: int):
|
||||
"""Start a new announcement to the specified endpoint group."""
|
||||
global announcement_task
|
||||
|
||||
# Find the group
|
||||
target_group = None
|
||||
for group in endpoint_groups:
|
||||
if group.id == group_id:
|
||||
target_group = group
|
||||
break
|
||||
# Get the group from active groups or database
|
||||
group = endpoints_db.get_group_by_id(group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail=f"Group {group_id} not found")
|
||||
|
||||
if not target_group:
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
# Check if we're already processing an announcement
|
||||
#if announcement_task and not announcement_task.done():
|
||||
# raise HTTPException(status_code=400, detail="Already processing an announcement")
|
||||
|
||||
# Check if an announcement is already in progress
|
||||
if active_group_id is not None:
|
||||
raise HTTPException(status_code=400, detail="An announcement is already in progress")
|
||||
|
||||
# Process the announcement
|
||||
announcement_task = asyncio.create_task(process_announcement(text, target_group))
|
||||
|
||||
return {"detail": "Announcement started"}
|
||||
# Start the announcement task
|
||||
announcement_task = asyncio.create_task(make_announcement(text, group))
|
||||
return {"status": "Announcement started", "group_id": group_id}
|
||||
|
||||
|
||||
@app.get("/api/announcements/status")
|
||||
async def get_announcement_status():
|
||||
"""Get the status of the current announcement."""
|
||||
# If no group is active, check if the last completed group exists
|
||||
if active_group_id is None:
|
||||
if last_completed_group_id is not None:
|
||||
# Find the last completed group
|
||||
for group in endpoint_groups:
|
||||
if group.id == last_completed_group_id:
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints,
|
||||
"languages": group.languages,
|
||||
"progress": {
|
||||
"current_state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error
|
||||
},
|
||||
"parameters": {
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# If the last completed group couldn't be found
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
else:
|
||||
# No active or last completed group
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
# If a group is active, return its information
|
||||
for group in endpoint_groups:
|
||||
if group.id == active_group_id:
|
||||
return {
|
||||
"active_group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints,
|
||||
"languages": group.languages,
|
||||
"progress": {
|
||||
"current_state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error
|
||||
},
|
||||
"parameters": {
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
},
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
# If the active group couldn't be found
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints with their capabilities."""
|
||||
return [endpoint_db.dict() for endpoint_db in endpoints_db.get_all_endpoints()]
|
||||
|
||||
|
||||
@app.get("/api/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return endpoints_db.get_available_languages()
|
||||
|
||||
|
||||
@app.get("/api/endpoints/{endpoint_id}")
|
||||
@app.get("/endpoints/{endpoint_id}/status") # TODO: think about progress tracking
|
||||
async def get_endpoint_status(endpoint_id: str):
|
||||
"""Get the status of a specific endpoint."""
|
||||
# Check if the endpoint exists
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found")
|
||||
|
||||
status = endpoint_status.get(endpoint_id, {"active": False, "broadcasts": 0})
|
||||
return {
|
||||
"endpoint": endpoint.dict(),
|
||||
"status": status
|
||||
}
|
||||
# Return the status if available, otherwise a default status
|
||||
if endpoint_id in endpoint_status:
|
||||
return endpoint_status[endpoint_id]
|
||||
else:
|
||||
return {
|
||||
"active": False,
|
||||
"broadcasts": 0,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/endpoints/{endpoint_id}/reset")
|
||||
async def reset_endpoint(endpoint_id: str):
|
||||
"""Reset an endpoint's status."""
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found")
|
||||
|
||||
try:
|
||||
# Set the BASE_URL for this endpoint
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id)
|
||||
|
||||
# Stop any current audio
|
||||
multicast_client.stop_audio()
|
||||
|
||||
# Reinitialize the endpoint
|
||||
await init_endpoint(endpoint_id)
|
||||
|
||||
return {"detail": f"Endpoint {endpoint_id} reset successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reset endpoint: {str(e)}")
|
||||
@app.get("/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints with their capabilities."""
|
||||
return AVAILABLE_ENDPOINTS
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the caster on startup."""
|
||||
global caster, initialized
|
||||
|
||||
try:
|
||||
await init_caster()
|
||||
|
||||
# Initialize all endpoints
|
||||
for endpoint_id in AVAILABLE_ENDPOINTS:
|
||||
try:
|
||||
await init_endpoint(endpoint_id)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize endpoint {endpoint_id}: {e}")
|
||||
except Exception as e:
|
||||
log.error(f"Startup error: {e}")
|
||||
@app.get("/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return AVAILABLE_LANGUAGES
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Clean up resources on shutdown."""
|
||||
global caster
|
||||
|
||||
if caster is not None:
|
||||
try:
|
||||
log.info("Stopping caster...")
|
||||
caster.stop_streaming()
|
||||
caster = None
|
||||
log.info("Caster stopped")
|
||||
except Exception as e:
|
||||
log.error(f"Error during caster shutdown: {e}")
|
||||
|
||||
# Shutdown all active endpoints
|
||||
for endpoint_id in endpoint_status:
|
||||
if endpoint_status[endpoint_id].get("active", False):
|
||||
try:
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id)
|
||||
multicast_client.shutdown()
|
||||
except Exception as e:
|
||||
log.error(f"Error shutting down endpoint {endpoint_id}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
@@ -564,11 +256,13 @@ if __name__ == "__main__":
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
log.info("Starting Translator API server")
|
||||
# with reload=True logging of modules does not function as expected
|
||||
uvicorn.run(
|
||||
"translator_api:app",
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
reload=True,
|
||||
log_level="debug"
|
||||
)
|
||||
app,
|
||||
#'translator_api:app',
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
#reload=True,
|
||||
#log_config=None,
|
||||
#log_level="info"
|
||||
)
|
||||
|
||||
@@ -3,36 +3,50 @@ Models for the translator API.
|
||||
Similar to the models used in auracaster-webui but simplified for the translator middleware.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AnnouncementStates(str, Enum):
|
||||
IDLE = "idle"
|
||||
TRANSLATING = "translating"
|
||||
GENERATING_VOICE = "generating_voice"
|
||||
ROUTING = "routing"
|
||||
ACTIVE = "active"
|
||||
COMPLETE = "complete"
|
||||
ERROR = "error"
|
||||
class AnnouncementStates(Enum):
|
||||
IDLE = 0
|
||||
TRANSLATING = 0.2
|
||||
GENERATING_VOICE = 0.4
|
||||
ROUTING = 0.6
|
||||
BROADCASTING = 0.8
|
||||
COMPLETED = 1
|
||||
ERROR = 0
|
||||
|
||||
|
||||
class AnnouncementProgress(BaseModel):
|
||||
current_state: str = AnnouncementStates.IDLE
|
||||
progress: float = 0.0
|
||||
error: Optional[str] = None
|
||||
class Endpoint(BaseModel):
|
||||
"""Defines an endpoint with its URL and capabilities."""
|
||||
id: int
|
||||
name: str
|
||||
url: str
|
||||
max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts
|
||||
|
||||
|
||||
class AnnouncementParameters(BaseModel):
|
||||
text: str = ""
|
||||
languages: List[str] = []
|
||||
start_time: float = 0.0
|
||||
class TranslatorLangConfig(BaseModel):
|
||||
translator_llm: str = 'llama3.2:3b-instruct-q4_0'
|
||||
llm_client: str = 'ollama'
|
||||
llm_host_url: str | None = 'http://localhost:11434'
|
||||
llm_host_token: str | None = None
|
||||
tts_system: str = 'piper'
|
||||
tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
|
||||
class TranslatorConfig(BaseModel):
|
||||
deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
|
||||
eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
|
||||
fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
|
||||
spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
|
||||
ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')
|
||||
|
||||
|
||||
class EndpointGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
endpoints: List[str]
|
||||
languages: List[str]
|
||||
progress: AnnouncementProgress = AnnouncementProgress()
|
||||
parameters: AnnouncementParameters = AnnouncementParameters()
|
||||
endpoints: List[Endpoint]
|
||||
translator_config: TranslatorConfig = TranslatorConfig()
|
||||
current_state: str = AnnouncementStates.IDLE
|
||||
anouncement_start_time: float = 0.0
|
||||
|
||||
@@ -1,35 +1,20 @@
|
||||
import os
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
from auracast import auracast_config
|
||||
|
||||
VENV_DIR = os.path.join(os.path.dirname(__file__), './../../venv')
|
||||
|
||||
class TranslatorLangConfig(auracast_config.AuracastBigConfig):
|
||||
translator_llm: str = 'llama3.2:3b-instruct-q4_0'
|
||||
class TranslatorLangConfig(BaseModel):
|
||||
translator_llm: str = 'llama3.2:3b-instruct-q4_0' # TODO: this was migrated to translator_models - remove this
|
||||
llm_client: str = 'ollama'
|
||||
llm_host_url: str | None = 'http://localhost:11434'
|
||||
llm_host_token: str | None = None
|
||||
tts_system: str = 'piper'
|
||||
tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
class TranslatorConfig(BaseModel):
|
||||
deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high')
|
||||
eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium')
|
||||
fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium')
|
||||
spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium')
|
||||
ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium')
|
||||
|
||||
class TranslatorConfigDeu(TranslatorLangConfig, auracast_config.AuracastBigConfigDeu):
|
||||
tts_model: str ='de_DE-thorsten-high'
|
||||
|
||||
class TranslatorConfigEng(TranslatorLangConfig, auracast_config.AuracastBigConfigEng):
|
||||
tts_model: str = 'en_GB-alba-medium'
|
||||
|
||||
class TranslatorConfigFra(TranslatorLangConfig, auracast_config.AuracastBigConfigFra):
|
||||
tts_model: str = 'fr_FR-siwis-medium'
|
||||
|
||||
class TranslatorConfigSpa(TranslatorLangConfig, auracast_config.AuracastBigConfigSpa):
|
||||
tts_model: str = 'es_ES-sharvard-medium'
|
||||
|
||||
class TranslatorConfigIta(TranslatorLangConfig, auracast_config.AuracastBigConfigIta):
|
||||
tts_model: str = 'it_IT-paola-medium'
|
||||
|
||||
class TranslatorConfigGroup(auracast_config.AuracastGlobalConfig):
|
||||
bigs: List[TranslatorLangConfig] = [
|
||||
TranslatorConfigDeu(),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user