diff --git a/src/cpp/main.cpp b/src/cpp/main.cpp index 3204b2f..8accd98 100644 --- a/src/cpp/main.cpp +++ b/src/cpp/main.cpp @@ -188,7 +188,7 @@ void parseArgs(int argc, char *argv[], RunConfig &runConfig) { runConfig.outputPath = filesystem::path(argv[++i]); } else if (arg == "-s" || arg == "--speaker") { ensureArg(argc, argv, i); - runConfig.speakerId = (larynx::SpeakerId)stoi(argv[++i]); + runConfig.speakerId = (larynx::SpeakerId)stoll(argv[++i]); } else if (arg == "-h" || arg == "--help") { printUsage(argv); exit(0); diff --git a/src/cpp/model.hpp b/src/cpp/model.hpp index c1cfcf1..befdcfc 100644 --- a/src/cpp/model.hpp +++ b/src/cpp/model.hpp @@ -16,15 +16,16 @@ struct ModelSession { vector outputNames; Ort::AllocatorWithDefaultOptions allocator; Ort::SessionOptions options; + Ort::Env env; ModelSession() : onnx(nullptr){}; }; void loadModel(string modelPath, ModelSession &session) { - Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, - instanceName.c_str()); - env.DisableTelemetryEvents(); + session.env = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + instanceName.c_str()); + session.env.DisableTelemetryEvents(); // Slows down performance by ~2x // session.options.SetIntraOpNumThreads(1); @@ -44,7 +45,7 @@ void loadModel(string modelPath, ModelSession &session) { session.options.DisableProfiling(); auto startTime = chrono::steady_clock::now(); - session.onnx = Ort::Session(env, modelPath.c_str(), session.options); + session.onnx = Ort::Session(session.env, modelPath.c_str(), session.options); auto endTime = chrono::steady_clock::now(); auto loadDuration = chrono::duration(endTime - startTime); diff --git a/src/cpp/synthesize.hpp b/src/cpp/synthesize.hpp index 71070e2..a028549 100644 --- a/src/cpp/synthesize.hpp +++ b/src/cpp/synthesize.hpp @@ -56,7 +56,7 @@ void synthesize(SynthesisConfig &synthesisConfig, ModelSession &session, if (synthesisConfig.speakerId) { // Add speaker id vector speakerId{(int64_t)synthesisConfig.speakerId.value()}; - vector speakerIdShape{1}; + vector speakerIdShape{(int64_t)speakerId.size()}; inputTensors.push_back(Ort::Value::CreateTensor( memoryInfo, speakerId.data(), speakerId.size(), speakerIdShape.data(), speakerIdShape.size()));