Use multiprocessing in preprocess script

This commit is contained in:
Michael Hansen
2022-11-11 16:42:48 -05:00
parent eb60d8529b
commit 7c22049330
5 changed files with 140 additions and 194 deletions

View File

@@ -1,5 +1,4 @@
"""Shared access to package resources"""
import json
import os
import typing
from pathlib import Path

View File

@@ -148,7 +148,7 @@ def phonemize(text: str, phonemizer: Phonemizer) -> List[str]:
def phonemes_to_ids(
phonemes: Iterable[str],
phoneme_id_map: Optional[Mapping[str, Iterable[int]]] = None,
missing_phonemes: Optional[Counter[str]] = None,
missing_phonemes: "Optional[Counter[str]]" = None,
) -> List[int]:
if phoneme_id_map is None:
phoneme_id_map = DEFAULT_PHONEME_ID_MAP

View File

@@ -1,10 +1,13 @@
#!/usr/bin/env python3
import argparse
import csv
import dataclasses
import itertools
import json
import logging
import os
from dataclasses import dataclass
from multiprocessing import JoinableQueue, Process, Queue
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set
@@ -33,10 +36,21 @@ def main():
required=True,
help="Target sample rate for voice (hertz)",
)
parser.add_argument(
"--dataset-format", choices=("ljspeech", "mycroft"), required=True
)
parser.add_argument("--cache-dir", help="Directory to cache processed audio files")
parser.add_argument("--max-workers", type=int)
parser.add_argument(
"--debug", action="store_true", help="Print DEBUG messages to the console"
)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(level=level)
logging.getLogger().setLevel(level)
# Prevent log spam
logging.getLogger("numba").setLevel(logging.WARNING)
# Convert to paths and create output directories
@@ -51,17 +65,24 @@ def main():
)
args.cache_dir.mkdir(parents=True, exist_ok=True)
if args.dataset_format == "mycroft":
make_dataset = mycroft_dataset
else:
make_dataset = ljspeech_dataset
# Count speakers
_LOGGER.info("Counting number of speakers in the dataset")
_LOGGER.debug("Counting number of speakers in the dataset")
speakers: Set[str] = set()
for utt in mycroft_dataset(args.input_dir):
num_utterances = 0
for utt in make_dataset(args.input_dir):
speakers.add(utt.speaker or "")
num_utterances += 1
is_multispeaker = len(speakers) > 1
speaker_ids: Dict[str, int] = {}
if is_multispeaker:
_LOGGER.info("%s speaker(s) detected", len(speakers))
_LOGGER.info("%s speakers detected", len(speakers))
# Assign speaker ids in sorted order
for speaker_id, speaker in enumerate(sorted(speakers)):
@@ -94,22 +115,32 @@ def main():
)
_LOGGER.info("Wrote dataset config")
# Used to trim silence
silence_detector = make_silence_detector()
if (args.max_workers is None) or (args.max_workers < 1):
args.max_workers = os.cpu_count()
batch_size = int(num_utterances / (args.max_workers * 2))
queue_in = JoinableQueue()
queue_out = Queue()
# Start workers
processes = [
Process(target=process_batch, args=(args, queue_in, queue_out))
for _ in range(args.max_workers)
]
for proc in processes:
proc.start()
_LOGGER.info(
"Processing %s utterance(s) with %s worker(s)", num_utterances, args.max_workers
)
with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file:
phonemizer = Phonemizer(default_voice=args.language)
for utt in mycroft_dataset(args.input_dir):
try:
utt.audio_path = utt.audio_path.absolute()
_LOGGER.debug(utt)
utt.phonemes = phonemize(utt.text, phonemizer)
utt.phoneme_ids = phonemes_to_ids(utt.phonemes)
utt.audio_norm_path, utt.audio_spec_path = cache_norm_audio(
utt.audio_path, args.cache_dir, silence_detector, args.sample_rate
)
for utt_batch in batched(make_dataset(args.input_dir), batch_size):
queue_in.put(utt_batch)
_LOGGER.debug("Waiting for jobs to finish")
for _ in range(num_utterances):
utt = queue_out.get()
if utt is not None:
# JSONL
json.dump(
dataclasses.asdict(utt),
@@ -118,8 +149,48 @@ def main():
cls=PathEncoder,
)
print("", file=dataset_file)
except Exception:
_LOGGER.exception("Failed to process utterance: %s", utt)
# Signal workers to stop
for proc in processes:
queue_in.put(None)
# Wait for workers to stop
for proc in processes:
proc.join(timeout=1)
# -----------------------------------------------------------------------------
def process_batch(args: argparse.Namespace, queue_in: JoinableQueue, queue_out: Queue):
    """Worker loop: consume utterance batches from queue_in, phonemize and
    cache-normalize each utterance, and emit results on queue_out.

    A batch value of None is the shutdown sentinel from the main process.
    For every utterance received, exactly one item is put on queue_out:
    the processed utterance on success, or None on failure (the main
    process counts queue_out items against the total utterance count).
    """
    try:
        # Expensive resources are built once per worker process.
        silence_detector = make_silence_detector()
        phonemizer = Phonemizer(default_voice=args.language)
        while True:
            batch = queue_in.get()
            if batch is None:
                # Shutdown sentinel — stop consuming.
                break
            for utterance in batch:
                try:
                    _LOGGER.debug(utterance)
                    utterance.phonemes = phonemize(utterance.text, phonemizer)
                    utterance.phoneme_ids = phonemes_to_ids(utterance.phonemes)
                    utterance.audio_norm_path, utterance.audio_spec_path = cache_norm_audio(
                        utterance.audio_path,
                        args.cache_dir,
                        silence_detector,
                        args.sample_rate,
                    )
                    queue_out.put(utterance)
                except Exception:
                    # Keep the worker alive, but still emit one item so the
                    # main process's per-utterance count stays in sync.
                    _LOGGER.exception("Failed to process utterance: %s", utterance)
                    queue_out.put(None)
            queue_in.task_done()
    except Exception:
        # Last-resort guard: a dying worker at least leaves a log trace.
        _LOGGER.exception("process_batch")
# -----------------------------------------------------------------------------
@@ -143,6 +214,40 @@ class PathEncoder(json.JSONEncoder):
return super().default(o)
def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]:
    """Yield utterances from an LJSpeech-style dataset directory.

    Expects <dataset_dir>/metadata.csv with '|'-delimited rows of the form
    filename|text or filename|speaker|text (speaker is optional), and audio
    files under <dataset_dir>/wav or <dataset_dir>/wavs.
    Rows whose audio file cannot be found are logged and skipped.
    """
    # filename|speaker|text
    # speaker is optional
    metadata_path = dataset_dir / "metadata.csv"
    assert metadata_path.exists(), f"Missing {metadata_path}"

    wav_dir = dataset_dir / "wav"
    if not wav_dir.is_dir():
        wav_dir = dataset_dir / "wavs"
    assert wav_dir.is_dir(), f"Missing {wav_dir}"

    with open(metadata_path, "r", encoding="utf-8") as csv_file:
        reader = csv.reader(csv_file, delimiter="|")
        for row in reader:
            # Fixed typo in the assertion message ("colums" -> "columns").
            assert len(row) >= 2, "Not enough columns"

            speaker: Optional[str] = None
            if len(row) == 2:
                filename, text = row[0], row[-1]
            else:
                filename, speaker, text = row[0], row[1], row[-1]

            wav_path = wav_dir / filename
            if not wav_path.exists():
                # Metadata rows commonly omit the extension; retry with .wav.
                # (Original had a placeholder-less f-string "(unknown).wav",
                # which could never vary per row — clearly garbled.)
                wav_path = wav_dir / f"{filename}.wav"

            if not wav_path.exists():
                _LOGGER.warning("Missing %s", wav_path)
                continue

            yield Utterance(text=text, audio_path=wav_path, speaker=speaker)
def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
for info_path in dataset_dir.glob("*.info"):
wav_path = info_path.with_suffix(".wav")
@@ -151,6 +256,21 @@ def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]:
yield Utterance(text=text, audio_path=wav_path)
# -----------------------------------------------------------------------------
def batched(iterable, n):
    """Split *iterable* into consecutive lists of length n.

    The final list may hold fewer than n items.
    batched('ABCDEFG', 3) --> ABC DEF G
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)

    def take():
        # Next chunk of up to n items; [] once the source is exhausted.
        return list(itertools.islice(source, n))

    # iter() with a sentinel stops as soon as take() yields an empty list.
    yield from iter(take, [])
# -----------------------------------------------------------------------------
if __name__ == "__main__":

View File

@@ -8,7 +8,6 @@ import torch
from torch import FloatTensor, LongTensor
from torch.utils.data import Dataset
_LOGGER = logging.getLogger("vits.dataset")

View File

@@ -1,172 +0,0 @@
import json
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from typing import List, Optional
import librosa
import torch
from dataclasses_json import DataClassJsonMixin
from torch import Tensor
@dataclass
class DatasetUtterance(DataClassJsonMixin):
    """One utterance as serialized in dataset.jsonl.

    Only id is mandatory; the remaining fields are filled in lazily by the
    loading pipeline as phonemization / audio processing completes.
    """

    id: str
    text: Optional[str] = None
    phonemes: Optional[List[str]] = None
    phoneme_ids: Optional[List[int]] = None
    audio_path: Optional[Path] = None
    audio_norm_path: Optional[Path] = None
    mel_spec_path: Optional[Path] = None
    speaker: Optional[str] = None
    speaker_id: Optional[int] = None

    def __post_init__(self):
        # Cached JSON form of this utterance; set externally when the
        # utterance was parsed from a JSON line, otherwise computed lazily.
        self._original_json: Optional[str] = None

    @property
    def original_json(self) -> str:
        """JSON for this utterance, serialized on first access and cached."""
        cached = self._original_json
        if cached is None:
            cached = self.to_json(ensure_ascii=False)
            self._original_json = cached
        return cached
@dataclass
class TrainingUtterance:
    """Tensor-ready utterance consumed by the training loop."""

    id: str
    phoneme_ids: Tensor  # phoneme ids as a LongTensor (see load_utterance)
    audio_norm: Tensor  # normalized audio loaded via torch.load
    mel_spec: Tensor  # mel spectrogram loaded via torch.load
    speaker_id: Optional[Tensor] = None  # set only for multi-speaker models
# -----------------------------------------------------------------------------
@dataclass
class UtteranceLoadingContext:
    """Shared resources for converting dataset utterances to training form.

    The Popen fields are long-running helper programs; phonemize and
    phonemes2ids exchange one JSON line per utterance over stdin/stdout.
    The remaining helpers are presumably driven the same way — their
    loaders are still stubs, so this is unconfirmed.
    """

    cache_dir: Path  # directory for cached normalized audio / spectrograms
    is_multispeaker: bool  # whether a speaker id must be resolved
    phonemize: Optional[Popen] = None  # text -> phonemes helper
    phonemes2ids: Optional[Popen] = None  # phonemes -> phoneme ids helper
    speaker2id: Optional[Popen] = None  # speaker name -> id helper (unused here)
    audio2norm: Optional[Popen] = None  # audio normalization helper (unused here)
    audio2spec: Optional[Popen] = None  # spectrogram helper (unused here)
def load_utterance(
    utterance_json: str, context: UtteranceLoadingContext
) -> TrainingUtterance:
    """Materialize a TrainingUtterance from one dataset.jsonl line.

    Whatever is missing — phoneme ids, normalized audio, mel spectrogram,
    and (for multi-speaker models) the speaker id — is computed on demand
    through the helper processes held by *context*.
    """
    utt = DatasetUtterance.from_json(utterance_json)
    # Preserve the exact source line so helpers see the original JSON.
    # pylint: disable=protected-access
    utt._original_json = utterance_json

    # 1. phoneme ids
    if utt.phoneme_ids is None:
        _load_phoneme_ids(utt, context)

    # 2. normalized audio
    if (utt.audio_norm_path is None) or (not utt.audio_norm_path.exists()):
        _load_audio_norm(utt, context)

    # 3. mel spectrogram
    if (utt.mel_spec_path is None) or (not utt.mel_spec_path.exists()):
        _load_mel_spec(utt, context)

    # 4. speaker id (multi-speaker models only)
    if context.is_multispeaker and (utt.speaker_id is None):
        _load_speaker_id(utt, context)

    # All required pieces must now be present.
    assert utt.phoneme_ids is not None
    assert utt.audio_norm_path is not None
    assert utt.mel_spec_path is not None
    if context.is_multispeaker:
        assert utt.speaker_id is not None

    if utt.speaker_id is None:
        speaker_tensor = None
    else:
        speaker_tensor = torch.LongTensor(utt.speaker_id)

    return TrainingUtterance(
        id=utt.id,
        phoneme_ids=torch.LongTensor(utt.phoneme_ids),
        audio_norm=torch.load(utt.audio_norm_path),
        mel_spec=torch.load(utt.mel_spec_path),
        speaker_id=speaker_tensor,
    )
def _load_phoneme_ids(
    data_utterance: DatasetUtterance, context: UtteranceLoadingContext
):
    """Populate data_utterance.phoneme_ids via the phonemes2ids helper."""
    if data_utterance.phonemes is None:
        # Phoneme ids are derived from phonemes, so resolve those first.
        _load_phonemes(data_utterance, context)

    assert (
        data_utterance.phonemes is not None
    ), f"phonemes is required for phoneme ids: {data_utterance}"

    proc = context.phonemes2ids
    assert (
        proc is not None
    ), f"phonemes2ids program is required for phoneme ids: {data_utterance}"
    assert proc.stdin is not None
    assert proc.stdout is not None

    # One JSON line in, one JSON line out.
    print(data_utterance.original_json, file=proc.stdin, flush=True)
    result_json = proc.stdout.readline()
    result_dict = json.loads(result_json)

    # Adopt the helper's output as the new canonical JSON for this utterance.
    data_utterance.phoneme_ids = result_dict["phoneme_ids"]
    # pylint: disable=protected-access
    data_utterance._original_json = result_json
def _load_phonemes(data_utterance: DatasetUtterance, context: UtteranceLoadingContext):
    """Populate data_utterance.phonemes via the phonemize helper."""
    assert (
        data_utterance.text is not None
    ), f"text is required for phonemes: {data_utterance}"

    proc = context.phonemize
    assert (
        proc is not None
    ), f"phonemize program is required for phonemes: {data_utterance}"
    assert proc.stdin is not None
    assert proc.stdout is not None

    # One JSON line in, one JSON line out.
    print(data_utterance.original_json, file=proc.stdin, flush=True)
    result_json = proc.stdout.readline()
    result_dict = json.loads(result_json)

    # NOTE(review): sibling _load_phoneme_ids reads "phoneme_ids"; the
    # singular "phoneme" key here looks suspicious — confirm against the
    # phonemize program's actual output schema before changing it.
    data_utterance.phonemes = result_dict["phoneme"]
    # pylint: disable=protected-access
    data_utterance._original_json = result_json
def _load_audio_norm(
    data_utterance: DatasetUtterance, context: UtteranceLoadingContext
):
    """Populate data_utterance.audio_norm_path.

    Stub: not implemented yet — presumably meant to drive
    context.audio2norm; currently a no-op.
    """
    pass
def _load_mel_spec(data_utterance: DatasetUtterance, context: UtteranceLoadingContext):
    """Populate data_utterance.mel_spec_path.

    Stub: not implemented yet — presumably meant to drive
    context.audio2spec; currently a no-op.
    """
    pass