diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index a2ee0d3..0597c4c 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -30,6 +30,10 @@ def main(): choices=("x-low", "medium", "high"), help="Quality/size of model (default: medium)", ) + parser.add_argument( + "--resume_from_single_speaker_checkpoint", + help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training", + ) Trainer.add_argparse_args(parser) VitsModel.add_model_specific_args(parser) parser.add_argument("--seed", type=int, default=1234) @@ -82,12 +86,60 @@ def main(): num_speakers=num_speakers, sample_rate=sample_rate, dataset=[dataset_path], - **dict_args + **dict_args, ) + if args.resume_from_single_speaker_checkpoint: + assert ( + num_speakers > 1 + ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models." + + # Load single-speaker checkpoint + _LOGGER.debug( + "Resuming from single-speaker checkpoint: %s", + args.resume_from_single_speaker_checkpoint, + ) + model_single = VitsModel.load_from_checkpoint( + args.resume_from_single_speaker_checkpoint, + dataset=None, + ) + g_dict = model_single.model_g.state_dict() + for key in list(g_dict.keys()): + # Remove keys that can't be copied over due to missing speaker embedding + if ( + key.startswith("dec.cond") + or key.startswith("dp.cond") + or ("enc.cond_layer" in key) + ): + g_dict.pop(key, None) + + # Copy over the multi-speaker model, excluding keys related to the + # speaker embedding (which is missing from the single-speaker model). + load_state_dict(model.model_g, g_dict) + load_state_dict(model.model_d, model_single.model_d.state_dict()) + _LOGGER.info( + "Successfully converted single-speaker checkpoint to multi-speaker" + ) + trainer.fit(model) +def load_state_dict(model, saved_state_dict): + state_dict = model.state_dict() + new_state_dict = {} + + for k, v in state_dict.items(): + if k in saved_state_dict: + # Use saved value + new_state_dict[k] = saved_state_dict[k] + else: + # Use initialized value + _LOGGER.debug("%s is not in the checkpoint for %s", k, key) + new_state_dict[k] = v + + model.load_state_dict(new_state_dict) + + # -----------------------------------------------------------------------------