Implement a basic translator API to work with the auracaster webui
This commit is contained in:
1
src/multilang_translator/translator_api/__init__.py
Normal file
1
src/multilang_translator/translator_api/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Empty file to make the directory a package
|
||||||
394
src/multilang_translator/translator_api/api.py
Normal file
394
src/multilang_translator/translator_api/api.py
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
"""
|
||||||
|
FastAPI implementation of the Multilang Translator API.
|
||||||
|
This API mimics the mock_api from auracaster-webui to allow integration.
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import logging as log
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
# Import models
|
||||||
|
from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates
|
||||||
|
from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDe, TranslatorConfigEn, TranslatorConfigFr
|
||||||
|
from multilang_translator.translator import llm_translator
|
||||||
|
from auracast import multicast_control
|
||||||
|
|
||||||
|
# Create the FastAPI application that serves the translator API.
app = FastAPI()

# Allow cross-origin requests so the auracaster-webui frontend (served from a
# different origin) can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS; restrict in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Translator configuration: one config per broadcast group, one per language.
config = TranslatorConfigGroup(
    bigs=[
        TranslatorConfigDe(),
        TranslatorConfigEn(),
        TranslatorConfigFr(),
    ]
)

# Configure the transport
config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts'  # nrf52dongle hci_uart usb cdc

# Configure the bigs
for conf in config.bigs:
    conf.loop = False
    conf.llm_client = 'openwebui'  # comment out for local llm
    conf.llm_host_url = os.environ.get('LLM_HOST_URL', 'https://ollama.pstruebi.xyz')
    # SECURITY: the API token used to be hard-coded here. Prefer the
    # environment variable; the fallback keeps existing deployments working,
    # but the previously committed token should be rotated and the fallback
    # removed.
    conf.llm_host_token = os.environ.get('LLM_HOST_TOKEN', 'sk-17124cb84df14cc6ab2d9e17d0724d13')

# Available endpoints and languages
AVAILABLE_ENDPOINTS = ["endpoint1", "endpoint2", "endpoint3"]
AVAILABLE_LANGUAGES = ["German", "English", "French"]

# Default endpoint groups
endpoint_groups = [
    EndpointGroup(
        id=1,
        name="Gate1",
        endpoints=["endpoint1", "endpoint2"],
        languages=["German", "English"]
    ),
    EndpointGroup(
        id=2,
        name="Gate2",
        endpoints=["endpoint3"],
        languages=["German", "English", "French"]
    )
]

# Announcement state tracking
active_group_id = None          # id of the group currently being announced to
last_completed_group_id = None  # id of the group whose announcement last completed
announcement_task = None        # asyncio.Task running process_announcement()
caster = None                   # lazily created multicast_control.Multicaster

# Initialization flag: the caster is created on first use, not at import time.
initialized = False
|
|
||||||
|
async def init_caster():
    """Create and initialize the global Multicaster on first use.

    Idempotent: returns immediately when already initialized. Re-raises
    whatever the multicast setup raises when broadcast init fails.
    """
    global caster, initialized

    if initialized:
        return

    log.info("Initializing caster...")
    try:
        caster = multicast_control.Multicaster(
            config,
            list(config.bigs),
        )
        await caster.init_broadcast()
    except Exception:
        # log.exception records the traceback; bare `raise` re-raises the
        # original exception unchanged (unlike `raise e`, it keeps context).
        log.exception("Failed to initialize caster")
        raise
    initialized = True
    log.info("Caster initialized successfully")
|
|
||||||
|
|
||||||
|
async def process_announcement(text: str, group: EndpointGroup):
    """
    Translate `text`, synthesize speech per language, and stream it to `group`.

    Runs as a background task (scheduled by start_announcement); progress is
    reported through group.progress so GET /announcements/status can poll it.
    Based on the announcement_from_german_text function in main_local.py.
    """
    global active_group_id, last_completed_group_id, caster

    # Make sure the caster is initialized (lazy, first-use init).
    if not initialized:
        await init_caster()

    try:
        # Record the request parameters and mark this group as active.
        group.parameters.text = text
        group.parameters.languages = group.languages
        group.parameters.start_time = time.time()
        active_group_id = group.id

        # Update status to translating
        group.progress.current_state = AnnouncementStates.TRANSLATING
        group.progress.progress = 0.2

        base_lang = "deu"  # German is the base language (input text is German)
        # BIG language code -> language name as used in group.languages.
        lang_code_to_name = {"deu": "German", "eng": "English", "fra": "French"}

        for i, big in enumerate(config.bigs):
            # Skip languages this group does not announce in.
            lang_name = lang_code_to_name.get(big.language, "")
            if lang_name not in group.languages:
                continue

            # Translate only when this is not the base language.
            if big.language == base_lang:
                translated_text = text
            else:
                group.progress.current_state = AnnouncementStates.TRANSLATING
                # NOTE(review): blocking call inside a coroutine — stalls the
                # event loop while the LLM responds; consider
                # asyncio.to_thread. Left as-is to preserve behavior.
                translated_text = llm_translator.translate_de_to_x(
                    text,
                    big.language,
                    model=big.translator_llm,
                    client=big.llm_client,
                    host=big.llm_host_url,
                    token=big.llm_host_token
                )

            log.info(f'Translated text ({big.language}): {translated_text}')

            # Update status to generating voice
            group.progress.current_state = AnnouncementStates.GENERATING_VOICE
            group.progress.progress = 0.4

            # Uses the voice_provider's text_to_speech.synthesize function.
            from voice_provider import text_to_speech
            lc3_audio = text_to_speech.synthesize(
                translated_text,
                config.auracast_sampling_rate_hz,
                big.tts_system,
                big.tts_model,
                return_lc3=True
            )

            # Set the audio source for this language's broadcast group.
            caster.big_conf[i].audio_source = lc3_audio

        # Update status to routing
        group.progress.current_state = AnnouncementStates.ROUTING
        group.progress.progress = 0.6
        await asyncio.sleep(0.5)  # Small delay for routing

        # Update status to active and start streaming
        group.progress.current_state = AnnouncementStates.ACTIVE
        group.progress.progress = 0.8
        caster.start_streaming()

        # Wait for streaming to complete
        if hasattr(caster, 'streamer') and hasattr(caster.streamer, 'task'):
            await caster.streamer.task
        else:
            await asyncio.sleep(3)  # Fallback wait if no task available

        # Update status to complete
        group.progress.current_state = AnnouncementStates.COMPLETE
        group.progress.progress = 1.0
        last_completed_group_id = group.id

        # Reset active group if this is still the active one
        if active_group_id == group.id:
            active_group_id = None

        # After a while, reset to idle state (in a separate task)
        async def reset_to_idle():
            await asyncio.sleep(10)  # Keep completed state visible for 10 seconds
            if group.progress.current_state == AnnouncementStates.COMPLETE:
                group.progress.current_state = AnnouncementStates.IDLE
                group.progress.progress = 0.0

        # BUGFIX: keep a strong reference so the reset task is not
        # garbage-collected before it runs (the event loop holds only
        # weak references to tasks).
        process_announcement._reset_task = asyncio.create_task(reset_to_idle())

    except Exception as e:
        # BUGFIX: this coroutine runs as a background task, so raising
        # HTTPException here never reached an HTTP client — it only became an
        # unhandled-task error. Record the failure in the group's progress
        # (polled via /announcements/status) and re-raise for task logging.
        log.error(f"Error in processing announcement: {e}")
        group.progress.current_state = AnnouncementStates.ERROR
        group.progress.error = str(e)
        if active_group_id == group.id:
            active_group_id = None
        raise
|
|
||||||
|
|
||||||
|
# API Endpoints
|
||||||
|
|
||||||
|
@app.get("/groups", response_model=List[EndpointGroup])
|
||||||
|
async def get_groups():
|
||||||
|
"""Get all endpoint groups."""
|
||||||
|
return endpoint_groups
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/groups", response_model=EndpointGroup)
|
||||||
|
async def create_group(group: EndpointGroup):
|
||||||
|
"""Add a new endpoint group."""
|
||||||
|
# Ensure group with this ID doesn't already exist
|
||||||
|
for existing_group in endpoint_groups:
|
||||||
|
if existing_group.id == group.id:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Group with ID {group.id} already exists")
|
||||||
|
|
||||||
|
# If no ID is provided or ID is 0, auto-assign the next available ID
|
||||||
|
if group.id == 0:
|
||||||
|
max_id = max([g.id for g in endpoint_groups]) if endpoint_groups else 0
|
||||||
|
group.id = max_id + 1
|
||||||
|
|
||||||
|
endpoint_groups.append(group)
|
||||||
|
return group
|
||||||
|
|
||||||
|
|
||||||
|
@app.put("/groups/{group_id}", response_model=EndpointGroup)
|
||||||
|
async def update_group(group_id: int, updated_group: EndpointGroup):
|
||||||
|
"""Update an existing endpoint group."""
|
||||||
|
for i, group in enumerate(endpoint_groups):
|
||||||
|
if group.id == group_id:
|
||||||
|
# Ensure the ID doesn't change
|
||||||
|
updated_group.id = group_id
|
||||||
|
endpoint_groups[i] = updated_group
|
||||||
|
return updated_group
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/groups/{group_id}")
|
||||||
|
async def delete_group(group_id: int):
|
||||||
|
"""Delete an endpoint group."""
|
||||||
|
for i, group in enumerate(endpoint_groups):
|
||||||
|
if group.id == group_id:
|
||||||
|
del endpoint_groups[i]
|
||||||
|
return {"message": "Group deleted successfully"}
|
||||||
|
|
||||||
|
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/announcements")
|
||||||
|
async def start_announcement(text: str, group_id: int):
|
||||||
|
"""Start a new announcement to the specified endpoint group."""
|
||||||
|
global announcement_task, active_group_id
|
||||||
|
|
||||||
|
# Find the group
|
||||||
|
group = None
|
||||||
|
for g in endpoint_groups:
|
||||||
|
if g.id == group_id:
|
||||||
|
group = g
|
||||||
|
break
|
||||||
|
|
||||||
|
if not group:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||||
|
|
||||||
|
# Cancel any ongoing announcement
|
||||||
|
if active_group_id is not None:
|
||||||
|
# Reset the active group
|
||||||
|
active_group_id = None
|
||||||
|
|
||||||
|
# If caster is initialized, stop streaming
|
||||||
|
if initialized and caster:
|
||||||
|
caster.stop_streaming()
|
||||||
|
|
||||||
|
# Start a new announcement process in a background task
|
||||||
|
try:
|
||||||
|
# Create a new task for the announcement
|
||||||
|
announcement_task = asyncio.create_task(process_announcement(text, group))
|
||||||
|
return {"message": "Announcement started successfully"}
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/announcements/status")
|
||||||
|
async def get_announcement_status():
|
||||||
|
"""Get the status of the current announcement."""
|
||||||
|
# If an announcement is active, return its status
|
||||||
|
if active_group_id is not None:
|
||||||
|
group = None
|
||||||
|
for g in endpoint_groups:
|
||||||
|
if g.id == active_group_id:
|
||||||
|
group = g
|
||||||
|
break
|
||||||
|
|
||||||
|
if group:
|
||||||
|
return {
|
||||||
|
"state": group.progress.current_state,
|
||||||
|
"progress": group.progress.progress,
|
||||||
|
"error": group.progress.error,
|
||||||
|
"details": {
|
||||||
|
"group": {
|
||||||
|
"id": group.id,
|
||||||
|
"name": group.name,
|
||||||
|
"endpoints": group.endpoints
|
||||||
|
},
|
||||||
|
"text": group.parameters.text,
|
||||||
|
"languages": group.parameters.languages,
|
||||||
|
"start_time": group.parameters.start_time
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# If no announcement is active but we have a last completed one
|
||||||
|
elif last_completed_group_id is not None:
|
||||||
|
group = None
|
||||||
|
for g in endpoint_groups:
|
||||||
|
if g.id == last_completed_group_id:
|
||||||
|
group = g
|
||||||
|
break
|
||||||
|
|
||||||
|
if group and group.progress.current_state == AnnouncementStates.COMPLETE:
|
||||||
|
return {
|
||||||
|
"state": group.progress.current_state,
|
||||||
|
"progress": group.progress.progress,
|
||||||
|
"error": group.progress.error,
|
||||||
|
"details": {
|
||||||
|
"group": {
|
||||||
|
"id": group.id,
|
||||||
|
"name": group.name,
|
||||||
|
"endpoints": group.endpoints
|
||||||
|
},
|
||||||
|
"text": group.parameters.text,
|
||||||
|
"languages": group.parameters.languages,
|
||||||
|
"start_time": group.parameters.start_time
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default: no active announcement
|
||||||
|
return {
|
||||||
|
"state": AnnouncementStates.IDLE,
|
||||||
|
"progress": 0.0,
|
||||||
|
"error": None,
|
||||||
|
"details": {
|
||||||
|
"group": {
|
||||||
|
"id": 0,
|
||||||
|
"name": "",
|
||||||
|
"endpoints": []
|
||||||
|
},
|
||||||
|
"text": "",
|
||||||
|
"languages": [],
|
||||||
|
"start_time": time.time()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/endpoints")
|
||||||
|
async def get_available_endpoints():
|
||||||
|
"""Get all available endpoints."""
|
||||||
|
return AVAILABLE_ENDPOINTS
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/languages")
|
||||||
|
async def get_available_languages():
|
||||||
|
"""Get all available languages for announcements."""
|
||||||
|
return AVAILABLE_LANGUAGES
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""Initialize the caster on startup."""
|
||||||
|
log.basicConfig(
|
||||||
|
level=log.INFO,
|
||||||
|
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||||
|
)
|
||||||
|
# We don't initialize the caster here to avoid blocking startup
|
||||||
|
# It will be initialized on the first announcement request
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
"""Clean up resources on shutdown."""
|
||||||
|
global caster, initialized
|
||||||
|
if initialized and caster:
|
||||||
|
try:
|
||||||
|
await caster.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error during shutdown: {e}")
|
||||||
28
src/multilang_translator/translator_api/main_api_server.py
Normal file
28
src/multilang_translator/translator_api/main_api_server.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Entry point for the Translator API server.
|
||||||
|
This file starts the FastAPI server with the translator_api.
|
||||||
|
"""
|
||||||
|
import uvicorn
|
||||||
|
import logging as log
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add the parent directory to the Python path to find the multilang_translator package
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
||||||
|
if parent_dir not in sys.path:
|
||||||
|
sys.path.insert(0, parent_dir)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
log.basicConfig(
|
||||||
|
level=log.INFO,
|
||||||
|
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||||
|
)
|
||||||
|
log.info("Starting Translator API server")
|
||||||
|
uvicorn.run(
|
||||||
|
"multilang_translator.translator_api.api:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=7999,
|
||||||
|
reload=True,
|
||||||
|
log_level="debug"
|
||||||
|
)
|
||||||
39
src/multilang_translator/translator_api/translator_models.py
Normal file
39
src/multilang_translator/translator_api/translator_models.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
Models for the translator API.
|
||||||
|
Similar to the models used in auracaster-webui but simplified for the translator middleware.
|
||||||
|
"""
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class AnnouncementStates(str, Enum):
    """Lifecycle states of an announcement, in processing order.

    Inherits from str so values compare equal to plain strings and
    serialize directly in JSON responses.
    """

    IDLE = "idle"                          # nothing in progress
    TRANSLATING = "translating"            # LLM translation running
    GENERATING_VOICE = "generating_voice"  # text-to-speech synthesis
    ROUTING = "routing"                    # audio being routed to endpoints
    ACTIVE = "active"                      # streaming in progress
    COMPLETE = "complete"                  # finished successfully
    ERROR = "error"                        # failed; see progress.error
|
|
||||||
|
|
||||||
|
class AnnouncementProgress(BaseModel):
    """Live progress of an announcement, polled by the webui via /announcements/status."""
    # Current lifecycle state; holds AnnouncementStates values (str-typed, so
    # plain strings are also accepted by validation).
    current_state: str = AnnouncementStates.IDLE
    # Fraction complete in [0.0, 1.0].
    progress: float = 0.0
    # Human-readable error message when current_state is "error", else None.
    error: Optional[str] = None
|
|
||||||
|
|
||||||
|
class AnnouncementParameters(BaseModel):
    """Input parameters of the most recent announcement made to a group."""
    # Original (German) announcement text.
    text: str = ""
    # Language names the announcement targets, e.g. "German", "English".
    languages: List[str] = []
    # Unix timestamp (time.time()) when processing started; 0.0 means never.
    start_time: float = 0.0
|
|
||||||
|
|
||||||
|
class EndpointGroup(BaseModel):
    """A named set of endpoints plus the languages announced to them."""
    id: int               # unique group id; 0 means "auto-assign" on create
    name: str             # display name, e.g. "Gate1"
    endpoints: List[str]  # endpoint identifiers belonging to this group
    languages: List[str]  # language names announced to this group
    # NOTE: instance defaults are safe here — pydantic copies field defaults
    # per model instance rather than sharing one object across instances.
    progress: AnnouncementProgress = AnnouncementProgress()
    parameters: AnnouncementParameters = AnnouncementParameters()
|
||||||
Reference in New Issue
Block a user