diff --git a/src/multilang_translator/translator_api/endpoints_db.py b/src/multilang_translator/translator_api/endpoints_db.py index dab28cb..0161118 100644 --- a/src/multilang_translator/translator_api/endpoints_db.py +++ b/src/multilang_translator/translator_api/endpoints_db.py @@ -2,60 +2,144 @@ Database file for endpoint definitions. This file contains configurations for auracast endpoints including their IP addresses and capabilities. """ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set from pydantic import BaseModel +from multilang_translator.translator_api.translator_models import EndpointGroup, Endpoint -class EndpointDefinition(BaseModel): - """Defines an endpoint with its URL and capabilities.""" - id: str - name: str - url: str - max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts - requires_authentication: bool = False # Whether authentication is required - +SUPPORTED_LANGUAGES = ["deu", "eng", "fra", "spa", "ita"] # Database of endpoints -ENDPOINTS: Dict[str, EndpointDefinition] = { - "endpoint0": EndpointDefinition( - id="endpoint0", +ENDPOINTS: dict[int: Endpoint] = { # for now make sure, .id and key are the same + 0: Endpoint( + id=0, name="Local Endpoint", url="http://localhost:5000", - max_broadcasts=2, + max_broadcasts=3, ), - "endpoint1": EndpointDefinition( - id="endpoint1", + 1: Endpoint( + id=1, name="Gate 1 Endpoint", url="http://192.168.1.101:5000", - max_broadcasts=2, + max_broadcasts=3, ), - "endpoint2": EndpointDefinition( - id="endpoint2", + 2: Endpoint( + id=2, name="Gate 2 Endpoint", url="http://192.168.1.102:5000", - max_broadcasts=1, + max_broadcasts=3, ), } +# Database of endpoint groups with default endpoints +ENDPOINT_GROUPS: dict[int:EndpointGroup] = { # for now make sure , .id and key are the same + 0: EndpointGroup( + id=0, + name="Gate1", + languages=["deu", "eng"], + endpoints=[ENDPOINTS[0]], + ), + 1: EndpointGroup( + id=1, + name="Gate2", + languages=["deu", "eng", "fra"], + endpoints=[ENDPOINTS[2]], + ) +} -def get_all_endpoints() -> List[EndpointDefinition]: +def get_available_languages() -> List[str]: + """Get a list of all supported languages.""" + return SUPPORTED_LANGUAGES + +# Endpoint functions +def get_all_endpoints() -> List[Endpoint]: """Get all active endpoints.""" - return ENDPOINTS.values() + return ENDPOINTS -def get_endpoint_by_id(endpoint_id: str) -> Optional[EndpointDefinition]: + +def get_endpoint_by_id(endpoint_id: str) -> Optional[Endpoint]: """Get an endpoint by its ID.""" - return ENDPOINTS.get(endpoint_id) + return ENDPOINTS[endpoint_id] -def get_max_broadcasts(endpoint_id: str) -> int: - """Get the maximum number of simultaneous broadcasts for a specific endpoint.""" - endpoint = get_endpoint_by_id(endpoint_id) - if not endpoint: - raise ValueError(f"Endpoint {endpoint_id} not found") - return endpoint.max_broadcasts -def get_endpoint_url(endpoint_id: str) -> str: - """Get the full URL for an endpoint.""" - endpoint = get_endpoint_by_id(endpoint_id) - if not endpoint: +def add_endpoint(endpoint: Endpoint) -> Endpoint: + """Add a new endpoint to the database.""" + if endpoint.id in ENDPOINTS: + raise ValueError(f"Endpoint with ID {endpoint.id} already exists") + ENDPOINTS[endpoint.id] = endpoint + return endpoint + + +def update_endpoint(endpoint_id: str, updated_endpoint: Endpoint) -> Endpoint: + """Update an existing endpoint in the database.""" + if endpoint_id not in ENDPOINTS: raise ValueError(f"Endpoint {endpoint_id} not found") - return endpoint.url + + # Ensure the ID is preserved + updated_endpoint.id = endpoint_id + ENDPOINTS[endpoint_id] = updated_endpoint + return updated_endpoint + + +def delete_endpoint(endpoint_id: str) -> None: + """Delete an endpoint from the database.""" + if endpoint_id not in ENDPOINTS: + raise ValueError(f"Endpoint {endpoint_id} not found") + + # Check if this endpoint is used in any groups + for group in ENDPOINT_GROUPS.values(): + if endpoint_id in group.endpoints: + raise ValueError(f"Cannot delete endpoint {endpoint_id}, it is used in group {group.id}") + + del ENDPOINTS[endpoint_id] + + +# Endpoint Group functions +def get_all_groups() -> List[EndpointGroup]: + """Get all endpoint groups.""" + return list(ENDPOINT_GROUPS.values()) + + +def get_group_by_id(group_id: int) -> Optional[EndpointGroup]: + """Get an endpoint group by its ID.""" + return ENDPOINT_GROUPS.get(group_id) + + +def add_group(group: EndpointGroup) -> EndpointGroup: + """Add a new endpoint group to the database.""" + if group.id in ENDPOINT_GROUPS: + raise ValueError(f"Group with ID {group.id} already exists") + + # Validate that all referenced endpoints exist + for endpoint_id in group.endpoints: + if endpoint_id not in ENDPOINTS: + raise ValueError(f"Endpoint {endpoint_id} not found") + + ENDPOINT_GROUPS[group.id] = group + return group + + +def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup: + """Update an existing endpoint group in the database.""" + if group_id not in ENDPOINT_GROUPS: + raise ValueError(f"Group {group_id} not found") + + # Validate that all referenced endpoints exist + for endpoint in updated_group.endpoints: + if endpoint.id not in ENDPOINTS.keys(): + raise ValueError(f"Endpoint {endpoint_id} not found") + + # Ensure the ID is preserved + updated_group.id = group_id + ENDPOINT_GROUPS[group_id] = updated_group + return updated_group + + +def delete_group(group_id: int) -> None: + """Delete an endpoint group from the database.""" + if group_id not in ENDPOINT_GROUPS: + raise ValueError(f"Group {group_id} not found") + + del ENDPOINT_GROUPS[group_id] + + diff --git a/src/multilang_translator/translator_api/translator_api.py b/src/multilang_translator/translator_api/translator_api.py index 7b80ecd..cdfa76a 100644 --- a/src/multilang_translator/translator_api/translator_api.py +++ b/src/multilang_translator/translator_api/translator_api.py @@ -4,22 +4,18 @@ This API mimics the mock_api from auracaster-webui to allow integration. """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -import sys -import os import asyncio -import threading import time -import logging as log -from typing import List, Dict, Optional +import logging as log # Import models -from multilang_translator.translator_api.translator_models import EndpointGroup, AnnouncementStates -from multilang_translator.translator_config import TranslatorConfigGroup, TranslatorConfigDeu, TranslatorConfigEng, TranslatorConfigFra +from multilang_translator.translator_api.translator_models import AnnouncementStates, Endpoint, EndpointGroup from multilang_translator.translator import llm_translator -from auracast import multicast_control +from voice_provider import text_to_speech + # Import the endpoints database and multicast client from multilang_translator.translator_api import endpoints_db -from auracast import multicast_client +from auracast import multicast_client, auracast_config # Create FastAPI app app = FastAPI() @@ -33,530 +29,226 @@ app.add_middleware( allow_headers=["*"], ) -# Global variables -config = TranslatorConfigGroup( - bigs=[ - TranslatorConfigDeu(), - TranslatorConfigEng(), - TranslatorConfigFra(), - ] -) - -# Configure the transport -config.transport = 'serial:/dev/serial/by-id/usb-ZEPHYR_Zephyr_HCI_UART_sample_81BD14B8D71B5662-if00,115200,rtscts' # nrf52dongle hci_uart usb cdc - -# Configure the bigs -for conf in config.bigs: - conf.loop = False - conf.llm_client = 'openwebui' # comment out for local llm - conf.llm_host_url = 'https://ollama.pstruebi.xyz' - conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' # Get available endpoints and languages from the database AVAILABLE_ENDPOINTS = endpoints_db.get_all_endpoints() -AVAILABLE_LANGUAGES = ['deu', 'eng', 'fra'] +AVAILABLE_LANGUAGES = endpoints_db.get_available_languages() -# Default endpoint groups -endpoint_groups = [ - EndpointGroup( - id=1, - name="Gate1", - endpoints=["endpoint1", "endpoint2"], - languages=["deu", "eng"] - ), - EndpointGroup( - id=2, - name="Gate2", - endpoints=["endpoint3"], - languages=["deu", "eng", "fra"] - ) -] +CURRENT_ENDPOINT_CONFIG = {} -# Announcement state tracking -active_group_id = None -last_completed_group_id = None -announcement_task = None -caster = None -reset_task = None - -# Initialization flag -initialized = False - -# Dictionary to track endpoint status and active broadcasts -endpoint_status = {} - - -async def init_caster(): - """Initialize the multicaster if not already initialized.""" - global caster, initialized - - if not initialized: - log.info("Initializing caster...") - try: - caster = multicast_control.Multicaster( - config, - [big for big in config.bigs] - ) - await caster.init_broadcast() - initialized = True - log.info("Caster initialized successfully") - except Exception as e: - log.error(f"Failed to initialize caster: {e}") - raise e - - -async def init_endpoint(endpoint_id: str): +def init_endpoint(endpoint: Endpoint, languages: list[str]): """Initialize a specific endpoint for multicast.""" - endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) - if not endpoint: - raise ValueError(f"Endpoint {endpoint_id} not found") - - log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.ip_address}:{endpoint.port}") - - # Update the BASE_URL in multicast_client for this endpoint - multicast_client.BASE_URL = f"http://{endpoint.ip_address}:{endpoint.port}" - + + log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}") # Create a config for this endpoint - endpoint_config = multicast_client.AuracastConfigGroup( + config = multicast_client.AuracastConfigGroup( bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")() - for lang in endpoint.supported_languages] + for lang in languages] ) - endpoint_config.transport = config.transport - + + # some default configs (for now) + # Configure the transport + config.transport = 'auto' + # Configure the bigs - for conf in endpoint_config.bigs: - conf.loop = False - - try: - # Initialize the endpoint - multicast_client.request_init(endpoint_config) - endpoint_status[endpoint_id] = { - "active": True, - "broadcasts": 0, - "max_broadcasts": endpoint.max_broadcasts - } - log.info(f"Endpoint {endpoint_id} initialized successfully") - except Exception as e: - log.error(f"Failed to initialize endpoint {endpoint_id}: {e}") - endpoint_status[endpoint_id] = { - "active": False, - "error": str(e) - } - raise e + for conf in config.bigs: # TODO: this is now part of the endpoint group config + conf.loop = False + # conf.llm_client = 'openwebui' # comment out for local llm + # conf.llm_host_url = 'https://ollama.pstruebi.xyz' + # conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13' + + # Initialize the endpoint if config changed or if it's not already initialized + if not multicast_client.get_status(base_url=endpoint.url)['is_initialized'] or config != CURRENT_ENDPOINT_CONFIG.get(endpoint.id): + multicast_client.request_init(config, base_url=endpoint.url) + CURRENT_ENDPOINT_CONFIG[endpoint.id] = config + else: + log.info('Endpoint %s was already initialized', endpoint.name) + + log.info(f"Endpoint {endpoint.name} initialized successfully") -async def process_announcement(text: str, group: EndpointGroup): +async def make_announcement(text: str, ep_group: endpoints_db.EndpointGroup): """ Process an announcement using the multilang_translator. This function now uses the multicast_client to send announcements to endpoints. """ - global active_group_id, last_completed_group_id, caster, reset_task + + ep_group.current_state = AnnouncementStates.IDLE + ep_group.anouncement_start_time = time.time() + # update the database with the new state and start time so this can be read by another process + endpoints_db.update_group(ep_group.id, ep_group) + - # Make sure the caster is initialized - if not initialized: - await init_caster() - - try: - # Set start time and track the active group - group.parameters.text = text - group.parameters.languages = group.languages - group.parameters.start_time = time.time() - active_group_id = group.id - - # Update status to translating - group.progress.current_state = AnnouncementStates.TRANSLATING.value - group.progress.progress = 0.2 - - # Initialize all endpoints in the group - endpoints_to_broadcast = [] - for endpoint_id in group.endpoints: - endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) - if not endpoint: - log.warning(f"Endpoint {endpoint_id} not found, skipping") - continue - - # Check if the endpoint supports all the required languages - if not all(lang in endpoint.supported_languages for lang in group.languages): - log.warning(f"Endpoint {endpoint_id} does not support all required languages, skipping") - continue - - # Check if the endpoint has room for another broadcast - if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]: - if endpoint_status[endpoint_id]["broadcasts"] >= endpoint_status[endpoint_id]["max_broadcasts"]: - log.warning(f"Endpoint {endpoint_id} already at maximum broadcasts, skipping") - continue - - endpoint_status[endpoint_id]["broadcasts"] += 1 - endpoints_to_broadcast.append(endpoint) - else: - # Initialize the endpoint if not already active - try: - await init_endpoint(endpoint_id) - endpoint_status[endpoint_id]["broadcasts"] += 1 - endpoints_to_broadcast.append(endpoint) - except Exception as e: - log.error(f"Could not initialize endpoint {endpoint_id}: {e}") - continue - - if not endpoints_to_broadcast: - raise HTTPException(status_code=400, detail="No valid endpoints available for broadcast") - - # Translate the text for each language - base_lang = "deu" # German is the base language - audio_data = {} - - for i, big in enumerate(config.bigs): - # Check if this language is in the requested languages - if big.language not in group.languages: - continue - - # Translate if not the base language - if big.language == base_lang: - translated_text = text - else: - group.progress.current_state = AnnouncementStates.TRANSLATING.value - translated_text = llm_translator.translate_de_to_x( - text, - big.language, - model=big.translator_llm, - client=big.llm_client, - host=big.llm_host_url, - token=big.llm_host_token - ) - - log.info(f'Translated text ({big.language}): {translated_text}') - - # Update status to generating voice - group.progress.current_state = AnnouncementStates.GENERATING_VOICE.value - group.progress.progress = 0.4 - - # This will use the voice_provider's text_to_speech.synthesize function - from voice_provider import text_to_speech - lc3_audio = text_to_speech.synthesize( - translated_text, - config.auracast_sampling_rate_hz, - big.tts_system, - big.tts_model, - return_lc3=True + # Initialize all endpoints in the group if they were not initalized before + for endpoint in ep_group.endpoints: + init_endpoint(endpoint, ep_group.languages) + + # Translate the text for each language + base_lang = "deu" # German is the base language + ep_group.current_state = AnnouncementStates.TRANSLATING + endpoints_db.update_group(ep_group.id, ep_group) + + translations = {base_lang: text} + for lang in ep_group.languages: + if lang != base_lang: + # Translate the text + lang_conf = getattr(ep_group.translator_config, lang) + translation = llm_translator.translate_de_to_x( + text=text, + target_language=lang, + client=lang_conf.llm_client, + model=lang_conf.translator_llm, + host=lang_conf.llm_host_url, + token=lang_conf.llm_host_token ) - # Add the audio to the audio_data dictionary - audio_data[big.language] = lc3_audio.decode('latin-1') if isinstance(lc3_audio, bytes) else lc3_audio - - # Set the audio source for this language (for the traditional caster) - caster.big_conf[i].audio_source = lc3_audio - - # Update status to routing - group.progress.current_state = AnnouncementStates.ROUTING.value - group.progress.progress = 0.6 - await asyncio.sleep(0.5) # Small delay for routing - - # Update status to active and start streaming - group.progress.current_state = AnnouncementStates.ACTIVE.value - group.progress.progress = 0.8 - - # Send the audio to each endpoint using multicast_client - broadcast_tasks = [] - for endpoint in endpoints_to_broadcast: - # Set the BASE_URL for this endpoint - multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint.id) - - # Create a task to send the audio to this endpoint - broadcast_tasks.append(asyncio.create_task( - asyncio.to_thread(multicast_client.send_audio, audio_data) - )) - - # Also start the traditional caster - caster.start_streaming() - - # Wait for all broadcasts to complete - await asyncio.gather(*broadcast_tasks) - - # Wait for streaming to complete with the traditional caster - await caster.streamer.task - - # Update endpoint status - for endpoint in endpoints_to_broadcast: - if endpoint.id in endpoint_status and endpoint_status[endpoint.id]["broadcasts"] > 0: - endpoint_status[endpoint.id]["broadcasts"] -= 1 - - # Update status to complete - group.progress.current_state = AnnouncementStates.COMPLETE.value - group.progress.progress = 1.0 - last_completed_group_id = group.id - - # Reset active group if this is still the active one - if active_group_id == group.id: - active_group_id = None - - # After a while, reset to idle state (in a separate task) - async def reset_to_idle(): - log.info(f"Waiting 10 seconds before resetting group {group.id} to IDLE state") - await asyncio.sleep(10) # Keep completed state visible for 10 seconds - log.info(f"Resetting group {group.id} to IDLE state") - # Use direct value lookup for the state comparison - if group.progress.current_state == AnnouncementStates.COMPLETE.value: - group.progress.current_state = AnnouncementStates.IDLE.value - group.progress.progress = 0.0 - log.info(f"Group {group.id} state reset to IDLE") - - # Create and save the task so it won't be garbage collected - reset_task = asyncio.create_task(reset_to_idle()) - - except Exception as e: - log.error(f"Error during announcement processing: {e}") - group.progress.current_state = AnnouncementStates.ERROR.value - group.progress.error = str(e) - - # Reset active group if this is still the active one - if active_group_id == group.id: - active_group_id = None - - raise HTTPException(status_code=500, detail=str(e)) + translations[lang] = translation + log.info(f"Translated to {lang}: {translation}") -@app.get("/api/groups") + # Generate voices + ep_group.current_state = AnnouncementStates.GENERATING_VOICE + endpoints_db.update_group(ep_group.id, ep_group) + + audio = {} + for lang, text in translations.items(): + # Get the appropriate language configuration + lang_conf = getattr(ep_group.translator_config, lang) + + # Convert text to LC3-encoded audio using the configuration's TTS settings + audio[lang] = text_to_speech.synthesize( + translations[lang], + 16000, # Sample rate from auracast config # TODO: take sampling rate from auracast config + lang_conf.tts_system, # TTS system from config + lang_conf.tts_model, # TTS model from config + return_lc3=True + ).decode('latin-1') + + # Add to audio data dictionary (decode bytes to string for JSON serialization) + + # Update group progress + ep_group.current_state = AnnouncementStates.BROADCASTING + endpoints_db.update_group(ep_group.id, ep_group) + + + # Broadcast to all endpoints in group + for endpoint in ep_group.endpoints: + # Send the audio data to the server using the existing send_audio function + log.info(f"Broadcastcasting to {endpoint.name} for languages: {', '.join(audio.keys())}") + multicast_client.send_audio(audio, base_url=endpoint.url) + + + # Update group state to completed + #ep_group.current_state = AnnouncementStates.COMPLETED # TODO: somehow the group state needs to be updated to completed after a broadcast + + # Schedule group endpoint reset + # if ep_group.reset_task and not ep_group.reset_task.done(): + # ep_group.reset_task.cancel() + + # ep_group.reset_task = asyncio.create_task(reset_endpoints_after_delay(ep_group.endpoints, 60)) + + # Return the translations + return {"translations": translations} + + +async def reset_endpoints_after_delay(endpoint_ids, delay_seconds): + """Reset endpoints after a delay.""" + await asyncio.sleep(delay_seconds) + + for endpoint_id in endpoint_ids: + if endpoint_id in endpoint_status and endpoint_status[endpoint_id]["active"]: + try: + endpoint_status[endpoint_id]["broadcasts"] = 0 + log.info(f"Reset broadcasts count for endpoint {endpoint_id}") + except Exception as e: + log.error(f"Failed to reset endpoint {endpoint_id}: {e}") + + +@app.get("/groups") async def get_groups(): - """Get all endpoint groups.""" - return endpoint_groups + """Get all endpoint groups with their current status.""" + return endpoints_db.get_all_groups() -@app.post("/api/groups") -async def create_group(group: EndpointGroup): + +@app.post("/groups") +async def create_group(group: endpoints_db.EndpointGroup): """Add a new endpoint group.""" - # Validate endpoints - for endpoint_id in group.endpoints: - if endpoint_id not in AVAILABLE_ENDPOINTS: - raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available") - - # Validate languages - for language in group.languages: - if language not in AVAILABLE_LANGUAGES: - raise HTTPException(status_code=400, detail=f"Language {language} is not available") - - # Add the group - endpoint_groups.append(group) - return group + try: + return endpoints_db.add_group(group) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) -@app.put("/api/groups/{group_id}") -async def update_group(group_id: int, updated_group: EndpointGroup): +@app.put("/groups/{group_id}") +async def update_group(group_id: int, updated_group: endpoints_db.EndpointGroup): """Update an existing endpoint group.""" - for i, group in enumerate(endpoint_groups): - if group.id == group_id: - # Validate endpoints - for endpoint_id in updated_group.endpoints: - if endpoint_id not in AVAILABLE_ENDPOINTS: - raise HTTPException(status_code=400, detail=f"Endpoint {endpoint_id} is not available") - - # Validate languages - for language in updated_group.languages: - if language not in AVAILABLE_LANGUAGES: - raise HTTPException(status_code=400, detail=f"Language {language} is not available") - - # Update the group - endpoint_groups[i] = updated_group - return updated_group - - raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + try: + return endpoints_db.update_group(group_id, updated_group) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) -@app.delete("/api/groups/{group_id}") +@app.delete("/groups/{group_id}") async def delete_group(group_id: int): """Delete an endpoint group.""" - for i, group in enumerate(endpoint_groups): - if group.id == group_id: - del endpoint_groups[i] - return {"detail": f"Group with id {group_id} deleted"} - - raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + try: + endpoints_db.delete_group(group_id) + return {"message": f"Group {group_id} deleted successfully"} + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) -@app.post("/api/announcements") +@app.post("/announcement") async def start_announcement(text: str, group_id: int): """Start a new announcement to the specified endpoint group.""" global announcement_task - # Find the group - target_group = None - for group in endpoint_groups: - if group.id == group_id: - target_group = group - break + # Get the group from active groups or database + group = endpoints_db.get_group_by_id(group_id) + if not group: + raise HTTPException(status_code=400, detail=f"Group {group_id} not found") - if not target_group: - raise HTTPException(status_code=404, detail=f"Group with id {group_id} not found") + # Check if we're already processing an announcement + #if announcement_task and not announcement_task.done(): + # raise HTTPException(status_code=400, detail="Already processing an announcement") - # Check if an announcement is already in progress - if active_group_id is not None: - raise HTTPException(status_code=400, detail="An announcement is already in progress") - - # Process the announcement - announcement_task = asyncio.create_task(process_announcement(text, target_group)) - - return {"detail": "Announcement started"} + # Start the announcement task + announcement_task = asyncio.create_task(make_announcement(text, group)) + return {"status": "Announcement started", "group_id": group_id} -@app.get("/api/announcements/status") -async def get_announcement_status(): - """Get the status of the current announcement.""" - # If no group is active, check if the last completed group exists - if active_group_id is None: - if last_completed_group_id is not None: - # Find the last completed group - for group in endpoint_groups: - if group.id == last_completed_group_id: - return { - "active_group": None, - "last_completed_group": { - "id": group.id, - "name": group.name, - "endpoints": group.endpoints, - "languages": group.languages, - "progress": { - "current_state": group.progress.current_state, - "progress": group.progress.progress, - "error": group.progress.error - }, - "parameters": { - "text": group.parameters.text, - "languages": group.parameters.languages, - "start_time": group.parameters.start_time - } - } - } - - # If the last completed group couldn't be found - return { - "active_group": None, - "last_completed_group": None - } - else: - # No active or last completed group - return { - "active_group": None, - "last_completed_group": None - } - - # If a group is active, return its information - for group in endpoint_groups: - if group.id == active_group_id: - return { - "active_group": { - "id": group.id, - "name": group.name, - "endpoints": group.endpoints, - "languages": group.languages, - "progress": { - "current_state": group.progress.current_state, - "progress": group.progress.progress, - "error": group.progress.error - }, - "parameters": { - "text": group.parameters.text, - "languages": group.parameters.languages, - "start_time": group.parameters.start_time - } - }, - "last_completed_group": None - } - - # If the active group couldn't be found - return { - "active_group": None, - "last_completed_group": None - } - -@app.get("/api/endpoints") -async def get_available_endpoints(): - """Get all available endpoints with their capabilities.""" - return [endpoint_db.dict() for endpoint_db in endpoints_db.get_all_endpoints()] - - -@app.get("/api/languages") -async def get_available_languages(): - """Get all available languages for announcements.""" - return endpoints_db.get_available_languages() - - -@app.get("/api/endpoints/{endpoint_id}") +@app.get("/endpoints/{endpoint_id}/status") # TODO: think about progress tracking async def get_endpoint_status(endpoint_id: str): """Get the status of a specific endpoint.""" + # Check if the endpoint exists endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) if not endpoint: raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found") - status = endpoint_status.get(endpoint_id, {"active": False, "broadcasts": 0}) - return { - "endpoint": endpoint.dict(), - "status": status - } + # Return the status if available, otherwise a default status + if endpoint_id in endpoint_status: + return endpoint_status[endpoint_id] + else: + return { + "active": False, + "broadcasts": 0, + } -@app.post("/api/endpoints/{endpoint_id}/reset") -async def reset_endpoint(endpoint_id: str): - """Reset an endpoint's status.""" - endpoint = endpoints_db.get_endpoint_by_id(endpoint_id) - if not endpoint: - raise HTTPException(status_code=404, detail=f"Endpoint {endpoint_id} not found") - - try: - # Set the BASE_URL for this endpoint - multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id) - - # Stop any current audio - multicast_client.stop_audio() - - # Reinitialize the endpoint - await init_endpoint(endpoint_id) - - return {"detail": f"Endpoint {endpoint_id} reset successfully"} - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to reset endpoint: {str(e)}") +@app.get("/endpoints") +async def get_available_endpoints(): + """Get all available endpoints with their capabilities.""" + return AVAILABLE_ENDPOINTS -@app.on_event("startup") -async def startup_event(): - """Initialize the caster on startup.""" - global caster, initialized - - try: - await init_caster() - - # Initialize all endpoints - for endpoint_id in AVAILABLE_ENDPOINTS: - try: - await init_endpoint(endpoint_id) - except Exception as e: - log.error(f"Failed to initialize endpoint {endpoint_id}: {e}") - except Exception as e: - log.error(f"Startup error: {e}") +@app.get("/languages") +async def get_available_languages(): + """Get all available languages for announcements.""" + return AVAILABLE_LANGUAGES -@app.on_event("shutdown") -async def shutdown_event(): - """Clean up resources on shutdown.""" - global caster - - if caster is not None: - try: - log.info("Stopping caster...") - caster.stop_streaming() - caster = None - log.info("Caster stopped") - except Exception as e: - log.error(f"Error during caster shutdown: {e}") - - # Shutdown all active endpoints - for endpoint_id in endpoint_status: - if endpoint_status[endpoint_id].get("active", False): - try: - multicast_client.BASE_URL = endpoints_db.get_endpoint_url(endpoint_id) - multicast_client.shutdown() - except Exception as e: - log.error(f"Error shutting down endpoint {endpoint_id}: {e}") - if __name__ == "__main__": import uvicorn @@ -564,11 +256,13 @@ if __name__ == "__main__": level=log.INFO, format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s' ) - log.info("Starting Translator API server") + # with reload=True logging of modules does not function as expected uvicorn.run( - "translator_api:app", - host="0.0.0.0", - port=7999, - reload=True, - log_level="debug" - ) \ No newline at end of file + app, + #'translator_api:app', + host="0.0.0.0", + port=7999, + #reload=True, + #log_config=None, + #log_level="info" + ) diff --git a/src/multilang_translator/translator_api/translator_models.py b/src/multilang_translator/translator_api/translator_models.py index 547133f..f27ca52 100644 --- a/src/multilang_translator/translator_api/translator_models.py +++ b/src/multilang_translator/translator_api/translator_models.py @@ -3,36 +3,50 @@ Models for the translator API. Similar to the models used in auracaster-webui but simplified for the translator middleware. """ from enum import Enum -from typing import List, Optional, Dict, Any +from typing import List, Optional from pydantic import BaseModel -class AnnouncementStates(str, Enum): - IDLE = "idle" - TRANSLATING = "translating" - GENERATING_VOICE = "generating_voice" - ROUTING = "routing" - ACTIVE = "active" - COMPLETE = "complete" - ERROR = "error" +class AnnouncementStates(Enum): + IDLE = 0 + TRANSLATING = 0.2 + GENERATING_VOICE = 0.4 + ROUTING = 0.6 + BROADCASTING = 0.8 + COMPLETED = 1 + ERROR = 0 -class AnnouncementProgress(BaseModel): - current_state: str = AnnouncementStates.IDLE - progress: float = 0.0 - error: Optional[str] = None +class Endpoint(BaseModel): + """Defines an endpoint with its URL and capabilities.""" + id: int + name: str + url: str + max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts -class AnnouncementParameters(BaseModel): - text: str = "" - languages: List[str] = [] - start_time: float = 0.0 +class TranslatorLangConfig(BaseModel): + translator_llm: str = 'llama3.2:3b-instruct-q4_0' + llm_client: str = 'ollama' + llm_host_url: str | None = 'http://localhost:11434' + llm_host_token: str | None = None + tts_system: str = 'piper' + tts_model: str ='de_DE-kerstin-low' + + +class TranslatorConfig(BaseModel): + deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high') + eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium') + fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium') + spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium') + ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium') class EndpointGroup(BaseModel): id: int name: str - endpoints: List[str] languages: List[str] - progress: AnnouncementProgress = AnnouncementProgress() - parameters: AnnouncementParameters = AnnouncementParameters() + endpoints: List[Endpoint] + translator_config: TranslatorConfig = TranslatorConfig() + current_state: str = AnnouncementStates.IDLE + anouncement_start_time: float = 0.0 diff --git a/src/multilang_translator/translator_config.py b/src/multilang_translator/translator_config.py index afaa906..0066b1b 100644 --- a/src/multilang_translator/translator_config.py +++ b/src/multilang_translator/translator_config.py @@ -1,35 +1,20 @@ import os -from typing import List from pydantic import BaseModel -from auracast import auracast_config VENV_DIR = os.path.join(os.path.dirname(__file__), './../../venv') -class TranslatorLangConfig(auracast_config.AuracastBigConfig): - translator_llm: str = 'llama3.2:3b-instruct-q4_0' +class TranslatorLangConfig(BaseModel): + translator_llm: str = 'llama3.2:3b-instruct-q4_0' # TODO: this was migrated to translator_models - remove this llm_client: str = 'ollama' llm_host_url: str | None = 'http://localhost:11434' llm_host_token: str | None = None tts_system: str = 'piper' tts_model: str ='de_DE-kerstin-low' +class TranslatorConfig(BaseModel): + deu: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'de_DE-thorsten-high') + eng: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'en_GB-alba-medium') + fra: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'fr_FR-siwis-medium') + spa: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'es_ES-sharvard-medium') + ita: TranslatorLangConfig = TranslatorLangConfig(tts_model = 'it_IT-paola-medium') -class TranslatorConfigDeu(TranslatorLangConfig, auracast_config.AuracastBigConfigDeu): - tts_model: str ='de_DE-thorsten-high' - -class TranslatorConfigEng(TranslatorLangConfig, auracast_config.AuracastBigConfigEng): - tts_model: str = 'en_GB-alba-medium' - -class TranslatorConfigFra(TranslatorLangConfig, auracast_config.AuracastBigConfigFra): - tts_model: str = 'fr_FR-siwis-medium' - -class TranslatorConfigSpa(TranslatorLangConfig, auracast_config.AuracastBigConfigSpa): - tts_model: str = 'es_ES-sharvard-medium' - -class TranslatorConfigIta(TranslatorLangConfig, auracast_config.AuracastBigConfigIta): - tts_model: str = 'it_IT-paola-medium' - -class TranslatorConfigGroup(auracast_config.AuracastGlobalConfig): - bigs: List[TranslatorLangConfig] = [ - TranslatorConfigDeu(), - ]