From bebc36014ac24c8a6d29b5a695e9e05eed7882e2 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Sat, 7 Jan 2023 10:57:34 -0600 Subject: [PATCH] Add torchscript export --- src/python/larynx_train/export_torchscript.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/python/larynx_train/export_torchscript.py diff --git a/src/python/larynx_train/export_torchscript.py b/src/python/larynx_train/export_torchscript.py new file mode 100644 index 0000000..2669147 --- /dev/null +++ b/src/python/larynx_train/export_torchscript.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +from typing import Optional + +import torch + +from .vits.lightning import VitsModel + +_LOGGER = logging.getLogger("larynx_train.export_torchscript") + + +def main(): + """Main entry point""" + torch.manual_seed(1234) + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)") + parser.add_argument("output", help="Path to output model (.onnx)") + + parser.add_argument( + "--debug", action="store_true", help="Print DEBUG messages to the console" + ) + args = parser.parse_args() + + if args.debug: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + _LOGGER.debug(args) + + # ------------------------------------------------------------------------- + + args.checkpoint = Path(args.checkpoint) + args.output = Path(args.output) + args.output.parent.mkdir(parents=True, exist_ok=True) + + model = VitsModel.load_from_checkpoint(args.checkpoint) + model_g = model.model_g + + num_symbols = model_g.n_vocab + num_speakers = model_g.n_speakers + + # Inference only + model_g.eval() + + with torch.no_grad(): + model_g.dec.remove_weight_norm() + + model_g.forward = model_g.infer + + jitted_model = torch.jit.script(model_g) + torch.jit.save(jitted_model, str(args.output)) + + _LOGGER.info("Saved TorchScript model to %s", args.output) + + +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + main()