mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-20 15:14:48 +00:00
Use multiprocess in preprocess script
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import csv
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import JoinableQueue, Process, Queue
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Set
|
||||
|
||||
@@ -33,10 +36,21 @@ def main():
|
||||
required=True,
|
||||
help="Target sample rate for voice (hertz)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-format", choices=("ljspeech", "mycroft"), required=True
|
||||
)
|
||||
parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
|
||||
parser.add_argument("--max-workers", type=int)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="Print DEBUG messages to the console"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
level = logging.DEBUG if args.debug else logging.INFO
|
||||
logging.basicConfig(level=level)
|
||||
logging.getLogger().setLevel(level)
|
||||
|
||||
# Prevent log spam
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
# Convert to paths and create output directories
|
||||
@@ -51,17 +65,24 @@ def main():
|
||||
)
|
||||
args.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.dataset_format == "mycroft":
|
||||
make_dataset = mycroft_dataset
|
||||
else:
|
||||
make_dataset = ljspeech_dataset
|
||||
|
||||
# Count speakers
|
||||
_LOGGER.info("Counting number of speakers in the dataset")
|
||||
_LOGGER.debug("Counting number of speakers in the dataset")
|
||||
speakers: Set[str] = set()
|
||||
for utt in mycroft_dataset(args.input_dir):
|
||||
num_utterances = 0
|
||||
for utt in make_dataset(args.input_dir):
|
||||
speakers.add(utt.speaker or "")
|
||||
num_utterances += 1
|
||||
|
||||
is_multispeaker = len(speakers) > 1
|
||||
speaker_ids: Dict[str, int] = {}
|
||||
|
||||
if is_multispeaker:
|
||||
_LOGGER.info("%s speaker(s) detected", len(speakers))
|
||||
_LOGGER.info("%s speakers detected", len(speakers))
|
||||
|
||||
# Assign speaker ids in sorted order
|
||||
for speaker_id, speaker in enumerate(sorted(speakers)):
|
||||
@@ -94,22 +115,32 @@ def main():
|
||||
)
|
||||
_LOGGER.info("Wrote dataset config")
|
||||
|
||||
# Used to trim silence
|
||||
silence_detector = make_silence_detector()
|
||||
if (args.max_workers is None) or (args.max_workers < 1):
|
||||
args.max_workers = os.cpu_count()
|
||||
|
||||
batch_size = int(num_utterances / (args.max_workers * 2))
|
||||
queue_in = JoinableQueue()
|
||||
queue_out = Queue()
|
||||
|
||||
# Start workers
|
||||
processes = [
|
||||
Process(target=process_batch, args=(args, queue_in, queue_out))
|
||||
for _ in range(args.max_workers)
|
||||
]
|
||||
for proc in processes:
|
||||
proc.start()
|
||||
|
||||
_LOGGER.info(
|
||||
"Processing %s utterance(s) with %s worker(s)", num_utterances, args.max_workers
|
||||
)
|
||||
with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
|
||||
phonemizer = Phonemizer(default_voice=args.language)
|
||||
for utt in mycroft_dataset(args.input_dir):
|
||||
try:
|
||||
utt.audio_path = utt.audio_path.absolute()
|
||||
_LOGGER.debug(utt)
|
||||
|
||||
utt.phonemes = phonemize(utt.text, phonemizer)
|
||||
utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
|
||||
utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
|
||||
utt.audio_path, args.cache_dir, silence_detector, args.sample_rate
|
||||
)
|
||||
for utt_batch in batched(make_dataset(args.input_dir), batch_size):
|
||||
queue_in.put(utt_batch)
|
||||
|
||||
_LOGGER.debug("Waiting for jobs to finish")
|
||||
for _ in range(num_utterances):
|
||||
utt = queue_out.get()
|
||||
if utt is not None:
|
||||
# JSONL
|
||||
json.dump(
|
||||
dataclasses.asdict(utt),
|
||||
@@ -118,8 +149,48 @@ def main():
|
||||
cls=PathEncoder,
|
||||
)
|
||||
print("", file=dataset_file)
|
||||
except Exception:
|
||||
_LOGGER.exception("Failed to process utterance: %s", utt)
|
||||
|
||||
# Signal workers to stop
|
||||
for proc in processes:
|
||||
queue_in.put(None)
|
||||
|
||||
# Wait for workers to stop
|
||||
for proc in processes:
|
||||
proc.join(timeout=1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def process_batch(args: argparse.Namespace, queue_in: JoinableQueue, queue_out: Queue):
|
||||
try:
|
||||
silence_detector = make_silence_detector()
|
||||
phonemizer = Phonemizer(default_voice=args.language)
|
||||
|
||||
while True:
|
||||
utt_batch = queue_in.get()
|
||||
if utt_batch is None:
|
||||
break
|
||||
|
||||
for utt in utt_batch:
|
||||
try:
|
||||
_LOGGER.debug(utt)
|
||||
utt.phonemes = phonemize(utt.text, phonemizer)
|
||||
utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
|
||||
utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
|
||||
utt.audio_path,
|
||||
args.cache_dir,
|
||||
silence_detector,
|
||||
args.sample_rate,
|
||||
)
|
||||
queue_out.put(utt)
|
||||
except Exception:
|
||||
_LOGGER.exception("Failed to process utterance: %s", utt)
|
||||
queue_out.put(None)
|
||||
|
||||
queue_in.task_done()
|
||||
except Exception:
|
||||
_LOGGER.exception("process_batch")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -143,6 +214,40 @@ class PathEncoder(json.JSONEncoder):
|
||||
return super().default(o)
|
||||
|
||||
|
||||
def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]:
|
||||
# filename|speaker|text
|
||||
# speaker is optional
|
||||
metadata_path = dataset_dir / "metadata.csv"
|
||||
assert metadata_path.exists(), f"Missing {metadata_path}"
|
||||
|
||||
wav_dir = dataset_dir / "wav"
|
||||
if not wav_dir.is_dir():
|
||||
wav_dir = dataset_dir / "wavs"
|
||||
|
||||
assert wav_dir.is_dir(), f"Missing {wav_dir}"
|
||||
|
||||
with open(metadata_path, "r", encoding="utf-8") as csv_file:
|
||||
reader = csv.reader(csv_file, delimiter="|")
|
||||
for row in reader:
|
||||
assert len(row) >= 2, "Not enough colums"
|
||||
|
||||
speaker: Optional[str] = None
|
||||
if len(row) == 2:
|
||||
filename, text = row[0], row[-1]
|
||||
else:
|
||||
filename, speaker, text = row[0], row[1], row[-1]
|
||||
|
||||
wav_path = wav_dir / filename
|
||||
if not wav_path.exists():
|
||||
wav_path = wav_dir / f"{filename}.wav"
|
||||
|
||||
if not wav_path.exists():
|
||||
_LOGGER.warning("Missing %s", wav_path)
|
||||
continue
|
||||
|
||||
yield Utterance(text=text, audio_path=wav_path, speaker=speaker)
|
||||
|
||||
|
||||
def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
|
||||
for info_path in dataset_dir.glob("*.info"):
|
||||
wav_path = info_path.with_suffix(".wav")
|
||||
@@ -151,6 +256,21 @@ def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
|
||||
yield Utterance(text=text, audio_path=wav_path)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def batched(iterable, n):
|
||||
"Batch data into lists of length n. The last batch may be shorter."
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
it = iter(iterable)
|
||||
batch = list(itertools.islice(it, n))
|
||||
while batch:
|
||||
yield batch
|
||||
batch = list(itertools.islice(it, n))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user