diff --git a/src/benchmark/benchmark_generator.py b/src/benchmark/benchmark_generator.py
index bfb70be..c972877 100644
--- a/src/benchmark/benchmark_generator.py
+++ b/src/benchmark/benchmark_generator.py
@@ -6,13 +6,13 @@ import sys
 
 import torch
 
-_SPEAKER_ID = 0
-
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", required=True, help="Path to Onnx model file")
-    parser.add_argument("-c", "--config", help="Path to model config file")
+    parser.add_argument(
+        "-m", "--model", required=True, help="Path to generator file (.pt)"
+    )
+    parser.add_argument("-c", "--config", help="Path to model config file (.json)")
     args = parser.parse_args()
 
     if not args.config:
diff --git a/src/benchmark/benchmark_onnx.py b/src/benchmark/benchmark_onnx.py
index 22426cd..553ac14 100644
--- a/src/benchmark/benchmark_onnx.py
+++ b/src/benchmark/benchmark_onnx.py
@@ -2,6 +2,7 @@
 import argparse
 import json
 import time
+import statistics
 import sys
 
 import onnxruntime
@@ -10,13 +11,14 @@ import numpy as np
 _NOISE_SCALE = 0.667
 _LENGTH_SCALE = 1.0
 _NOISE_W = 0.8
-_SPEAKER_ID = 0
 
 
 def main() -> None:
     parser = argparse.ArgumentParser()
-    parser.add_argument("-m", "--model", required=True, help="Path to Onnx model file")
-    parser.add_argument("-c", "--config", help="Path to model config file")
+    parser.add_argument(
+        "-m", "--model", required=True, help="Path to Onnx model file (.onnx)"
+    )
+    parser.add_argument("-c", "--config", help="Path to model config file (.json)")
     args = parser.parse_args()
 
     if not args.config:
@@ -29,7 +31,25 @@ def main() -> None:
     utterances = [json.loads(line) for line in sys.stdin]
 
     start_time = time.monotonic_ns()
-    session = onnxruntime.InferenceSession(args.model)
+
+    session_options = onnxruntime.SessionOptions()
+    session_options.graph_optimization_level = (
+        onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
+    )
+    # session_options.enable_cpu_mem_arena = False
+    # session_options.enable_mem_pattern = False
+    session_options.enable_mem_reuse = False
+    # session_options.enable_profiling = False
+    # session_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL
+    # session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED
+
+    session = onnxruntime.InferenceSession(
+        args.model,
+        sess_options=session_options,
+    )
+    # session.intra_op_num_threads = 1
+    # session.inter_op_num_threads = 1
+
     end_time = time.monotonic_ns()
 
     load_sec = (end_time - start_time) / 1e9
@@ -47,7 +67,12 @@ def main() -> None:
         )
 
     json.dump(
-        {"load_sec": load_sec, "synthesize_rtf": synthesize_rtf},
+        {
+            "load_sec": load_sec,
+            "rtf_mean": statistics.mean(synthesize_rtf),
+            "rtf_stdev": statistics.stdev(synthesize_rtf),
+            "rtfs": synthesize_rtf,
+        },
         sys.stdout,
     )
 
diff --git a/src/benchmark/benchmark_torchscript.py b/src/benchmark/benchmark_torchscript.py
new file mode 100644
index 0000000..11cc51c
--- /dev/null
+++ b/src/benchmark/benchmark_torchscript.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+import argparse
+import json
+import time
+import sys
+
+import torch
+
+_NOISE_SCALE = 0.667
+_LENGTH_SCALE = 1.0
+_NOISE_W = 0.8
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m", "--model", required=True, help="Path to Torchscript file (.ts)"
+    )
+    parser.add_argument("-c", "--config", help="Path to model config file (.json)")
+    args = parser.parse_args()
+
+    if not args.config:
+        args.config = f"{args.model}.json"
+
+    with open(args.config, "r", encoding="utf-8") as config_file:
+        config = json.load(config_file)
+
+    sample_rate = config["audio"]["sample_rate"]
+    utterances = [json.loads(line) for line in sys.stdin]
+
+    start_time = time.monotonic_ns()
+    model = torch.jit.load(args.model)
+    end_time = time.monotonic_ns()
+
+    model.eval()
+
+    load_sec = (end_time - start_time) / 1e9
+    synthesize_rtf = []
+    for utterance in utterances:
+        phoneme_ids = utterance["phoneme_ids"]
+        speaker_id = utterance.get("speaker_id")
+        synthesize_rtf.append(
+            synthesize(
+                model,
+                phoneme_ids,
+                speaker_id,
+                sample_rate,
+            )
+        )
+
+    json.dump(
+        {"load_sec": load_sec, "synthesize_rtf": synthesize_rtf},
+        sys.stdout,
+    )
+
+
+def synthesize(model, phoneme_ids, speaker_id, sample_rate) -> float:
+    text = torch.LongTensor(phoneme_ids).unsqueeze(0)
+    text_lengths = torch.LongTensor([len(phoneme_ids)])
+    sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None
+
+    start_time = time.monotonic_ns()
+    audio = (
+        model(
+            text,
+            text_lengths,
+            sid,
+            torch.FloatTensor([_NOISE_SCALE]),
+            torch.FloatTensor([_LENGTH_SCALE]),
+            torch.FloatTensor([_NOISE_W]),
+        )[0]
+        .detach()
+        .numpy()
+        .squeeze()
+    )
+    end_time = time.monotonic_ns()
+
+    audio_sec = (len(audio) / 2) / sample_rate
+    infer_sec = (end_time - start_time) / 1e9
+
+    return infer_sec / audio_sec
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/python/piper_train/export_torchscript.py b/src/python/piper_train/export_torchscript.py
index 80e413f..3555a20 100644
--- a/src/python/piper_train/export_torchscript.py
+++ b/src/python/piper_train/export_torchscript.py
@@ -57,9 +57,7 @@ def main():
     )
     sequence_lengths = torch.LongTensor([sequences.size(1)])
 
-    sid: Optional[int] = None
-    if num_speakers > 1:
-        sid = torch.LongTensor([0])
+    sid = torch.LongTensor([0])
 
     dummy_input = (
         sequences,