diff --git a/src/python/larynx_train/export_onnx.py b/src/python/larynx_train/export_onnx.py
index d4e9573..37bccc1 100644
--- a/src/python/larynx_train/export_onnx.py
+++ b/src/python/larynx_train/export_onnx.py
@@ -10,7 +10,7 @@ from .vits.lightning import VitsModel
 
 _LOGGER = logging.getLogger("larynx_train.export_onnx")
 
-OPSET_VERSION = 15
+OPSET_VERSION = 16
 
 
 def main():
diff --git a/src/python/larynx_train/export_torchscript.py b/src/python/larynx_train/export_torchscript.py
index 2669147..36a59fc 100644
--- a/src/python/larynx_train/export_torchscript.py
+++ b/src/python/larynx_train/export_torchscript.py
@@ -51,7 +51,28 @@ def main():
 
     model_g.forward = model_g.infer
 
-    jitted_model = torch.jit.script(model_g)
+    dummy_input_length = 50
+    sequences = torch.randint(
+        low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
+    )
+    sequence_lengths = torch.LongTensor([sequences.size(1)])
+
+    # NOTE(review): sid stays None for single-speaker models; torch.jit.trace
+    # may not accept None among example inputs -- TODO confirm that export works.
+    sid: Optional[torch.Tensor] = None
+    if num_speakers > 1:
+        sid = torch.LongTensor([0])
+
+    dummy_input = (
+        sequences,
+        sequence_lengths,
+        sid,
+        torch.FloatTensor([0.667]),
+        torch.FloatTensor([1.0]),
+        torch.FloatTensor([0.8]),
+    )
+
+    jitted_model = torch.jit.trace(model_g, dummy_input)
     torch.jit.save(jitted_model, str(args.output))
 
     _LOGGER.info("Saved TorchScript model to %s", args.output)
diff --git a/src/python/larynx_train/infer.py b/src/python/larynx_train/infer.py
index 993a305..0c10402 100644
--- a/src/python/larynx_train/infer.py
+++ b/src/python/larynx_train/infer.py
@@ -12,13 +12,13 @@ from .vits.lightning import VitsModel
 from .vits.utils import audio_float_to_int16
 from .vits.wavfile import write as write_wav
 
-_LOGGER = logging.getLogger("mimic3_train.infer")
+_LOGGER = logging.getLogger("larynx_train.infer")
 
 
 def main():
     """Main entry point"""
     logging.basicConfig(level=logging.DEBUG)
-    parser = argparse.ArgumentParser(prog="mimic3_train.infer")
+    parser = argparse.ArgumentParser(prog="larynx_train.infer")
     parser.add_argument(
         "--checkpoint", required=True, help="Path to model checkpoint (.ckpt)"
     )
diff --git a/src/python/larynx_train/infer_onnx.py b/src/python/larynx_train/infer_onnx.py
index 6d3b594..6f98a63 100644
--- a/src/python/larynx_train/infer_onnx.py
+++ b/src/python/larynx_train/infer_onnx.py
@@ -13,13 +13,13 @@ import onnxruntime
 from .vits.utils import audio_float_to_int16
 from .vits.wavfile import write as write_wav
 
-_LOGGER = logging.getLogger("mimic3_train.infer_onnx")
+_LOGGER = logging.getLogger("larynx_train.infer_onnx")
 
 
 def main():
     """Main entry point"""
     logging.basicConfig(level=logging.DEBUG)
-    parser = argparse.ArgumentParser(prog="mimic3_train.infer_onnx")
+    parser = argparse.ArgumentParser(prog="larynx_train.infer_onnx")
     parser.add_argument("--model", required=True, help="Path to model (.onnx)")
     parser.add_argument("--output-dir", required=True, help="Path to write WAV files")
     parser.add_argument("--sample-rate", type=int, default=22050)
diff --git a/src/python/larynx_train/infer_torchscript.py b/src/python/larynx_train/infer_torchscript.py
new file mode 100755
index 0000000..84d5afb
--- /dev/null
+++ b/src/python/larynx_train/infer_torchscript.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+
+import torch
+
+from .vits.lightning import VitsModel
+from .vits.utils import audio_float_to_int16
+from .vits.wavfile import write as write_wav
+
+_LOGGER = logging.getLogger("larynx_train.infer_torchscript")
+
+
+def main():
+    """Synthesize audio with a TorchScript model.
+
+    Reads one JSON object per line from stdin ("phoneme_ids" plus an optional
+    "speaker_id") and writes one WAV file per non-empty line to --output-dir,
+    named after the zero-based input line index.
+    """
+    logging.basicConfig(level=logging.DEBUG)
+    parser = argparse.ArgumentParser(prog="larynx_train.infer_torchscript")
+    parser.add_argument(
+        "--model", required=True, help="Path to torchscript checkpoint (.ts)"
+    )
+    parser.add_argument("--output-dir", required=True, help="Path to write WAV files")
+    parser.add_argument("--sample-rate", type=int, default=22050)
+    args = parser.parse_args()
+
+    args.output_dir = Path(args.output_dir)
+    args.output_dir.mkdir(parents=True, exist_ok=True)
+
+    model = torch.jit.load(args.model)
+
+    # Inference only
+    # model.eval()
+
+    # with torch.no_grad():
+    #     model.model_g.dec.remove_weight_norm()
+
+    for i, line in enumerate(sys.stdin):
+        line = line.strip()
+        if not line:
+            continue
+
+        utt = json.loads(line)
+        utt_id = str(i)
+        phoneme_ids = utt["phoneme_ids"]
+        speaker_id = utt.get("speaker_id")
+
+        text = torch.LongTensor(phoneme_ids).unsqueeze(0)
+        text_lengths = torch.LongTensor([len(phoneme_ids)])
+        # scales = [0.667, 1.0, 0.8]
+        sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None
+
+        start_time = time.perf_counter()
+        # Inference only: no autograd graph is needed for synthesis
+        with torch.no_grad():
+            audio = (
+                model(
+                    text,
+                    text_lengths,
+                    sid,
+                    torch.FloatTensor([0.667]),
+                    torch.FloatTensor([1.0]),
+                    torch.FloatTensor([0.8]),
+                )[0]
+                .detach()
+                .numpy()
+            )
+        audio = audio_float_to_int16(audio)
+        end_time = time.perf_counter()
+
+        audio_duration_sec = audio.shape[-1] / args.sample_rate
+        infer_sec = end_time - start_time
+        real_time_factor = (
+            infer_sec / audio_duration_sec if audio_duration_sec > 0 else 0.0
+        )
+
+        _LOGGER.debug(
+            "Real-time factor for %s: %0.2f (infer=%0.2f sec, audio=%0.2f sec)",
+            i + 1,
+            real_time_factor,
+            infer_sec,
+            audio_duration_sec,
+        )
+
+        output_path = args.output_dir / f"{utt_id}.wav"
+        write_wav(str(output_path), args.sample_rate, audio)
+
+
+if __name__ == "__main__":
+    main()