diff --git a/src/python/larynx_train/vits/attentions.py b/src/python/larynx_train/vits/attentions.py index 4f0e4af..20b7a9e 100644 --- a/src/python/larynx_train/vits/attentions.py +++ b/src/python/larynx_train/vits/attentions.py @@ -60,14 +60,16 @@ class Encoder(nn.Module): def forward(self, x, x_mask): attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) x = x * x_mask - for i in range(self.n_layers): - y = self.attn_layers[i](x, x, attn_mask) + for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip( + self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 + ): + y = attn_layer(x, x, attn_mask) y = self.drop(y) - x = self.norm_layers_1[i](x + y) + x = norm_layer_1(x + y) - y = self.ffn_layers[i](x, x_mask) + y = ffn_layer(x, x_mask) y = self.drop(y) - x = self.norm_layers_2[i](x + y) + x = norm_layer_2(x + y) x = x * x_mask return x @@ -181,7 +183,7 @@ class MultiHeadAttention(nn.Module): self.block_length = block_length self.proximal_bias = proximal_bias self.proximal_init = proximal_init - self.attn = None + self.attn = torch.zeros(1) self.k_channels = channels // n_heads self.conv_q = nn.Conv1d(channels, channels, 1) @@ -222,7 +224,7 @@ class MultiHeadAttention(nn.Module): def attention(self, query, key, value, mask=None): # reshape [b, d, t] -> [b, n_h, t, d_k] - b, d, t_s, t_t = (*key.size(), query.size(2)) + b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2)) query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) @@ -287,7 +289,7 @@ class MultiHeadAttention(nn.Module): ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) return ret - def _get_relative_embeddings(self, relative_embeddings, length): + def _get_relative_embeddings(self, relative_embeddings, length: int): # max_relative_position = 2 * self.window_size + 1 # Pad first before slice to avoid using cond ops. pad_length = max(length - (self.window_size + 1), 0) @@ -345,7 +347,7 @@ class MultiHeadAttention(nn.Module): x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] return x_final - def _attention_bias_proximal(self, length): + def _attention_bias_proximal(self, length: int): """Bias for self-attention to encourage attention to close positions. Args: length: an integer scalar. @@ -365,7 +367,7 @@ class FFN(nn.Module): filter_channels: int, kernel_size: int, p_dropout: float = 0.0, - activation: typing.Optional[str] = None, + activation: str = "", causal: bool = False, ): super().__init__() @@ -377,23 +379,31 @@ class FFN(nn.Module): self.activation = activation self.causal = causal - if causal: - self.padding = self._causal_padding - else: - self.padding = self._same_padding - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) self.drop = nn.Dropout(p_dropout) def forward(self, x, x_mask): - x = self.conv_1(self.padding(x * x_mask)) + if self.causal: + padding1 = self._causal_padding(x * x_mask) + else: + padding1 = self._same_padding(x * x_mask) + + x = self.conv_1(padding1) + if self.activation == "gelu": x = x * torch.sigmoid(1.702 * x) else: x = torch.relu(x) x = self.drop(x) - x = self.conv_2(self.padding(x * x_mask)) + + if self.causal: + padding2 = self._causal_padding(x * x_mask) + else: + padding2 = self._same_padding(x * x_mask) + + x = self.conv_2(padding2) + return x * x_mask def _causal_padding(self, x): diff --git a/src/python/larynx_train/vits/commons.py b/src/python/larynx_train/vits/commons.py index 615af25..3f334c3 100644 --- a/src/python/larynx_train/vits/commons.py +++ b/src/python/larynx_train/vits/commons.py @@ -1,5 +1,6 @@ import logging import math +from typing import Optional import torch from torch.nn import functional as F @@ -90,7 +91,7 @@ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) -def subsequent_mask(length): +def subsequent_mask(length: int): mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) return mask @@ -105,7 +106,7 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): return acts -def sequence_mask(length, max_length=None): +def sequence_mask(length, max_length: Optional[int] = None): if max_length is None: max_length = length.max() x = torch.arange(max_length, dtype=length.dtype, device=length.device) diff --git a/src/python/larynx_train/vits/models.py b/src/python/larynx_train/vits/models.py index fb100bf..68ef7ba 100644 --- a/src/python/larynx_train/vits/models.py +++ b/src/python/larynx_train/vits/models.py @@ -309,6 +309,7 @@ class Generator(torch.nn.Module): gin_channels: int = 0, ): super(Generator, self).__init__() + self.LRELU_SLOPE = 0.1 self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) self.conv_pre = Conv1d( @@ -349,15 +350,16 @@ class Generator(torch.nn.Module): if g is not None: x = x + self.cond(g) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) + for i, up in enumerate(self.ups): + x = F.leaky_relu(x, self.LRELU_SLOPE) + x = up(x) + xs = torch.zeros(1) + for j, resblock in enumerate(self.resblocks): + index = j - (i * self.num_kernels) + if index == 0: + xs = resblock(x) + elif (index > 0) and (index < self.num_kernels): + xs += resblock(x) x = xs / self.num_kernels x = F.leaky_relu(x) x = self.conv_post(x) @@ -382,6 +384,7 @@ class DiscriminatorP(torch.nn.Module): use_spectral_norm: bool = False, ): super(DiscriminatorP, self).__init__() + self.LRELU_SLOPE = 0.1 self.period = period self.use_spectral_norm = use_spectral_norm norm_f = weight_norm if not use_spectral_norm else spectral_norm @@ -449,7 +452,7 @@ class DiscriminatorP(torch.nn.Module): for l in self.convs: x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = F.leaky_relu(x, self.LRELU_SLOPE) fmap.append(x) x = self.conv_post(x) fmap.append(x) @@ -461,6 +464,7 @@ class DiscriminatorP(torch.nn.Module): class DiscriminatorS(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(DiscriminatorS, self).__init__() + self.LRELU_SLOPE = 0.1 norm_f = spectral_norm if use_spectral_norm else weight_norm self.convs = nn.ModuleList( [ @@ -479,7 +483,7 @@ class DiscriminatorS(torch.nn.Module): for l in self.convs: x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = F.leaky_relu(x, self.LRELU_SLOPE) fmap.append(x) x = self.conv_post(x) fmap.append(x) diff --git a/src/python/larynx_train/vits/modules.py b/src/python/larynx_train/vits/modules.py index 1ca701f..656e497 100644 --- a/src/python/larynx_train/vits/modules.py +++ b/src/python/larynx_train/vits/modules.py @@ -10,8 +10,6 @@ from torch.nn.utils import remove_weight_norm, weight_norm from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights from .transforms import piecewise_rational_quadratic_transform -LRELU_SLOPE = 0.1 - class LayerNorm(nn.Module): def __init__(self, channels: int, eps: float = 1e-5): @@ -227,6 +225,7 @@ class ResBlock1(torch.nn.Module): dilation: typing.Tuple[int] = (1, 3, 5), ): super(ResBlock1, self).__init__() + self.LRELU_SLOPE = 0.1 self.convs1 = nn.ModuleList( [ weight_norm( @@ -301,11 +300,11 @@ class ResBlock1(torch.nn.Module): def forward(self, x, x_mask=None): for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) + xt = F.leaky_relu(x, self.LRELU_SLOPE) if x_mask is not None: xt = xt * x_mask xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = F.leaky_relu(xt, self.LRELU_SLOPE) if x_mask is not None: xt = xt * x_mask xt = c2(xt) @@ -326,6 +325,7 @@ class ResBlock2(torch.nn.Module): self, channels: int, kernel_size: int = 3, dilation: typing.Tuple[int] = (1, 3) ): super(ResBlock2, self).__init__() + self.LRELU_SLOPE = 0.1 self.convs = nn.ModuleList( [ weight_norm( @@ -354,7 +354,7 @@ class ResBlock2(torch.nn.Module): def forward(self, x, x_mask=None): for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) + xt = F.leaky_relu(x, self.LRELU_SLOPE) if x_mask is not None: xt = xt * x_mask xt = c(xt)