Add --checkpoint-epochs

This commit is contained in:
Michael Hansen
2023-02-24 15:03:02 -06:00
parent 657a1fae74
commit b8e3058d7a

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from .vits.lightning import VitsModel
@@ -18,6 +19,11 @@ def main():
parser.add_argument(
"--dataset-dir", required=True, help="Path to pre-processed dataset directory"
)
parser.add_argument(
"--checkpoint-epochs",
type=int,
help="Save checkpoint every N epochs (default: 1)",
)
Trainer.add_argparse_args(parser)
VitsModel.add_model_specific_args(parser)
parser.add_argument("--seed", type=int, default=1234)
@@ -42,6 +48,12 @@ def main():
sample_rate = int(config["audio"]["sample_rate"])
trainer = Trainer.from_argparse_args(args)
if args.checkpoint_epochs is not None:
trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
_LOGGER.debug(
"Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
)
dict_args = vars(args)
model = VitsModel(
num_symbols=num_symbols,