diff --git a/src/python/larynx_train/vits/dataset.py b/src/python/larynx_train/vits/dataset.py
index f5b6839..c257e28 100644
--- a/src/python/larynx_train/vits/dataset.py
+++ b/src/python/larynx_train/vits/dataset.py
@@ -67,15 +67,19 @@ class LarynxDataset(Dataset):
     def __init__(
         self,
-        dataset_paths: List[Union[str, Path]],  # settings: LarynxDatasetSettings
+        dataset_paths: List[Union[str, Path]],
+        max_phoneme_ids: Optional[int] = None,
     ):
-        # self.settings = settings
         self.utterances: List[Utterance] = []
 
         for dataset_path in dataset_paths:
             dataset_path = Path(dataset_path)
             _LOGGER.debug("Loading dataset: %s", dataset_path)
-            self.utterances.extend(LarynxDataset.load_dataset(dataset_path))
+            self.utterances.extend(
+                LarynxDataset.load_dataset(
+                    dataset_path, max_phoneme_ids=max_phoneme_ids
+                )
+            )
 
     def __len__(self):
         return len(self.utterances)
 
@@ -93,7 +97,11 @@ class LarynxDataset(Dataset):
         )
 
     @staticmethod
-    def load_dataset(dataset_path: Path) -> Iterable[Utterance]:
+    def load_dataset(
+        dataset_path: Path, max_phoneme_ids: Optional[int] = None,
+    ) -> Iterable[Utterance]:
+        num_skipped = 0
+
         with open(dataset_path, "r", encoding="utf-8") as dataset_file:
             for line_idx, line in enumerate(dataset_file):
                 line = line.strip()
@@ -101,15 +109,21 @@ class LarynxDataset(Dataset):
                     continue
 
                 try:
-                    yield LarynxDataset.load_utterance(line)
+                    utt = LarynxDataset.load_utterance(line)
+                    if (max_phoneme_ids is None) or (
+                        len(utt.phoneme_ids) <= max_phoneme_ids
+                    ):
+                        yield utt
+                    else:
+                        num_skipped += 1
                 except Exception:
                     _LOGGER.exception(
-                        "Error on line %s of %s: %s",
-                        line_idx + 1,
-                        dataset_path,
-                        line,
+                        "Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
                     )
 
+        if num_skipped > 0:
+            _LOGGER.warning("Skipped %s utterance(s)", num_skipped)
+
     @staticmethod
     def load_utterance(line: str) -> Utterance:
         utt_dict = json.loads(line)
diff --git a/src/python/larynx_train/vits/lightning.py b/src/python/larynx_train/vits/lightning.py
index cb5b549..f497e7e 100644
--- a/src/python/larynx_train/vits/lightning.py
+++ b/src/python/larynx_train/vits/lightning.py
@@ -25,11 +25,7 @@ class VitsModel(pl.LightningModule):
         # audio
         resblock="2",
         resblock_kernel_sizes=(3, 5, 7),
-        resblock_dilation_sizes=(
-            (1, 2),
-            (2, 6),
-            (3, 12),
-        ),
+        resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
         upsample_rates=(8, 8, 4),
         upsample_initial_channel=256,
         upsample_kernel_sizes=(16, 16, 8),
@@ -72,7 +68,8 @@ class VitsModel(pl.LightningModule):
         seed: int = 1234,
         num_test_examples: int = 5,
         validation_split: float = 0.1,
-        **kwargs
+        max_phoneme_ids: Optional[int] = None,
+        **kwargs,
     ):
         super().__init__()
         self.save_hyperparameters()
@@ -111,14 +108,21 @@ class VitsModel(pl.LightningModule):
         self._train_dataset: Optional[Dataset] = None
         self._val_dataset: Optional[Dataset] = None
         self._test_dataset: Optional[Dataset] = None
-        self._load_datasets(validation_split, num_test_examples)
+        self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)
 
         # State kept between training optimizers
         self._y = None
         self._y_hat = None
 
-    def _load_datasets(self, validation_split: float, num_test_examples: int):
-        full_dataset = LarynxDataset(self.hparams.dataset)
+    def _load_datasets(
+        self,
+        validation_split: float,
+        num_test_examples: int,
+        max_phoneme_ids: Optional[int] = None,
+    ):
+        full_dataset = LarynxDataset(
+            self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
+        )
 
         valid_set_size = int(len(full_dataset) * validation_split)
         train_set_size = len(full_dataset) - valid_set_size - num_test_examples
@@ -211,9 +215,7 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y_mel = slice_segments(
-            mel,
-            ids_slice,
-            self.hparams.segment_size // self.hparams.hop_length,
+            mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
@@ -226,9 +228,7 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y = slice_segments(
-            y,
-            ids_slice * self.hparams.hop_length,
-            self.hparams.segment_size,
+            y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
         )  # slice
 
         # Save for training_step_d
@@ -320,6 +320,11 @@ class VitsModel(pl.LightningModule):
         parser.add_argument("--batch-size", type=int, required=True)
         parser.add_argument("--validation-split", type=float, default=0.1)
         parser.add_argument("--num-test-examples", type=int, default=5)
+        parser.add_argument(
+            "--max-phoneme-ids",
+            type=int,
+            help="Exclude utterances with phoneme id lists longer than this",
+        )
 
         # parser.add_argument("--hidden-channels", type=int, default=192)
         parser.add_argument("--inter-channels", type=int, default=192)
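For reference, the filtering behavior this patch adds can be previewed without running the full training stack. The sketch below is not part of the patch; it assumes each dataset line is a JSON object containing a "phoneme_ids" list (consistent with load_utterance() calling json.loads() and the len(utt.phoneme_ids) check above), and the file path and threshold are placeholder values.

#!/usr/bin/env python3
"""Rough sketch: report how many utterances --max-phoneme-ids would skip.

Assumes each dataset line is a JSON object with a "phoneme_ids" list,
matching what load_utterance() parses in dataset.py. The path and the
threshold below are examples, not project defaults.
"""
import json
from pathlib import Path

max_phoneme_ids = 400  # example threshold
dataset_path = Path("dataset.jsonl")  # placeholder path

num_kept = 0
num_skipped = 0
with open(dataset_path, "r", encoding="utf-8") as dataset_file:
    for line in dataset_file:
        line = line.strip()
        if not line:
            continue
        utt_dict = json.loads(line)
        # Mirror the patch: keep utterances at or under the limit
        if len(utt_dict["phoneme_ids"]) <= max_phoneme_ids:
            num_kept += 1
        else:
            num_skipped += 1

print(f"kept={num_kept} skipped={num_skipped}")

On the training side, the new option is passed alongside the existing arguments added by add_model_specific_args, e.g. --max-phoneme-ids 400 next to --batch-size and --validation-split; leaving it unset keeps every utterance, since the filter is a no-op when max_phoneme_ids is None.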