diff --git a/src/python/piper_train/export_onnx_streaming.py b/src/python/piper_train/export_onnx_streaming.py index fd67631..719b59f 100644 --- a/src/python/piper_train/export_onnx_streaming.py +++ b/src/python/piper_train/export_onnx_streaming.py @@ -55,9 +55,7 @@ class VitsEncoder(nn.Module): ) # [b, t', t], [b, t, d] -> [b, d, t'] z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - # Should we move the `flow` to the decoder? - z = gen.flow(z_p, y_mask, g=g, reverse=True) - return z, y_mask, g + return z_p, y_mask, g class VitsDecoder(nn.Module): @@ -66,6 +64,7 @@ class VitsDecoder(nn.Module): self.gen = gen def forward(self, z, y_mask, g=None): + z = self.gen.flow(z, y_mask, g=g, reverse=True) output = self.gen.dec((z * y_mask), g=g) return output