mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-17 22:05:30 +00:00
Move the flow component to the decoder. RTF of encoder is now 0.01
This commit is contained in:
@@ -55,9 +55,7 @@ class VitsEncoder(nn.Module):
|
|||||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||||
|
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
# Should we move the `flow` to the decoder?
|
return z_p, y_mask, g
|
||||||
z = gen.flow(z_p, y_mask, g=g, reverse=True)
|
|
||||||
return z, y_mask, g
|
|
||||||
|
|
||||||
|
|
||||||
class VitsDecoder(nn.Module):
|
class VitsDecoder(nn.Module):
|
||||||
@@ -66,6 +64,7 @@ class VitsDecoder(nn.Module):
|
|||||||
self.gen = gen
|
self.gen = gen
|
||||||
|
|
||||||
def forward(self, z, y_mask, g=None):
|
def forward(self, z, y_mask, g=None):
|
||||||
|
z = self.gen.flow(z, y_mask, g=g, reverse=True)
|
||||||
output = self.gen.dec((z * y_mask), g=g)
|
output = self.gen.dec((z * y_mask), g=g)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user