From b06ec8109959fc05c73267f43b25446e17d55d4b Mon Sep 17 00:00:00 2001 From: pstruebi Date: Tue, 11 Mar 2025 17:34:15 +0100 Subject: [PATCH] refractoring --- .../translator_api/api.py | 405 ------------ .../translator_api/endpoints_db.py | 61 ++ ...n_api_server.py => main_translator_api.py} | 0 .../translator_api/translator_api.py | 574 ++++++++++++++++++ .../translator_api/translator_models.py | 1 - src/multilang_translator/translator_config.py | 12 +- 6 files changed, 641 insertions(+), 412 deletions(-) delete mode 100644 src/multilang_translator/translator_api/api.py create mode 100644 src/multilang_translator/translator_api/endpoints_db.py rename src/multilang_translator/translator_api/{main_api_server.py => main_translator_api.py} (100%) create mode 100644 src/multilang_translator/translator_api/translator_api.py diff --git a/src/multilang_translator/translator_api/api.py b/src/multilang_translator/translator_api/api.py deleted file mode 100644 index ec0fcaa..0000000 --- a/src/multilang_translator/translator_api/api.py +++ /dev/null @@ -1,405 +0,0 @@ -""" -FastAPI implementation of the Multilang Translator API. -This API mimics the mock_api from auracaster-webui to allow integration. -""" -from fastapi import FastAPI, HTTPException -from fastapi.middleware.cors import CORSMiddleware -import sys -import os -import asyncio -import threading -import time -import logging as log -from typing import List, Dict, Optional - -# Import models -from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates -from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDe, TranslatorConfigEn, TranslatorConfigFr -from multilang_translator.translator import llm_translator -from auracast import multicast_control - -# Create FastAPI app -app = FastAPI() - -# Add CORS middleware to allow cross-origin requests -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# Global variables -config = TranslatorConfigGroup( - bigs=[ - TranslatorConfigDe(), - TranslatorConfigEn(), - TranslatorConfigFr(), - ] -) - -# Configure the transport -config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc - -# Configure the bigs -for conf in config.bigs: - conf.loop = False - conf.llm_client = 'openwebui' # comment out for local llm - conf.llm_host_url = 'https://ollama.pstruebi.xyz' - conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' - -# Available endpoints and languages -AVAILABLE_ENDPOINTS = ["endpoint1", "endpoint2", "endpoint3"] -AVAILABLE_LANGUAGES = ["German", "English", "French"] - -# Default endpoint groups -endpoint_groups = [ - EndpointGroup( - id=1, - name="Gate1", - endpoints=["endpoint1", "endpoint2"], - languages=["German", "English"] - ), - EndpointGroup( - id=2, - name="Gate2", - endpoints=["endpoint3"], - languages=["German", "English", "French"] - ) -] - -# Announcement state tracking -active_group_id = None -last_completed_group_id = None -announcement_task = None -caster = None -reset_task = None - -# Initialization flag -initialized = False - - -async def init_caster(): - """Initialize the multicaster if not already initialized.""" - global caster, initialized - - if not initialized: - log.info("Initializing caster...") - try: - caster = multicast_control.Multicaster( - config, - [big for big in config.bigs] - ) - await caster.init_broadcast() - initialized = True - log.info("Caster initialized successfully") - except Exception as e: - log.error(f"Failed to initialize caster: {e}") - raise e - - -async def process_announcement(text: str, group: EndpointGroup): - """ - Process an announcement using the multilang_translator. - This is based on the announcement_from_german_text function in main_local.py. - """ - global active_group_id, last_completed_group_id, caster, reset_task - - # Make sure the caster is initialized - if not initialized: - await init_caster() - - try: - # Set start time and track the active group - group.parameters.text = text - group.parameters.languages = group.languages - group.parameters.start_time = time.time() - active_group_id = group.id - - # Update status to translating - group.progress.current_state = AnnouncementStates.TRANSLATING.value - group.progress.progress = 0.2 - - # Translate the text for each language - base_lang = "deu" # German is the base language - - for i, big in enumerate(config.bigs): - # Check if this language is in the requested languages - lang_code_to_name = {"deu": "German", "eng": "English", "fra": "French"} - lang_name = lang_code_to_name.get(big.language, "") - - if lang_name not in group.languages: - continue - - # Translate if not the base language - if big.language == base_lang: - translated_text = text - else: - group.progress.current_state = AnnouncementStates.TRANSLATING.value - translated_text = llm_translator.translate_de_to_x( - text, - big.language, - model=big.translator_llm, - client=big.llm_client, - host=big.llm_host_url, - token=big.llm_host_token - ) - - log.info(f'Translated text ({big.language}): {translated_text}') - - # Update status to generating voice - group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value - group.progress.progress = 0.4 - - # This will use the voice_provider's text_to_speech.synthesize function - from voice_provider import text_to_speech - lc3_audio = text_to_speech.synthesize( - translated_text, - config.auracast_sampling_rate_hz, - big.tts_system, - big.tts_model, - return_lc3=True - ) - - # Set the audio source for this language - caster.big_conf[i].audio_source = lc3_audio - - # Update status to routing - group.progress.current_state = AnnouncementStates.ROUTING.value - group.progress.progress = 0.6 - await asyncio.sleep(0.5) # Small delay for routing # TODO: actually needs to be implemented - - # Update status to active and start streaming - group.progress.current_state = AnnouncementStates.ACTIVE.value - group.progress.progress = 0.8 - caster.start_streaming() - - # Wait for streaming to complete - await caster.streamer.task - - # Update status to complete - group.progress.current_state = AnnouncementStates.COMPLETE.value - group.progress.progress = 1.0 - last_completed_group_id = group.id - - # Reset active group if this is still the active one - if active_group_id == group.id: - active_group_id = None - - # After a while, reset to idle state (in a separate task) - async def reset_to_idle(): - log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state") - await asyncio.sleep(10) # Keep completed state visible for 10 seconds - log.info(f"Resetting group {group.id} to IDLE state") - # Use direct value lookup for the state comparison - if group.progress.current_state == AnnouncementStates.COMPLETE.value: - group.progress.current_state = AnnouncementStates.IDLE.value - group.progress.progress = 0.0 - log.info(f"Group {group.id} state reset to IDLE") - - # Create and save the task so it won't be garbage collected - reset_task = asyncio.create_task(reset_to_idle()) - - except Exception as e: - log.error(f"Error in processing announcement: {e}") - group.progress.current_state = AnnouncementStates.ERROR.value - group.progress.error = str(e) - if active_group_id == group.id: - active_group_id = None - raise HTTPException(status_code=500, detail=str(e)) - - -# API Endpoints - -@app.get("/groups", response_model=List[EndpointGroup]) -async def get_groups(): - """Get all endpoint groups.""" - return endpoint_groups - - -@app.post("/groups", response_model=EndpointGroup) -async def create_group(group: EndpointGroup): - """Add a new endpoint group.""" - # Ensure group with this ID doesn't already exist - for existing_group in endpoint_groups: - if existing_group.id == group.id: - raise HTTPException(status_code=400, detail=f"Group with ID {group.id} already exists") - - # If no ID is provided or ID is 0, auto-assign the next available ID - if group.id == 0: - max_id = max([g.id for g in endpoint_groups]) if endpoint_groups else 0 - group.id = max_id + 1 - - endpoint_groups.append(group) - return group - - -@app.put("/groups/{group_id}", response_model=EndpointGroup) -async def update_group(group_id: int, updated_group: EndpointGroup): - """Update an existing endpoint group.""" - for i, group in enumerate(endpoint_groups): - if group.id == group_id: - # Ensure the ID doesn't change - updated_group.id = group_id - endpoint_groups[i] = updated_group - return updated_group - - raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found") - - -@app.delete("/groups/{group_id}") -async def delete_group(group_id: int): - """Delete an endpoint group.""" - for i, group in enumerate(endpoint_groups): - if group.id == group_id: - del endpoint_groups[i] - return {"message": "Group deleted successfully"} - - raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found") - - -@app.post("/announcements") -async def start_announcement(text: str, group_id: int): - """Start a new announcement to the specified endpoint group.""" - global announcement_task, active_group_id - - log.info(f"Received announcement request - text: '{text}', group_id: {group_id}") - - # Find the group - group = None - for g in endpoint_groups: - if g.id == group_id: - group = g - break - - if not group: - raise HTTPException(status_code=404, detail=f"Group with ID {group_id} not found") - - # Cancel any ongoing announcement - if active_group_id is not None: - # Reset the active group - active_group_id = None - - # If caster is initialized, stop streaming - if initialized and caster: - caster.stop_streaming() - - # Start a new announcement process in a background task - try: - # Create a new task for the announcement - announcement_task = asyncio.create_task(process_announcement(text, group)) - return {"message": "Announcement started successfully"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@app.get("/announcements/status") -async def get_announcement_status(): - """Get the status of the current announcement.""" - global active_group_id, last_completed_group_id - log.debug(f"Status request - active_group_id: {active_group_id}, last_completed_group_id: {last_completed_group_id}") - - # If an announcement is active, return its status - if active_group_id is not None: - group = None - for g in endpoint_groups: - if g.id == active_group_id: - group = g - break - - if group: - log.debug(f"Returning active status for group {group.id}: {group.progress.current_state}") - return { - "state": group.progress.current_state, - "progress": group.progress.progress, - "error": group.progress.error, - "details": { - "group": { - "id": group.id, - "name": group.name, - "endpoints": group.endpoints - }, - "text": group.parameters.text, - "languages": group.parameters.languages, - "start_time": group.parameters.start_time - } - } - - # If no announcement is active but we have a last completed one - elif last_completed_group_id is not None: - group = None - for g in endpoint_groups: - if g.id == last_completed_group_id: - group = g - break - - if group and group.progress.current_state == AnnouncementStates.COMPLETE.value: - log.debug(f"Returning completed status for group {group.id}") - return { - "state": group.progress.current_state, - "progress": group.progress.progress, - "error": group.progress.error, - "details": { - "group": { - "id": group.id, - "name": group.name, - "endpoints": group.endpoints - }, - "text": group.parameters.text, - "languages": group.parameters.languages, - "start_time": group.parameters.start_time - } - } - - # Default: no active announcement - log.debug("Returning idle status (no active or completed announcements)") - return { - "state": AnnouncementStates.IDLE.value, - "progress": 0.0, - "error": None, - "details": { - "group": { - "id": 0, - "name": "", - "endpoints": [] - }, - "text": "", - "languages": [], - "start_time": time.time() - } - } - - -@app.get("/endpoints") -async def get_available_endpoints(): - """Get all available endpoints.""" - return AVAILABLE_ENDPOINTS - - -@app.get("/languages") -async def get_available_languages(): - """Get all available languages for announcements.""" - return AVAILABLE_LANGUAGES - - -@app.on_event("startup") -async def startup_event(): - """Initialize the caster on startup.""" - log.basicConfig( - level=log.INFO, - format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' - ) - # We don't initialize the caster here to avoid blocking startup - # It will be initialized on the first announcement request - - -@app.on_event("shutdown") -async def shutdown_event(): - """Clean up resources on shutdown.""" - global caster, initialized - if initialized and caster: - try: - await caster.shutdown() - except Exception as e: - log.error(f"Error during shutdown: {e}") diff --git a/src/multilang_translator/translator_api/endpoints_db.py b/src/multilang_translator/translator_api/endpoints_db.py new file mode 100644 index 0000000..dab28cb --- /dev/null +++ b/src/multilang_translator/translator_api/endpoints_db.py @@ -0,0 +1,61 @@ +""" +Database file for endpoint definitions. +This file contains configurations for auracast endpoints including their IP addresses and capabilities. +""" +from typing import Dict, List, Optional +from pydantic import BaseModel + + +class EndpointDefinition(BaseModel): + """Defines an endpoint with its URL and capabilities.""" + id: str + name: str + url: str + max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts + requires_authentication: bool = False # Whether authentication is required + + +# Database of endpoints +ENDPOINTS: Dict[str, EndpointDefinition] = { + "endpoint0": EndpointDefinition( + id="endpoint0", + name="Local Endpoint", + url="http://localhost:5000", + max_broadcasts=2, + ), + "endpoint1": EndpointDefinition( + id="endpoint1", + name="Gate 1 Endpoint", + url="http://192.168.1.101:5000", + max_broadcasts=2, + ), + "endpoint2": EndpointDefinition( + id="endpoint2", + name="Gate 2 Endpoint", + url="http://192.168.1.102:5000", + max_broadcasts=1, + ), +} + + +def get_all_endpoints() -> List[EndpointDefinition]: + """Get all active endpoints.""" + return ENDPOINTS.values() + +def get_endpoint_by_id(endpoint_id: str) -> Optional[EndpointDefinition]: + """Get an endpoint by its ID.""" + return ENDPOINTS.get(endpoint_id) + +def get_max_broadcasts(endpoint_id: str) -> int: + """Get the maximum number of simultaneous broadcasts for a specific endpoint.""" + endpoint = get_endpoint_by_id(endpoint_id) + if not endpoint: + raise ValueError(f"Endpoint {endpoint_id} not found") + return endpoint.max_broadcasts + +def get_endpoint_url(endpoint_id: str) -> str: + """Get the full URL for an endpoint.""" + endpoint = get_endpoint_by_id(endpoint_id) + if not endpoint: + raise ValueError(f"Endpoint {endpoint_id} not found") + return endpoint.url diff --git a/src/multilang_translator/translator_api/main_api_server.py b/src/multilang_translator/translator_api/main_translator_api.py similarity index 100% rename from src/multilang_translator/translator_api/main_api_server.py rename to src/multilang_translator/translator_api/main_translator_api.py diff --git a/src/multilang_translator/translator_api/translator_api.py b/src/multilang_translator/translator_api/translator_api.py new file mode 100644 index 0000000..7b80ecd --- /dev/null +++ b/src/multilang_translator/translator_api/translator_api.py @@ -0,0 +1,574 @@ +""" +FastAPI implementation of the Multilang Translator API. +This API mimics the mock_api from auracaster-webui to allow integration. +""" +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +import sys +import os +import asyncio +import threading +import time +import logging as log +from typing import List, Dict, Optional + +# Import models +from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates +from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDeu, TranslatorConfigEng, TranslatorConfigFra +from multilang_translator.translator import llm_translator +from auracast import multicast_control +# Import the endpoints database and multicast client +from multilang_translator.translator_api import endpoints_db +from auracast import multicast_client + +# Create FastAPI app +app = FastAPI() + +# Add CORS middleware to allow cross-origin requests +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Global variables +config = TranslatorConfigGroup( + bigs=[ + TranslatorConfigDeu(), + TranslatorConfigEng(), + TranslatorConfigFra(), + ] +) + +# Configure the transport +config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc + +# Configure the bigs +for conf in config.bigs: + conf.loop = False + conf.llm_client = 'openwebui' # comment out for local llm + conf.llm_host_url = 'https://ollama.pstruebi.xyz' + conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' + +# Get available endpoints and languages from the database +AVAILABLE_ENDPOINTS = endpoints_db.get_all_endpoints() +AVAILABLE_LANGUAGES = ['deu', 'eng', 'fra'] + +# Default endpoint groups +endpoint_groups = [ + EndpointGroup( + id=1, + name="Gate1", + endpoints=["endpoint1", "endpoint2"], + languages=["deu", "eng"] + ), + EndpointGroup( + id=2, + name="Gate2", + endpoints=["endpoint3"], + languages=["deu", "eng", "fra"] + ) +] + +# Announcement state tracking +active_group_id = None +last_completed_group_id = None +announcement_task = None +caster = None +reset_task = None + +# Initialization flag +initialized = False + +# Dictionary to track endpoint status and active broadcasts +endpoint_status = {} + + +async def init_caster(): + """Initialize the multicaster if not already initialized.""" + global caster, initialized + + if not initialized: + log.info("Initializing caster...") + try: + caster = multicast_control.Multicaster( + config, + [big for big in config.bigs] + ) + await caster.init_broadcast() + initialized = True + log.info("Caster initialized successfully") + except Exception as e: + log.error(f"Failed to initialize caster: {e}") + raise e + + +async def init_endpoint(endpoint_id: str): + """Initialize a specific endpoint for multicast.""" + endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) + if not endpoint: + raise ValueError(f"Endpoint {endpoint_id} not found") + + log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.ip_address}:{endpoint.port}") + + # Update the BASE_URL in multicast_client for this endpoint + multicast_client.BASE_URL = f"http://{endpoint.ip_address}:{endpoint.port}" + + # Create a config for this endpoint + endpoint_config = multicast_client.AuracastConfigGroup( + bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")() + for lang in endpoint.supported_languages] + ) + endpoint_config.transport = config.transport + + # Configure the bigs + for conf in endpoint_config.bigs: + conf.loop = False + + try: + # Initialize the endpoint + multicast_client.request_init(endpoint_config) + endpoint_status[endpoint_id] = { + "active": True, + "broadcasts": 0, + "max_broadcasts": endpoint.max_broadcasts + } + log.info(f"Endpoint {endpoint_id} initialized successfully") + except Exception as e: + log.error(f"Failed to initialize endpoint {endpoint_id}: {e}") + endpoint_status[endpoint_id] = { + "active": False, + "error": str(e) + } + raise e + + +async def process_announcement(text: str, group: EndpointGroup): + """ + Process an announcement using the multilang_translator. + This function now uses the multicast_client to send announcements to endpoints. + """ + global active_group_id, last_completed_group_id, caster, reset_task + + # Make sure the caster is initialized + if not initialized: + await init_caster() + + try: + # Set start time and track the active group + group.parameters.text = text + group.parameters.languages = group.languages + group.parameters.start_time = time.time() + active_group_id = group.id + + # Update status to translating + group.progress.current_state = AnnouncementStates.TRANSLATING.value + group.progress.progress = 0.2 + + # Initialize all endpoints in the group + endpoints_to_broadcast = [] + for endpoint_id in group.endpoints: + endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) + if not endpoint: + log.warning(f"Endpoint {endpoint_id} not found, skipping") + continue + + # Check if the endpoint supports all the required languages + if not all(lang in endpoint.supported_languages for lang in group.languages): + log.warning(f"Endpoint {endpoint_id} does not support all required languages, skipping") + continue + + # Check if the endpoint has room for another broadcast + if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]: + if endpoint_status[endpoint_id]["broadcasts"] >= endpoint_status[endpoint_id]["max_broadcasts"]: + log.warning(f"Endpoint {endpoint_id} already at maximum broadcasts, skipping") + continue + + endpoint_status[endpoint_id]["broadcasts"] += 1 + endpoints_to_broadcast.append(endpoint) + else: + # Initialize the endpoint if not already active + try: + await init_endpoint(endpoint_id) + endpoint_status[endpoint_id]["broadcasts"] += 1 + endpoints_to_broadcast.append(endpoint) + except Exception as e: + log.error(f"Could not initialize endpoint {endpoint_id}: {e}") + continue + + if not endpoints_to_broadcast: + raise HTTPException(status_code=400, detail="No valid endpoints available for broadcast") + + # Translate the text for each language + base_lang = "deu" # German is the base language + audio_data = {} + + for i, big in enumerate(config.bigs): + # Check if this language is in the requested languages + if big.language not in group.languages: + continue + + # Translate if not the base language + if big.language == base_lang: + translated_text = text + else: + group.progress.current_state = AnnouncementStates.TRANSLATING.value + translated_text = llm_translator.translate_de_to_x( + text, + big.language, + model=big.translator_llm, + client=big.llm_client, + host=big.llm_host_url, + token=big.llm_host_token + ) + + log.info(f'Translated text ({big.language}): {translated_text}') + + # Update status to generating voice + group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value + group.progress.progress = 0.4 + + # This will use the voice_provider's text_to_speech.synthesize function + from voice_provider import text_to_speech + lc3_audio = text_to_speech.synthesize( + translated_text, + config.auracast_sampling_rate_hz, + big.tts_system, + big.tts_model, + return_lc3=True + ) + + # Add the audio to the audio_data dictionary + audio_data[big.language] = lc3_audio.decode('latin-1') if isinstance(lc3_audio, bytes) else lc3_audio + + # Set the audio source for this language (for the traditional caster) + caster.big_conf[i].audio_source = lc3_audio + + # Update status to routing + group.progress.current_state = AnnouncementStates.ROUTING.value + group.progress.progress = 0.6 + await asyncio.sleep(0.5) # Small delay for routing + + # Update status to active and start streaming + group.progress.current_state = AnnouncementStates.ACTIVE.value + group.progress.progress = 0.8 + + # Send the audio to each endpoint using multicast_client + broadcast_tasks = [] + for endpoint in endpoints_to_broadcast: + # Set the BASE_URL for this endpoint + multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint.id) + + # Create a task to send the audio to this endpoint + broadcast_tasks.append(asyncio.create_task( + asyncio.to_thread(multicast_client.send_audio, audio_data) + )) + + # Also start the traditional caster + caster.start_streaming() + + # Wait for all broadcasts to complete + await asyncio.gather(*broadcast_tasks) + + # Wait for streaming to complete with the traditional caster + await caster.streamer.task + + # Update endpoint status + for endpoint in endpoints_to_broadcast: + if endpoint.id in endpoint_status and endpoint_status[endpoint.id]["broadcasts"] > 0: + endpoint_status[endpoint.id]["broadcasts"] -= 1 + + # Update status to complete + group.progress.current_state = AnnouncementStates.COMPLETE.value + group.progress.progress = 1.0 + last_completed_group_id = group.id + + # Reset active group if this is still the active one + if active_group_id == group.id: + active_group_id = None + + # After a while, reset to idle state (in a separate task) + async def reset_to_idle(): + log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state") + await asyncio.sleep(10) # Keep completed state visible for 10 seconds + log.info(f"Resetting group {group.id} to IDLE state") + # Use direct value lookup for the state comparison + if group.progress.current_state == AnnouncementStates.COMPLETE.value: + group.progress.current_state = AnnouncementStates.IDLE.value + group.progress.progress = 0.0 + log.info(f"Group {group.id} state reset to IDLE") + + # Create and save the task so it won't be garbage collected + reset_task = asyncio.create_task(reset_to_idle()) + + except Exception as e: + log.error(f"Error during announcement processing: {e}") + group.progress.current_state = AnnouncementStates.ERROR.value + group.progress.error = str(e) + + # Reset active group if this is still the active one + if active_group_id == group.id: + active_group_id = None + + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/api/groups") +async def get_groups(): + """Get all endpoint groups.""" + return endpoint_groups + + +@app.post("/api/groups") +async def create_group(group: EndpointGroup): + """Add a new endpoint group.""" + # Validate endpoints + for endpoint_id in group.endpoints: + if endpoint_id not in AVAILABLE_ENDPOINTS: + raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available") + + # Validate languages + for language in group.languages: + if language not in AVAILABLE_LANGUAGES: + raise HTTPException(status_code=400, detail=f"Language {language} is not available") + + # Add the group + endpoint_groups.append(group) + return group + + +@app.put("/api/groups/{group_id}") +async def update_group(group_id: int, updated_group: EndpointGroup): + """Update an existing endpoint group.""" + for i, group in enumerate(endpoint_groups): + if group.id == group_id: + # Validate endpoints + for endpoint_id in updated_group.endpoints: + if endpoint_id not in AVAILABLE_ENDPOINTS: + raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available") + + # Validate languages + for language in updated_group.languages: + if language not in AVAILABLE_LANGUAGES: + raise HTTPException(status_code=400, detail=f"Language {language} is not available") + + # Update the group + endpoint_groups[i] = updated_group + return updated_group + + raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + + +@app.delete("/api/groups/{group_id}") +async def delete_group(group_id: int): + """Delete an endpoint group.""" + for i, group in enumerate(endpoint_groups): + if group.id == group_id: + del endpoint_groups[i] + return {"detail": f"Group with id {group_id} deleted"} + + raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + + +@app.post("/api/announcements") +async def start_announcement(text: str, group_id: int): + """Start a new announcement to the specified endpoint group.""" + global announcement_task + + # Find the group + target_group = None + for group in endpoint_groups: + if group.id == group_id: + target_group = group + break + + if not target_group: + raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + + # Check if an announcement is already in progress + if active_group_id is not None: + raise HTTPException(status_code=400, detail="An announcement is already in progress") + + # Process the announcement + announcement_task = asyncio.create_task(process_announcement(text, target_group)) + + return {"detail": "Announcement started"} + + +@app.get("/api/announcements/status") +async def get_announcement_status(): + """Get the status of the current announcement.""" + # If no group is active, check if the last completed group exists + if active_group_id is None: + if last_completed_group_id is not None: + # Find the last completed group + for group in endpoint_groups: + if group.id == last_completed_group_id: + return { + "active_group": None, + "last_completed_group": { + "id": group.id, + "name": group.name, + "endpoints": group.endpoints, + "languages": group.languages, + "progress": { + "current_state": group.progress.current_state, + "progress": group.progress.progress, + "error": group.progress.error + }, + "parameters": { + "text": group.parameters.text, + "languages": group.parameters.languages, + "start_time": group.parameters.start_time + } + } + } + + # If the last completed group couldn't be found + return { + "active_group": None, + "last_completed_group": None + } + else: + # No active or last completed group + return { + "active_group": None, + "last_completed_group": None + } + + # If a group is active, return its information + for group in endpoint_groups: + if group.id == active_group_id: + return { + "active_group": { + "id": group.id, + "name": group.name, + "endpoints": group.endpoints, + "languages": group.languages, + "progress": { + "current_state": group.progress.current_state, + "progress": group.progress.progress, + "error": group.progress.error + }, + "parameters": { + "text": group.parameters.text, + "languages": group.parameters.languages, + "start_time": group.parameters.start_time + } + }, + "last_completed_group": None + } + + # If the active group couldn't be found + return { + "active_group": None, + "last_completed_group": None + } + + +@app.get("/api/endpoints") +async def get_available_endpoints(): + """Get all available endpoints with their capabilities.""" + return [endpoint_db.dict() for endpoint_db in endpoints_db.get_all_endpoints()] + + +@app.get("/api/languages") +async def get_available_languages(): + """Get all available languages for announcements.""" + return endpoints_db.get_available_languages() + + +@app.get("/api/endpoints/{endpoint_id}") +async def get_endpoint_status(endpoint_id: str): + """Get the status of a specific endpoint.""" + endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) + if not endpoint: + raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found") + + status = endpoint_status.get(endpoint_id, {"active": False, "broadcasts": 0}) + return { + "endpoint": endpoint.dict(), + "status": status + } + + +@app.post("/api/endpoints/{endpoint_id}/reset") +async def reset_endpoint(endpoint_id: str): + """Reset an endpoint's status.""" + endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) + if not endpoint: + raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found") + + try: + # Set the BASE_URL for this endpoint + multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id) + + # Stop any current audio + multicast_client.stop_audio() + + # Reinitialize the endpoint + await init_endpoint(endpoint_id) + + return {"detail": f"Endpoint {endpoint_id} reset successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to reset endpoint: {str(e)}") + + +@app.on_event("startup") +async def startup_event(): + """Initialize the caster on startup.""" + global caster, initialized + + try: + await init_caster() + + # Initialize all endpoints + for endpoint_id in AVAILABLE_ENDPOINTS: + try: + await init_endpoint(endpoint_id) + except Exception as e: + log.error(f"Failed to initialize endpoint {endpoint_id}: {e}") + except Exception as e: + log.error(f"Startup error: {e}") + + +@app.on_event("shutdown") +async def shutdown_event(): + """Clean up resources on shutdown.""" + global caster + + if caster is not None: + try: + log.info("Stopping caster...") + caster.stop_streaming() + caster = None + log.info("Caster stopped") + except Exception as e: + log.error(f"Error during caster shutdown: {e}") + + # Shutdown all active endpoints + for endpoint_id in endpoint_status: + if endpoint_status[endpoint_id].get("active", False): + try: + multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id) + multicast_client.shutdown() + except Exception as e: + log.error(f"Error shutting down endpoint {endpoint_id}: {e}") + + +if __name__ == "__main__": + import uvicorn + log.basicConfig( + level=log.INFO, + format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' + ) + log.info("Starting Translator API server") + uvicorn.run( + "translator_api:app", + host="0.0.0.0", + port=7999, + reload=True, + log_level="debug" + ) \ No newline at end of file diff --git a/src/multilang_translator/translator_api/translator_models.py b/src/multilang_translator/translator_api/translator_models.py index 0103e3f..547133f 100644 --- a/src/multilang_translator/translator_api/translator_models.py +++ b/src/multilang_translator/translator_api/translator_models.py @@ -5,7 +5,6 @@ Similar to the models used in auracaster-webui but simplified for the translator from enum import Enum from typing import List, Optional, Dict, Any from pydantic import BaseModel -import time class AnnouncementStates(str, Enum): diff --git a/src/multilang_translator/translator_config.py b/src/multilang_translator/translator_config.py index a50c810..afaa906 100644 --- a/src/multilang_translator/translator_config.py +++ b/src/multilang_translator/translator_config.py @@ -14,22 +14,22 @@ class TranslatorLangConfig(auracast_config.AuracastBigConfig): tts_model: str ='de_DE-kerstin-low' -class TranslatorConfigDe(TranslatorLangConfig, auracast_config.AuracastBigConfigDe): +class TranslatorConfigDeu(TranslatorLangConfig, auracast_config.AuracastBigConfigDeu): tts_model: str ='de_DE-thorsten-high' -class TranslatorConfigEn(TranslatorLangConfig, auracast_config.AuracastBigConfigEn): +class TranslatorConfigEng(TranslatorLangConfig, auracast_config.AuracastBigConfigEng): tts_model: str = 'en_GB-alba-medium' -class TranslatorConfigFr(TranslatorLangConfig, auracast_config.AuracastBigConfigFr): +class TranslatorConfigFra(TranslatorLangConfig, auracast_config.AuracastBigConfigFra): tts_model: str = 'fr_FR-siwis-medium' -class TranslatorConfigEs(TranslatorLangConfig, auracast_config.AuracastBigConfigEs): +class TranslatorConfigSpa(TranslatorLangConfig, auracast_config.AuracastBigConfigSpa): tts_model: str = 'es_ES-sharvard-medium' -class TranslatorConfigIt(TranslatorLangConfig, auracast_config.AuracastBigConfigIt): +class TranslatorConfigIta(TranslatorLangConfig, auracast_config.AuracastBigConfigIta): tts_model: str = 'it_IT-paola-medium' class TranslatorConfigGroup(auracast_config.AuracastGlobalConfig): bigs: List[TranslatorLangConfig] = [ - TranslatorConfigDe(), + TranslatorConfigDeu(), ]