refractoring
This commit is contained in:
@@ -1,405 +0,0 @@
|
||||
"""
|
||||
FastAPI implementation of the Multilang Translator API.
|
||||
This API mimics the mock_api from auracaster-webui to allow integration.
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import logging as log
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Import models
|
||||
from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates
|
||||
from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDe, TranslatorConfigEn, TranslatorConfigFr
|
||||
from multilang_translator.translator import llm_translator
|
||||
from auracast import multicast_control
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware to allow cross-origin requests
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global variables
|
||||
config = TranslatorConfigGroup(
|
||||
bigs=[
|
||||
TranslatorConfigDe(),
|
||||
TranslatorConfigEn(),
|
||||
TranslatorConfigFr(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the transport
|
||||
config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc
|
||||
|
||||
# Configure the bigs
|
||||
for conf in config.bigs:
|
||||
conf.loop = False
|
||||
conf.llm_client = 'openwebui' # comment out for local llm
|
||||
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
# Available endpoints and languages
|
||||
AVAILABLE_ENDPOINTS = ["endpoint1", "endpoint2", "endpoint3"]
|
||||
AVAILABLE_LANGUAGES = ["German", "English", "French"]
|
||||
|
||||
# Default endpoint groups
|
||||
endpoint_groups = [
|
||||
EndpointGroup(
|
||||
id=1,
|
||||
name="Gate1",
|
||||
endpoints=["endpoint1", "endpoint2"],
|
||||
languages=["German", "English"]
|
||||
),
|
||||
EndpointGroup(
|
||||
id=2,
|
||||
name="Gate2",
|
||||
endpoints=["endpoint3"],
|
||||
languages=["German", "English", "French"]
|
||||
)
|
||||
]
|
||||
|
||||
# Announcement state tracking
|
||||
active_group_id = None
|
||||
last_completed_group_id = None
|
||||
announcement_task = None
|
||||
caster = None
|
||||
reset_task = None
|
||||
|
||||
# Initialization flag
|
||||
initialized = False
|
||||
|
||||
|
||||
async def init_caster():
|
||||
"""Initialize the multicaster if not already initialized."""
|
||||
global caster, initialized
|
||||
|
||||
if not initialized:
|
||||
log.info("Initializing caster...")
|
||||
try:
|
||||
caster = multicast_control.Multicaster(
|
||||
config,
|
||||
[big for big in config.bigs]
|
||||
)
|
||||
await caster.init_broadcast()
|
||||
initialized = True
|
||||
log.info("Caster initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize caster: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
async def process_announcement(text: str, group: EndpointGroup):
|
||||
"""
|
||||
Process an announcement using the multilang_translator.
|
||||
This is based on the announcement_from_german_text function in main_local.py.
|
||||
"""
|
||||
global active_group_id, last_completed_group_id, caster, reset_task
|
||||
|
||||
# Make sure the caster is initialized
|
||||
if not initialized:
|
||||
await init_caster()
|
||||
|
||||
try:
|
||||
# Set start time and track the active group
|
||||
group.parameters.text = text
|
||||
group.parameters.languages = group.languages
|
||||
group.parameters.start_time = time.time()
|
||||
active_group_id = group.id
|
||||
|
||||
# Update status to translating
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
group.progress.progress = 0.2
|
||||
|
||||
# Translate the text for each language
|
||||
base_lang = "deu" # German is the base language
|
||||
|
||||
for i, big in enumerate(config.bigs):
|
||||
# Check if this language is in the requested languages
|
||||
lang_code_to_name = {"deu": "German", "eng": "English", "fra": "French"}
|
||||
lang_name = lang_code_to_name.get(big.language, "")
|
||||
|
||||
if lang_name not in group.languages:
|
||||
continue
|
||||
|
||||
# Translate if not the base language
|
||||
if big.language == base_lang:
|
||||
translated_text = text
|
||||
else:
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
translated_text = llm_translator.translate_de_to_x(
|
||||
text,
|
||||
big.language,
|
||||
model=big.translator_llm,
|
||||
client=big.llm_client,
|
||||
host=big.llm_host_url,
|
||||
token=big.llm_host_token
|
||||
)
|
||||
|
||||
log.info(f'Translated text ({big.language}): {translated_text}')
|
||||
|
||||
# Update status to generating voice
|
||||
group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value
|
||||
group.progress.progress = 0.4
|
||||
|
||||
# This will use the voice_provider's text_to_speech.synthesize function
|
||||
from voice_provider import text_to_speech
|
||||
lc3_audio = text_to_speech.synthesize(
|
||||
translated_text,
|
||||
config.auracast_sampling_rate_hz,
|
||||
big.tts_system,
|
||||
big.tts_model,
|
||||
return_lc3=True
|
||||
)
|
||||
|
||||
# Set the audio source for this language
|
||||
caster.big_conf[i].audio_source = lc3_audio
|
||||
|
||||
# Update status to routing
|
||||
group.progress.current_state = AnnouncementStates.ROUTING.value
|
||||
group.progress.progress = 0.6
|
||||
await asyncio.sleep(0.5) # Small delay for routing # TODO: actually needs to be implemented
|
||||
|
||||
# Update status to active and start streaming
|
||||
group.progress.current_state = AnnouncementStates.ACTIVE.value
|
||||
group.progress.progress = 0.8
|
||||
caster.start_streaming()
|
||||
|
||||
# Wait for streaming to complete
|
||||
await caster.streamer.task
|
||||
|
||||
# Update status to complete
|
||||
group.progress.current_state = AnnouncementStates.COMPLETE.value
|
||||
group.progress.progress = 1.0
|
||||
last_completed_group_id = group.id
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
# After a while, reset to idle state (in a separate task)
|
||||
async def reset_to_idle():
|
||||
log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state")
|
||||
await asyncio.sleep(10) # Keep completed state visible for 10 seconds
|
||||
log.info(f"Resetting group {group.id} to IDLE state")
|
||||
# Use direct value lookup for the state comparison
|
||||
if group.progress.current_state == AnnouncementStates.COMPLETE.value:
|
||||
group.progress.current_state = AnnouncementStates.IDLE.value
|
||||
group.progress.progress = 0.0
|
||||
log.info(f"Group {group.id} state reset to IDLE")
|
||||
|
||||
# Create and save the task so it won't be garbage collected
|
||||
reset_task = asyncio.create_task(reset_to_idle())
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error in processing announcement: {e}")
|
||||
group.progress.current_state = AnnouncementStates.ERROR.value
|
||||
group.progress.error = str(e)
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@app.get("/groups", response_model=List[EndpointGroup])
|
||||
async def get_groups():
|
||||
"""Get all endpoint groups."""
|
||||
return endpoint_groups
|
||||
|
||||
|
||||
@app.post("/groups", response_model=EndpointGroup)
|
||||
async def create_group(group: EndpointGroup):
|
||||
"""Add a new endpoint group."""
|
||||
# Ensure group with this ID doesn't already exist
|
||||
for existing_group in endpoint_groups:
|
||||
if existing_group.id == group.id:
|
||||
raise HTTPException(status_code=400, detail=f"Group with ID {group.id} already exists")
|
||||
|
||||
# If no ID is provided or ID is 0, auto-assign the next available ID
|
||||
if group.id == 0:
|
||||
max_id = max([g.id for g in endpoint_groups]) if endpoint_groups else 0
|
||||
group.id = max_id + 1
|
||||
|
||||
endpoint_groups.append(group)
|
||||
return group
|
||||
|
||||
|
||||
@app.put("/groups/{group_id}", response_model=EndpointGroup)
|
||||
async def update_group(group_id: int, updated_group: EndpointGroup):
|
||||
"""Update an existing endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
# Ensure the ID doesn't change
|
||||
updated_group.id = group_id
|
||||
endpoint_groups[i] = updated_group
|
||||
return updated_group
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
|
||||
@app.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: int):
|
||||
"""Delete an endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
del endpoint_groups[i]
|
||||
return {"message": "Group deleted successfully"}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
|
||||
@app.post("/announcements")
|
||||
async def start_announcement(text: str, group_id: int):
|
||||
"""Start a new announcement to the specified endpoint group."""
|
||||
global announcement_task, active_group_id
|
||||
|
||||
log.info(f"Received announcement request - text: '{text}', group_id: {group_id}")
|
||||
|
||||
# Find the group
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
# Cancel any ongoing announcement
|
||||
if active_group_id is not None:
|
||||
# Reset the active group
|
||||
active_group_id = None
|
||||
|
||||
# If caster is initialized, stop streaming
|
||||
if initialized and caster:
|
||||
caster.stop_streaming()
|
||||
|
||||
# Start a new announcement process in a background task
|
||||
try:
|
||||
# Create a new task for the announcement
|
||||
announcement_task = asyncio.create_task(process_announcement(text, group))
|
||||
return {"message": "Announcement started successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/announcements/status")
|
||||
async def get_announcement_status():
|
||||
"""Get the status of the current announcement."""
|
||||
global active_group_id, last_completed_group_id
|
||||
log.debug(f"Status request - active_group_id: {active_group_id}, last_completed_group_id: {last_completed_group_id}")
|
||||
|
||||
# If an announcement is active, return its status
|
||||
if active_group_id is not None:
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == active_group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if group:
|
||||
log.debug(f"Returning active status for group {group.id}: {group.progress.current_state}")
|
||||
return {
|
||||
"state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints
|
||||
},
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
|
||||
# If no announcement is active but we have a last completed one
|
||||
elif last_completed_group_id is not None:
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == last_completed_group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if group and group.progress.current_state == AnnouncementStates.COMPLETE.value:
|
||||
log.debug(f"Returning completed status for group {group.id}")
|
||||
return {
|
||||
"state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints
|
||||
},
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
|
||||
# Default: no active announcement
|
||||
log.debug("Returning idle status (no active or completed announcements)")
|
||||
return {
|
||||
"state": AnnouncementStates.IDLE.value,
|
||||
"progress": 0.0,
|
||||
"error": None,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": 0,
|
||||
"name": "",
|
||||
"endpoints": []
|
||||
},
|
||||
"text": "",
|
||||
"languages": [],
|
||||
"start_time": time.time()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints."""
|
||||
return AVAILABLE_ENDPOINTS
|
||||
|
||||
|
||||
@app.get("/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return AVAILABLE_LANGUAGES
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the caster on startup."""
|
||||
log.basicConfig(
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
# We don't initialize the caster here to avoid blocking startup
|
||||
# It will be initialized on the first announcement request
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Clean up resources on shutdown."""
|
||||
global caster, initialized
|
||||
if initialized and caster:
|
||||
try:
|
||||
await caster.shutdown()
|
||||
except Exception as e:
|
||||
log.error(f"Error during shutdown: {e}")
|
||||
61
src/multilang_translator/translator_api/endpoints_db.py
Normal file
61
src/multilang_translator/translator_api/endpoints_db.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Database file for endpoint definitions.
|
||||
This file contains configurations for auracast endpoints including their IP addresses and capabilities.
|
||||
"""
|
||||
from typing import Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EndpointDefinition(BaseModel):
|
||||
"""Defines an endpoint with its URL and capabilities."""
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts
|
||||
requires_authentication: bool = False # Whether authentication is required
|
||||
|
||||
|
||||
# Database of endpoints
|
||||
ENDPOINTS: Dict[str, EndpointDefinition] = {
|
||||
"endpoint0": EndpointDefinition(
|
||||
id="endpoint0",
|
||||
name="Local Endpoint",
|
||||
url="http://localhost:5000",
|
||||
max_broadcasts=2,
|
||||
),
|
||||
"endpoint1": EndpointDefinition(
|
||||
id="endpoint1",
|
||||
name="Gate 1 Endpoint",
|
||||
url="http://192.168.1.101:5000",
|
||||
max_broadcasts=2,
|
||||
),
|
||||
"endpoint2": EndpointDefinition(
|
||||
id="endpoint2",
|
||||
name="Gate 2 Endpoint",
|
||||
url="http://192.168.1.102:5000",
|
||||
max_broadcasts=1,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_all_endpoints() -> List[EndpointDefinition]:
|
||||
"""Get all active endpoints."""
|
||||
return ENDPOINTS.values()
|
||||
|
||||
def get_endpoint_by_id(endpoint_id: str) -> Optional[EndpointDefinition]:
|
||||
"""Get an endpoint by its ID."""
|
||||
return ENDPOINTS.get(endpoint_id)
|
||||
|
||||
def get_max_broadcasts(endpoint_id: str) -> int:
|
||||
"""Get the maximum number of simultaneous broadcasts for a specific endpoint."""
|
||||
endpoint = get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
return endpoint.max_broadcasts
|
||||
|
||||
def get_endpoint_url(endpoint_id: str) -> str:
|
||||
"""Get the full URL for an endpoint."""
|
||||
endpoint = get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
return endpoint.url
|
||||
574
src/multilang_translator/translator_api/translator_api.py
Normal file
574
src/multilang_translator/translator_api/translator_api.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
FastAPI implementation of the Multilang Translator API.
|
||||
This API mimics the mock_api from auracaster-webui to allow integration.
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import logging as log
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Import models
|
||||
from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates
|
||||
from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDeu, TranslatorConfigEng, TranslatorConfigFra
|
||||
from multilang_translator.translator import llm_translator
|
||||
from auracast import multicast_control
|
||||
# Import the endpoints database and multicast client
|
||||
from multilang_translator.translator_api import endpoints_db
|
||||
from auracast import multicast_client
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware to allow cross-origin requests
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global variables
|
||||
config = TranslatorConfigGroup(
|
||||
bigs=[
|
||||
TranslatorConfigDeu(),
|
||||
TranslatorConfigEng(),
|
||||
TranslatorConfigFra(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the transport
|
||||
config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc
|
||||
|
||||
# Configure the bigs
|
||||
for conf in config.bigs:
|
||||
conf.loop = False
|
||||
conf.llm_client = 'openwebui' # comment out for local llm
|
||||
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
# Get available endpoints and languages from the database
|
||||
AVAILABLE_ENDPOINTS = endpoints_db.get_all_endpoints()
|
||||
AVAILABLE_LANGUAGES = ['deu', 'eng', 'fra']
|
||||
|
||||
# Default endpoint groups
|
||||
endpoint_groups = [
|
||||
EndpointGroup(
|
||||
id=1,
|
||||
name="Gate1",
|
||||
endpoints=["endpoint1", "endpoint2"],
|
||||
languages=["deu", "eng"]
|
||||
),
|
||||
EndpointGroup(
|
||||
id=2,
|
||||
name="Gate2",
|
||||
endpoints=["endpoint3"],
|
||||
languages=["deu", "eng", "fra"]
|
||||
)
|
||||
]
|
||||
|
||||
# Announcement state tracking
|
||||
active_group_id = None
|
||||
last_completed_group_id = None
|
||||
announcement_task = None
|
||||
caster = None
|
||||
reset_task = None
|
||||
|
||||
# Initialization flag
|
||||
initialized = False
|
||||
|
||||
# Dictionary to track endpoint status and active broadcasts
|
||||
endpoint_status = {}
|
||||
|
||||
|
||||
async def init_caster():
|
||||
"""Initialize the multicaster if not already initialized."""
|
||||
global caster, initialized
|
||||
|
||||
if not initialized:
|
||||
log.info("Initializing caster...")
|
||||
try:
|
||||
caster = multicast_control.Multicaster(
|
||||
config,
|
||||
[big for big in config.bigs]
|
||||
)
|
||||
await caster.init_broadcast()
|
||||
initialized = True
|
||||
log.info("Caster initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize caster: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
async def init_endpoint(endpoint_id: str):
|
||||
"""Initialize a specific endpoint for multicast."""
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise ValueError(f"Endpoint {endpoint_id} not found")
|
||||
|
||||
log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.ip_address}:{endpoint.port}")
|
||||
|
||||
# Update the BASE_URL in multicast_client for this endpoint
|
||||
multicast_client.BASE_URL = f"http://{endpoint.ip_address}:{endpoint.port}"
|
||||
|
||||
# Create a config for this endpoint
|
||||
endpoint_config = multicast_client.AuracastConfigGroup(
|
||||
bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
|
||||
for lang in endpoint.supported_languages]
|
||||
)
|
||||
endpoint_config.transport = config.transport
|
||||
|
||||
# Configure the bigs
|
||||
for conf in endpoint_config.bigs:
|
||||
conf.loop = False
|
||||
|
||||
try:
|
||||
# Initialize the endpoint
|
||||
multicast_client.request_init(endpoint_config)
|
||||
endpoint_status[endpoint_id] = {
|
||||
"active": True,
|
||||
"broadcasts": 0,
|
||||
"max_broadcasts": endpoint.max_broadcasts
|
||||
}
|
||||
log.info(f"Endpoint {endpoint_id} initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize endpoint {endpoint_id}: {e}")
|
||||
endpoint_status[endpoint_id] = {
|
||||
"active": False,
|
||||
"error": str(e)
|
||||
}
|
||||
raise e
|
||||
|
||||
|
||||
async def process_announcement(text: str, group: EndpointGroup):
|
||||
"""
|
||||
Process an announcement using the multilang_translator.
|
||||
This function now uses the multicast_client to send announcements to endpoints.
|
||||
"""
|
||||
global active_group_id, last_completed_group_id, caster, reset_task
|
||||
|
||||
# Make sure the caster is initialized
|
||||
if not initialized:
|
||||
await init_caster()
|
||||
|
||||
try:
|
||||
# Set start time and track the active group
|
||||
group.parameters.text = text
|
||||
group.parameters.languages = group.languages
|
||||
group.parameters.start_time = time.time()
|
||||
active_group_id = group.id
|
||||
|
||||
# Update status to translating
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
group.progress.progress = 0.2
|
||||
|
||||
# Initialize all endpoints in the group
|
||||
endpoints_to_broadcast = []
|
||||
for endpoint_id in group.endpoints:
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
log.warning(f"Endpoint {endpoint_id} not found, skipping")
|
||||
continue
|
||||
|
||||
# Check if the endpoint supports all the required languages
|
||||
if not all(lang in endpoint.supported_languages for lang in group.languages):
|
||||
log.warning(f"Endpoint {endpoint_id} does not support all required languages, skipping")
|
||||
continue
|
||||
|
||||
# Check if the endpoint has room for another broadcast
|
||||
if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]:
|
||||
if endpoint_status[endpoint_id]["broadcasts"] >= endpoint_status[endpoint_id]["max_broadcasts"]:
|
||||
log.warning(f"Endpoint {endpoint_id} already at maximum broadcasts, skipping")
|
||||
continue
|
||||
|
||||
endpoint_status[endpoint_id]["broadcasts"] += 1
|
||||
endpoints_to_broadcast.append(endpoint)
|
||||
else:
|
||||
# Initialize the endpoint if not already active
|
||||
try:
|
||||
await init_endpoint(endpoint_id)
|
||||
endpoint_status[endpoint_id]["broadcasts"] += 1
|
||||
endpoints_to_broadcast.append(endpoint)
|
||||
except Exception as e:
|
||||
log.error(f"Could not initialize endpoint {endpoint_id}: {e}")
|
||||
continue
|
||||
|
||||
if not endpoints_to_broadcast:
|
||||
raise HTTPException(status_code=400, detail="No valid endpoints available for broadcast")
|
||||
|
||||
# Translate the text for each language
|
||||
base_lang = "deu" # German is the base language
|
||||
audio_data = {}
|
||||
|
||||
for i, big in enumerate(config.bigs):
|
||||
# Check if this language is in the requested languages
|
||||
if big.language not in group.languages:
|
||||
continue
|
||||
|
||||
# Translate if not the base language
|
||||
if big.language == base_lang:
|
||||
translated_text = text
|
||||
else:
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING.value
|
||||
translated_text = llm_translator.translate_de_to_x(
|
||||
text,
|
||||
big.language,
|
||||
model=big.translator_llm,
|
||||
client=big.llm_client,
|
||||
host=big.llm_host_url,
|
||||
token=big.llm_host_token
|
||||
)
|
||||
|
||||
log.info(f'Translated text ({big.language}): {translated_text}')
|
||||
|
||||
# Update status to generating voice
|
||||
group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value
|
||||
group.progress.progress = 0.4
|
||||
|
||||
# This will use the voice_provider's text_to_speech.synthesize function
|
||||
from voice_provider import text_to_speech
|
||||
lc3_audio = text_to_speech.synthesize(
|
||||
translated_text,
|
||||
config.auracast_sampling_rate_hz,
|
||||
big.tts_system,
|
||||
big.tts_model,
|
||||
return_lc3=True
|
||||
)
|
||||
|
||||
# Add the audio to the audio_data dictionary
|
||||
audio_data[big.language] = lc3_audio.decode('latin-1') if isinstance(lc3_audio, bytes) else lc3_audio
|
||||
|
||||
# Set the audio source for this language (for the traditional caster)
|
||||
caster.big_conf[i].audio_source = lc3_audio
|
||||
|
||||
# Update status to routing
|
||||
group.progress.current_state = AnnouncementStates.ROUTING.value
|
||||
group.progress.progress = 0.6
|
||||
await asyncio.sleep(0.5) # Small delay for routing
|
||||
|
||||
# Update status to active and start streaming
|
||||
group.progress.current_state = AnnouncementStates.ACTIVE.value
|
||||
group.progress.progress = 0.8
|
||||
|
||||
# Send the audio to each endpoint using multicast_client
|
||||
broadcast_tasks = []
|
||||
for endpoint in endpoints_to_broadcast:
|
||||
# Set the BASE_URL for this endpoint
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint.id)
|
||||
|
||||
# Create a task to send the audio to this endpoint
|
||||
broadcast_tasks.append(asyncio.create_task(
|
||||
asyncio.to_thread(multicast_client.send_audio, audio_data)
|
||||
))
|
||||
|
||||
# Also start the traditional caster
|
||||
caster.start_streaming()
|
||||
|
||||
# Wait for all broadcasts to complete
|
||||
await asyncio.gather(*broadcast_tasks)
|
||||
|
||||
# Wait for streaming to complete with the traditional caster
|
||||
await caster.streamer.task
|
||||
|
||||
# Update endpoint status
|
||||
for endpoint in endpoints_to_broadcast:
|
||||
if endpoint.id in endpoint_status and endpoint_status[endpoint.id]["broadcasts"] > 0:
|
||||
endpoint_status[endpoint.id]["broadcasts"] -= 1
|
||||
|
||||
# Update status to complete
|
||||
group.progress.current_state = AnnouncementStates.COMPLETE.value
|
||||
group.progress.progress = 1.0
|
||||
last_completed_group_id = group.id
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
# After a while, reset to idle state (in a separate task)
|
||||
async def reset_to_idle():
|
||||
log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state")
|
||||
await asyncio.sleep(10) # Keep completed state visible for 10 seconds
|
||||
log.info(f"Resetting group {group.id} to IDLE state")
|
||||
# Use direct value lookup for the state comparison
|
||||
if group.progress.current_state == AnnouncementStates.COMPLETE.value:
|
||||
group.progress.current_state = AnnouncementStates.IDLE.value
|
||||
group.progress.progress = 0.0
|
||||
log.info(f"Group {group.id} state reset to IDLE")
|
||||
|
||||
# Create and save the task so it won't be garbage collected
|
||||
reset_task = asyncio.create_task(reset_to_idle())
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during announcement processing: {e}")
|
||||
group.progress.current_state = AnnouncementStates.ERROR.value
|
||||
group.progress.error = str(e)
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/groups")
|
||||
async def get_groups():
|
||||
"""Get all endpoint groups."""
|
||||
return endpoint_groups
|
||||
|
||||
|
||||
@app.post("/api/groups")
|
||||
async def create_group(group: EndpointGroup):
|
||||
"""Add a new endpoint group."""
|
||||
# Validate endpoints
|
||||
for endpoint_id in group.endpoints:
|
||||
if endpoint_id not in AVAILABLE_ENDPOINTS:
|
||||
raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available")
|
||||
|
||||
# Validate languages
|
||||
for language in group.languages:
|
||||
if language not in AVAILABLE_LANGUAGES:
|
||||
raise HTTPException(status_code=400, detail=f"Language {language} is not available")
|
||||
|
||||
# Add the group
|
||||
endpoint_groups.append(group)
|
||||
return group
|
||||
|
||||
|
||||
@app.put("/api/groups/{group_id}")
|
||||
async def update_group(group_id: int, updated_group: EndpointGroup):
|
||||
"""Update an existing endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
# Validate endpoints
|
||||
for endpoint_id in updated_group.endpoints:
|
||||
if endpoint_id not in AVAILABLE_ENDPOINTS:
|
||||
raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available")
|
||||
|
||||
# Validate languages
|
||||
for language in updated_group.languages:
|
||||
if language not in AVAILABLE_LANGUAGES:
|
||||
raise HTTPException(status_code=400, detail=f"Language {language} is not available")
|
||||
|
||||
# Update the group
|
||||
endpoint_groups[i] = updated_group
|
||||
return updated_group
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
|
||||
|
||||
@app.delete("/api/groups/{group_id}")
|
||||
async def delete_group(group_id: int):
|
||||
"""Delete an endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
del endpoint_groups[i]
|
||||
return {"detail": f"Group with id {group_id} deleted"}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
|
||||
|
||||
@app.post("/api/announcements")
|
||||
async def start_announcement(text: str, group_id: int):
|
||||
"""Start a new announcement to the specified endpoint group."""
|
||||
global announcement_task
|
||||
|
||||
# Find the group
|
||||
target_group = None
|
||||
for group in endpoint_groups:
|
||||
if group.id == group_id:
|
||||
target_group = group
|
||||
break
|
||||
|
||||
if not target_group:
|
||||
raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found")
|
||||
|
||||
# Check if an announcement is already in progress
|
||||
if active_group_id is not None:
|
||||
raise HTTPException(status_code=400, detail="An announcement is already in progress")
|
||||
|
||||
# Process the announcement
|
||||
announcement_task = asyncio.create_task(process_announcement(text, target_group))
|
||||
|
||||
return {"detail": "Announcement started"}
|
||||
|
||||
|
||||
@app.get("/api/announcements/status")
|
||||
async def get_announcement_status():
|
||||
"""Get the status of the current announcement."""
|
||||
# If no group is active, check if the last completed group exists
|
||||
if active_group_id is None:
|
||||
if last_completed_group_id is not None:
|
||||
# Find the last completed group
|
||||
for group in endpoint_groups:
|
||||
if group.id == last_completed_group_id:
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints,
|
||||
"languages": group.languages,
|
||||
"progress": {
|
||||
"current_state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error
|
||||
},
|
||||
"parameters": {
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# If the last completed group couldn't be found
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
else:
|
||||
# No active or last completed group
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
# If a group is active, return its information
|
||||
for group in endpoint_groups:
|
||||
if group.id == active_group_id:
|
||||
return {
|
||||
"active_group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints,
|
||||
"languages": group.languages,
|
||||
"progress": {
|
||||
"current_state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error
|
||||
},
|
||||
"parameters": {
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
},
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
# If the active group couldn't be found
|
||||
return {
|
||||
"active_group": None,
|
||||
"last_completed_group": None
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints with their capabilities."""
|
||||
return [endpoint_db.dict() for endpoint_db in endpoints_db.get_all_endpoints()]
|
||||
|
||||
|
||||
@app.get("/api/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return endpoints_db.get_available_languages()
|
||||
|
||||
|
||||
@app.get("/api/endpoints/{endpoint_id}")
|
||||
async def get_endpoint_status(endpoint_id: str):
|
||||
"""Get the status of a specific endpoint."""
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found")
|
||||
|
||||
status = endpoint_status.get(endpoint_id, {"active": False, "broadcasts": 0})
|
||||
return {
|
||||
"endpoint": endpoint.dict(),
|
||||
"status": status
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/endpoints/{endpoint_id}/reset")
|
||||
async def reset_endpoint(endpoint_id: str):
|
||||
"""Reset an endpoint's status."""
|
||||
endpoint = endpoints_db.get_endpoint_by_id(endpoint_id)
|
||||
if not endpoint:
|
||||
raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found")
|
||||
|
||||
try:
|
||||
# Set the BASE_URL for this endpoint
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id)
|
||||
|
||||
# Stop any current audio
|
||||
multicast_client.stop_audio()
|
||||
|
||||
# Reinitialize the endpoint
|
||||
await init_endpoint(endpoint_id)
|
||||
|
||||
return {"detail": f"Endpoint {endpoint_id} reset successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to reset endpoint: {str(e)}")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the caster on startup."""
|
||||
global caster, initialized
|
||||
|
||||
try:
|
||||
await init_caster()
|
||||
|
||||
# Initialize all endpoints
|
||||
for endpoint_id in AVAILABLE_ENDPOINTS:
|
||||
try:
|
||||
await init_endpoint(endpoint_id)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize endpoint {endpoint_id}: {e}")
|
||||
except Exception as e:
|
||||
log.error(f"Startup error: {e}")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Clean up resources on shutdown."""
|
||||
global caster
|
||||
|
||||
if caster is not None:
|
||||
try:
|
||||
log.info("Stopping caster...")
|
||||
caster.stop_streaming()
|
||||
caster = None
|
||||
log.info("Caster stopped")
|
||||
except Exception as e:
|
||||
log.error(f"Error during caster shutdown: {e}")
|
||||
|
||||
# Shutdown all active endpoints
|
||||
for endpoint_id in endpoint_status:
|
||||
if endpoint_status[endpoint_id].get("active", False):
|
||||
try:
|
||||
multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id)
|
||||
multicast_client.shutdown()
|
||||
except Exception as e:
|
||||
log.error(f"Error shutting down endpoint {endpoint_id}: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
log.basicConfig(
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
log.info("Starting Translator API server")
|
||||
uvicorn.run(
|
||||
"translator_api:app",
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
reload=True,
|
||||
log_level="debug"
|
||||
)
|
||||
@@ -5,7 +5,6 @@ Similar to the models used in auracaster-webui but simplified for the translator
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
import time
|
||||
|
||||
|
||||
class AnnouncementStates(str, Enum):
|
||||
|
||||
@@ -14,22 +14,22 @@ class TranslatorLangConfig(auracast_config.AuracastBigConfig):
|
||||
tts_model: str ='de_DE-kerstin-low'
|
||||
|
||||
|
||||
class TranslatorConfigDe(TranslatorLangConfig, auracast_config.AuracastBigConfigDe):
|
||||
class TranslatorConfigDeu(TranslatorLangConfig, auracast_config.AuracastBigConfigDeu):
|
||||
tts_model: str ='de_DE-thorsten-high'
|
||||
|
||||
class TranslatorConfigEn(TranslatorLangConfig, auracast_config.AuracastBigConfigEn):
|
||||
class TranslatorConfigEng(TranslatorLangConfig, auracast_config.AuracastBigConfigEng):
|
||||
tts_model: str = 'en_GB-alba-medium'
|
||||
|
||||
class TranslatorConfigFr(TranslatorLangConfig, auracast_config.AuracastBigConfigFr):
|
||||
class TranslatorConfigFra(TranslatorLangConfig, auracast_config.AuracastBigConfigFra):
|
||||
tts_model: str = 'fr_FR-siwis-medium'
|
||||
|
||||
class TranslatorConfigEs(TranslatorLangConfig, auracast_config.AuracastBigConfigEs):
|
||||
class TranslatorConfigSpa(TranslatorLangConfig, auracast_config.AuracastBigConfigSpa):
|
||||
tts_model: str = 'es_ES-sharvard-medium'
|
||||
|
||||
class TranslatorConfigIt(TranslatorLangConfig, auracast_config.AuracastBigConfigIt):
|
||||
class TranslatorConfigIta(TranslatorLangConfig, auracast_config.AuracastBigConfigIta):
|
||||
tts_model: str = 'it_IT-paola-medium'
|
||||
|
||||
class TranslatorConfigGroup(auracast_config.AuracastGlobalConfig):
|
||||
bigs: List[TranslatorLangConfig] = [
|
||||
TranslatorConfigDe(),
|
||||
TranslatorConfigDeu(),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user