diff --git a/src/python/larynx_train/export_generator.py b/src/python/larynx_train/export_generator.py index 7421bdd..33c4668 100644 --- a/src/python/larynx_train/export_generator.py +++ b/src/python/larynx_train/export_generator.py @@ -35,7 +35,7 @@ def main(): args.output = Path(args.output) args.output.parent.mkdir(parents=True, exist_ok=True) - model = VitsModel.load_from_checkpoint(args.checkpoint) + model = VitsModel.load_from_checkpoint(args.checkpoint, dataset=None) model_g = model.model_g # Inference only diff --git a/src/python/larynx_train/export_onnx.py b/src/python/larynx_train/export_onnx.py index 37bccc1..a556354 100644 --- a/src/python/larynx_train/export_onnx.py +++ b/src/python/larynx_train/export_onnx.py @@ -39,7 +39,7 @@ def main(): args.output = Path(args.output) args.output.parent.mkdir(parents=True, exist_ok=True) - model = VitsModel.load_from_checkpoint(args.checkpoint) + model = VitsModel.load_from_checkpoint(args.checkpoint, dataset=None) model_g = model.model_g num_symbols = model_g.n_vocab diff --git a/src/python/larynx_train/export_torchscript.py b/src/python/larynx_train/export_torchscript.py index 36a59fc..10718af 100644 --- a/src/python/larynx_train/export_torchscript.py +++ b/src/python/larynx_train/export_torchscript.py @@ -37,7 +37,7 @@ def main(): args.output = Path(args.output) args.output.parent.mkdir(parents=True, exist_ok=True) - model = VitsModel.load_from_checkpoint(args.checkpoint) + model = VitsModel.load_from_checkpoint(args.checkpoint, dataset=None) model_g = model.model_g num_symbols = model_g.n_vocab diff --git a/src/python/larynx_train/vits/lightning.py b/src/python/larynx_train/vits/lightning.py index 75f0ed2..4836b06 100644 --- a/src/python/larynx_train/vits/lightning.py +++ b/src/python/larynx_train/vits/lightning.py @@ -124,6 +124,10 @@ class VitsModel(pl.LightningModule): num_test_examples: int, max_phoneme_ids: Optional[int] = None, ): + if self.hparams.dataset is None: + _LOGGER.debug("No dataset to load") + return + full_dataset = LarynxDataset( self.hparams.dataset, max_phoneme_ids=max_phoneme_ids )