Add --resume_from_single_speaker_checkpoint

This commit is contained in:
Michael Hansen
2023-06-13 16:46:35 -05:00
parent b9c42da613
commit 32410e36d9

View File

@@ -30,6 +30,10 @@ def main():
choices=("x-low", "medium", "high"),
help="Quality/size of model (default: medium)",
)
parser.add_argument(
"--resume_from_single_speaker_checkpoint",
help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
)
Trainer.add_argparse_args(parser)
VitsModel.add_model_specific_args(parser)
parser.add_argument("--seed", type=int, default=1234)
@@ -82,12 +86,60 @@ def main():
num_speakers=num_speakers,
sample_rate=sample_rate,
dataset=[dataset_path],
**dict_args
**dict_args,
)
if args.resume_from_single_speaker_checkpoint:
assert (
num_speakers > 1
), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
# Load single-speaker checkpoint
_LOGGER.debug(
"Resuming from single-speaker checkpoint: %s",
args.resume_from_single_speaker_checkpoint,
)
model_single = VitsModel.load_from_checkpoint(
args.resume_from_single_speaker_checkpoint,
dataset=None,
)
g_dict = model_single.model_g.state_dict()
for key in list(g_dict.keys()):
# Remove keys that can't be copied over due to missing speaker embedding
if (
key.startswith("dec.cond")
or key.startswith("dp.cond")
or ("enc.cond_layer" in key)
):
g_dict.pop(key, None)
# Copy the remaining single-speaker weights into the multi-speaker model.
# Keys that depend on the speaker embedding were removed above because the
# single-speaker checkpoint has no matching values for them.
load_state_dict(model.model_g, g_dict)
load_state_dict(model.model_d, model_single.model_d.state_dict())
_LOGGER.info(
"Successfully converted single-speaker checkpoint to multi-speaker"
)
trainer.fit(model)
def load_state_dict(model, saved_state_dict):
    """Load ``saved_state_dict`` into ``model``, tolerating missing entries.

    Builds a complete state dict for ``model`` by taking each entry from
    ``saved_state_dict`` when it is present and falling back to the model's
    current (already initialized) value when it is not. The merged dict is
    then passed to ``model.load_state_dict``.

    Entries of ``saved_state_dict`` whose keys do not exist in the model's
    own state dict are ignored.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        saved_state_dict: Mapping of parameter name to saved value.
    """
    state_dict = model.state_dict()
    new_state_dict = {}

    for k, v in state_dict.items():
        if k in saved_state_dict:
            # Use saved value
            new_state_dict[k] = saved_state_dict[k]
        else:
            # Use initialized value. The log arguments must only reference
            # names defined in this function: they are evaluated even when
            # debug logging is disabled, so an undefined name here would
            # raise NameError on every missing key.
            _LOGGER.debug(
                "%s is not in the checkpoint for %s", k, type(model).__name__
            )
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
# -----------------------------------------------------------------------------