move translator_client
This commit is contained in:
@@ -22,6 +22,7 @@ dependencies = [
|
|||||||
"fastapi>=0.95.0",
|
"fastapi>=0.95.0",
|
||||||
"uvicorn>=0.22.0",
|
"uvicorn>=0.22.0",
|
||||||
"pydantic>=1.10.0",
|
"pydantic>=1.10.0",
|
||||||
|
"multilang-translator>=0.1.0" # TODO: this should only include
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
API client functions for interacting with the Airport Announcement System backend.
|
|
||||||
"""
|
|
||||||
import requests
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
|
|
||||||
# This can be overridden through environment variables
|
|
||||||
API_BASE_URL = "http://localhost:7999"
|
|
||||||
|
|
||||||
def get_groups() -> List[dict]:
|
|
||||||
"""Get all endpoint groups."""
|
|
||||||
response = requests.get(f"{API_BASE_URL}/groups")
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def get_group(group_id: int) -> Optional[dict]:
|
|
||||||
"""Get a specific endpoint group by ID."""
|
|
||||||
response = requests.get(f"{API_BASE_URL}/groups/{group_id}")
|
|
||||||
if response.status_code == 404:
|
|
||||||
return None
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def create_group(group: dict) -> dict:
|
|
||||||
"""Create a new endpoint group."""
|
|
||||||
response = requests.post(f"{API_BASE_URL}/groups", json=group)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def update_group(group_id: int, updated_group: dict) -> dict:
|
|
||||||
"""Update an existing endpoint group."""
|
|
||||||
response = requests.put(f"{API_BASE_URL}/groups/{group_id}", json=updated_group)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def delete_group(group_id: int) -> None:
|
|
||||||
"""Delete an endpoint group."""
|
|
||||||
response = requests.delete(f"{API_BASE_URL}/groups/{group_id}")
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
def start_announcement(text: str, group_id: int) -> None:
|
|
||||||
"""Start a new announcement."""
|
|
||||||
# Changed from /announcements to /announcement to match translator_api.py
|
|
||||||
response = requests.post(f"{API_BASE_URL}/announcement", params={"text": text, "group_id": group_id})
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
def get_group_state(group_id: int) -> Dict[str, Any]:
|
|
||||||
"""Get the status of the current announcement for a specific group.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group_id: The ID of the group to check the announcement status for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with 'name' and 'value' keys representing the current state
|
|
||||||
"""
|
|
||||||
response = requests.get(f"{API_BASE_URL}/groups/{group_id}/state")
|
|
||||||
|
|
||||||
state = response.json()
|
|
||||||
return state
|
|
||||||
|
|
||||||
def get_available_endpoints() -> List[str]:
|
|
||||||
"""Get all available endpoints."""
|
|
||||||
response = requests.get(f"{API_BASE_URL}/endpoints")
|
|
||||||
response.raise_for_status()
|
|
||||||
# Transform the endpoint objects to match the expected format in app.py
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def get_available_languages() -> List[str]:
|
|
||||||
"""Get all available languages for announcements."""
|
|
||||||
response = requests.get(f"{API_BASE_URL}/languages")
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
"""
|
|
||||||
API models for the Airport Announcement System.
|
|
||||||
"""
|
|
||||||
from enum import Enum
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
class AnnouncementStates(Enum):
|
|
||||||
IDLE: str = "Ready"
|
|
||||||
TRANSLATING: str = "Translating"
|
|
||||||
GENERATING_VOICE: str = "Generating voice synthesis"
|
|
||||||
ROUTING: str = "Routing to endpoints"
|
|
||||||
ACTIVE: str = "Broadcasting announcement"
|
|
||||||
COMPLETE: str = "Complete"
|
|
||||||
ERROR: str = "Error"
|
|
||||||
|
|
||||||
class AnnouncementParameters(BaseModel):
|
|
||||||
text: Optional[str] = None
|
|
||||||
languages: List[str] = []
|
|
||||||
start_time: Optional[float] = None
|
|
||||||
|
|
||||||
class AnnouncementProgress(BaseModel):
|
|
||||||
current_state: str = AnnouncementStates.IDLE.value
|
|
||||||
progress: float = Field(default=0.0, ge=0.0, le=1.0)
|
|
||||||
error: Optional[str] = None
|
|
||||||
|
|
||||||
class EndpointGroup(BaseModel):
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
endpoints: List[str]
|
|
||||||
languages: List[str]
|
|
||||||
|
|
||||||
# Announcement parameters and progress as nested models
|
|
||||||
parameters: AnnouncementParameters = Field(default_factory=AnnouncementParameters)
|
|
||||||
progress: AnnouncementProgress = Field(default_factory=AnnouncementProgress)
|
|
||||||
@@ -8,10 +8,11 @@ st.set_page_config(page_title="Airport Announcement System", page_icon="✈️")
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
import requests
|
import requests
|
||||||
from api_client.client import (
|
from multilang_translator.translator_client.translator_client import (
|
||||||
get_groups, get_available_languages, get_group_state,
|
get_groups, get_available_languages, get_group_state,
|
||||||
start_announcement, update_group, get_available_endpoints
|
start_announcement, update_group, get_available_endpoints
|
||||||
)
|
)
|
||||||
|
from multilang_translator.translator_api.translator_models import Endpoint, EndpointGroup, AnnouncementStates
|
||||||
|
|
||||||
# Initialize session state for configuration
|
# Initialize session state for configuration
|
||||||
if "endpoint_groups" not in st.session_state:
|
if "endpoint_groups" not in st.session_state:
|
||||||
@@ -48,16 +49,16 @@ def show_announcement_status(group_id: int):
|
|||||||
"""Show the status of an announcement for a specific group."""
|
"""Show the status of an announcement for a specific group."""
|
||||||
try:
|
try:
|
||||||
# Get the group for additional information
|
# Get the group for additional information
|
||||||
group = next((g for g in st.session_state.endpoint_groups if g["id"] == group_id), None)
|
group = next((g for g in st.session_state.endpoint_groups if g.id == group_id), None)
|
||||||
if not group:
|
if not group:
|
||||||
st.error(f"Group with ID {group_id} not found")
|
st.error(f"Group with ID {group_id} not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get the state from the API
|
# Get the state from the API
|
||||||
state = get_group_state(group_id)
|
state_name, state_value = get_group_state(group_id)
|
||||||
|
|
||||||
# Only show status if state is not IDLE
|
# Only show status if state is not IDLE
|
||||||
if state["value"] != 0:
|
if state_value != AnnouncementStates.IDLE.value:
|
||||||
# Create a container with a unique key for each announcement
|
# Create a container with a unique key for each announcement
|
||||||
# This ensures we get a fresh container for each new announcement
|
# This ensures we get a fresh container for each new announcement
|
||||||
with st.container(key=f"status_container_{st.session_state.status_container_key}"):
|
with st.container(key=f"status_container_{st.session_state.status_container_key}"):
|
||||||
@@ -67,8 +68,8 @@ def show_announcement_status(group_id: int):
|
|||||||
with status:
|
with status:
|
||||||
# Calculate progress from state value (normalize to 0.0-1.0 range)
|
# Calculate progress from state value (normalize to 0.0-1.0 range)
|
||||||
# Assuming states are ordered from IDLE(0) to COMPLETED(4)
|
# Assuming states are ordered from IDLE(0) to COMPLETED(4)
|
||||||
max_state_value = 1.0 # COMPLETED is the maximum state value
|
max_state_value = AnnouncementStates.COMPLETED.value # COMPLETED is the maximum state value
|
||||||
progress = min(state["value"] / max_state_value, 1.0)
|
progress = min(state_value / max_state_value, 1.0)
|
||||||
|
|
||||||
# Progress elements
|
# Progress elements
|
||||||
progress_bar = st.progress(progress)
|
progress_bar = st.progress(progress)
|
||||||
@@ -82,7 +83,7 @@ def show_announcement_status(group_id: int):
|
|||||||
start_tracking_time = time.time()
|
start_tracking_time = time.time()
|
||||||
|
|
||||||
# Update loop
|
# Update loop
|
||||||
while state["name"] not in ["COMPLETED", "IDLE", "ERROR"]:
|
while state_name not in ["COMPLETED", "IDLE", "ERROR"]:
|
||||||
# Check for timeout
|
# Check for timeout
|
||||||
elapsed_time = time.time() - start_tracking_time
|
elapsed_time = time.time() - start_tracking_time
|
||||||
if elapsed_time > max_wait_time:
|
if elapsed_time > max_wait_time:
|
||||||
@@ -90,8 +91,9 @@ def show_announcement_status(group_id: int):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Only update stage display and elapsed time if state changed
|
# Only update stage display and elapsed time if state changed
|
||||||
if state != last_state:
|
current_state = (state_name, state_value)
|
||||||
stage_col.write(f"**Stage:** {state['name']}")
|
if current_state != last_state:
|
||||||
|
stage_col.write(f"**Stage:** {state_name}")
|
||||||
|
|
||||||
# Update elapsed time only on state change
|
# Update elapsed time only on state change
|
||||||
if "announcement_start_time" in st.session_state:
|
if "announcement_start_time" in st.session_state:
|
||||||
@@ -103,24 +105,24 @@ def show_announcement_status(group_id: int):
|
|||||||
elapsed_seconds = time.time() - start_tracking_time
|
elapsed_seconds = time.time() - start_tracking_time
|
||||||
time_col.write(f"⏱️ {elapsed_seconds:.1f}s")
|
time_col.write(f"⏱️ {elapsed_seconds:.1f}s")
|
||||||
|
|
||||||
last_state = state
|
last_state = current_state
|
||||||
|
|
||||||
# Update progress bar directly from state value
|
# Update progress bar directly from state value
|
||||||
progress = min(state["value"] / max_state_value, 1.0)
|
progress = min(state_value / max_state_value, 1.0)
|
||||||
progress_bar.progress(progress)
|
progress_bar.progress(progress)
|
||||||
|
|
||||||
# Add a small delay between requests to avoid hammering the API
|
# Add a small delay between requests to avoid hammering the API
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
# Refresh state
|
# Refresh state
|
||||||
state = get_group_state(group_id)
|
state_name, state_value = get_group_state(group_id)
|
||||||
|
|
||||||
# Final update to progress bar
|
# Final update to progress bar
|
||||||
progress = min(state["value"] / max_state_value, 1.0)
|
progress = min(state_value / max_state_value, 1.0)
|
||||||
progress_bar.progress(progress)
|
progress_bar.progress(progress)
|
||||||
|
|
||||||
# Final update to stage display - this ensures we see the COMPLETED state
|
# Final update to stage display - this ensures we see the COMPLETED state
|
||||||
stage_col.write(f"**Stage:** {state['name']}")
|
stage_col.write(f"**Stage:** {state_name}")
|
||||||
|
|
||||||
# Final update to elapsed time display
|
# Final update to elapsed time display
|
||||||
if "announcement_start_time" in st.session_state:
|
if "announcement_start_time" in st.session_state:
|
||||||
@@ -132,22 +134,20 @@ def show_announcement_status(group_id: int):
|
|||||||
time_col.write(f"⏱️ {final_elapsed_seconds:.1f}s")
|
time_col.write(f"⏱️ {final_elapsed_seconds:.1f}s")
|
||||||
|
|
||||||
# Final state
|
# Final state
|
||||||
if state["name"] == "COMPLETED":
|
if state_name == "COMPLETED":
|
||||||
st.success("✅ Announcement completed successfully")
|
st.success("✅ Announcement completed successfully")
|
||||||
|
|
||||||
# Display group information
|
# Display group information
|
||||||
st.write(f"📢 Announcement made to group {group.get('name', '')}")
|
st.write(f"📢 Announcement made to group {group.name}")
|
||||||
|
|
||||||
# Display endpoints if available
|
# Display endpoints if available
|
||||||
endpoints = group.get("endpoints", [])
|
endpoints = group.endpoints
|
||||||
if endpoints:
|
if endpoints:
|
||||||
endpoint_names = [ep.get("name", ep) for ep in endpoints if isinstance(ep, dict)]
|
endpoint_names = [ep.name for ep in endpoints]
|
||||||
if not endpoint_names: # Handle case where endpoints is a list of strings
|
|
||||||
endpoint_names = endpoints
|
|
||||||
st.write(f"📡 Endpoints: {', '.join(endpoint_names)}")
|
st.write(f"📡 Endpoints: {', '.join(endpoint_names)}")
|
||||||
|
|
||||||
# Display languages if available
|
# Display languages if available
|
||||||
languages = group.get("languages", [])
|
languages = group.languages
|
||||||
if languages:
|
if languages:
|
||||||
st.write(f"🌐 Languages: {', '.join(languages)}")
|
st.write(f"🌐 Languages: {', '.join(languages)}")
|
||||||
|
|
||||||
@@ -181,7 +181,7 @@ with st.container():
|
|||||||
# Custom announcement
|
# Custom announcement
|
||||||
with st.form("custom_announcement"):
|
with st.form("custom_announcement"):
|
||||||
# Get all groups with their names and IDs
|
# Get all groups with their names and IDs
|
||||||
group_options = [(g["name"], g["id"]) for g in st.session_state.endpoint_groups]
|
group_options = [(g.name, g.id) for g in st.session_state.endpoint_groups]
|
||||||
selected_group_name = st.selectbox(
|
selected_group_name = st.selectbox(
|
||||||
"Select announcement area",
|
"Select announcement area",
|
||||||
options=[g[0] for g in group_options] if group_options else ["No groups available"]
|
options=[g[0] for g in group_options] if group_options else ["No groups available"]
|
||||||
@@ -231,21 +231,21 @@ with st.sidebar:
|
|||||||
|
|
||||||
# Initialize the previous value in session state if not present
|
# Initialize the previous value in session state if not present
|
||||||
if f"prev_{input_key}" not in st.session_state:
|
if f"prev_{input_key}" not in st.session_state:
|
||||||
st.session_state[f"prev_{input_key}"] = group["name"]
|
st.session_state[f"prev_{input_key}"] = group.name
|
||||||
|
|
||||||
new_name = st.text_input(
|
new_name = st.text_input(
|
||||||
f"Group Name",
|
f"Group Name",
|
||||||
value=group["name"],
|
value=group.name,
|
||||||
key=input_key,
|
key=input_key,
|
||||||
on_change=lambda: None # Prevent automatic callbacks
|
on_change=lambda: None # Prevent automatic callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only update if the name has changed and it's different from the previous value
|
# Only update if the name has changed and it's different from the previous value
|
||||||
if new_name != group["name"] and new_name != st.session_state[f"prev_{input_key}"]:
|
if new_name != group.name and new_name != st.session_state[f"prev_{input_key}"]:
|
||||||
try:
|
try:
|
||||||
updated_group = group.copy()
|
updated_group = group.model_copy(deep=True)
|
||||||
updated_group["name"] = new_name
|
updated_group.name = new_name
|
||||||
update_group(group["id"], updated_group)
|
update_group(group.id, updated_group)
|
||||||
# Update the session state with the latest groups
|
# Update the session state with the latest groups
|
||||||
st.session_state.endpoint_groups = get_groups()
|
st.session_state.endpoint_groups = get_groups()
|
||||||
# Update the previous value before rerunning
|
# Update the previous value before rerunning
|
||||||
@@ -255,47 +255,42 @@ with st.sidebar:
|
|||||||
st.error(f"Failed to update group name: {str(e)}")
|
st.error(f"Failed to update group name: {str(e)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
endpoints_dict = get_available_endpoints()
|
endpoints = get_available_endpoints()
|
||||||
# Extract endpoint names from the endpoints dictionary
|
|
||||||
available_endpoint_names = []
|
|
||||||
for endpoint_id, endpoint_data in endpoints_dict.items():
|
|
||||||
if isinstance(endpoint_data, dict) and "name" in endpoint_data:
|
|
||||||
available_endpoint_names.append(endpoint_data["name"])
|
|
||||||
|
|
||||||
# Use a unique key for the endpoints multiselect
|
# Use a unique key for the endpoints multiselect
|
||||||
endpoints_key = f"endpoints_select_{i}"
|
endpoints_key = f"endpoints_select_{i}"
|
||||||
|
|
||||||
# Initialize the previous value in session state if not present
|
# Initialize the previous value in session state if not present
|
||||||
|
current_endpoints = [ep.id for ep in group.endpoints]
|
||||||
if f"prev_{endpoints_key}" not in st.session_state:
|
if f"prev_{endpoints_key}" not in st.session_state:
|
||||||
st.session_state[f"prev_{endpoints_key}"] = group.get("endpoints", [])
|
st.session_state[f"prev_{endpoints_key}"] = current_endpoints
|
||||||
|
|
||||||
# Extract endpoint names from group.endpoints if they are dictionaries
|
# Create mapping of endpoint ID to name for the multiselect
|
||||||
current_endpoints = []
|
endpoint_name_to_endpoint = {ep.name: ep for ep in endpoints}
|
||||||
for ep in group.get("endpoints", []):
|
|
||||||
if isinstance(ep, dict) and "name" in ep:
|
|
||||||
current_endpoints.append(ep["name"])
|
|
||||||
elif isinstance(ep, str):
|
|
||||||
current_endpoints.append(ep)
|
|
||||||
|
|
||||||
selected_endpoints = st.multiselect(
|
selected_endpoint_names = st.multiselect(
|
||||||
f"Endpoints",
|
f"Endpoints",
|
||||||
options=available_endpoint_names,
|
options=list(endpoint_name_to_endpoint.keys()),
|
||||||
default=current_endpoints,
|
default=[ep.name for ep in group.endpoints],
|
||||||
key=endpoints_key
|
key=endpoints_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert selected names back to Endpoint objects
|
||||||
|
selected_endpoints = [endpoint_name_to_endpoint[name] for name in selected_endpoint_names]
|
||||||
|
selected_endpoint_ids = [ep.id for ep in selected_endpoints]
|
||||||
|
|
||||||
# Only update if endpoints have changed and they're different from previous value
|
# Only update if endpoints have changed and they're different from previous value
|
||||||
endpoints_changed = selected_endpoints != current_endpoints
|
endpoints_changed = selected_endpoint_ids != current_endpoints
|
||||||
endpoints_diff_from_prev = selected_endpoints != st.session_state[f"prev_{endpoints_key}"]
|
endpoints_diff_from_prev = selected_endpoint_ids != st.session_state[f"prev_{endpoints_key}"]
|
||||||
|
|
||||||
if endpoints_changed and endpoints_diff_from_prev:
|
if endpoints_changed and endpoints_diff_from_prev:
|
||||||
updated_group = group.copy()
|
updated_group = group.model_copy(deep=True)
|
||||||
updated_group["endpoints"] = selected_endpoints
|
updated_group.endpoints = selected_endpoints
|
||||||
update_group(group["id"], updated_group)
|
update_group(group.id, updated_group)
|
||||||
# Update the session state with the latest groups
|
# Update the session state with the latest groups
|
||||||
st.session_state.endpoint_groups = get_groups()
|
st.session_state.endpoint_groups = get_groups()
|
||||||
# Update the previous value before rerunning
|
# Update the previous value before rerunning
|
||||||
st.session_state[f"prev_{endpoints_key}"] = selected_endpoints
|
st.session_state[f"prev_{endpoints_key}"] = selected_endpoint_ids
|
||||||
st.rerun()
|
st.rerun()
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
st.error(f"Failed to load endpoints: {str(e)}")
|
st.error(f"Failed to load endpoints: {str(e)}")
|
||||||
@@ -307,31 +302,23 @@ with st.sidebar:
|
|||||||
|
|
||||||
# Initialize the previous value in session state if not present
|
# Initialize the previous value in session state if not present
|
||||||
if f"prev_{languages_key}" not in st.session_state:
|
if f"prev_{languages_key}" not in st.session_state:
|
||||||
st.session_state[f"prev_{languages_key}"] = group.get("languages", [])
|
st.session_state[f"prev_{languages_key}"] = group.languages
|
||||||
|
|
||||||
# Extract language codes from group.languages if they are dictionaries
|
|
||||||
current_languages = []
|
|
||||||
for lang in group.get("languages", []):
|
|
||||||
if isinstance(lang, dict) and "code" in lang:
|
|
||||||
current_languages.append(lang["code"])
|
|
||||||
elif isinstance(lang, str):
|
|
||||||
current_languages.append(lang)
|
|
||||||
|
|
||||||
selected_languages = st.multiselect(
|
selected_languages = st.multiselect(
|
||||||
f"Languages",
|
f"Languages",
|
||||||
options=st.session_state.available_languages,
|
options=st.session_state.available_languages,
|
||||||
default=current_languages,
|
default=group.languages,
|
||||||
key=languages_key
|
key=languages_key
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only update if languages have changed and they're different from previous value
|
# Only update if languages have changed and they're different from previous value
|
||||||
languages_changed = selected_languages != current_languages
|
languages_changed = selected_languages != group.languages
|
||||||
languages_diff_from_prev = selected_languages != st.session_state[f"prev_{languages_key}"]
|
languages_diff_from_prev = selected_languages != st.session_state[f"prev_{languages_key}"]
|
||||||
|
|
||||||
if languages_changed and languages_diff_from_prev:
|
if languages_changed and languages_diff_from_prev:
|
||||||
updated_group = group.copy()
|
updated_group = group.model_copy(deep=True)
|
||||||
updated_group["languages"] = selected_languages
|
updated_group.languages = selected_languages
|
||||||
update_group(group["id"], updated_group)
|
update_group(group.id, updated_group)
|
||||||
# Update the session state with the latest groups
|
# Update the session state with the latest groups
|
||||||
st.session_state.endpoint_groups = get_groups()
|
st.session_state.endpoint_groups = get_groups()
|
||||||
# Update the previous value before rerunning
|
# Update the previous value before rerunning
|
||||||
|
|||||||
Reference in New Issue
Block a user