mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-18 06:15:30 +00:00
Pass speaker id during verification
This commit is contained in:
@@ -98,7 +98,8 @@ class LarynxDataset(Dataset):
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(
|
||||
dataset_path: Path, max_phoneme_ids: Optional[int] = None,
|
||||
dataset_path: Path,
|
||||
max_phoneme_ids: Optional[int] = None,
|
||||
) -> Iterable[Utterance]:
|
||||
num_skipped = 0
|
||||
|
||||
@@ -118,7 +119,10 @@ class LarynxDataset(Dataset):
|
||||
num_skipped += 1
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Error on line %s of %s: %s", line_idx + 1, dataset_path, line,
|
||||
"Error on line %s of %s: %s",
|
||||
line_idx + 1,
|
||||
dataset_path,
|
||||
line,
|
||||
)
|
||||
|
||||
if num_skipped > 0:
|
||||
|
||||
@@ -25,7 +25,11 @@ class VitsModel(pl.LightningModule):
|
||||
# audio
|
||||
resblock="2",
|
||||
resblock_kernel_sizes=(3, 5, 7),
|
||||
resblock_dilation_sizes=((1, 2), (2, 6), (3, 12),),
|
||||
resblock_dilation_sizes=(
|
||||
(1, 2),
|
||||
(2, 6),
|
||||
(3, 12),
|
||||
),
|
||||
upsample_rates=(8, 8, 4),
|
||||
upsample_initial_channel=256,
|
||||
upsample_kernel_sizes=(16, 16, 8),
|
||||
@@ -215,7 +219,9 @@ class VitsModel(pl.LightningModule):
|
||||
self.hparams.mel_fmax,
|
||||
)
|
||||
y_mel = slice_segments(
|
||||
mel, ids_slice, self.hparams.segment_size // self.hparams.hop_length,
|
||||
mel,
|
||||
ids_slice,
|
||||
self.hparams.segment_size // self.hparams.hop_length,
|
||||
)
|
||||
y_hat_mel = mel_spectrogram_torch(
|
||||
y_hat.squeeze(1),
|
||||
@@ -228,7 +234,9 @@ class VitsModel(pl.LightningModule):
|
||||
self.hparams.mel_fmax,
|
||||
)
|
||||
y = slice_segments(
|
||||
y, ids_slice * self.hparams.hop_length, self.hparams.segment_size,
|
||||
y,
|
||||
ids_slice * self.hparams.hop_length,
|
||||
self.hparams.segment_size,
|
||||
) # slice
|
||||
|
||||
# Save for training_step_d
|
||||
@@ -276,7 +284,12 @@ class VitsModel(pl.LightningModule):
|
||||
text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
|
||||
text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
|
||||
scales = [0.667, 1.0, 0.8]
|
||||
test_audio = self(text, text_lengths, scales).detach()
|
||||
sid = (
|
||||
test_utt.speaker_id.to(self.device)
|
||||
if test_utt.speaker_id is not None
|
||||
else None
|
||||
)
|
||||
test_audio = self(text, text_lengths, scales, sid=sid).detach()
|
||||
|
||||
# Scale to make louder in [-1, 1]
|
||||
test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))
|
||||
|
||||
@@ -686,6 +686,7 @@ class SynthesizerTrn(nn.Module):
|
||||
):
|
||||
x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
|
||||
if self.n_speakers > 1:
|
||||
assert sid is not None, "Missing speaker id"
|
||||
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
||||
else:
|
||||
g = None
|
||||
|
||||
Reference in New Issue
Block a user