diff --git a/src/python/larynx_train/infer.py b/src/python/larynx_train/infer.py index 0c10402..aed620f 100644 --- a/src/python/larynx_train/infer.py +++ b/src/python/larynx_train/infer.py @@ -29,7 +29,7 @@ def main(): args.output_dir = Path(args.output_dir) args.output_dir.mkdir(parents=True, exist_ok=True) - model = VitsModel.load_from_checkpoint(args.checkpoint) + model = VitsModel.load_from_checkpoint(args.checkpoint, dataset=None) # Inference only model.eval()