diff --git a/src/python/larynx_train/preprocess.py b/src/python/larynx_train/preprocess.py index 54d30f3..be625f9 100644 --- a/src/python/larynx_train/preprocess.py +++ b/src/python/larynx_train/preprocess.py @@ -6,6 +6,7 @@ import itertools import json import logging import os +from collections import Counter from dataclasses import dataclass from multiprocessing import JoinableQueue, Process, Queue from pathlib import Path @@ -71,23 +72,26 @@ def main(): make_dataset = ljspeech_dataset # Count speakers - _LOGGER.debug("Counting number of speakers in the dataset") - speakers: Set[str] = set() + _LOGGER.debug("Counting number of speakers/utterances in the dataset") + speaker_counts: Counter[str] = Counter() num_utterances = 0 for utt in make_dataset(args.input_dir): - speakers.add(utt.speaker or "") + speaker = utt.speaker or "" + speaker_counts[speaker] += 1 num_utterances += 1 assert num_utterances > 0, "No utterances found" - is_multispeaker = len(speakers) > 1 + is_multispeaker = len(speaker_counts) > 1 speaker_ids: Dict[str, int] = {} if is_multispeaker: - _LOGGER.info("%s speakers detected", len(speakers)) + _LOGGER.info("%s speakers detected", len(speaker_counts)) - # Assign speaker ids in sorted order - for speaker_id, speaker in enumerate(sorted(speakers)): + # Assign speaker ids by most number of utterances first + for speaker_id, (speaker, _speaker_count) in enumerate( + speaker_counts.most_common() + ): speaker_ids[speaker] = speaker_id else: _LOGGER.info("Single speaker dataset") @@ -108,7 +112,7 @@ def main(): "num_symbols": len( set(itertools.chain.from_iterable(DEFAULT_PHONEME_ID_MAP.values())) ), - "num_speakers": len(speakers), + "num_speakers": len(speaker_counts), "speaker_id_map": speaker_ids, }, config_file,