{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "AutoVC training .ipynb", "provenance": [], "collapsed_sections": [], "mount_file_id": "1zL5O_UHp9FMhn4AqNTf5c-_46aCB78g5", "authorship_tag": "ABX9TyN936oOrQasCqhzr517rjM0", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KlD_l2VGnlMq", "outputId": "5fa8ae30-e794-4491-fbb8-49402db93255" }, "source": [ "!git clone https://github.com/gkv856/end2end_auto_voice_conversion.git" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Cloning into 'end2end_auto_voice_conversion'...\n", "remote: Enumerating objects: 720, done.\u001b[K\n", "remote: Counting objects: 100% (720/720), done.\u001b[K\n", "remote: Compressing objects: 100% (521/521), done.\u001b[K\n", "remote: Total 720 (delta 257), reused 616 (delta 153), pack-reused 0\u001b[K\n", "Receiving objects: 100% (720/720), 266.45 MiB | 31.29 MiB/s, done.\n", "Resolving deltas: 100% (257/257), done.\n", "Checking out files: 100% (213/213), done.\n" ] } ] }, { "cell_type": "code", "metadata": { "id": "9F-Jn5-brXJP" }, "source": [ "!mv end2end_auto_voice_conversion/ AVC/" ], "execution_count": 2, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "iIVOtzwUqLkp", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "03d25bd9-6ddc-4566-9301-355b416a0b56" }, "source": [ "pip install webrtcvad" ], "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting webrtcvad\n", " Downloading webrtcvad-2.0.10.tar.gz (66 kB)\n", "\u001b[?25l\r\u001b[K |█████ | 10 kB 23.1 MB/s eta 0:00:01\r\u001b[K |██████████ | 20 kB 22.2 MB/s eta 
0:00:01\r\u001b[K |██████████████▉ | 30 kB 11.2 MB/s eta 0:00:01\r\u001b[K |███████████████████▉ | 40 kB 9.0 MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 51 kB 5.3 MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▊ | 61 kB 5.8 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 66 kB 2.8 MB/s \n", "\u001b[?25hBuilding wheels for collected packages: webrtcvad\n", " Building wheel for webrtcvad (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for webrtcvad: filename=webrtcvad-2.0.10-cp37-cp37m-linux_x86_64.whl size=72381 sha256=d85a7b82aeb1d30339feabc54147db687457fadaf29ff9f172277b7d7f152858\n", " Stored in directory: /root/.cache/pip/wheels/11/f9/67/a3158d131f57e1c0a7d8d966a707d4a2fb27567a4fe47723ad\n", "Successfully built webrtcvad\n", "Installing collected packages: webrtcvad\n", "Successfully installed webrtcvad-2.0.10\n" ] } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6IjU5f1KrngW", "outputId": "f86fe1f2-6513-48c5-9a84-1a83d4f93ed7" }, "source": [ "from AVC.strings.constants import hp" ], "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device type available = 'cuda:0'\n" ] } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 402 }, "id": "Ikq17qcOn1W_", "outputId": "2aebfae8-cc44-44e7-8176-741561ac093b" }, "source": [ "import torch\n", "\n", "from AVC.s3_auto_voice_cloner.s1_create_emb_per_speaker import create_embbedings_per_speaker\n", "\n", "from sklearn.manifold import TSNE\n", "import matplotlib.pyplot as plt\n", "\n", "# hp.m_ge2e.best_model_path = \"static/model_chk_pts/ge2e/final_epoch_1000_L_0.0390.pth\"\n", "\n", "utterances = create_embbedings_per_speaker(hp)\n", "\n", "\n", "labels = []\n", "embs = []\n", "for k, v in utterances.items():\n", " embs.append(v)\n", " labels.append(k)\n", "\n", "\n", "embeddings = torch.tensor(embs)\n", "\n", "scatters = 
TSNE(n_components=2, random_state=0).fit_transform(embeddings.cpu().detach().numpy())\n", "fig = plt.figure(figsize=(5, 5))\n", "\n", "current_Label = labels[0]\n", "current_Index = 0\n", "for index, label in enumerate(labels[1:], 1):\n", " if label != current_Label:\n", " plt.scatter(scatters[current_Index:index, 0], scatters[current_Index:index, 1],\n", " label='{}'.format(current_Label))\n", " current_Label = label\n", " current_Index = index\n", "\n", "plt.scatter(scatters[current_Index:, 0], scatters[current_Index:, 1], label='{}'.format(current_Label))\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.show()\n" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Pre-trained model loaded /content/AVC/static/model_chk_pts/ge2e/embedding_model_GE2E_loss_epoch_1000_L_0.0003.pth\n", "File saved!!\n" ] }, { "output_type": "display_data", "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "id": "xtIewt8kpdR6" }, "source": [ "from AVC.s3_auto_voice_cloner.s5_auto_vc_train import TrainAutoVCNetwork" ], "execution_count": 6, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nr0mhvWUq7_e", "outputId": "17235df3-d6d5-434d-e555-ffe19884f8f6" }, "source": [ "\n", "hp.m_avc.tpm.lambda_cd = 1\n", "hp.m_avc.tpm.num_iters = 1000\n", "hp.m_avc.tpm.log_step = 100\n", "hp.m_avc.tpm.dot_print = 10\n", "hp.m_avc.tpm.checkpoint_interval = 200\n", "hp.m_avc.tpm.lr = 0.001\n", "hp.m_avc.tpm.reduce_lr_interval = 100\n", "hp.m_avc.tpm.data_batch_size = 2\n", "\n", "solver = TrainAutoVCNetwork(hp)\n", "solver.optimizer = torch.optim.Adam(solver.auto_vc_net.parameters(), \n", " lr=solver.lr,\n", " betas=(0.9, 0.999),\n", " eps=1e-7,\n", " weight_decay=0)\n", "# start the training\n", "auto_vc_model, lst_loss_tuple = solver.start_training()\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Started Batched Training...\n", "Epoch:[100/1000] .." 
] } ] },
{ "cell_type": "code", "metadata": { "id": "T9d3hSD23gvi" }, "source": [
    "# Checkpoint that the AutoVC config (hp.m_avc.gen) should load.\n",
    "# NOTE(review): presumably written by the training run above (num_iters = 1000) - confirm the file exists.\n",
    "hp.m_avc.gen.best_model_path = \"/content/AVC/static/model_chk_pts/autovc/final_1000.pth\""
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "dXgEIKxlB17a" }, "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install wavenet_vocoder\n",
    "%pip install webrtcvad"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "Ig_jn3Z2q9is" }, "source": [
    "from itertools import groupby\n",
    "\n",
    "import torch\n",
    "\n",
    "from AVC.s3_auto_voice_cloner.s1_create_emb_per_speaker import create_embbedings_per_speaker\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# hp.m_ge2e.best_model_path = \"static/model_chk_pts/ge2e/final_epoch_1000_L_0.0390.pth\"\n",
    "\n",
    "# Mapping whose keys become the legend labels and whose values are the embeddings.\n",
    "# NOTE(review): presumably speaker id -> speaker embedding; confirm against create_embbedings_per_speaker.\n",
    "utterances = create_embbedings_per_speaker(hp)\n",
    "\n",
    "speaker_items = list(utterances.items())\n",
    "labels = [speaker for speaker, _ in speaker_items]\n",
    "embs = [emb for _, emb in speaker_items]\n",
    "\n",
    "embeddings = torch.tensor(embs)\n",
    "\n",
    "# Project the embeddings to 2-D for plotting; random_state pins the t-SNE seed.\n",
    "scatters = TSNE(n_components=2, random_state=0).fit_transform(embeddings.cpu().detach().numpy())\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "\n",
    "# Draw every run of consecutive identical labels as a single scatter / legend entry.\n",
    "run_start = 0\n",
    "for run_label, run_members in groupby(labels):\n",
    "    run_stop = run_start + len(list(run_members))\n",
    "    plt.scatter(scatters[run_start:run_stop, 0], scatters[run_start:run_stop, 1],\n",
    "                label='{}'.format(run_label))\n",
    "    run_start = run_stop\n",
    "\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "6EcBFWkg3Wie" }, "source": [
    "import tqdm\n",
    "\n",
    "from AVC.s3_auto_voice_cloner.s6_create_cross_speaker_mel_specs import VoiceCloner\n",
    "import soundfile as sf\n",
    "import os"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "oee-8QuUB_bD" }, "source": [
    "# Checkpoints for the AutoVC (m_avc) and WaveNet (m_wave_net) models.\n",
    "# The WaveNet path lives under /content/drive, so Google Drive has to be mounted before this cell runs.\n",
    "hp.m_avc.gen.best_model_path = \"/content/AVC/static/model_chk_pts/autovc/final_1000.pth\"\n",
    "hp.m_wave_net.gen.best_model_path = \"/content/drive/MyDrive/AI_ML_DL/model_chk_pts/wavenet_model/wavenet_pretrained_step001000000_ema.pth\"\n",
    "vcs_obj = VoiceCloner(hp, tqdm, absolute_path=True)\n",
    "\n",
    "# Mel spectrogram of one p225 utterance, with the wav path resolved against the project root.\n",
    "path_audio = \"static/raw_data/wavs/p225/p225_003.wav\"\n",
    "path_audio = os.path.join(hp.general.project_root, path_audio)\n",
    "spkr_p225_mel_spec = vcs_obj.au.get_mel_spects_from_audio(path_audio, partial_slices=False)\n",
    "\n",
    "# Same for one p226 utterance.\n",
    "path_audio = \"static/raw_data/wavs/p226/p226_003.wav\"\n",
    "path_audio = os.path.join(hp.general.project_root, path_audio)\n",
    "spkr_p226_mel_spec = vcs_obj.au.get_mel_spects_from_audio(path_audio, partial_slices=False)"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "B8UzGtndCLeV" }, "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math\n",
    "import librosa\n",
    "import librosa.display"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "W4sgHQ2bCXCs" }, "source": [
    "# Plot mel spectrograms\n",
    "fig, ax = plt.subplots(1,2, figsize = (20,10))\n",
    "\n",
    "ax[0].set(title = 'Mel Spectrogram of P225')\n",
    "p225_img = librosa.display.specshow(spkr_p225_mel_spec, ax=ax[0])\n",
    "\n",
    "ax[1].set(title = 'Mel Spectrogram of P226')\n",
    "librosa.display.specshow(spkr_p226_mel_spec, ax=ax[1])\n",
    "\n",
    "# The colour bar is built from the left image only, so attach it to that panel explicitly\n",
    "# instead of relying on whichever Axes happens to be current.\n",
    "fig.colorbar(p225_img, ax=ax[0]);"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "XlNXkRO0Cfsz" }, "source": [
    "# Convert the p225 spectrogram towards p226; only the first 320 rows are passed in.\n",
    "# NOTE(review): rows are presumably time frames - confirm against get_mel_spects_from_audio.\n",
    "avc_mel_specs = vcs_obj.create_cross_spkr_mel_spects(\"p225\", \"p226\", spkr_p225_mel_spec[:320, :])\n",
    "avc_mel_specs.shape"
  ], "execution_count": null, "outputs": [] },
{ "cell_type": "code", "metadata": { "id": "C63S2fAiEMzQ" }, "source": [
    "# Plot mel spectrograms\n",
    "fig, ax = plt.subplots(1,2, figsize = (20,10))\n",
    "\n",
    "# Left panel: the same p225 slice that was passed to create_cross_spkr_mel_spects.\n",
    "ax[0].set(title = 'Mel Spectrogram of P225')\n",
    "p225_img = librosa.display.specshow(spkr_p225_mel_spec[:320, :], ax=ax[0])\n",
    "\n",
    "# Right panel: the spectrogram returned by create_cross_spkr_mel_spects.\n",
    "ax[1].set(title = 'AVC mel spect')\n",
    "librosa.display.specshow(avc_mel_specs, ax=ax[1])\n",
    "\n",
    "# The colour bar is built from the left image only, so attach it to that panel explicitly.\n",
    "fig.colorbar(p225_img, ax=ax[0]);"
  ], "execution_count": null, "outputs": [] }
] }