Include discriminator loss in validation

This commit is contained in:
Michael Hansen
2023-02-25 20:08:50 -06:00
parent b8e3058d7a
commit 93d3744614

View File

@@ -280,7 +280,7 @@ class VitsModel(pl.LightningModule):
return loss_disc_all
def validation_step(self, batch: Batch, batch_idx: int):
val_loss = self.training_step_g(batch)
val_loss = self.training_step_g(batch) + self.training_step_d(batch)
self.log("val_loss", val_loss)
# Generate audio examples