diff --git a/src/python/larynx_train/preprocess.py b/src/python/larynx_train/preprocess.py index 9257367..1f6fa79 100644 --- a/src/python/larynx_train/preprocess.py +++ b/src/python/larynx_train/preprocess.py @@ -45,11 +45,18 @@ def main() -> None: parser.add_argument( "--single-speaker", action="store_true", help="Force single speaker dataset" ) + parser.add_argument( + "--speaker-id", type=int, help="Add speaker id to single speaker dataset" + ) parser.add_argument( "--debug", action="store_true", help="Print DEBUG messages to the console" ) args = parser.parse_args() + if args.single_speaker and (args.speaker_id is not None): + _LOGGER.fatal("--single-speaker and --speaker-id cannot both be provided") + return + level = logging.DEBUG if args.debug else logging.INFO logging.basicConfig(level=level) logging.getLogger().setLevel(level) @@ -78,7 +85,7 @@ def main() -> None: _LOGGER.debug("Counting number of speakers/utterances in the dataset") speaker_counts: Counter[str] = Counter() num_utterances = 0 - for utt in make_dataset(args.input_dir, args.single_speaker): + for utt in make_dataset(args.input_dir, args.single_speaker, args.speaker_id): speaker = utt.speaker or "" speaker_counts[speaker] += 1 num_utterances += 1 @@ -146,7 +153,8 @@ def main() -> None: ) with open(args.output_dir / "dataset.jsonl", "w", encoding="utf-8") as dataset_file: for utt_batch in batched( - make_dataset(args.input_dir, args.single_speaker), batch_size + make_dataset(args.input_dir, args.single_speaker, args.speaker_id), + batch_size, ): queue_in.put(utt_batch) @@ -231,7 +239,9 @@ class PathEncoder(json.JSONEncoder): return super().default(o) -def ljspeech_dataset(dataset_dir: Path, is_single_speaker: bool) -> Iterable[Utterance]: +def ljspeech_dataset( + dataset_dir: Path, is_single_speaker: bool, speaker_id: Optional[int] = None +) -> Iterable[Utterance]: # filename|speaker|text # speaker is optional metadata_path = dataset_dir / "metadata.csv" @@ -271,15 +281,19 @@ def ljspeech_dataset(dataset_dir: Path, is_single_speaker: bool) -> Iterable[Utt _LOGGER.warning("Missing %s", filename) continue - yield Utterance(text=text, audio_path=wav_path, speaker=speaker) + yield Utterance( + text=text, audio_path=wav_path, speaker=speaker, speaker_id=speaker_id + ) -def mycroft_dataset(dataset_dir: Path, is_single_speaker: bool) -> Iterable[Utterance]: +def mycroft_dataset( + dataset_dir: Path, is_single_speaker: bool, speaker_id: Optional[int] = None +) -> Iterable[Utterance]: for info_path in dataset_dir.glob("*.info"): wav_path = info_path.with_suffix(".wav") if wav_path.exists(): text = info_path.read_text(encoding="utf-8").strip() - yield Utterance(text=text, audio_path=wav_path) + yield Utterance(text=text, audio_path=wav_path, speaker_id=speaker_id) # -----------------------------------------------------------------------------