From aac1335055d949537fb8decc6540d8f095a488e8 Mon Sep 17 00:00:00 2001
From: Mateo Cedillo <54605382+rmcpantoja@users.noreply.github.com>
Date: Tue, 11 Jul 2023 08:55:48 -0500
Subject: [PATCH] Added inference notebook for test ckpt models.
---
notebooks/piper_inference_(ckpt).ipynb | 518 +++++++++++++++++++++++++
1 file changed, 518 insertions(+)
create mode 100644 notebooks/piper_inference_(ckpt).ipynb
diff --git a/notebooks/piper_inference_(ckpt).ipynb b/notebooks/piper_inference_(ckpt).ipynb
new file mode 100644
index 0000000..e2e8fc6
--- /dev/null
+++ b/notebooks/piper_inference_(ckpt).ipynb
@@ -0,0 +1,518 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4",
+ "authorship_tag": "ABX9TyNYaSm2J1hcbQ4WsDh9OTnr",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# **[Piper](https://github.com/rhasspy/piper) inferencing notebook.**\n",
+ "## \n",
+ "\n",
+ "---\n",
+ "\n",
+ "- Notebook made by [rmcpantoja](http://github.com/rmcpantoja)\n",
+ "- Collaborator: [Xx_Nessu_xX](https://fakeyou.com/profile/Xx_Nessu_xX)"
+ ],
+ "metadata": {
+ "id": "eK3nmYDB6C1a"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# First steps"
+ ],
+ "metadata": {
+ "id": "9wIvcSmOby84"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Install software and settings\n",
+ "#@markdown The speech synthesizer and other important dependencies will be installed in this cell. But first, some settings:\n",
+ "\n",
+ "#@markdown #### Enhable Enhanced Accessibility?\n",
+ "#@markdown This Enhanced Accessibility functionality is designed for the visually impaired, in which most of the interface can be used by voice guides.\n",
+ "enhanced_accessibility = True #@param {type:\"boolean\"}\n",
+ "#@markdown ---\n",
+ "\n",
+ "#@markdown #### Please select your language:\n",
+ "lang_select = \"English\" #@param [\"English\", \"Spanish\"]\n",
+ "if lang_select == \"English\":\n",
+ " lang = \"en\"\n",
+ "elif lang_select == \"Spanish\":\n",
+ " lang = \"es\"\n",
+ "else:\n",
+ " raise Exception(\"Language not supported.\")\n",
+ "#@markdown ---\n",
+ "#@markdown #### Do you want to use the GPU for inference?\n",
+ "\n",
+ "#@markdown The GPU can be enabled in the edit/notebook settings menu, and this step must be done before connecting to a runtime. The GPU can lead to a higher response speed in inference, but you can use the CPU, for example, if your colab runtime to use GPU's has been ended.\n",
+ "use_gpu = False #@param {type:\"boolean\"}\n",
+ "\n",
+ "if enhanced_accessibility:\n",
+ " from google.colab import output\n",
+ " guideurl = f\"https://github.com/rmcpantoja/piper/blob/master/notebooks/wav/{lang}\"\n",
+ " def playaudio(filename, extension = \"wav\"):\n",
+ " return output.eval_js(f'new Audio(\"{guideurl}/{filename}.{extension}?raw=true\").play()')\n",
+ "\n",
+ "%cd /content\n",
+ "print(\"Installing...\")\n",
+ "if enhanced_accessibility:\n",
+ " playaudio(\"installing\")\n",
+ "!git clone -q https://github.com/rmcpantoja/piper\n",
+ "%cd /content/piper/src/python\n",
+ "!pip install -q -r requirements.txt\n",
+ "!pip install -q torchtext==0.12.0 torchvision==0.12.0\n",
+ "# fixing recent compativility isswes:\n",
+ "!pip install -q torchaudio==0.11.0 torchmetrics==0.11.4\n",
+ "!bash build_monotonic_align.sh\n",
+ "!apt-get install -q espeak-ng\n",
+ "import os\n",
+ "if not os.path.exists(\"/content/piper/src/python/lng\"):\n",
+ " !cp -r \"/content/piper/notebooks/lng\" /content/piper/src/python/lng\n",
+ "import sys\n",
+ "sys.path.append('/content/piper/notebooks')\n",
+ "from translator import *\n",
+ "lan = Translator()\n",
+ "print(\"Checking GPU...\")\n",
+ "gpu_info = !nvidia-smi\n",
+ "if use_gpu and any('not found' in info for info in gpu_info[0].split(':')):\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"nogpu\")\n",
+ " raise Exception(lan.translate(lang, \"The Use GPU checkbox is checked, but you don't have a GPU runtime.\"))\n",
+ "elif not use_gpu and not any('not found' in info for info in gpu_info[0].split(':')):\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"gpuavailable\")\n",
+ " raise Exception(lan.translate(lang, \"The Use GPU checkbox is unchecked, however you are using a GPU runtime environment. We recommend you check the checkbox to use GPU to take advantage of it.\"))\n",
+ "\n",
+ "if enhanced_accessibility:\n",
+ " playaudio(\"installed\")\n",
+ "print(\"Success!\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "v8b_PEtXb8co"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Download model and config\n",
+ "%cd /content/piper/src/python\n",
+ "import os\n",
+ "#@markdown #### Model ID or link (ckpt format):\n",
+ "model_url_or_id = \"\" #@param {type:\"string\"}\n",
+ "if model_url_or_id == \"\" or model_url_or_id == \"http\" or model_url_or_id == \"1\":\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"noid\")\n",
+ " raise Exception(lan.translate(lang, \"Invalid link or ID!\"))\n",
+ "print(\"Downloading model...\")\n",
+ "if enhanced_accessibility:\n",
+ " playaudio(\"downloading\")\n",
+ "if model_url_or_id.startswith(\"1\"):\n",
+ " !gdown -q \"{model_url_or_id}\"\n",
+ "elif model_url_or_id.startswith(\"https://drive.google.com/file/d/\"):\n",
+ " !gdown -q \"{model_url_or_id}\" --fuzzy\n",
+ "else:\n",
+ " !wget -q \"{model_url_or_id}\"\n",
+ "#@markdown ---\n",
+ "#@markdown #### ID or URL of the config.json file:\n",
+ "config_url_or_id = \"\" #@param {type:\"string\"}\n",
+ "if config_url_or_id.startswith(\"1\"):\n",
+ " !gdown -q \"{config_url_or_id}\"\n",
+ "elif config_url_or_id.startswith(\"https://drive.google.com/file/d/\"):\n",
+ " !gdown -q \"{config_url_or_id}\" --fuzzy\n",
+ "else:\n",
+ " !wget -q \"{config_url_or_id}\"\n",
+ "#@markdown ---\n",
+ "if enhanced_accessibility:\n",
+ " playaudio(\"downloaded\")\n",
+ "print(\"Voice package downloaded!\")"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "ykIYmVXccg6s"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# inferencing"
+ ],
+ "metadata": {
+ "id": "MRvkYJF6g5FT"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title run inference\n",
+ "#@markdown #### before you enjoy... Some notes!\n",
+ "#@markdown * You can run the cell to download voice packs and download voices you want at any time, even if you run this cell!\n",
+ "#@markdown * When you download a new voice, run this cell again and you will now be able to toggle between all the ones you download. Incredible, right?\n",
+ "\n",
+ "#@markdown Enjoy!!\n",
+ "\n",
+ "%cd /content/piper/src/python\n",
+ "# original: infer.py\n",
+ "import json\n",
+ "import logging\n",
+ "import sys\n",
+ "from pathlib import Path\n",
+ "import torch\n",
+ "from piper_train.vits.lightning import VitsModel\n",
+ "from piper_train.vits.utils import audio_float_to_int16\n",
+ "from piper_train.vits.wavfile import write as write_wav\n",
+ "import numpy as np\n",
+ "import glob\n",
+ "import ipywidgets as widgets\n",
+ "from IPython.display import display, Audio, Markdown, clear_output\n",
+ "from espeak_phonemizer import Phonemizer\n",
+ "from piper_train import phonemize\n",
+ "\n",
+ "_LOGGER = logging.getLogger(\"piper_train.infer_onnx\")\n",
+ "\n",
+ "def detect_ckpt_models(path):\n",
+ " ckpt_models = glob.glob(path + '/*.ckpt')\n",
+ " if len(ckpt_models) > 1:\n",
+ " return ckpt_models\n",
+ " elif len(ckpt_models) == 1:\n",
+ " return ckpt_models[0]\n",
+ " else:\n",
+ " return None\n",
+ "\n",
+ "\n",
+ "def main():\n",
+ " \"\"\"Main entry point\"\"\"\n",
+ " models_path = \"/content/piper/src/python\"\n",
+ " logging.basicConfig(level=logging.DEBUG)\n",
+ " model = None\n",
+ " ckpt_models = detect_ckpt_models(models_path)\n",
+ " speaker_selection = widgets.Dropdown(\n",
+ " options=[],\n",
+ " description=f'{lan.translate(lang, \"Select speaker\")}:',\n",
+ " layout={'visibility': 'hidden'}\n",
+ " )\n",
+ " if ckpt_models is None:\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"novoices\")\n",
+ " raise Exception(lan.translate(lang, \"No downloaded voice packages!\"))\n",
+ " elif isinstance(ckpt_models, str):\n",
+ " ckpt_model = ckpt_models\n",
+ " model, config = load_ckpt(ckpt_model)\n",
+ " if config[\"num_speakers\"] > 1:\n",
+ " speaker_selection.options = config[\"speaker_id_map\"].values()\n",
+ " speaker_selection.layout.visibility = 'visible'\n",
+ " preview_sid = 0\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"multispeaker\")\n",
+ " else:\n",
+ " speaker_selection.layout.visibility = 'hidden'\n",
+ " preview_sid = None\n",
+ "\n",
+ " if enhanced_accessibility:\n",
+ " inferencing(\n",
+ " model,\n",
+ " config,\n",
+ " preview_sid,\n",
+ " lan.translate(\n",
+ " config[\"espeak\"][\"voice\"][:2],\n",
+ " \"Interface openned. Write your texts, configure the different synthesis options or download all the voices you want. Enjoy!\"\n",
+ " )\n",
+ " )\n",
+ " else:\n",
+ " voice_model_names = []\n",
+ " for current in ckpt_models:\n",
+ " voice_struct = current.split(\"/\")[5]\n",
+ " voice_model_names.append(voice_struct)\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"selectmodel\")\n",
+ " selection = widgets.Dropdown(\n",
+ " options=voice_model_names,\n",
+ " description=f'{lan.translate(lang, \"Select voice package\")}:',\n",
+ " )\n",
+ " load_btn = widgets.Button(\n",
+ " description=lan.translate(lang, \"Load it!\")\n",
+ " )\n",
+ " config = None\n",
+ " def load_model(button):\n",
+ " nonlocal config\n",
+ " global ckpt_model\n",
+ " nonlocal model\n",
+ " nonlocal models_path\n",
+ " selected_voice = selection.value\n",
+ " ckpt_model = f\"{models_path}/{selected_voice}\"\n",
+ " model, config = load_ckpt(onnx_model, sess_options, providers)\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"loaded\")\n",
+ " if config[\"num_speakers\"] > 1:\n",
+ " speaker_selection.options = config[\"speaker_id_map\"].values()\n",
+ " speaker_selection.layout.visibility = 'visible'\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"multispeaker\")\n",
+ " else:\n",
+ " speaker_selection.layout.visibility = 'hidden'\n",
+ "\n",
+ " load_btn.on_click(load_model)\n",
+ " display(selection, load_btn)\n",
+ " display(speaker_selection)\n",
+ " speed_slider = widgets.FloatSlider(\n",
+ " value=1,\n",
+ " min=0.25,\n",
+ " max=4,\n",
+ " step=0.1,\n",
+ " description=lan.translate(lang, \"Rate scale\"),\n",
+ " orientation='horizontal',\n",
+ " )\n",
+ " noise_scale_slider = widgets.FloatSlider(\n",
+ " value=0.667,\n",
+ " min=0.25,\n",
+ " max=4,\n",
+ " step=0.1,\n",
+ " description=lan.translate(lang, \"Phoneme noise scale\"),\n",
+ " orientation='horizontal',\n",
+ " )\n",
+ " noise_scale_w_slider = widgets.FloatSlider(\n",
+ " value=1,\n",
+ " min=0.25,\n",
+ " max=4,\n",
+ " step=0.1,\n",
+ " description=lan.translate(lang, \"Phoneme stressing scale\"),\n",
+ " orientation='horizontal',\n",
+ " )\n",
+ " play = widgets.Checkbox(\n",
+ " value=True,\n",
+ " description=lan.translate(lang, \"Auto-play\"),\n",
+ " disabled=False\n",
+ " )\n",
+ " text_input = widgets.Text(\n",
+ " value='',\n",
+ " placeholder=f'{lan.translate(lang, \"Enter your text here\")}:',\n",
+ " description=lan.translate(lang, \"Text to synthesize\"),\n",
+ " layout=widgets.Layout(width='80%')\n",
+ " )\n",
+ " synthesize_button = widgets.Button(\n",
+ " description=lan.translate(lang, \"Synthesize\"),\n",
+ " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n",
+ " tooltip=lan.translate(lang, \"Click here to synthesize the text.\"),\n",
+ " icon='check'\n",
+ " )\n",
+ " close_button = widgets.Button(\n",
+ " description=lan.translate(lang, \"Exit\"),\n",
+ " tooltip=lan.translate(lang, \"Closes this GUI.\"),\n",
+ " icon='check'\n",
+ " )\n",
+ "\n",
+ " def on_synthesize_button_clicked(b):\n",
+ " if model is None:\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"nomodel\")\n",
+ " raise Exception(lan.translate(lang, \"You have not loaded any model from the list!\"))\n",
+ " text = text_input.value\n",
+ " if config[\"num_speakers\"] > 1:\n",
+ " sid = speaker_selection.value\n",
+ " else:\n",
+ " sid = None\n",
+ " rate = speed_slider.value\n",
+ " noise_scale = noise_scale_slider.value\n",
+ " noise_scale_w = noise_scale_w_slider.value\n",
+ " auto_play = play.value\n",
+ " inferencing(model, config, sid, text, rate, noise_scale, noise_scale_w, auto_play)\n",
+ "\n",
+ " def on_close_button_clicked(b):\n",
+ " clear_output()\n",
+ " if enhanced_accessibility:\n",
+ " playaudio(\"exit\")\n",
+ "\n",
+ " synthesize_button.on_click(on_synthesize_button_clicked)\n",
+ " close_button.on_click(on_close_button_clicked)\n",
+ " display(text_input)\n",
+ " display(speed_slider)\n",
+ " display(noise_scale_slider)\n",
+ " display(noise_scale_w_slider)\n",
+ " display(play)\n",
+ " display(synthesize_button)\n",
+ " display(close_button)\n",
+ "\n",
+ "def load_ckpt(model):\n",
+ " _LOGGER.debug(\"Loading model from %s\", model)\n",
+ " config = load_config(model)\n",
+ " model = VitsModel.load_from_checkpoint(str(model), dataset=None)\n",
+ " # Inference only\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " model.model_g.dec.remove_weight_norm()\n",
+ "\n",
+ " _LOGGER.info(\"Loaded model from %s\", model)\n",
+ " return model, config\n",
+ "\n",
+ "def load_config(model):\n",
+ " with open(\"config.json\", \"r\") as file:\n",
+ " config = json.load(file)\n",
+ " return config\n",
+ "\n",
+ "def inferencing(model, config, sid, line, length_scale = 1, noise_scale = 0.667, noise_scale_w = 0.8, auto_play=True):\n",
+ " espeak_voice = config[\"espeak\"][\"voice\"]\n",
+ " phonemizer = Phonemizer(default_voice=espeak_voice)\n",
+ " phonemes = phonemize.phonemize(line, phonemizer)\n",
+ " ids = phonemize.phonemes_to_ids(phonemes)\n",
+ " phoneme_ids = ids\n",
+ " num_speakers = config[\"num_speakers\"]\n",
+ " if num_speakers == 1:\n",
+ " speaker_id = None # for now\n",
+ " else:\n",
+ " speaker_id = sid\n",
+ " text = torch.LongTensor(phoneme_ids).unsqueeze(0)\n",
+ " text_lengths = torch.LongTensor([len(phoneme_ids)])\n",
+ " scales = [\n",
+ " noise_scale,\n",
+ " length_scale,\n",
+ " noise_scale_w\n",
+ " ]\n",
+ " sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None\n",
+ " audio = model(\n",
+ " text,\n",
+ " text_lengths,\n",
+ " scales,\n",
+ " sid=sid\n",
+ " ).detach().numpy()\n",
+ " audio = audio_float_to_int16(audio.squeeze())\n",
+ " sample_rate = config[\"audio\"][\"sample_rate\"]\n",
+ " display(Markdown(f\"{line}\"))\n",
+ " display(Audio(audio, rate=sample_rate, autoplay=auto_play))\n",
+ "\n",
+ "def denoise(\n",
+ " audio: np.ndarray, bias_spec: np.ndarray, denoiser_strength: float\n",
+ ") -> np.ndarray:\n",
+ " audio_spec, audio_angles = transform(audio)\n",
+ "\n",
+ " a = bias_spec.shape[-1]\n",
+ " b = audio_spec.shape[-1]\n",
+ " repeats = max(1, math.ceil(b / a))\n",
+ " bias_spec_repeat = np.repeat(bias_spec, repeats, axis=-1)[..., :b]\n",
+ "\n",
+ " audio_spec_denoised = audio_spec - (bias_spec_repeat * denoiser_strength)\n",
+ " audio_spec_denoised = np.clip(audio_spec_denoised, a_min=0.0, a_max=None)\n",
+ " audio_denoised = inverse(audio_spec_denoised, audio_angles)\n",
+ "\n",
+ " return audio_denoised\n",
+ "\n",
+ "\n",
+ "def stft(x, fft_size, hopsamp):\n",
+ " \"\"\"Compute and return the STFT of the supplied time domain signal x.\n",
+ " Args:\n",
+ " x (1-dim Numpy array): A time domain signal.\n",
+ " fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.\n",
+ " hopsamp (int):\n",
+ " Returns:\n",
+ " The STFT. The rows are the time slices and columns are the frequency bins.\n",
+ " \"\"\"\n",
+ " window = np.hanning(fft_size)\n",
+ " fft_size = int(fft_size)\n",
+ " hopsamp = int(hopsamp)\n",
+ " return np.array(\n",
+ " [\n",
+ " np.fft.rfft(window * x[i : i + fft_size])\n",
+ " for i in range(0, len(x) - fft_size, hopsamp)\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def istft(X, fft_size, hopsamp):\n",
+ " \"\"\"Invert a STFT into a time domain signal.\n",
+ " Args:\n",
+ " X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.\n",
+ " fft_size (int):\n",
+ " hopsamp (int): The hop size, in samples.\n",
+ " Returns:\n",
+ " The inverse STFT.\n",
+ " \"\"\"\n",
+ " fft_size = int(fft_size)\n",
+ " hopsamp = int(hopsamp)\n",
+ " window = np.hanning(fft_size)\n",
+ " time_slices = X.shape[0]\n",
+ " len_samples = int(time_slices * hopsamp + fft_size)\n",
+ " x = np.zeros(len_samples)\n",
+ " for n, i in enumerate(range(0, len(x) - fft_size, hopsamp)):\n",
+ " x[i : i + fft_size] += window * np.real(np.fft.irfft(X[n]))\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "def inverse(magnitude, phase):\n",
+ " recombine_magnitude_phase = np.concatenate(\n",
+ " [magnitude * np.cos(phase), magnitude * np.sin(phase)], axis=1\n",
+ " )\n",
+ "\n",
+ " x_org = recombine_magnitude_phase\n",
+ " n_b, n_f, n_t = x_org.shape # pylint: disable=unpacking-non-sequence\n",
+ " x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)\n",
+ " x.real = x_org[:, : n_f // 2]\n",
+ " x.imag = x_org[:, n_f // 2 :]\n",
+ " inverse_transform = []\n",
+ " for y in x:\n",
+ " y_ = istft(y.T, fft_size=1024, hopsamp=256)\n",
+ " inverse_transform.append(y_[None, :])\n",
+ "\n",
+ " inverse_transform = np.concatenate(inverse_transform, 0)\n",
+ "\n",
+ " return inverse_transform\n",
+ "\n",
+ "\n",
+ "def transform(input_data):\n",
+ " x = input_data\n",
+ " real_part = []\n",
+ " imag_part = []\n",
+ " for y in x:\n",
+ " y_ = stft(y, fft_size=1024, hopsamp=256).T\n",
+ " real_part.append(y_.real[None, :, :]) # pylint: disable=unsubscriptable-object\n",
+ " imag_part.append(y_.imag[None, :, :]) # pylint: disable=unsubscriptable-object\n",
+ " real_part = np.concatenate(real_part, 0)\n",
+ " imag_part = np.concatenate(imag_part, 0)\n",
+ "\n",
+ " magnitude = np.sqrt(real_part**2 + imag_part**2)\n",
+ " phase = np.arctan2(imag_part.data, real_part.data)\n",
+ "\n",
+ " return magnitude, phase\n",
+ "\n",
+ "main()"
+ ],
+ "metadata": {
+ "id": "hcKk8M2ug8kM",
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file