diff --git a/src/python/piper_train/preprocess.py b/src/python/piper_train/preprocess.py index 1f6fa79..6224585 100644 --- a/src/python/piper_train/preprocess.py +++ b/src/python/piper_train/preprocess.py @@ -7,6 +7,7 @@ import json import logging import os from collections import Counter +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from multiprocessing import JoinableQueue, Process, Queue from pathlib import Path @@ -191,28 +192,36 @@ def process_batch(args: argparse.Namespace, queue_in: JoinableQueue, queue_out: silence_detector = make_silence_detector() phonemizer = Phonemizer(default_voice=args.language) - while True: - utt_batch = queue_in.get() - if utt_batch is None: - break + with ThreadPoolExecutor(max_workers=1) as executor: + while True: + utt_batch = queue_in.get() + if utt_batch is None: + break - for utt in utt_batch: - try: - _LOGGER.debug(utt) - utt.phonemes = phonemize(utt.text, phonemizer) - utt.phoneme_ids = phonemes_to_ids(utt.phonemes) - utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio( - utt.audio_path, - args.cache_dir, - silence_detector, - args.sample_rate, - ) - queue_out.put(utt) - except Exception: - _LOGGER.exception("Failed to process utterance: %s", utt) - queue_out.put(None) + for utt in utt_batch: + try: + utt.phonemes = next( + executor.map( + lambda utt: phonemize(utt.text, phonemizer), + [utt], + timeout=1, + ) + ) + utt.phoneme_ids = phonemes_to_ids(utt.phonemes) + utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio( + utt.audio_path, + args.cache_dir, + silence_detector, + args.sample_rate, + ) + queue_out.put(utt) + except TimeoutError: + _LOGGER.error("Skipping utterance due to timeout: %s", utt) + except Exception: + _LOGGER.exception("Failed to process utterance: %s", utt) + queue_out.put(None) - queue_in.task_done() + queue_in.task_done() except Exception: _LOGGER.exception("process_batch")