mirror of
https://github.com/pstrueb/piper.git
synced 2026-04-18 06:15:30 +00:00
Add --resume_from_single_speaker_checkpoint
This commit is contained in:
@@ -30,6 +30,10 @@ def main():
|
||||
choices=("x-low", "medium", "high"),
|
||||
help="Quality/size of model (default: medium)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_single_speaker_checkpoint",
|
||||
help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
|
||||
)
|
||||
Trainer.add_argparse_args(parser)
|
||||
VitsModel.add_model_specific_args(parser)
|
||||
parser.add_argument("--seed", type=int, default=1234)
|
||||
@@ -82,12 +86,60 @@ def main():
|
||||
num_speakers=num_speakers,
|
||||
sample_rate=sample_rate,
|
||||
dataset=[dataset_path],
|
||||
**dict_args
|
||||
**dict_args,
|
||||
)
|
||||
|
||||
if args.resume_from_single_speaker_checkpoint:
|
||||
assert (
|
||||
num_speakers > 1
|
||||
), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."
|
||||
|
||||
# Load single-speaker checkpoint
|
||||
_LOGGER.debug(
|
||||
"Resuming from single-speaker checkpoint: %s",
|
||||
args.resume_from_single_speaker_checkpoint,
|
||||
)
|
||||
model_single = VitsModel.load_from_checkpoint(
|
||||
args.resume_from_single_speaker_checkpoint,
|
||||
dataset=None,
|
||||
)
|
||||
g_dict = model_single.model_g.state_dict()
|
||||
for key in list(g_dict.keys()):
|
||||
# Remove keys that can't be copied over due to missing speaker embedding
|
||||
if (
|
||||
key.startswith("dec.cond")
|
||||
or key.startswith("dp.cond")
|
||||
or ("enc.cond_layer" in key)
|
||||
):
|
||||
g_dict.pop(key, None)
|
||||
|
||||
# Copy weights into the multi-speaker model, excluding keys related to the
|
||||
# speaker embedding (which is missing from the single-speaker model).
|
||||
load_state_dict(model.model_g, g_dict)
|
||||
load_state_dict(model.model_d, model_single.model_d.state_dict())
|
||||
_LOGGER.info(
|
||||
"Successfully converted single-speaker checkpoint to multi-speaker"
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def load_state_dict(model, saved_state_dict):
    """Load ``saved_state_dict`` into ``model``, tolerating missing keys.

    A new state dict is assembled with exactly the keys reported by
    ``model.state_dict()``. For each key, the value from ``saved_state_dict``
    is used when present; otherwise the model's current value is kept. The
    result is then passed to ``model.load_state_dict``.

    Keys that exist only in ``saved_state_dict`` are ignored. No shape
    comparison is done here; any mismatch is left to
    ``model.load_state_dict`` to report.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        saved_state_dict: Mapping of parameter name to saved value.
    """
    merged_state = {}

    for name, current_value in model.state_dict().items():
        if name in saved_state_dict:
            # Use saved value
            merged_state[name] = saved_state_dict[name]
        else:
            # Use initialized value. The previous log call passed an undefined
            # name (`key`) as a second argument, which raised NameError the
            # first time a parameter was missing from the checkpoint.
            _LOGGER.debug("%s is not in the checkpoint", name)
            merged_state[name] = current_value

    model.load_state_dict(merged_state)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user