From aac1335055d949537fb8decc6540d8f095a488e8 Mon Sep 17 00:00:00 2001 From: Mateo Cedillo <54605382+rmcpantoja@users.noreply.github.com> Date: Tue, 11 Jul 2023 08:55:48 -0500 Subject: [PATCH] Added inference notebook for test ckpt models. --- notebooks/piper_inference_(ckpt).ipynb | 518 +++++++++++++++++++++++++ 1 file changed, 518 insertions(+) create mode 100644 notebooks/piper_inference_(ckpt).ipynb diff --git a/notebooks/piper_inference_(ckpt).ipynb b/notebooks/piper_inference_(ckpt).ipynb new file mode 100644 index 0000000..e2e8fc6 --- /dev/null +++ b/notebooks/piper_inference_(ckpt).ipynb @@ -0,0 +1,518 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4", + "authorship_tag": "ABX9TyNYaSm2J1hcbQ4WsDh9OTnr", + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "source": [ + "# **[Piper](https://github.com/rhasspy/piper) inferencing notebook.**\n", + "## ![Piper logo](https://contribute.rhasspy.org/img/logo.png)\n", + "\n", + "---\n", + "\n", + "- Notebook made by [rmcpantoja](http://github.com/rmcpantoja)\n", + "- Collaborator: [Xx_Nessu_xX](https://fakeyou.com/profile/Xx_Nessu_xX)" + ], + "metadata": { + "id": "eK3nmYDB6C1a" + } + }, + { + "cell_type": "markdown", + "source": [ + "# First steps" + ], + "metadata": { + "id": "9wIvcSmOby84" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Install software and settings\n", + "#@markdown The speech synthesizer and other important dependencies will be installed in this cell. But first, some settings:\n", + "\n", + "#@markdown #### Enhable Enhanced Accessibility?\n", + "#@markdown This Enhanced Accessibility functionality is designed for the visually impaired, in which most of the interface can be used by voice guides.\n", + "enhanced_accessibility = True #@param {type:\"boolean\"}\n", + "#@markdown ---\n", + "\n", + "#@markdown #### Please select your language:\n", + "lang_select = \"English\" #@param [\"English\", \"Spanish\"]\n", + "if lang_select == \"English\":\n", + " lang = \"en\"\n", + "elif lang_select == \"Spanish\":\n", + " lang = \"es\"\n", + "else:\n", + " raise Exception(\"Language not supported.\")\n", + "#@markdown ---\n", + "#@markdown #### Do you want to use the GPU for inference?\n", + "\n", + "#@markdown The GPU can be enabled in the edit/notebook settings menu, and this step must be done before connecting to a runtime. The GPU can lead to a higher response speed in inference, but you can use the CPU, for example, if your colab runtime to use GPU's has been ended.\n", + "use_gpu = False #@param {type:\"boolean\"}\n", + "\n", + "if enhanced_accessibility:\n", + " from google.colab import output\n", + " guideurl = f\"https://github.com/rmcpantoja/piper/blob/master/notebooks/wav/{lang}\"\n", + " def playaudio(filename, extension = \"wav\"):\n", + " return output.eval_js(f'new Audio(\"{guideurl}/{filename}.{extension}?raw=true\").play()')\n", + "\n", + "%cd /content\n", + "print(\"Installing...\")\n", + "if enhanced_accessibility:\n", + " playaudio(\"installing\")\n", + "!git clone -q https://github.com/rmcpantoja/piper\n", + "%cd /content/piper/src/python\n", + "!pip install -q -r requirements.txt\n", + "!pip install -q torchtext==0.12.0 torchvision==0.12.0\n", + "# fixing recent compativility isswes:\n", + "!pip install -q torchaudio==0.11.0 torchmetrics==0.11.4\n", + "!bash build_monotonic_align.sh\n", + "!apt-get install -q espeak-ng\n", + "import os\n", + "if not os.path.exists(\"/content/piper/src/python/lng\"):\n", + " !cp -r \"/content/piper/notebooks/lng\" /content/piper/src/python/lng\n", + "import sys\n", + "sys.path.append('/content/piper/notebooks')\n", + "from translator import *\n", + "lan = Translator()\n", + "print(\"Checking GPU...\")\n", + "gpu_info = !nvidia-smi\n", + "if use_gpu and any('not found' in info for info in gpu_info[0].split(':')):\n", + " if enhanced_accessibility:\n", + " playaudio(\"nogpu\")\n", + " raise Exception(lan.translate(lang, \"The Use GPU checkbox is checked, but you don't have a GPU runtime.\"))\n", + "elif not use_gpu and not any('not found' in info for info in gpu_info[0].split(':')):\n", + " if enhanced_accessibility:\n", + " playaudio(\"gpuavailable\")\n", + " raise Exception(lan.translate(lang, \"The Use GPU checkbox is unchecked, however you are using a GPU runtime environment. We recommend you check the checkbox to use GPU to take advantage of it.\"))\n", + "\n", + "if enhanced_accessibility:\n", + " playaudio(\"installed\")\n", + "print(\"Success!\")" + ], + "metadata": { + "cellView": "form", + "id": "v8b_PEtXb8co" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title Download model and config\n", + "%cd /content/piper/src/python\n", + "import os\n", + "#@markdown #### Model ID or link (ckpt format):\n", + "model_url_or_id = \"\" #@param {type:\"string\"}\n", + "if model_url_or_id == \"\" or model_url_or_id == \"http\" or model_url_or_id == \"1\":\n", + " if enhanced_accessibility:\n", + " playaudio(\"noid\")\n", + " raise Exception(lan.translate(lang, \"Invalid link or ID!\"))\n", + "print(\"Downloading model...\")\n", + "if enhanced_accessibility:\n", + " playaudio(\"downloading\")\n", + "if model_url_or_id.startswith(\"1\"):\n", + " !gdown -q \"{model_url_or_id}\"\n", + "elif model_url_or_id.startswith(\"https://drive.google.com/file/d/\"):\n", + " !gdown -q \"{model_url_or_id}\" --fuzzy\n", + "else:\n", + " !wget -q \"{model_url_or_id}\"\n", + "#@markdown ---\n", + "#@markdown #### ID or URL of the config.json file:\n", + "config_url_or_id = \"\" #@param {type:\"string\"}\n", + "if config_url_or_id.startswith(\"1\"):\n", + " !gdown -q \"{config_url_or_id}\"\n", + "elif config_url_or_id.startswith(\"https://drive.google.com/file/d/\"):\n", + " !gdown -q \"{config_url_or_id}\" --fuzzy\n", + "else:\n", + " !wget -q \"{config_url_or_id}\"\n", + "#@markdown ---\n", + "if enhanced_accessibility:\n", + " playaudio(\"downloaded\")\n", + "print(\"Voice package downloaded!\")" + ], + "metadata": { + "cellView": "form", + "id": "ykIYmVXccg6s" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# inferencing" + ], + "metadata": { + "id": "MRvkYJF6g5FT" + } + }, + { + "cell_type": "code", + "source": [ + "#@title run inference\n", + "#@markdown #### before you enjoy... Some notes!\n", + "#@markdown * You can run the cell to download voice packs and download voices you want at any time, even if you run this cell!\n", + "#@markdown * When you download a new voice, run this cell again and you will now be able to toggle between all the ones you download. Incredible, right?\n", + "\n", + "#@markdown Enjoy!!\n", + "\n", + "%cd /content/piper/src/python\n", + "# original: infer.py\n", + "import json\n", + "import logging\n", + "import sys\n", + "from pathlib import Path\n", + "import torch\n", + "from piper_train.vits.lightning import VitsModel\n", + "from piper_train.vits.utils import audio_float_to_int16\n", + "from piper_train.vits.wavfile import write as write_wav\n", + "import numpy as np\n", + "import glob\n", + "import ipywidgets as widgets\n", + "from IPython.display import display, Audio, Markdown, clear_output\n", + "from espeak_phonemizer import Phonemizer\n", + "from piper_train import phonemize\n", + "\n", + "_LOGGER = logging.getLogger(\"piper_train.infer_onnx\")\n", + "\n", + "def detect_ckpt_models(path):\n", + " ckpt_models = glob.glob(path + '/*.ckpt')\n", + " if len(ckpt_models) > 1:\n", + " return ckpt_models\n", + " elif len(ckpt_models) == 1:\n", + " return ckpt_models[0]\n", + " else:\n", + " return None\n", + "\n", + "\n", + "def main():\n", + " \"\"\"Main entry point\"\"\"\n", + " models_path = \"/content/piper/src/python\"\n", + " logging.basicConfig(level=logging.DEBUG)\n", + " model = None\n", + " ckpt_models = detect_ckpt_models(models_path)\n", + " speaker_selection = widgets.Dropdown(\n", + " options=[],\n", + " description=f'{lan.translate(lang, \"Select speaker\")}:',\n", + " layout={'visibility': 'hidden'}\n", + " )\n", + " if ckpt_models is None:\n", + " if enhanced_accessibility:\n", + " playaudio(\"novoices\")\n", + " raise Exception(lan.translate(lang, \"No downloaded voice packages!\"))\n", + " elif isinstance(ckpt_models, str):\n", + " ckpt_model = ckpt_models\n", + " model, config = load_ckpt(ckpt_model)\n", + " if config[\"num_speakers\"] > 1:\n", + " speaker_selection.options = config[\"speaker_id_map\"].values()\n", + " speaker_selection.layout.visibility = 'visible'\n", + " preview_sid = 0\n", + " if enhanced_accessibility:\n", + " playaudio(\"multispeaker\")\n", + " else:\n", + " speaker_selection.layout.visibility = 'hidden'\n", + " preview_sid = None\n", + "\n", + " if enhanced_accessibility:\n", + " inferencing(\n", + " model,\n", + " config,\n", + " preview_sid,\n", + " lan.translate(\n", + " config[\"espeak\"][\"voice\"][:2],\n", + " \"Interface openned. Write your texts, configure the different synthesis options or download all the voices you want. Enjoy!\"\n", + " )\n", + " )\n", + " else:\n", + " voice_model_names = []\n", + " for current in ckpt_models:\n", + " voice_struct = current.split(\"/\")[5]\n", + " voice_model_names.append(voice_struct)\n", + " if enhanced_accessibility:\n", + " playaudio(\"selectmodel\")\n", + " selection = widgets.Dropdown(\n", + " options=voice_model_names,\n", + " description=f'{lan.translate(lang, \"Select voice package\")}:',\n", + " )\n", + " load_btn = widgets.Button(\n", + " description=lan.translate(lang, \"Load it!\")\n", + " )\n", + " config = None\n", + " def load_model(button):\n", + " nonlocal config\n", + " global ckpt_model\n", + " nonlocal model\n", + " nonlocal models_path\n", + " selected_voice = selection.value\n", + " ckpt_model = f\"{models_path}/{selected_voice}\"\n", + " model, config = load_ckpt(onnx_model, sess_options, providers)\n", + " if enhanced_accessibility:\n", + " playaudio(\"loaded\")\n", + " if config[\"num_speakers\"] > 1:\n", + " speaker_selection.options = config[\"speaker_id_map\"].values()\n", + " speaker_selection.layout.visibility = 'visible'\n", + " if enhanced_accessibility:\n", + " playaudio(\"multispeaker\")\n", + " else:\n", + " speaker_selection.layout.visibility = 'hidden'\n", + "\n", + " load_btn.on_click(load_model)\n", + " display(selection, load_btn)\n", + " display(speaker_selection)\n", + " speed_slider = widgets.FloatSlider(\n", + " value=1,\n", + " min=0.25,\n", + " max=4,\n", + " step=0.1,\n", + " description=lan.translate(lang, \"Rate scale\"),\n", + " orientation='horizontal',\n", + " )\n", + " noise_scale_slider = widgets.FloatSlider(\n", + " value=0.667,\n", + " min=0.25,\n", + " max=4,\n", + " step=0.1,\n", + " description=lan.translate(lang, \"Phoneme noise scale\"),\n", + " orientation='horizontal',\n", + " )\n", + " noise_scale_w_slider = widgets.FloatSlider(\n", + " value=1,\n", + " min=0.25,\n", + " max=4,\n", + " step=0.1,\n", + " description=lan.translate(lang, \"Phoneme stressing scale\"),\n", + " orientation='horizontal',\n", + " )\n", + " play = widgets.Checkbox(\n", + " value=True,\n", + " description=lan.translate(lang, \"Auto-play\"),\n", + " disabled=False\n", + " )\n", + " text_input = widgets.Text(\n", + " value='',\n", + " placeholder=f'{lan.translate(lang, \"Enter your text here\")}:',\n", + " description=lan.translate(lang, \"Text to synthesize\"),\n", + " layout=widgets.Layout(width='80%')\n", + " )\n", + " synthesize_button = widgets.Button(\n", + " description=lan.translate(lang, \"Synthesize\"),\n", + " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", + " tooltip=lan.translate(lang, \"Click here to synthesize the text.\"),\n", + " icon='check'\n", + " )\n", + " close_button = widgets.Button(\n", + " description=lan.translate(lang, \"Exit\"),\n", + " tooltip=lan.translate(lang, \"Closes this GUI.\"),\n", + " icon='check'\n", + " )\n", + "\n", + " def on_synthesize_button_clicked(b):\n", + " if model is None:\n", + " if enhanced_accessibility:\n", + " playaudio(\"nomodel\")\n", + " raise Exception(lan.translate(lang, \"You have not loaded any model from the list!\"))\n", + " text = text_input.value\n", + " if config[\"num_speakers\"] > 1:\n", + " sid = speaker_selection.value\n", + " else:\n", + " sid = None\n", + " rate = speed_slider.value\n", + " noise_scale = noise_scale_slider.value\n", + " noise_scale_w = noise_scale_w_slider.value\n", + " auto_play = play.value\n", + " inferencing(model, config, sid, text, rate, noise_scale, noise_scale_w, auto_play)\n", + "\n", + " def on_close_button_clicked(b):\n", + " clear_output()\n", + " if enhanced_accessibility:\n", + " playaudio(\"exit\")\n", + "\n", + " synthesize_button.on_click(on_synthesize_button_clicked)\n", + " close_button.on_click(on_close_button_clicked)\n", + " display(text_input)\n", + " display(speed_slider)\n", + " display(noise_scale_slider)\n", + " display(noise_scale_w_slider)\n", + " display(play)\n", + " display(synthesize_button)\n", + " display(close_button)\n", + "\n", + "def load_ckpt(model):\n", + " _LOGGER.debug(\"Loading model from %s\", model)\n", + " config = load_config(model)\n", + " model = VitsModel.load_from_checkpoint(str(model), dataset=None)\n", + " # Inference only\n", + " model.eval()\n", + " with torch.no_grad():\n", + " model.model_g.dec.remove_weight_norm()\n", + "\n", + " _LOGGER.info(\"Loaded model from %s\", model)\n", + " return model, config\n", + "\n", + "def load_config(model):\n", + " with open(\"config.json\", \"r\") as file:\n", + " config = json.load(file)\n", + " return config\n", + "\n", + "def inferencing(model, config, sid, line, length_scale = 1, noise_scale = 0.667, noise_scale_w = 0.8, auto_play=True):\n", + " espeak_voice = config[\"espeak\"][\"voice\"]\n", + " phonemizer = Phonemizer(default_voice=espeak_voice)\n", + " phonemes = phonemize.phonemize(line, phonemizer)\n", + " ids = phonemize.phonemes_to_ids(phonemes)\n", + " phoneme_ids = ids\n", + " num_speakers = config[\"num_speakers\"]\n", + " if num_speakers == 1:\n", + " speaker_id = None # for now\n", + " else:\n", + " speaker_id = sid\n", + " text = torch.LongTensor(phoneme_ids).unsqueeze(0)\n", + " text_lengths = torch.LongTensor([len(phoneme_ids)])\n", + " scales = [\n", + " noise_scale,\n", + " length_scale,\n", + " noise_scale_w\n", + " ]\n", + " sid = torch.LongTensor([speaker_id]) if speaker_id is not None else None\n", + " audio = model(\n", + " text,\n", + " text_lengths,\n", + " scales,\n", + " sid=sid\n", + " ).detach().numpy()\n", + " audio = audio_float_to_int16(audio.squeeze())\n", + " sample_rate = config[\"audio\"][\"sample_rate\"]\n", + " display(Markdown(f\"{line}\"))\n", + " display(Audio(audio, rate=sample_rate, autoplay=auto_play))\n", + "\n", + "def denoise(\n", + " audio: np.ndarray, bias_spec: np.ndarray, denoiser_strength: float\n", + ") -> np.ndarray:\n", + " audio_spec, audio_angles = transform(audio)\n", + "\n", + " a = bias_spec.shape[-1]\n", + " b = audio_spec.shape[-1]\n", + " repeats = max(1, math.ceil(b / a))\n", + " bias_spec_repeat = np.repeat(bias_spec, repeats, axis=-1)[..., :b]\n", + "\n", + " audio_spec_denoised = audio_spec - (bias_spec_repeat * denoiser_strength)\n", + " audio_spec_denoised = np.clip(audio_spec_denoised, a_min=0.0, a_max=None)\n", + " audio_denoised = inverse(audio_spec_denoised, audio_angles)\n", + "\n", + " return audio_denoised\n", + "\n", + "\n", + "def stft(x, fft_size, hopsamp):\n", + " \"\"\"Compute and return the STFT of the supplied time domain signal x.\n", + " Args:\n", + " x (1-dim Numpy array): A time domain signal.\n", + " fft_size (int): FFT size. Should be a power of 2, otherwise DFT will be used.\n", + " hopsamp (int):\n", + " Returns:\n", + " The STFT. The rows are the time slices and columns are the frequency bins.\n", + " \"\"\"\n", + " window = np.hanning(fft_size)\n", + " fft_size = int(fft_size)\n", + " hopsamp = int(hopsamp)\n", + " return np.array(\n", + " [\n", + " np.fft.rfft(window * x[i : i + fft_size])\n", + " for i in range(0, len(x) - fft_size, hopsamp)\n", + " ]\n", + " )\n", + "\n", + "\n", + "def istft(X, fft_size, hopsamp):\n", + " \"\"\"Invert a STFT into a time domain signal.\n", + " Args:\n", + " X (2-dim Numpy array): Input spectrogram. The rows are the time slices and columns are the frequency bins.\n", + " fft_size (int):\n", + " hopsamp (int): The hop size, in samples.\n", + " Returns:\n", + " The inverse STFT.\n", + " \"\"\"\n", + " fft_size = int(fft_size)\n", + " hopsamp = int(hopsamp)\n", + " window = np.hanning(fft_size)\n", + " time_slices = X.shape[0]\n", + " len_samples = int(time_slices * hopsamp + fft_size)\n", + " x = np.zeros(len_samples)\n", + " for n, i in enumerate(range(0, len(x) - fft_size, hopsamp)):\n", + " x[i : i + fft_size] += window * np.real(np.fft.irfft(X[n]))\n", + " return x\n", + "\n", + "\n", + "def inverse(magnitude, phase):\n", + " recombine_magnitude_phase = np.concatenate(\n", + " [magnitude * np.cos(phase), magnitude * np.sin(phase)], axis=1\n", + " )\n", + "\n", + " x_org = recombine_magnitude_phase\n", + " n_b, n_f, n_t = x_org.shape # pylint: disable=unpacking-non-sequence\n", + " x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)\n", + " x.real = x_org[:, : n_f // 2]\n", + " x.imag = x_org[:, n_f // 2 :]\n", + " inverse_transform = []\n", + " for y in x:\n", + " y_ = istft(y.T, fft_size=1024, hopsamp=256)\n", + " inverse_transform.append(y_[None, :])\n", + "\n", + " inverse_transform = np.concatenate(inverse_transform, 0)\n", + "\n", + " return inverse_transform\n", + "\n", + "\n", + "def transform(input_data):\n", + " x = input_data\n", + " real_part = []\n", + " imag_part = []\n", + " for y in x:\n", + " y_ = stft(y, fft_size=1024, hopsamp=256).T\n", + " real_part.append(y_.real[None, :, :]) # pylint: disable=unsubscriptable-object\n", + " imag_part.append(y_.imag[None, :, :]) # pylint: disable=unsubscriptable-object\n", + " real_part = np.concatenate(real_part, 0)\n", + " imag_part = np.concatenate(imag_part, 0)\n", + "\n", + " magnitude = np.sqrt(real_part**2 + imag_part**2)\n", + " phase = np.arctan2(imag_part.data, real_part.data)\n", + "\n", + " return magnitude, phase\n", + "\n", + "main()" + ], + "metadata": { + "id": "hcKk8M2ug8kM", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file