implement a basic translator api to work with the auracaster webui
This commit is contained in:
1
src/multilang_translator/translator_api/__init__.py
Normal file
1
src/multilang_translator/translator_api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Empty file to make the directory a package
|
||||
394
src/multilang_translator/translator_api/api.py
Normal file
394
src/multilang_translator/translator_api/api.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
FastAPI implementation of the Multilang Translator API.
|
||||
This API mimics the mock_api from auracaster-webui to allow integration.
|
||||
"""
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import logging as log
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Import models
|
||||
from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates
|
||||
from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDe, TranslatorConfigEn, TranslatorConfigFr
|
||||
from multilang_translator.translator import llm_translator
|
||||
from auracast import multicast_control
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware to allow cross-origin requests
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Global variables
|
||||
config = TranslatorConfigGroup(
|
||||
bigs=[
|
||||
TranslatorConfigDe(),
|
||||
TranslatorConfigEn(),
|
||||
TranslatorConfigFr(),
|
||||
]
|
||||
)
|
||||
|
||||
# Configure the transport
|
||||
config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc
|
||||
|
||||
# Configure the bigs
|
||||
for conf in config.bigs:
|
||||
conf.loop = False
|
||||
conf.llm_client = 'openwebui' # comment out for local llm
|
||||
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
|
||||
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
|
||||
|
||||
# Available endpoints and languages
|
||||
AVAILABLE_ENDPOINTS = ["endpoint1", "endpoint2", "endpoint3"]
|
||||
AVAILABLE_LANGUAGES = ["German", "English", "French"]
|
||||
|
||||
# Default endpoint groups
|
||||
endpoint_groups = [
|
||||
EndpointGroup(
|
||||
id=1,
|
||||
name="Gate1",
|
||||
endpoints=["endpoint1", "endpoint2"],
|
||||
languages=["German", "English"]
|
||||
),
|
||||
EndpointGroup(
|
||||
id=2,
|
||||
name="Gate2",
|
||||
endpoints=["endpoint3"],
|
||||
languages=["German", "English", "French"]
|
||||
)
|
||||
]
|
||||
|
||||
# Announcement state tracking
|
||||
active_group_id = None
|
||||
last_completed_group_id = None
|
||||
announcement_task = None
|
||||
caster = None
|
||||
|
||||
# Initialization flag
|
||||
initialized = False
|
||||
|
||||
|
||||
async def init_caster():
|
||||
"""Initialize the multicaster if not already initialized."""
|
||||
global caster, initialized
|
||||
|
||||
if not initialized:
|
||||
log.info("Initializing caster...")
|
||||
try:
|
||||
caster = multicast_control.Multicaster(
|
||||
config,
|
||||
[big for big in config.bigs]
|
||||
)
|
||||
await caster.init_broadcast()
|
||||
initialized = True
|
||||
log.info("Caster initialized successfully")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize caster: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
async def process_announcement(text: str, group: EndpointGroup):
|
||||
"""
|
||||
Process an announcement using the multilang_translator.
|
||||
This is based on the announcement_from_german_text function in main_local.py.
|
||||
"""
|
||||
global active_group_id, last_completed_group_id, caster
|
||||
|
||||
# Make sure the caster is initialized
|
||||
if not initialized:
|
||||
await init_caster()
|
||||
|
||||
try:
|
||||
# Set start time and track the active group
|
||||
group.parameters.text = text
|
||||
group.parameters.languages = group.languages
|
||||
group.parameters.start_time = time.time()
|
||||
active_group_id = group.id
|
||||
|
||||
# Update status to translating
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING
|
||||
group.progress.progress = 0.2
|
||||
|
||||
# Translate the text for each language
|
||||
base_lang = "deu" # German is the base language
|
||||
|
||||
for i, big in enumerate(config.bigs):
|
||||
# Check if this language is in the requested languages
|
||||
lang_code_to_name = {"deu": "German", "eng": "English", "fra": "French"}
|
||||
lang_name = lang_code_to_name.get(big.language, "")
|
||||
|
||||
if lang_name not in group.languages:
|
||||
continue
|
||||
|
||||
# Translate if not the base language
|
||||
if big.language == base_lang:
|
||||
translated_text = text
|
||||
else:
|
||||
group.progress.current_state = AnnouncementStates.TRANSLATING
|
||||
translated_text = llm_translator.translate_de_to_x(
|
||||
text,
|
||||
big.language,
|
||||
model=big.translator_llm,
|
||||
client=big.llm_client,
|
||||
host=big.llm_host_url,
|
||||
token=big.llm_host_token
|
||||
)
|
||||
|
||||
log.info(f'Translated text ({big.language}): {translated_text}')
|
||||
|
||||
# Update status to generating voice
|
||||
group.progress.current_state = AnnouncementStates.GENERATING_VOICE
|
||||
group.progress.progress = 0.4
|
||||
|
||||
# This will use the voice_provider's text_to_speech.synthesize function
|
||||
from voice_provider import text_to_speech
|
||||
lc3_audio = text_to_speech.synthesize(
|
||||
translated_text,
|
||||
config.auracast_sampling_rate_hz,
|
||||
big.tts_system,
|
||||
big.tts_model,
|
||||
return_lc3=True
|
||||
)
|
||||
|
||||
# Set the audio source for this language
|
||||
caster.big_conf[i].audio_source = lc3_audio
|
||||
|
||||
# Update status to routing
|
||||
group.progress.current_state = AnnouncementStates.ROUTING
|
||||
group.progress.progress = 0.6
|
||||
await asyncio.sleep(0.5) # Small delay for routing
|
||||
|
||||
# Update status to active and start streaming
|
||||
group.progress.current_state = AnnouncementStates.ACTIVE
|
||||
group.progress.progress = 0.8
|
||||
caster.start_streaming()
|
||||
|
||||
# Wait for streaming to complete
|
||||
if hasattr(caster, 'streamer') and hasattr(caster.streamer, 'task'):
|
||||
await caster.streamer.task
|
||||
else:
|
||||
await asyncio.sleep(3) # Fallback wait if no task available
|
||||
|
||||
# Update status to complete
|
||||
group.progress.current_state = AnnouncementStates.COMPLETE
|
||||
group.progress.progress = 1.0
|
||||
last_completed_group_id = group.id
|
||||
|
||||
# Reset active group if this is still the active one
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
|
||||
# After a while, reset to idle state (in a separate task)
|
||||
async def reset_to_idle():
|
||||
await asyncio.sleep(10) # Keep completed state visible for 10 seconds
|
||||
if group.progress.current_state == AnnouncementStates.COMPLETE:
|
||||
group.progress.current_state = AnnouncementStates.IDLE
|
||||
group.progress.progress = 0.0
|
||||
|
||||
asyncio.create_task(reset_to_idle())
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error in processing announcement: {e}")
|
||||
group.progress.current_state = AnnouncementStates.ERROR
|
||||
group.progress.error = str(e)
|
||||
if active_group_id == group.id:
|
||||
active_group_id = None
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
@app.get("/groups", response_model=List[EndpointGroup])
|
||||
async def get_groups():
|
||||
"""Get all endpoint groups."""
|
||||
return endpoint_groups
|
||||
|
||||
|
||||
@app.post("/groups", response_model=EndpointGroup)
|
||||
async def create_group(group: EndpointGroup):
|
||||
"""Add a new endpoint group."""
|
||||
# Ensure group with this ID doesn't already exist
|
||||
for existing_group in endpoint_groups:
|
||||
if existing_group.id == group.id:
|
||||
raise HTTPException(status_code=400, detail=f"Group with ID {group.id} already exists")
|
||||
|
||||
# If no ID is provided or ID is 0, auto-assign the next available ID
|
||||
if group.id == 0:
|
||||
max_id = max([g.id for g in endpoint_groups]) if endpoint_groups else 0
|
||||
group.id = max_id + 1
|
||||
|
||||
endpoint_groups.append(group)
|
||||
return group
|
||||
|
||||
|
||||
@app.put("/groups/{group_id}", response_model=EndpointGroup)
|
||||
async def update_group(group_id: int, updated_group: EndpointGroup):
|
||||
"""Update an existing endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
# Ensure the ID doesn't change
|
||||
updated_group.id = group_id
|
||||
endpoint_groups[i] = updated_group
|
||||
return updated_group
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
|
||||
@app.delete("/groups/{group_id}")
|
||||
async def delete_group(group_id: int):
|
||||
"""Delete an endpoint group."""
|
||||
for i, group in enumerate(endpoint_groups):
|
||||
if group.id == group_id:
|
||||
del endpoint_groups[i]
|
||||
return {"message": "Group deleted successfully"}
|
||||
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
|
||||
@app.post("/announcements")
|
||||
async def start_announcement(text: str, group_id: int):
|
||||
"""Start a new announcement to the specified endpoint group."""
|
||||
global announcement_task, active_group_id
|
||||
|
||||
# Find the group
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found")
|
||||
|
||||
# Cancel any ongoing announcement
|
||||
if active_group_id is not None:
|
||||
# Reset the active group
|
||||
active_group_id = None
|
||||
|
||||
# If caster is initialized, stop streaming
|
||||
if initialized and caster:
|
||||
caster.stop_streaming()
|
||||
|
||||
# Start a new announcement process in a background task
|
||||
try:
|
||||
# Create a new task for the announcement
|
||||
announcement_task = asyncio.create_task(process_announcement(text, group))
|
||||
return {"message": "Announcement started successfully"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/announcements/status")
|
||||
async def get_announcement_status():
|
||||
"""Get the status of the current announcement."""
|
||||
# If an announcement is active, return its status
|
||||
if active_group_id is not None:
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == active_group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if group:
|
||||
return {
|
||||
"state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints
|
||||
},
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
|
||||
# If no announcement is active but we have a last completed one
|
||||
elif last_completed_group_id is not None:
|
||||
group = None
|
||||
for g in endpoint_groups:
|
||||
if g.id == last_completed_group_id:
|
||||
group = g
|
||||
break
|
||||
|
||||
if group and group.progress.current_state == AnnouncementStates.COMPLETE:
|
||||
return {
|
||||
"state": group.progress.current_state,
|
||||
"progress": group.progress.progress,
|
||||
"error": group.progress.error,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": group.id,
|
||||
"name": group.name,
|
||||
"endpoints": group.endpoints
|
||||
},
|
||||
"text": group.parameters.text,
|
||||
"languages": group.parameters.languages,
|
||||
"start_time": group.parameters.start_time
|
||||
}
|
||||
}
|
||||
|
||||
# Default: no active announcement
|
||||
return {
|
||||
"state": AnnouncementStates.IDLE,
|
||||
"progress": 0.0,
|
||||
"error": None,
|
||||
"details": {
|
||||
"group": {
|
||||
"id": 0,
|
||||
"name": "",
|
||||
"endpoints": []
|
||||
},
|
||||
"text": "",
|
||||
"languages": [],
|
||||
"start_time": time.time()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/endpoints")
|
||||
async def get_available_endpoints():
|
||||
"""Get all available endpoints."""
|
||||
return AVAILABLE_ENDPOINTS
|
||||
|
||||
|
||||
@app.get("/languages")
|
||||
async def get_available_languages():
|
||||
"""Get all available languages for announcements."""
|
||||
return AVAILABLE_LANGUAGES
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize the caster on startup."""
|
||||
log.basicConfig(
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
# We don't initialize the caster here to avoid blocking startup
|
||||
# It will be initialized on the first announcement request
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Clean up resources on shutdown."""
|
||||
global caster, initialized
|
||||
if initialized and caster:
|
||||
try:
|
||||
await caster.shutdown()
|
||||
except Exception as e:
|
||||
log.error(f"Error during shutdown: {e}")
|
||||
28
src/multilang_translator/translator_api/main_api_server.py
Normal file
28
src/multilang_translator/translator_api/main_api_server.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Entry point for the Translator API server.
|
||||
This file starts the FastAPI server with the translator_api.
|
||||
"""
|
||||
import uvicorn
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the parent directory to the Python path to find the multilang_translator package
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
log.basicConfig(
|
||||
level=log.INFO,
|
||||
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
|
||||
)
|
||||
log.info("Starting Translator API server")
|
||||
uvicorn.run(
|
||||
"multilang_translator.translator_api.api:app",
|
||||
host="0.0.0.0",
|
||||
port=7999,
|
||||
reload=True,
|
||||
log_level="debug"
|
||||
)
|
||||
39
src/multilang_translator/translator_api/translator_models.py
Normal file
39
src/multilang_translator/translator_api/translator_models.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Models for the translator API.
|
||||
Similar to the models used in auracaster-webui but simplified for the translator middleware.
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
import time
|
||||
|
||||
|
||||
class AnnouncementStates(str, Enum):
|
||||
IDLE = "idle"
|
||||
TRANSLATING = "translating"
|
||||
GENERATING_VOICE = "generating_voice"
|
||||
ROUTING = "routing"
|
||||
ACTIVE = "active"
|
||||
COMPLETE = "complete"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class AnnouncementProgress(BaseModel):
|
||||
current_state: str = AnnouncementStates.IDLE
|
||||
progress: float = 0.0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class AnnouncementParameters(BaseModel):
|
||||
text: str = ""
|
||||
languages: List[str] = []
|
||||
start_time: float = 0.0
|
||||
|
||||
|
||||
class EndpointGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
endpoints: List[str]
|
||||
languages: List[str]
|
||||
progress: AnnouncementProgress = AnnouncementProgress()
|
||||
parameters: AnnouncementParameters = AnnouncementParameters()
|
||||
Reference in New Issue
Block a user