diff --git a/src/cpp/main.cpp b/src/cpp/main.cpp index 9022cd3..bd75006 100644 --- a/src/cpp/main.cpp +++ b/src/cpp/main.cpp @@ -19,8 +19,8 @@ #endif #ifdef _WIN32 - #include - #include +#include +#include #endif #ifdef __APPLE__ @@ -85,6 +85,9 @@ struct RunConfig { // Seconds of extra silence to insert after a single phoneme optional> phonemeSilenceSeconds; + + // true to use CUDA execution provider + bool useCuda = false; }; void parseArgs(int argc, char *argv[], RunConfig &runConfig); @@ -114,7 +117,8 @@ int main(int argc, char *argv[]) { auto startTime = chrono::steady_clock::now(); loadVoice(piperConfig, runConfig.modelPath.string(), - runConfig.modelConfigPath.string(), voice, runConfig.speakerId); + runConfig.modelConfigPath.string(), voice, runConfig.speakerId, + runConfig.useCuda); auto endTime = chrono::steady_clock::now(); spdlog::info("Loaded voice in {} second(s)", chrono::duration(endTime - startTime).count()); @@ -314,8 +318,8 @@ int main(int argc, char *argv[]) { #ifdef _WIN32 // Needed on Windows to avoid terminal conversions - setmode(fileno(stdout),O_BINARY); - setmode(fileno(stdin),O_BINARY); + setmode(fileno(stdout), O_BINARY); + setmode(fileno(stdin), O_BINARY); #endif thread rawOutputThread(rawOutputProc, ref(sharedAudioBuffer), @@ -434,10 +438,11 @@ void printUsage(char *argv[]) { cerr << " --json-input stdin input is lines of JSON " "instead of plain text" << endl; + cerr << " --use-cuda use CUDA execution provider" + << endl; cerr << " --debug print DEBUG messages to the console" << endl; - cerr << " -q --quiet disable logging" - << endl; + cerr << " -q --quiet disable logging" << endl; cerr << endl; } @@ -518,6 +523,8 @@ void parseArgs(int argc, char *argv[], RunConfig &runConfig) { runConfig.tashkeelModelPath = filesystem::path(argv[++i]); } else if (arg == "--json_input" || arg == "--json-input") { runConfig.jsonInput = true; + } else if (arg == "--use_cuda" || arg == "--use-cuda") { + runConfig.useCuda = true; } else if (arg == "--version") { std::cout << piper::getVersion() << std::endl; exit(0); diff --git a/src/cpp/piper.cpp b/src/cpp/piper.cpp index 5310037..71f3fef 100644 --- a/src/cpp/piper.cpp +++ b/src/cpp/piper.cpp @@ -259,12 +259,18 @@ void terminate(PiperConfig &config) { spdlog::info("Terminated piper"); } -void loadModel(std::string modelPath, ModelSession &session) { +void loadModel(std::string modelPath, ModelSession &session, bool useCuda) { spdlog::debug("Loading onnx model from {}", modelPath); session.env = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, instanceName.c_str()); session.env.DisableTelemetryEvents(); + if (useCuda) { + // Use CUDA provider + OrtCUDAProviderOptions cuda_options{}; + session.options.AppendExecutionProvider_CUDA(cuda_options); + } + // Slows down performance by ~2x // session.options.SetIntraOpNumThreads(1); @@ -301,7 +307,7 @@ void loadModel(std::string modelPath, ModelSession &session) { // Load Onnx model and JSON config file void loadVoice(PiperConfig &config, std::string modelPath, std::string modelConfigPath, Voice &voice, - std::optional &speakerId) { + std::optional &speakerId, bool useCuda) { spdlog::debug("Parsing voice config at {}", modelConfigPath); std::ifstream modelConfigFile(modelConfigPath); voice.configRoot = json::parse(modelConfigFile); @@ -322,7 +328,7 @@ void loadVoice(PiperConfig &config, std::string modelPath, spdlog::debug("Voice contains {} speaker(s)", voice.modelConfig.numSpeakers); - loadModel(modelPath, voice.session); + loadModel(modelPath, voice.session, useCuda); } /* loadVoice */ diff --git a/src/cpp/piper.hpp b/src/cpp/piper.hpp index 21f0ece..7b956f7 100644 --- a/src/cpp/piper.hpp +++ b/src/cpp/piper.hpp @@ -116,7 +116,7 @@ void terminate(PiperConfig &config); // Load Onnx model and JSON config file void loadVoice(PiperConfig &config, std::string modelPath, std::string modelConfigPath, Voice &voice, - std::optional &speakerId); + std::optional &speakerId, bool useCuda); // Phonemize text and synthesize audio void textToAudio(PiperConfig &config, Voice &voice, std::string text, diff --git a/src/cpp/test.cpp b/src/cpp/test.cpp index 1b0782e..d2d5a52 100644 --- a/src/cpp/test.cpp +++ b/src/cpp/test.cpp @@ -36,14 +36,16 @@ int main(int argc, char *argv[]) { auto outputPath = std::string(argv[3]); optional speakerId; - loadVoice(piperConfig, modelPath, modelPath + ".json", voice, speakerId); + loadVoice(piperConfig, modelPath, modelPath + ".json", voice, speakerId, + false); piper::initialize(piperConfig); // Output audio to WAV file ofstream audioFile(outputPath, ios::binary); piper::SynthesisResult result; - piper::textToWavFile(piperConfig, voice, "This is a test.", audioFile, result); + piper::textToWavFile(piperConfig, voice, "This is a test.", audioFile, + result); piper::terminate(piperConfig); // Verify that file has some data