mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-18 06:15:30 +00:00
Add --checkpoint-epochs
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
|
||||
from .vits.lightning import VitsModel
|
||||
|
||||
@@ -18,6 +19,11 @@ def main():
|
||||
parser.add_argument(
|
||||
"--dataset-dir", required=True, help="Path to pre-processed dataset directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-epochs",
|
||||
type=int,
|
||||
help="Save checkpoint every N epochs (default: 1)",
|
||||
)
|
||||
Trainer.add_argparse_args(parser)
|
||||
VitsModel.add_model_specific_args(parser)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
@@ -42,6 +48,12 @@ def main():
|
||||
sample_rate = int(config["audio"]["sample_rate"])
|
||||
|
||||
trainer = Trainer.from_argparse_args(args)
|
||||
if args.checkpoint_epochs is not None:
|
||||
trainer.callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
|
||||
_LOGGER.debug(
|
||||
"Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
|
||||
)
|
||||
|
||||
dict_args = vars(args)
|
||||
model = VitsModel(
|
||||
num_symbols=num_symbols,
|
||||
|
||||
Reference in New Issue
Block a user