Unsupported value kind: Tensor

This commit is contained in:
Michael Hansen
2023-01-07 11:20:19 -06:00
parent bebc36014a
commit f1cc4e58bd
4 changed files with 50 additions and 35 deletions

View File

@@ -309,6 +309,7 @@ class Generator(torch.nn.Module):
gin_channels: int = 0,
):
super(Generator, self).__init__()
self.LRELU_SLOPE = 0.1
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
@@ -349,15 +350,16 @@ class Generator(torch.nn.Module):
if g is not None:
x = x + self.cond(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
for i, up in enumerate(self.ups):
x = F.leaky_relu(x, self.LRELU_SLOPE)
x = up(x)
xs = torch.zeros(1)
for j, resblock in enumerate(self.resblocks):
index = j - (i * self.num_kernels)
if index == 0:
xs = resblock(x)
elif (index > 0) and (index < self.num_kernels):
xs += resblock(x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
@@ -382,6 +384,7 @@ class DiscriminatorP(torch.nn.Module):
use_spectral_norm: bool = False,
):
super(DiscriminatorP, self).__init__()
self.LRELU_SLOPE = 0.1
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if not use_spectral_norm else spectral_norm
@@ -449,7 +452,7 @@ class DiscriminatorP(torch.nn.Module):
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, self.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
@@ -461,6 +464,7 @@ class DiscriminatorP(torch.nn.Module):
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
self.LRELU_SLOPE = 0.1
norm_f = spectral_norm if use_spectral_norm else weight_norm
self.convs = nn.ModuleList(
[
@@ -479,7 +483,7 @@ class DiscriminatorS(torch.nn.Module):
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
x = F.leaky_relu(x, self.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)