diff --git a/src/python/larynx_train/__main__.py b/src/python/larynx_train/__main__.py index f6153e4..e27e5e9 100644 --- a/src/python/larynx_train/__main__.py +++ b/src/python/larynx_train/__main__.py @@ -5,6 +5,7 @@ from pathlib import Path import torch from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from .vits.lightning import VitsModel @@ -18,6 +19,11 @@ def main(): parser.add_argument( "--dataset-dir", required=True, help="Path to pre-processed dataset directory" ) + parser.add_argument( + "--checkpoint-epochs", + type=int, + help="Save checkpoint every N epochs (default: 1)", + ) Trainer.add_argparse_args(parser) VitsModel.add_model_specific_args(parser) parser.add_argument("--seed", type=int, default=1234) @@ -42,6 +48,12 @@ def main(): sample_rate = int(config["audio"]["sample_rate"]) trainer = Trainer.from_argparse_args(args) + if args.checkpoint_epochs is not None: + trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)] + _LOGGER.debug( + "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs + ) + dict_args = vars(args) model = VitsModel( num_symbols=num_symbols,