diff --git a/src/python/larynx_train/export_generator.py b/src/python/larynx_train/export_generator.py
new file mode 100644
index 0000000..7421bdd
--- /dev/null
+++ b/src/python/larynx_train/export_generator.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from .vits.lightning import VitsModel
+
+_LOGGER = logging.getLogger("larynx_train.export_generator")
+
+
+def main():
+    """Main entry point"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("checkpoint", help="Path to model checkpoint (.ckpt)")
+    parser.add_argument("output", help="Path to output model (.pt)")
+
+    parser.add_argument(
+        "--debug", action="store_true", help="Print DEBUG messages to the console"
+    )
+    args = parser.parse_args()
+
+    if args.debug:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+
+    _LOGGER.debug(args)
+
+    # -------------------------------------------------------------------------
+
+    args.checkpoint = Path(args.checkpoint)
+    args.output = Path(args.output)
+    args.output.parent.mkdir(parents=True, exist_ok=True)
+
+    model = VitsModel.load_from_checkpoint(args.checkpoint)
+    model_g = model.model_g
+
+    # Inference only
+    model_g.eval()
+
+    with torch.no_grad():
+        model_g.dec.remove_weight_norm()
+
+    model_g.forward = model_g.infer
+
+    torch.save(model_g, args.output)
+
+    _LOGGER.info("Exported model to %s", args.output)
+
+
+# -----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    main()
diff --git a/src/python/larynx_train/infer_generator.py b/src/python/larynx_train/infer_generator.py
new file mode 100644
index 0000000..e58f608
--- /dev/null
+++ b/src/python/larynx_train/infer_generator.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+
+import torch
+
+from .vits.lightning import VitsModel
+from .vits.utils import audio_float_to_int16
+from .vits.wavfile import write as write_wav
+
+_LOGGER = logging.getLogger("larynx_train.infer_generator")
+
+
+def main():
+    """Main entry point"""
+    logging.basicConfig(level=logging.DEBUG)
+    parser = argparse.ArgumentParser(prog="larynx_train.infer_generator")
+    parser.add_argument(
+        "--model", required=True, help="Path to generator (.pt)"
+    )
+    parser.add_argument("--output-dir", required=True, help="Path to write WAV files")
+    parser.add_argument("--sample-rate", type=int, default=22050)
+    args = parser.parse_args()
+
+    args.output_dir = Path(args.output_dir)
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    model = torch.load(args.model)
+
+    # Inference only
+    model.eval()
+
+    for i, line in enumerate(sys.stdin):
+        line = line.strip()
+        if not line:
+            continue
+
+        utt = json.loads(line)
+        utt_id = str(i)
+        phoneme_ids = utt["phoneme_ids"]
+        speaker_id = utt.get("speaker_id")
+
+        text = torch.LongTensor(phoneme_ids).unsqueeze(0)
+        text_lengths = torch.LongTensor([len(phoneme_ids)])
+        sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None
+
+        start_time = time.perf_counter()
+        audio = (
+            model(
+                text,
+                text_lengths,
+                sid,
+                # torch.FloatTensor([0.667]),
+                # torch.FloatTensor([1.0]),
+                # torch.FloatTensor([0.8]),
+            )[0]
+            .detach()
+            .numpy()
+        )
+        audio = audio_float_to_int16(audio)
+        end_time = time.perf_counter()
+
+        audio_duration_sec = audio.shape[-1] / args.sample_rate
+        infer_sec = end_time - start_time
+        real_time_factor = (
+            infer_sec / audio_duration_sec if audio_duration_sec > 0 else 0.0
+        )
+
+        _LOGGER.debug(
+            "Real-time factor for %s: %0.2f (infer=%0.2f sec, audio=%0.2f sec)",
+            i + 1,
+            real_time_factor,
+            infer_sec,
+            audio_duration_sec,
+        )
+
+        output_path = args.output_dir / f"{utt_id}.wav"
+        write_wav(str(output_path), args.sample_rate, audio)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/python/larynx_train/infer_torchscript.py b/src/python/larynx_train/infer_torchscript.py
index 84d5afb..b0223ac 100755
--- a/src/python/larynx_train/infer_torchscript.py
+++ b/src/python/larynx_train/infer_torchscript.py
@@ -32,10 +32,7 @@ def main():
     model = torch.jit.load(args.model)
 
     # Inference only
-    # model.eval()
-
-    # with torch.no_grad():
-    #     model.model_g.dec.remove_weight_norm()
+    model.eval()
 
     for i, line in enumerate(sys.stdin):
         line = line.strip()
@@ -49,7 +46,6 @@ def main():
 
         text = torch.LongTensor(phoneme_ids).unsqueeze(0)
         text_lengths = torch.LongTensor([len(phoneme_ids)])
-        # scales = [0.667, 1.0, 0.8]
         sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None
 
         start_time = time.perf_counter()
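For reference, the new scripts can be driven end to end roughly as follows: export a generator with python3 -m larynx_train.export_generator /path/to/model.ckpt /path/to/generator.pt, then feed JSON lines containing "phoneme_ids" (and optionally "speaker_id") to python3 -m larynx_train.infer_generator --model /path/to/generator.pt --output-dir /path/to/wavs. The sketch below is a minimal Python equivalent of what infer_generator.py does for a single utterance; the paths and phoneme ids are placeholders, and it assumes the larynx_train package is importable so torch.load can unpickle the saved generator module.

import torch

from larynx_train.vits.utils import audio_float_to_int16
from larynx_train.vits.wavfile import write as write_wav

# Load the generator exported by larynx_train.export_generator
# (its forward is bound to infer at export time), as infer_generator.py does.
model = torch.load("generator.pt")
model.eval()

# Hypothetical phoneme ids; in practice these come from the preprocessing pipeline.
phoneme_ids = [0, 19, 84, 51, 48, 102, 0]
text = torch.LongTensor(phoneme_ids).unsqueeze(0)
text_lengths = torch.LongTensor([len(phoneme_ids)])

with torch.no_grad():
    # Single-speaker model: sid is None, matching infer_generator.py.
    audio = model(text, text_lengths, None)[0].detach().numpy()

# 22050 is the --sample-rate default used by infer_generator.py.
write_wav("output.wav", 22050, audio_float_to_int16(audio))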