diff --git a/src/python/larynx_train/__main__.py b/src/python/larynx_train/__main__.py index e27e5e9..a2ee0d3 100644 --- a/src/python/larynx_train/__main__.py +++ b/src/python/larynx_train/__main__.py @@ -24,6 +24,12 @@ def main(): type=int, help="Save checkpoint every N epochs (default: 1)", ) + parser.add_argument( + "--quality", + default="medium", + choices=("x-low", "medium", "high"), + help="Quality/size of model (default: medium)", + ) Trainer.add_argparse_args(parser) VitsModel.add_model_specific_args(parser) parser.add_argument("--seed", type=int, default=1234) @@ -55,6 +61,22 @@ def main(): ) dict_args = vars(args) + if args.quality == "x-low": + dict_args["hidden_channels"] = 96 + dict_args["inter_channels"] = 96 + dict_args["filter_channels"] = 384 + elif args.quality == "high": + dict_args["resblock"] = "1" + dict_args["resblock_kernel_sizes"] = (3, 7, 11) + dict_args["resblock_dilation_sizes"] = ( + (1, 3, 5), + (1, 3, 5), + (1, 3, 5), + ) + dict_args["upsample_rates"] = (8, 8, 2, 2) + dict_args["upsample_initial_channel"] = 512 + dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4) + model = VitsModel( num_symbols=num_symbols, num_speakers=num_speakers,