diff --git a/src/python/larynx_train/preprocess.py b/src/python/larynx_train/preprocess.py index 9fe6a7b..9257367 100644 --- a/src/python/larynx_train/preprocess.py +++ b/src/python/larynx_train/preprocess.py @@ -42,6 +42,9 @@ def main() -> None: ) parser.add_argument("--cache-dir", help="Directory to cache processed audio files") parser.add_argument("--max-workers", type=int) + parser.add_argument( + "--single-speaker", action="store_true", help="Force single speaker dataset" + ) parser.add_argument( "--debug", action="store_true", help="Print DEBUG messages to the console" ) @@ -75,7 +78,7 @@ def main() -> None: _LOGGER.debug("Counting number of speakers/utterances in the dataset") speaker_counts: Counter[str] = Counter() num_utterances = 0 - for utt in make_dataset(args.input_dir): + for utt in make_dataset(args.input_dir, args.single_speaker): speaker = utt.speaker or "" speaker_counts[speaker] += 1 num_utterances += 1 @@ -142,7 +145,9 @@ def main() -> None: "Processing %s utterance(s) with %s worker(s)", num_utterances, args.max_workers ) with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file: - for utt_batch in batched(make_dataset(args.input_dir), batch_size): + for utt_batch in batched( + make_dataset(args.input_dir, args.single_speaker), batch_size + ): queue_in.put(utt_batch) _LOGGER.debug("Waiting for jobs to finish") @@ -226,7 +231,7 @@ class PathEncoder(json.JSONEncoder): return super().default(o) -def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]: +def ljspeech_dataset(dataset_dir: Path, is_single_speaker: bool) -> Iterable[Utterance]: # filename|speaker|text # speaker is optional metadata_path = dataset_dir / "metadata.csv" @@ -242,7 +247,7 @@ def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]: assert len(row) >= 2, "Not enough colums" speaker: Optional[str] = None - if len(row) == 2: + if is_single_speaker or (len(row) == 2): filename, text = row[0], row[-1] else: filename, speaker, text = row[0], row[1], row[-1] @@ -269,7 +274,7 @@ def ljspeech_dataset(dataset_dir: Path) -> Iterable[Utterance]: yield Utterance(text=text, audio_path=wav_path, speaker=speaker) -def mycroft_dataset(dataset_dir: Path) -> Iterable[Utterance]: +def mycroft_dataset(dataset_dir: Path, is_single_speaker: bool) -> Iterable[Utterance]: for info_path in dataset_dir.glob("*.info"): wav_path = info_path.with_suffix(".wav") if wav_path.exists():