diff --git a/src/python/larynx_train/vits/dataset.py b/src/python/larynx_train/vits/dataset.py
index c257e28..fc6841d 100644
--- a/src/python/larynx_train/vits/dataset.py
+++ b/src/python/larynx_train/vits/dataset.py
@@ -98,7 +98,8 @@ class LarynxDataset(Dataset):
 
     @staticmethod
     def load_dataset(
-        dataset_path: Path, max_phoneme_ids: Optional[int] = None,
+        dataset_path: Path,
+        max_phoneme_ids: Optional[int] = None,
     ) -> Iterable[Utterance]:
         num_skipped = 0
 
@@ -118,7 +119,10 @@ class LarynxDataset(Dataset):
                         num_skipped += 1
                 except Exception:
                     _LOGGER.exception(
-                        "Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
+                        "Error on line %s of %s: %s",
+                        line_idx + 1,
+                        dataset_path,
+                        line,
                     )
 
         if num_skipped > 0:
diff --git a/src/python/larynx_train/vits/lightning.py b/src/python/larynx_train/vits/lightning.py
index f497e7e..75f0ed2 100644
--- a/src/python/larynx_train/vits/lightning.py
+++ b/src/python/larynx_train/vits/lightning.py
@@ -25,7 +25,11 @@ class VitsModel(pl.LightningModule):
         # audio
         resblock="2",
         resblock_kernel_sizes=(3, 5, 7),
-        resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
+        resblock_dilation_sizes=(
+            (1, 2),
+            (2, 6),
+            (3, 12),
+        ),
         upsample_rates=(8, 8, 4),
         upsample_initial_channel=256,
         upsample_kernel_sizes=(16, 16, 8),
@@ -215,7 +219,9 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y_mel = slice_segments(
-            mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
+            mel,
+            ids_slice,
+            self.hparams.segment_size // self.hparams.hop_length,
         )
         y_hat_mel = mel_spectrogram_torch(
             y_hat.squeeze(1),
@@ -228,7 +234,9 @@ class VitsModel(pl.LightningModule):
             self.hparams.mel_fmax,
         )
         y = slice_segments(
-            y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
+            y,
+            ids_slice * self.hparams.hop_length,
+            self.hparams.segment_size,
         )  # slice
 
         # Save for training_step_d
@@ -276,7 +284,12 @@ class VitsModel(pl.LightningModule):
         text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
         text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
         scales = [0.667, 1.0, 0.8]
-        test_audio = self(text, text_lengths, scales).detach()
+        sid = (
+            test_utt.speaker_id.to(self.device)
+            if test_utt.speaker_id is not None
+            else None
+        )
+        test_audio = self(text, text_lengths, scales, sid=sid).detach()
 
         # Scale to make louder in [-1, 1]
         test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
diff --git a/src/python/larynx_train/vits/models.py b/src/python/larynx_train/vits/models.py
index 123bd06..fb100bf 100644
--- a/src/python/larynx_train/vits/models.py
+++ b/src/python/larynx_train/vits/models.py
@@ -686,6 +686,7 @@ class SynthesizerTrn(nn.Module):
     ):
         x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
         if self.n_speakers > 1:
+            assert sid is not None, "Missing speaker id"
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
             g = None
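
For context, a minimal self-contained sketch (not part of the patch) of the speaker-conditioning pattern the lightning.py and models.py hunks enforce: pass `sid` through to the model only when the utterance actually has one, and fail fast with a readable message when a multi-speaker model is called without it. The `TinyMultiSpeakerSynth` class and its sizes are invented here purely for illustration; only the `n_speakers`/`emb_g`/`sid` names mirror `SynthesizerTrn`.

    from typing import Optional

    import torch
    import torch.nn as nn


    class TinyMultiSpeakerSynth(nn.Module):
        """Toy stand-in for SynthesizerTrn's speaker-conditioning branch."""

        def __init__(self, n_speakers: int = 4, gin_channels: int = 8):
            super().__init__()
            self.n_speakers = n_speakers
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

        def forward(self, sid: Optional[torch.Tensor] = None):
            if self.n_speakers > 1:
                # Fail fast with a clear message instead of a cryptic
                # embedding error deeper in the stack.
                assert sid is not None, "Missing speaker id"
                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            else:
                g = None
            return g


    model = TinyMultiSpeakerSynth()
    print(model(torch.LongTensor([2])).shape)  # torch.Size([1, 8, 1])
    try:
        model(None)
    except AssertionError as err:
        print(err)  # Missing speaker id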