diff --git a/src/python/larynx_train/vits/lightning.py b/src/python/larynx_train/vits/lightning.py index 4836b06..441a038 100644 --- a/src/python/larynx_train/vits/lightning.py +++ b/src/python/larynx_train/vits/lightning.py @@ -280,7 +280,7 @@ class VitsModel(pl.LightningModule): return loss_disc_all def validation_step(self, batch: Batch, batch_idx: int): - val_loss = self.training_step_g(batch) + val_loss = self.training_step_g(batch) + self.training_step_d(batch) self.log("val_loss", val_loss) # Generate audio examples