Use multiprocess in preprocess script

This commit is contained in:
Michael Hansen
2022-11-11 16:42:48 -05:00
parent eb60d8529b
commit 7c22049330
5 changed files with 140 additions and 194 deletions

View File

@@ -1,10 +1,13 @@
#!/usr/bin/env python3
import argparse
import csv
import dataclasses
import itertools
import json
import logging
import os
from dataclasses import dataclass
from multiprocessing import JoinableQueue, Process, Queue
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set
@@ -33,10 +36,21 @@ def main():
required=True,
help="Target sample rate for voice (hertz)",
)
parser.add_argument(
"--dataset-format", choices=("ljspeech", "mycroft"), required=True
)
parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
parser.add_argument("--max-workers", type=int)
parser.add_argument(
"--debug", action="store_true", help="Print DEBUG messages to the console"
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(level=level)
logging.getLogger().setLevel(level)
# Prevent log spam
logging.getLogger("numba").setLevel(logging.WARNING)
# Convert to paths and create output directories
@@ -51,17 +65,24 @@ def main():
)
args.cache_dir.mkdir(parents=True, exist_ok=True)
if args.dataset_format == "mycroft":
make_dataset = mycroft_dataset
else:
make_dataset = ljspeech_dataset
# Count speakers
_LOGGER.info("Counting number of speakers in the dataset")
_LOGGER.debug("Counting number of speakers in the dataset")
speakers: Set[str] = set()
for utt in mycroft_dataset(args.input_dir):
num_utterances = 0
for utt in make_dataset(args.input_dir):
speakers.add(utt.speaker or "")
num_utterances += 1
is_multispeaker = len(speakers) > 1
speaker_ids: Dict[str, int] = {}
if is_multispeaker:
_LOGGER.info("%s speaker(s) detected", len(speakers))
_LOGGER.info("%s speakers detected", len(speakers))
# Assign speaker ids in sorted order
for speaker_id, speaker in enumerate(sorted(speakers)):
@@ -94,22 +115,32 @@ def main():
)
_LOGGER.info("Wrote dataset config")
# Used to trim silence
silence_detector = make_silence_detector()
if (args.max_workers is None) or (args.max_workers < 1):
args.max_workers = os.cpu_count()
batch_size = int(num_utterances / (args.max_workers * 2))
queue_in = JoinableQueue()
queue_out = Queue()
# Start workers
processes = [
Process(target=process_batch, args=(args, queue_in, queue_out))
for _ in range(args.max_workers)
]
for proc in processes:
proc.start()
_LOGGER.info(
"Processing %s utterance(s) with %s worker(s)", num_utterances, args.max_workers
)
with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
phonemizer = Phonemizer(default_voice=args.language)
for utt in mycroft_dataset(args.input_dir):
try:
utt.audio_path = utt.audio_path.absolute()
_LOGGER.debug(utt)
utt.phonemes = phonemize(utt.text, phonemizer)
utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
utt.audio_path, args.cache_dir, silence_detector, args.sample_rate
)
for utt_batch in batched(make_dataset(args.input_dir), batch_size):
queue_in.put(utt_batch)
_LOGGER.debug("Waiting for jobs to finish")
for _ in range(num_utterances):
utt = queue_out.get()
if utt is not None:
# JSONL
json.dump(
dataclasses.asdict(utt),
@@ -118,8 +149,48 @@ def main():
cls=PathEncoder,
)
print("", file=dataset_file)
except Exception:
_LOGGER.exception("Failed to process utterance: %s", utt)
# Signal workers to stop
for proc in processes:
queue_in.put(None)
# Wait for workers to stop
for proc in processes:
proc.join(timeout=1)
# -----------------------------------------------------------------------------
def process_batch(args: argparse.Namespace, queue_in: JoinableQueue, queue_out: Queue):
try:
silence_detector = make_silence_detector()
phonemizer = Phonemizer(default_voice=args.language)
while True:
utt_batch = queue_in.get()
if utt_batch is None:
break
for utt in utt_batch:
try:
_LOGGER.debug(utt)
utt.phonemes = phonemize(utt.text, phonemizer)
utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
utt.audio_path,
args.cache_dir,
silence_detector,
args.sample_rate,
)
queue_out.put(utt)
except Exception:
_LOGGER.exception("Failed to process utterance: %s", utt)
queue_out.put(None)
queue_in.task_done()
except Exception:
_LOGGER.exception("process_batch")
# -----------------------------------------------------------------------------
@@ -143,6 +214,40 @@ class PathEncoder(json.JSONEncoder):
return super().default(o)
def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]:
# filename|speaker|text
# speaker is optional
metadata_path = dataset_dir / "metadata.csv"
assert metadata_path.exists(), f"Missing {metadata_path}"
wav_dir = dataset_dir / "wav"
if not wav_dir.is_dir():
wav_dir = dataset_dir / "wavs"
assert wav_dir.is_dir(), f"Missing {wav_dir}"
with open(metadata_path, "r", encoding="utf-8") as csv_file:
reader = csv.reader(csv_file, delimiter="|")
for row in reader:
assert len(row) >= 2, "Not enough colums"
speaker: Optional[str] = None
if len(row) == 2:
filename, text = row[0], row[-1]
else:
filename, speaker, text = row[0], row[1], row[-1]
wav_path = wav_dir / filename
if not wav_path.exists():
wav_path = wav_dir / f"{filename}.wav"
if not wav_path.exists():
_LOGGER.warning("Missing %s", wav_path)
continue
yield Utterance(text=text, audio_path=wav_path, speaker=speaker)
def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
for info_path in dataset_dir.glob("*.info"):
wav_path = info_path.with_suffix(".wav")
@@ -151,6 +256,21 @@ def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
yield Utterance(text=text, audio_path=wav_path)
# -----------------------------------------------------------------------------
def batched(iterable, n):
"Batch data into lists of length n. The last batch may be shorter."
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
batch = list(itertools.islice(it, n))
while batch:
yield batch
batch = list(itertools.islice(it, n))
# -----------------------------------------------------------------------------
if __name__ == "__main__":