Unsupported value kind: Tensor

This commit is contained in:
Michael Hansen
2023-01-07 11:20:19 -06:00
parent bebc36014a
commit f1cc4e58bd
4 changed files with 50 additions and 35 deletions

View File

@@ -60,14 +60,16 @@ class Encoder(nn.Module):
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
y = self.attn_layers[i](x, x, attn_mask)
for attn_layer, norm_layer_1, ffn_layer, norm_layer_2 in zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
):
y = attn_layer(x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
x = norm_layer_1(x + y)
y = self.ffn_layers[i](x, x_mask)
y = ffn_layer(x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = norm_layer_2(x + y)
x = x * x_mask
return x
@@ -181,7 +183,7 @@ class MultiHeadAttention(nn.Module):
self.block_length = block_length
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
self.attn = None
self.attn = torch.zeros(1)
self.k_channels = channels // n_heads
self.conv_q = nn.Conv1d(channels, channels, 1)
@@ -222,7 +224,7 @@ class MultiHeadAttention(nn.Module):
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
b, d, t_s, t_t = (key.size(0), key.size(1), key.size(2), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
@@ -287,7 +289,7 @@ class MultiHeadAttention(nn.Module):
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
def _get_relative_embeddings(self, relative_embeddings, length: int):
# max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.window_size + 1), 0)
@@ -345,7 +347,7 @@ class MultiHeadAttention(nn.Module):
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
return x_final
def _attention_bias_proximal(self, length):
def _attention_bias_proximal(self, length: int):
"""Bias for self-attention to encourage attention to close positions.
Args:
length: an integer scalar.
@@ -365,7 +367,7 @@ class FFN(nn.Module):
filter_channels: int,
kernel_size: int,
p_dropout: float = 0.0,
activation: typing.Optional[str] = None,
activation: str = "",
causal: bool = False,
):
super().__init__()
@@ -377,23 +379,31 @@ class FFN(nn.Module):
self.activation = activation
self.causal = causal
if causal:
self.padding = self._causal_padding
else:
self.padding = self._same_padding
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
self.drop = nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(self.padding(x * x_mask))
if self.causal:
padding1 = self._causal_padding(x * x_mask)
else:
padding1 = self._same_padding(x * x_mask)
x = self.conv_1(padding1)
if self.activation == "gelu":
x = x * torch.sigmoid(1.702 * x)
else:
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(self.padding(x * x_mask))
if self.causal:
padding2 = self._causal_padding(x * x_mask)
else:
padding2 = self._same_padding(x * x_mask)
x = self.conv_2(padding2)
return x * x_mask
def _causal_padding(self, x):

View File

@@ -1,5 +1,6 @@
import logging
import math
from typing import Optional
import torch
from torch.nn import functional as F
@@ -90,7 +91,7 @@ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
def subsequent_mask(length):
def subsequent_mask(length: int):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@@ -105,7 +106,7 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
return acts
def sequence_mask(length, max_length=None):
def sequence_mask(length, max_length: Optional[int] = None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)

View File

@@ -309,6 +309,7 @@ class Generator(torch.nn.Module):
gin_channels: int = 0,
):
super(Generator, self).__init__()
self.LRELU_SLOPE = 0.1
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
@@ -349,15 +350,16 @@ class Generator(torch.nn.Module):
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
for i, up in enumerate(self.ups):
x = F.leaky_relu(x, self.LRELU_SLOPE)
x = up(x)
xs = torch.zeros(1)
for j, resblock in enumerate(self.resblocks):
index = j - (i * self.num_kernels)
if index == 0:
xs = resblock(x)
elif (index > 0) and (index < self.num_kernels):
xs += resblock(x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
@@ -382,6 +384,7 @@ class DiscriminatorP(torch.nn.Module):
use_spectral_norm: bool = False,
):
super(DiscriminatorP, self).__init__()
self.LRELU_SLOPE = 0.1
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if not use_spectral_norm else spectral_norm
@@ -449,7 +452,7 @@ class DiscriminatorP(torch.nn.Module):
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, self.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
@@ -461,6 +464,7 @@ class DiscriminatorP(torch.nn.Module):
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
self.LRELU_SLOPE = 0.1
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
@@ -479,7 +483,7 @@ class DiscriminatorS(torch.nn.Module):
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, self.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)

View File

@@ -10,8 +10,6 @@ from torch.nn.utils import remove_weight_norm, weight_norm
from .commons import fused_add_tanh_sigmoid_multiply, get_padding, init_weights
from .transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1
class LayerNorm(nn.Module):
def __init__(self, channels: int, eps: float = 1e-5):
@@ -227,6 +225,7 @@ class ResBlock1(torch.nn.Module):
dilation: typing.Tuple[int] = (1, 3, 5),
):
super(ResBlock1, self).__init__()
self.LRELU_SLOPE = 0.1
self.convs1 = nn.ModuleList(
[
weight_norm(
@@ -301,11 +300,11 @@ class ResBlock1(torch.nn.Module):
def forward(self, x, x_mask=None):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = F.leaky_relu(x, self.LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = F.leaky_relu(xt, self.LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c2(xt)
@@ -326,6 +325,7 @@ class ResBlock2(torch.nn.Module):
self, channels: int, kernel_size: int = 3, dilation: typing.Tuple[int] = (1, 3)
):
super(ResBlock2, self).__init__()
self.LRELU_SLOPE = 0.1
self.convs = nn.ModuleList(
[
weight_norm(
@@ -354,7 +354,7 @@ class ResBlock2(torch.nn.Module):
def forward(self, x, x_mask=None):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = F.leaky_relu(x, self.LRELU_SLOPE)
if x_mask is not None:
xt = xt * x_mask
xt = c(xt)