use asyncio for the multicast client

This commit is contained in:
2025-03-19 10:43:29 +01:00
parent c3a74c2a21
commit 8501184de5
7 changed files with 133 additions and 68 deletions

View File

@@ -10,6 +10,7 @@ dependencies = [
"aioconsole==0.8.1",
"fastapi==0.115.11",
"uvicorn==0.34.0",
"aiohttp==3.9.3",
]
[project.optional-dependencies]

View File

@@ -1,5 +1,6 @@
from typing import List
import time
import asyncio
import logging as log
@@ -15,7 +16,7 @@ import voice_client.tts_client
import voice_models.request_models
def announcement_from_german_text(
async def announcement_from_german_text(
config: translator_config.TranslatorConfigGroup,
text_de
):
@@ -50,12 +51,12 @@ def announcement_from_german_text(
log.info('Voice synth took %s', time.time() - start)
audio_data_dict[big.language] = lc3_audio.decode('latin-1') # TODO: should be .hex in the future
multicast_client.send_audio(
await multicast_client.send_audio(
audio_data_dict
)
if __name__ == '__main__':
async def main():
log.basicConfig(
level=log.INFO,
format='%(module)s.py:%(lineno)d %(levelname)s: %(message)s'
@@ -77,10 +78,12 @@ if __name__ == '__main__':
conf.llm_host_url = 'https://ollama.pstruebi.xyz'
conf.llm_host_token = 'sk-17124cb84df14cc6ab2d9e17d0724d13'
multicast_client.request_init(
await multicast_client.init(
config
)
announcement_from_german_text(config, 'Hello')
await announcement_from_german_text(config, 'Hello')
# TODO: make everything async
if __name__ == '__main__':
asyncio.run(main())

View File

@@ -4,6 +4,7 @@ import json
import logging as log
import time
import ollama
import aiohttp
from multilang_translator.translator import syspromts
@@ -12,10 +13,6 @@ from multilang_translator.translator import syspromts
# from_='llama3.2', system="You are Mario from Super Mario Bros."
# )
async def chat():
message = {'role': 'user', 'content': 'Why is the sky blue?'}
response = await ollama.AsyncClient().chat(model='llama3.2', messages=[message])
def query_openwebui(model, system, query, url, token):
url = f'{url}/api/chat/completions'
@@ -50,6 +47,41 @@ def query_ollama(model, system, query, host='http://localhost:11434'):
return response.message.content
async def query_openwebui_async(model, system, query, url, token):
    """Query an Open WebUI chat-completions endpoint asynchronously.

    Args:
        model: Name of the LLM to run.
        system: System prompt content.
        query: User prompt content.
        url: Base URL of the Open WebUI instance (API path is appended here).
        token: Bearer token used for authorization.

    Returns:
        The assistant message content of the first choice.

    Raises:
        aiohttp.ClientResponseError: On an HTTP error status. Previously an
            error response surfaced as an opaque KeyError when indexing the
            error body for 'choices'.
    """
    url = f'{url}/api/chat/completions'
    headers = {
        'Authorization': f'Bearer {token}',
    }
    payload = {
        'model': model,
        'messages': [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': query}
        ],
    }
    start = time.time()
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            # Fail fast on HTTP errors instead of KeyError-ing on the body below.
            response.raise_for_status()
            response_json = await response.json()
    log.info("Translating the text took %s s", round(time.time() - start, 2))
    return response_json['choices'][0]['message']['content']
async def query_ollama_async(model, system, query, host='http://localhost:11434'):
    """Send a system+user chat request to an Ollama server and return the reply text."""
    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': query},
    ]
    client = ollama.AsyncClient(host=host)
    response = await client.chat(model=model, messages=messages)
    return response.message.content
def translate_de_to_x( # TODO: use async ollama client later - implenent a translate async function
text:str,
target_language: str,
@@ -70,6 +102,27 @@ def translate_de_to_x( # TODO: use async ollama client later - implenent a trans
log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3))
return response
async def translate_de_to_x_async(
    text: str,
    target_language: str,
    client='ollama',
    model='llama3.2:3b-instruct-q4_0',  # remember to use instruct models
    host=None,
    token=None
):
    """Translate German text into `target_language` using an async LLM client.

    Looks up the matching system prompt from `syspromts`, dispatches to the
    selected backend ('ollama' or 'openwebui'), and logs the elapsed time.
    """
    start = time.time()
    system_prompt = getattr(syspromts, f"TRANSLATOR_DEU_{target_language.upper()}")
    if client == 'ollama':
        response = await query_ollama_async(model, system_prompt, text, host=host)
    elif client == 'openwebui':
        response = await query_openwebui_async(model, system_prompt, text, url=host, token=token)
    else:
        raise NotImplementedError('llm client not implemented')
    log.info('Running the translator to %s took %s s', target_language, round(time.time() - start, 3))
    return response
if __name__ == "__main__":
import time
from multilang_translator.translator import test_content

View File

@@ -25,7 +25,6 @@ class Endpoint(BaseModel):
url: str
max_broadcasts: int = 1 # Maximum number of simultaneous broadcasts
class TranslatorLangConfig(BaseModel):
translator_llm: str = 'llama3.2:3b-instruct-q4_0'
@@ -53,6 +52,7 @@ class EndpointGroup(BaseModel):
name: str
languages: List[str]
endpoints: List[Endpoint]
sampling_rate_hz: int = 16000
translator_config: TranslatorConfig = TranslatorConfig()
current_state: AnnouncementStates = AnnouncementStates.IDLE
anouncement_start_time: float = 0.0

View File

@@ -55,7 +55,6 @@ def get_all_endpoints() -> List[Endpoint]:
"""Get all active endpoints."""
return ENDPOINTS
def get_endpoint_by_id(endpoint_id: str) -> Optional[Endpoint]:
    """Get an endpoint by its ID, or None when the ID is unknown.

    The original body raised KeyError for a missing ID, contradicting the
    declared Optional return type; callers checking for None would crash.
    """
    # NOTE(review): get_all_endpoints annotates ENDPOINTS as List[Endpoint];
    # a str index on a list would raise TypeError, so ENDPOINTS is presumably
    # a dict keyed by endpoint id — confirm against its definition.
    try:
        return ENDPOINTS[endpoint_id]
    except KeyError:
        return None

View File

@@ -3,7 +3,7 @@ FastAPI implementation of the Multilang Translator API.
This API mimics the mock_api from auracaster-webui to allow integration.
"""
import time
import logging as log
import logging as log
import asyncio
import random
@@ -31,10 +31,10 @@ app.add_middleware(
allow_headers=["*"],
)
# Endpoint configuration cache
# Endpoint configuration cache
CURRENT_ENDPOINT_CONFIG = {}
def init_endpoint(endpoint: Endpoint, languages: list[str]):
async def init_endpoint(endpoint: Endpoint, languages: list[str], sampling_rate_hz: int):
"""Initialize a specific endpoint for multicast."""
current_config = CURRENT_ENDPOINT_CONFIG.get(endpoint.id)
@@ -42,20 +42,24 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]):
if current_config is not None:
current_langs = [big.language for big in current_config.bigs]
# if languages are unchanged and the caster client status is initiailized, skip init
if current_langs == languages and multicast_client.get_status(base_url=endpoint.url)['is_initialized']:
log.info('Endpoint %s was already initialized', endpoint.name)
return
if current_langs == languages:
# Get status asynchronously
status = await multicast_client.get_status(base_url=endpoint.url)
if status['is_initialized']:
log.info('Endpoint %s was already initialized', endpoint.name)
return
log.info(f"Initializing endpoint: {endpoint.name} at {endpoint.url}")
# Load a default config
config = auracast_config.AuracastConfigGroup(
bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
bigs=[getattr(auracast_config, f"AuracastBigConfig{lang.capitalize()}")()
for lang in languages]
)
# overwrite some default configs
config.transport = 'auto'
config.auracast_device_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
config.auracast_sampling_rate_hz = sampling_rate_hz
# Configure the bigs
for big in config.bigs:
@@ -63,9 +67,10 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]):
big.name = endpoint.name
big.random_address = ':'.join(f"{random.randint(0, 255):02X}" for _ in range(6))
big.id = random.randint(0, 2**16) #TODO: how many bits is this ?
#big.program_info = big.program_info + ' ' + endpoint.name
#big.program_info = big.program_info + ' ' + endpoint.name
ret = multicast_client.init(config, base_url=endpoint.url)
# make async init request
ret = await multicast_client.init(config, base_url=endpoint.url)
# if ret != 200: # TODO: this is not working, should probably be handled async
# log.error('Init of endpoint %s was unsucessfull', endpoint.name)
# raise Exception(f"Init was of endpoint {endpoint.name} was unsucessfull")
@@ -75,7 +80,6 @@ def init_endpoint(endpoint: Endpoint, languages: list[str]):
# log.info('Endpoint %s was already initialized', endpoint.name)
async def make_announcement(text: str, ep_group: EndpointGroup):
"""
Make an announcement to a group of endpoints.
@@ -89,93 +93,89 @@ async def make_announcement(text: str, ep_group: EndpointGroup):
ep_group.anouncement_start_time = time.time()
# update the database with the new state and start time so this can be read by another process
endpoints_db.update_group(ep_group.id, ep_group)
# Initialize all endpoints in the group if they were not initalized before
for endpoint in ep_group.endpoints:
ep_group.current_state = AnnouncementStates.INIT
endpoints_db.update_group(ep_group.id, ep_group)
init_endpoint(endpoint, ep_group.languages)
await init_endpoint(endpoint, ep_group.languages, ep_group.sampling_rate_hz)
# Translate the text for each language
base_lang = "deu" # German is the base language
ep_group.current_state = AnnouncementStates.TRANSLATING
endpoints_db.update_group(ep_group.id, ep_group)
translations = {base_lang: text}
for lang in ep_group.languages:
if lang != base_lang:
# Translate the text
lang_conf = getattr(ep_group.translator_config, lang)
translation = llm_translator.translate_de_to_x(
trans_conf = getattr(ep_group.translator_config, lang)
translation = await llm_translator.translate_de_to_x_async(
text=text,
target_language=lang,
client=lang_conf.llm_client,
model=lang_conf.translator_llm,
host=lang_conf.llm_host_url,
token=lang_conf.llm_host_token
client=trans_conf.llm_client,
model=trans_conf.translator_llm,
host=trans_conf.llm_host_url,
token=trans_conf.llm_host_token
)
translations[lang] = translation
log.info(f"Translated to {lang}: {translation}")
# Generate voices
ep_group.current_state = AnnouncementStates.GENERATING_VOICE
endpoints_db.update_group(ep_group.id, ep_group)
audio = {}
# Convert each translation to audio
for lang, text in translations.items():
for lang, text in translations.items():
# Get the appropriate language configuration
lang_conf = getattr(ep_group.translator_config, lang)
trans_conf = getattr(ep_group.translator_config, lang)
# Convert text to LC3-encoded audio using the configuration's TTS settings
audio[lang] = text_to_speech.synthesize(
audio[lang] = text_to_speech.synthesize(
translations[lang],
16000, # Sample rate from auracast config # TODO: take sampling rate from auracast config
lang_conf.tts_system, # TTS system from config
lang_conf.tts_model, # TTS model from config
ep_group.sampling_rate_hz, # Sample rate from auracast config # TODO: take sampling rate from auracast config
trans_conf.tts_system, # TTS system from config
trans_conf.tts_model, # TTS model from config
return_lc3=True
).decode('latin-1')
# Add to audio data dictionary (decode bytes to string for JSON serialization)
# Start the monitoring coroutine to wait for streaming to complete
# This will set the state to COMPLETED when finished
asyncio.create_task(monitor_streaming_completion(ep_group))
# Broadcast to all endpoints in group
for endpoint in ep_group.endpoints:
# Send the audio data to the server using the existing send_audio function
log.info(f"Broadcastcasting to {endpoint.name} for languages: {', '.join(audio.keys())}")
multicast_client.send_audio(audio, base_url=endpoint.url)
await multicast_client.send_audio(audio, base_url=endpoint.url)
# Return the translations
return {"translations": translations}
return {"translations": translations}
async def monitor_streaming_completion(ep_group: EndpointGroup):
"""
Monitor streaming status after audio is sent and update group state when complete.
Args:
ep_group: The endpoint group being monitored
"""
log.info(f"Starting streaming completion monitoring for endpoint group {ep_group.id}")
# Set a shorter timeout as requested
max_completion_time = 60 # seconds
# First check if we are actually in streaming state
streaming_started = False
initial_check_timeout = 10 # seconds
initial_check_start = time.time()
# Wait for streaming to start (with timeout)
while time.time() - initial_check_start < initial_check_timeout:
# Wait before checking again
@@ -183,7 +183,7 @@ async def monitor_streaming_completion(ep_group: EndpointGroup):
any_streaming = False
for endpoint in ep_group.endpoints:
status = multicast_client.get_status(base_url=endpoint.url)
status = await multicast_client.get_status(base_url=endpoint.url)
if status.get("is_streaming", False):
any_streaming = True
log.info(f"Streaming confirmed started on endpoint {endpoint.name}")
@@ -192,7 +192,6 @@ async def monitor_streaming_completion(ep_group: EndpointGroup):
if any_streaming:
streaming_started = True
break
if not streaming_started:
log.warning(f"No endpoints started streaming for group {ep_group.id} after {initial_check_timeout}s")
@@ -200,20 +199,21 @@ async def monitor_streaming_completion(ep_group: EndpointGroup):
ep_group.current_state = AnnouncementStates.ERROR
endpoints_db.update_group(ep_group.id, ep_group)
return
# Update group progress
ep_group.current_state = AnnouncementStates.BROADCASTING
endpoints_db.update_group(ep_group.id, ep_group)
# Now monitor until streaming completes on all endpoints
check_completion_start_time = time.time()
completed = [False for _ in ep_group.endpoints]
while not all(completed) or time.time() - check_completion_start_time > max_completion_time:
await asyncio.sleep(1)
# Check status of each endpoint
for i, endpoint in enumerate(ep_group.endpoints):
completed[i] = not multicast_client.get_status(base_url=endpoint.url)['is_streaming']
status = await multicast_client.get_status(base_url=endpoint.url)
completed[i] = not status['is_streaming']
if all(completed):
log.info(f"All endpoints completed streaming for group {ep_group.id}")
@@ -247,7 +247,7 @@ async def get_group_state(group_id: int):
ep_group = endpoints_db.get_group_by_id(group_id)
if not ep_group:
raise HTTPException(status_code=404, detail=f"Endpoint {group_id} not found")
return {"name": ep_group.current_state.name, "value": ep_group.current_state.value}
@@ -274,16 +274,16 @@ async def delete_group(group_id: int):
async def start_announcement(text: str, group_id: int):
"""Start a new announcement to the specified endpoint group."""
global announcement_task
# Get the group from active groups or database
group = endpoints_db.get_group_by_id(group_id)
if not group:
raise HTTPException(status_code=400, detail=f"Group {group_id} not found")
# Check if we're already processing an announcement
#if announcement_task and not announcement_task.done():
# raise HTTPException(status_code=400, detail="Already processing an announcement")
# Start the announcement task
announcement_task = asyncio.create_task(make_announcement(text, group))
return {"status": "Announcement started", "group_id": group_id}
@@ -312,10 +312,10 @@ if __name__ == "__main__":
# with reload=True logging of modules does not function as expected
uvicorn.run(
app,
#'translator_server:app',
host="0.0.0.0",
port=7999,
#reload=True,
#'translator_server:app',
host="0.0.0.0",
port=7999,
#reload=True,
#log_config=None,
#log_level="info"
)

9
tests/get_group_state.py Normal file
View File

@@ -0,0 +1,9 @@
import requests
import time

if __name__ == '__main__':
    # Poll the group state endpoint every 0.5 s for manual observation.
    while True:
        # An explicit timeout keeps the poll loop from hanging forever if the
        # server stops responding (requests has no default timeout).
        response = requests.get('http://localhost:7999/groups/0/state', timeout=5)
        print(response.json())
        time.sleep(0.5)