diff --git a/pyproject.toml b/pyproject.toml index 7a666f7..8746a35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,6 @@ dependencies = [ "requests==2.32.3", "ollama==0.4.7", "aioconsole==0.8.1", - "piper-phonemize==1.1.0", - "piper-tts==1.2.0", "fastapi==0.115.11", "uvicorn==0.34.0", ] @@ -19,6 +17,11 @@ test = [ "pytest >= 8.2", ] +[tool.poetry.group.tts.dependencies] +piper-phonemize = "==1.1.0" +piper-tts = "==1.2.0" + + [tool.pytest.ini_options] addopts = [ "--import-mode=importlib","--count=1","-s","-v" diff --git a/src/multilang_translator/translator_api/translator_models.py b/src/multilang_translator/translator_api/translator_models.py index d0114eb..390a666 100644 --- a/src/multilang_translator/translator_api/translator_models.py +++ b/src/multilang_translator/translator_api/translator_models.py @@ -54,5 +54,5 @@ class EndpointGroup(BaseModel): languages: List[str] endpoints: List[Endpoint] translator_config: TranslatorConfig = TranslatorConfig() - current_state: str = AnnouncementStates.IDLE + current_state: AnnouncementStates = AnnouncementStates.IDLE anouncement_start_time: float = 0.0 diff --git a/src/multilang_translator/translator_client/__init__.py b/src/multilang_translator/translator_client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/multilang_translator/translator_client/translator_client.py b/src/multilang_translator/translator_client/translator_client.py new file mode 100644 index 0000000..8ea0baf --- /dev/null +++ b/src/multilang_translator/translator_client/translator_client.py @@ -0,0 +1,90 @@ +""" +API client functions for interacting with the Translator API. +""" +import requests +from typing import List, Optional, Dict, Any, Tuple +from enum import Enum + + +from multilang_translator.translator_api.translator_models import AnnouncementStates, Endpoint, EndpointGroup + + +# This can be overridden through environment variables +API_BASE_URL = "http://localhost:7999" + +def get_groups() -> List[EndpointGroup]: + """Get all endpoint groups.""" + response = requests.get(f"{API_BASE_URL}/groups") + response.raise_for_status() + return [EndpointGroup.model_validate(group) for group in response.json()] + +def get_group(group_id: int) -> Optional[EndpointGroup]: + """Get a specific endpoint group by ID.""" + response = requests.get(f"{API_BASE_URL}/groups/{group_id}") + if response.status_code == 404: + return None + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def create_group(group: EndpointGroup) -> EndpointGroup: + """Create a new endpoint group.""" + response = requests.post(f"{API_BASE_URL}/groups", json=group.model_dump()) + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def update_group(group_id: int, updated_group: EndpointGroup) -> EndpointGroup: + """Update an existing endpoint group.""" + response = requests.put(f"{API_BASE_URL}/groups/{group_id}", json=updated_group.model_dump()) + response.raise_for_status() + return EndpointGroup.model_validate(response.json()) + +def delete_group(group_id: int) -> None: + """Delete an endpoint group.""" + response = requests.delete(f"{API_BASE_URL}/groups/{group_id}") + response.raise_for_status() + +def start_announcement(text: str, group_id: int) -> Dict[str, Any]: + """ + Start a new announcement. + + Args: + text: The text content of the announcement + group_id: The ID of the endpoint group to send the announcement to + + Returns: + Dictionary with status information + """ + response = requests.post(f"{API_BASE_URL}/announcement", params={"text": text, "group_id": group_id}) + response.raise_for_status() + return response.json() + +def get_group_state(group_id: int) -> Tuple[str, float]: + """ + Get the status of the current announcement for a specific group. + + Args: + group_id: The ID of the group to check the announcement status for + + Returns: + Tuple containing (state_name, state_value) + """ + response = requests.get(f"{API_BASE_URL}/groups/{group_id}/state") + response.raise_for_status() + state_data = response.json() + return (state_data["name"], state_data["value"]) + + +def get_available_endpoints() -> List[Endpoint]: + """Get all available endpoints.""" + response = requests.get(f"{API_BASE_URL}/endpoints") + response.raise_for_status() + endpoints_dict = response.json() + # API returns a dictionary with endpoint IDs as keys + # Convert this to a list of Endpoint objects + return [Endpoint.model_validate(endpoint_data) for endpoint_id, endpoint_data in endpoints_dict.items()] + +def get_available_languages() -> List[str]: + """Get all available languages for announcements.""" + response = requests.get(f"{API_BASE_URL}/languages") + response.raise_for_status() + return response.json()