1 line
31 KiB
Plaintext
1 line
31 KiB
Plaintext
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Tacotron2+HifiGAN派蒙语音合成.ipynb","provenance":[{"file_id":"1VAuIqEAnrmCig3Edt5zFgQdckY9TDi3N","timestamp":1659173090731},{"file_id":"1U5ZQJCBF6omYG_doSUjF48GAVKp5t7au","timestamp":1647423462684},{"file_id":"19PDvhxBXXF3ZPdqRYwu-fBc7BT9VG1UL","timestamp":1639549819501},{"file_id":"1uP578uonO0GGjKsQ1hVQO-TCeRl4IdrZ","timestamp":1639450346986}],"collapsed_sections":[],"private_outputs":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"gpuClass":"standard","accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["**2022/03/14 and the unpickling error is solved. The training part works as of 2022/03/14**\n","\n","**2022/03/15 Speech synsthesis with HiFi-GAN works** \n","\n","**2022/03/16 Speech synsthesis with Waveglow should work again now (tested)** \n","\n","**2022/08/16 本笔记在原作者基础上修改更新,并且不再使用Waveglow。在tarcotron2项目基础上适配tensorflow 2.x,感谢NO.zero的帮助。\n","语音合成部分现支持中文输入。**\n","\n","**Tacotron 2 Training and Synthesis Notebook**\n","originally based on the following notebooks\n","https://github.com/NVIDIA/tacotron2,\n","https://bit.ly/3F4DkH2\n","and those presented in Adam is cool and stuff (https://youtu.be/LQAOCXdU8p8 and https://youtu.be/XLt_K_692Mc)"],"metadata":{"id":"oxl2bna8ztl2"}},{"cell_type":"markdown","source":["#Data Preparation"],"metadata":{"id":"r2D2Mt80bamF"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"nmTLE46V4erB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Download Tacotron 2\n","%tensorflow_version 2.x\n","import os\n","\n","os.chdir('/content/drive/MyDrive/tacotron2')\n","!git submodule init\n","!git submodule update\n","!pip install -q unidecode tensorboardX\n","!pip install pypinyin"],"metadata":{"id":"67u1nnaJcyPt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#Model Preparation"],"metadata":{"id":"IOETvfJdbfYO"}},{"cell_type":"code","source":["#@title { display-mode: \"code\" }\n","\n","%matplotlib inline\n","import os\n","if os.getcwd() != '/content/drive/MyDrive/tacotron2':\n"," os.chdir('/content/drive/MyDrive/tacotron2')\n","import time\n","import argparse\n","import math\n","from numpy import finfo\n","\n","import torch\n","from distributed import apply_gradient_allreduce\n","import torch.distributed as dist\n","from torch.utils.data.distributed import DistributedSampler\n","from torch.utils.data import DataLoader\n","\n","from model import Tacotron2\n","from data_utils import TextMelLoader, TextMelCollate\n","from loss_function import Tacotron2Loss\n","from logger import Tacotron2Logger\n","from hparams import create_hparams\n"," \n","import random\n","import numpy as np\n","\n","import layers\n","from utils import load_wav_to_torch, load_filepaths_and_text\n","from text import text_to_sequence\n","from math import e\n","#from tqdm import tqdm # Terminal\n","#from tqdm import tqdm_notebook as tqdm # Legacy Notebook TQDM\n","from tqdm.notebook import tqdm # Modern Notebook TQDM\n","from distutils.dir_util import copy_tree\n","import matplotlib.pylab as plt\n","\n","def download_from_google_drive(file_id, file_name):\n"," # download a file from the Google Drive link\n"," !rm -f ./cookie\n"," !curl -c ./cookie -s -L \"https://drive.google.com/uc?export=download&id={file_id}\" > /dev/null\n"," confirm_text = !awk '/download/ {print $NF}' ./cookie\n"," confirm_text = confirm_text[0]\n"," !curl -Lb ./cookie \"https://drive.google.com/uc?export=download&confirm={confirm_text}&id={file_id}\" -o {file_name}\n","\n","def create_mels():\n"," print(\"Generating Mels\")\n"," stft = layers.TacotronSTFT(\n"," hparams.filter_length, hparams.hop_length, hparams.win_length,\n"," hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,\n"," hparams.mel_fmax)\n"," def save_mel(filename):\n"," audio, sampling_rate = load_wav_to_torch(filename)\n"," if sampling_rate != stft.sampling_rate:\n"," raise ValueError(\"{} {} SR doesn't match target {} SR\".format(filename, \n"," sampling_rate, stft.sampling_rate))\n"," audio_norm = audio / hparams.max_wav_value\n"," audio_norm = audio_norm.unsqueeze(0)\n"," audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)\n"," melspec = stft.mel_spectrogram(audio_norm)\n"," melspec = torch.squeeze(melspec, 0).cpu().numpy()\n"," np.save(filename.replace('.wav', ''), melspec)\n","\n"," import glob\n"," wavs = glob.glob('Paimon/*/*.wav')\n"," for i in tqdm(wavs):\n"," save_mel(i)\n","\n","\n","def reduce_tensor(tensor, n_gpus):\n"," rt = tensor.clone()\n"," dist.all_reduce(rt, op=dist.reduce_op.SUM)\n"," rt /= n_gpus\n"," return rt\n","\n","\n","def init_distributed(hparams, n_gpus, rank, group_name):\n"," assert torch.cuda.is_available(), \"Distributed mode requires CUDA.\"\n"," print(\"Initializing Distributed\")\n","\n"," # Set cuda device so everything is done on the right GPU.\n"," torch.cuda.set_device(rank % torch.cuda.device_count())\n","\n"," # Initialize distributed communication\n"," dist.init_process_group(\n"," backend=hparams.dist_backend, init_method=hparams.dist_url,\n"," world_size=n_gpus, rank=rank, group_name=group_name)\n","\n"," print(\"Done initializing distributed\")\n","\n","\n","def prepare_dataloaders(hparams):\n"," # Get data, data loaders and collate function ready\n"," trainset = TextMelLoader(hparams.training_files, hparams)\n"," valset = TextMelLoader(hparams.validation_files, hparams)\n"," collate_fn = TextMelCollate(hparams.n_frames_per_step)\n","\n"," if hparams.distributed_run:\n"," train_sampler = DistributedSampler(trainset)\n"," shuffle = False\n"," else:\n"," train_sampler = None\n"," shuffle = True\n","\n"," train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,\n"," sampler=train_sampler,\n"," batch_size=hparams.batch_size, pin_memory=False,\n"," drop_last=True, collate_fn=collate_fn)\n"," return train_loader, valset, collate_fn\n","\n","\n","def prepare_directories_and_logger(output_directory, log_directory, rank):\n"," if rank == 0:\n"," if not os.path.isdir(output_directory):\n"," os.makedirs(output_directory)\n"," os.chmod(output_directory, 0o775)\n"," logger = Tacotron2Logger(os.path.join(output_directory, log_directory))\n"," else:\n"," logger = None\n"," return logger\n","\n","\n","def load_model(hparams):\n"," model = Tacotron2(hparams).cuda()\n"," if hparams.fp16_run:\n"," model.decoder.attention_layer.score_mask_value = finfo('float16').min\n","\n"," if hparams.distributed_run:\n"," model = apply_gradient_allreduce(model)\n","\n"," return model\n","\n","\n","def warm_start_model(checkpoint_path, model, ignore_layers):\n"," assert os.path.isfile(checkpoint_path)\n"," print(\"Warm starting model from checkpoint '{}'\".format(checkpoint_path))\n"," checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n"," model_dict = checkpoint_dict['state_dict']\n"," if len(ignore_layers) > 0:\n"," model_dict = {k: v for k, v in model_dict.items()\n"," if k not in ignore_layers}\n"," dummy_dict = model.state_dict()\n"," dummy_dict.update(model_dict)\n"," model_dict = dummy_dict\n"," model.load_state_dict(model_dict)\n"," return model\n","\n","\n","def load_checkpoint(checkpoint_path, model, optimizer):\n"," assert os.path.isfile(checkpoint_path)\n"," print(\"Loading checkpoint '{}'\".format(checkpoint_path))\n"," checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')\n"," model.load_state_dict(checkpoint_dict['state_dict'])\n"," optimizer.load_state_dict(checkpoint_dict['optimizer'])\n"," learning_rate = checkpoint_dict['learning_rate']\n"," iteration = checkpoint_dict['iteration']\n"," print(\"Loaded checkpoint '{}' from iteration {}\" .format(\n"," checkpoint_path, iteration))\n"," return model, optimizer, learning_rate, iteration\n","\n","\n","def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):\n"," print(\"Saving model and optimizer state at iteration {} to {}\".format(\n"," iteration, filepath))\n"," try:\n"," torch.save({'iteration': iteration,\n"," 'state_dict': model.state_dict(),\n"," 'optimizer': optimizer.state_dict(),\n"," 'learning_rate': learning_rate}, filepath)\n"," except KeyboardInterrupt:\n"," print(\"interrupt received while saving, waiting for save to complete.\")\n"," torch.save({'iteration': iteration,'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(),'learning_rate': learning_rate}, filepath)\n"," print(\"Model Saved\")\n","\n","def plot_alignment(alignment, info=None):\n"," %matplotlib inline\n"," fig, ax = plt.subplots(figsize=(int(alignment_graph_width/100), int(alignment_graph_height/100)))\n"," im = ax.imshow(alignment, cmap='inferno', aspect='auto', origin='lower',\n"," interpolation='none')\n"," ax.autoscale(enable=True, axis=\"y\", tight=True)\n"," fig.colorbar(im, ax=ax)\n"," xlabel = 'Decoder timestep'\n"," if info is not None:\n"," xlabel += '\\n\\n' + info\n"," plt.xlabel(xlabel)\n"," plt.ylabel('Encoder timestep')\n"," plt.tight_layout()\n"," fig.canvas.draw()\n"," plt.show()\n","\n","def validate(model, criterion, valset, iteration, batch_size, n_gpus,\n"," collate_fn, logger, distributed_run, rank, epoch, start_eposh, learning_rate):\n"," \"\"\"Handles all the validation scoring and printing\"\"\"\n"," model.eval()\n"," with torch.no_grad():\n"," val_sampler = DistributedSampler(valset) if distributed_run else None\n"," val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,\n"," shuffle=False, batch_size=batch_size,\n"," pin_memory=False, collate_fn=collate_fn)\n","\n"," val_loss = 0.0\n"," for i, batch in enumerate(val_loader):\n"," x, y = model.parse_batch(batch)\n"," y_pred = model(x)\n"," loss = criterion(y_pred, y)\n"," if distributed_run:\n"," reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()\n"," else:\n"," reduced_val_loss = loss.item()\n"," val_loss += reduced_val_loss\n"," val_loss = val_loss / (i + 1)\n","\n"," model.train()\n"," if rank == 0:\n"," print(\"Epoch: {} Validation loss {}: {:9f} Time: {:.1f}m LR: {:.6f}\".format(epoch, iteration, val_loss,(time.perf_counter()-start_eposh)/60, learning_rate))\n"," logger.log_validation(val_loss, model, y, y_pred, iteration)\n"," if hparams.show_alignments:\n"," %matplotlib inline\n"," _, mel_outputs, gate_outputs, alignments = y_pred\n"," idx = random.randint(0, alignments.size(0) - 1)\n"," plot_alignment(alignments[idx].data.cpu().numpy().T)\n","\n","def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,\n"," rank, group_name, hparams, log_directory2):\n"," \"\"\"Training and validation logging results to tensorboard and stdout\n","\n"," Params\n"," ------\n"," output_directory (string): directory to save checkpoints\n"," log_directory (string) directory to save tensorboard logs\n"," checkpoint_path(string): checkpoint path\n"," n_gpus (int): number of gpus\n"," rank (int): rank of current gpu\n"," hparams (object): comma separated list of \"name=value\" pairs.\n"," \"\"\"\n"," if hparams.distributed_run:\n"," init_distributed(hparams, n_gpus, rank, group_name)\n","\n"," torch.manual_seed(hparams.seed)\n"," torch.cuda.manual_seed(hparams.seed)\n","\n"," model = load_model(hparams)\n"," learning_rate = hparams.learning_rate\n"," optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,\n"," weight_decay=hparams.weight_decay)\n","\n"," if hparams.fp16_run:\n"," from apex import amp\n"," model, optimizer = amp.initialize(\n"," model, optimizer, opt_level='O2')\n","\n"," if hparams.distributed_run:\n"," model = apply_gradient_allreduce(model)\n","\n"," criterion = Tacotron2Loss()\n","\n"," logger = prepare_directories_and_logger(\n"," output_directory, log_directory, rank)\n","\n"," train_loader, valset, collate_fn = prepare_dataloaders(hparams)\n","\n"," # Load checkpoint if one exists\n"," iteration = 0\n"," epoch_offset = 0\n"," if checkpoint_path is not None and os.path.isfile(checkpoint_path):\n"," if warm_start:\n"," model = warm_start_model(\n"," checkpoint_path, model, hparams.ignore_layers)\n"," else:\n"," model, optimizer, _learning_rate, iteration = load_checkpoint(\n"," checkpoint_path, model, optimizer)\n"," if hparams.use_saved_learning_rate:\n"," learning_rate = _learning_rate\n"," iteration += 1 # next iteration is iteration + 1\n"," epoch_offset = max(0, int(iteration / len(train_loader)))\n"," else:\n"," os.path.isfile(\"tacotron2_statedict.pt\")\n"," model = warm_start_model(\"tacotron2_statedict.pt\", model, hparams.ignore_layers)\n"," # download LJSpeech pretrained model if no checkpoint already exists\n"," \n"," start_eposh = time.perf_counter()\n"," learning_rate = 0.0\n"," model.train()\n"," is_overflow = False\n"," # ================ MAIN TRAINNIG LOOP! ===================\n"," for epoch in tqdm(range(epoch_offset, hparams.epochs)):\n"," print(\"\\nStarting Epoch: {} Iteration: {}\".format(epoch, iteration))\n"," start_eposh = time.perf_counter() # eposh is russian, not a typo\n"," for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):\n"," start = time.perf_counter()\n"," if iteration < hparams.decay_start: learning_rate = hparams.A_\n"," else: iteration_adjusted = iteration - hparams.decay_start; learning_rate = (hparams.A_*(e**(-iteration_adjusted/hparams.B_))) + hparams.C_\n"," learning_rate = max(hparams.min_learning_rate, learning_rate) # output the largest number\n"," for param_group in optimizer.param_groups:\n"," param_group['lr'] = learning_rate\n","\n"," model.zero_grad()\n"," x, y = model.parse_batch(batch)\n"," y_pred = model(x)\n","\n"," loss = criterion(y_pred, y)\n"," if hparams.distributed_run:\n"," reduced_loss = reduce_tensor(loss.data, n_gpus).item()\n"," else:\n"," reduced_loss = loss.item()\n"," if hparams.fp16_run:\n"," with amp.scale_loss(loss, optimizer) as scaled_loss:\n"," scaled_loss.backward()\n"," else:\n"," loss.backward()\n","\n"," if hparams.fp16_run:\n"," grad_norm = torch.nn.utils.clip_grad_norm_(\n"," amp.master_params(optimizer), hparams.grad_clip_thresh)\n"," is_overflow = math.isnan(grad_norm)\n"," else:\n"," grad_norm = torch.nn.utils.clip_grad_norm_(\n"," model.parameters(), hparams.grad_clip_thresh)\n","\n"," optimizer.step()\n","\n"," if not is_overflow and rank == 0:\n"," duration = time.perf_counter() - start\n"," logger.log_training(\n"," reduced_loss, grad_norm, learning_rate, duration, iteration)\n"," #print(\"Batch {} loss {:.6f} Grad Norm {:.6f} Time {:.6f}\".format(iteration, reduced_loss, grad_norm, duration), end='\\r', flush=True)\n","\n"," iteration += 1\n"," validate(model, criterion, valset, iteration,\n"," hparams.batch_size, n_gpus, collate_fn, logger,\n"," hparams.distributed_run, rank, epoch, start_eposh, learning_rate)\n"," save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)\n"," if log_directory2 != None:\n"," copy_tree(log_directory, log_directory2)\n","def check_dataset(hparams):\n"," from utils import load_wav_to_torch, load_filepaths_and_text\n"," import os\n"," import numpy as np\n"," def check_arr(filelist_arr):\n"," for i, file in enumerate(filelist_arr):\n"," if len(file) > 2:\n"," print(\"|\".join(file), \"\\nhas multiple '|', this may not be an error.\")\n"," if hparams.load_mel_from_disk and '.wav' in file[0]:\n"," print(\"[WARNING]\", file[0], \" in filelist while expecting .npy .\")\n"," else:\n"," if not hparams.load_mel_from_disk and '.npy' in file[0]:\n"," print(\"[WARNING]\", file[0], \" in filelist while expecting .wav .\")\n"," if (not os.path.exists(file[0])):\n"," print(\"|\".join(file), \"\\n[WARNING] does not exist.\")\n"," if len(file[1]) < 3:\n"," print(\"|\".join(file), \"\\n[info] has no/very little text.\")\n"," if not ((file[1].strip())[-1] in r\"!?,.;:\"):\n"," print(\"|\".join(file), \"\\n[info] has no ending punctuation.\")\n"," mel_length = 1\n"," if hparams.load_mel_from_disk and '.npy' in file[0]:\n"," melspec = torch.from_numpy(np.load(file[0], allow_pickle=True))\n"," mel_length = melspec.shape[1]\n"," if mel_length == 0:\n"," print(\"|\".join(file), \"\\n[WARNING] has 0 duration.\")\n"," print(\"Checking Training Files\")\n"," audiopaths_and_text = load_filepaths_and_text(hparams.training_files) # get split lines from training_files text file.\n"," check_arr(audiopaths_and_text)\n"," print(\"Checking Validation Files\")\n"," audiopaths_and_text = load_filepaths_and_text(hparams.validation_files) # get split lines from validation_files text file.\n"," check_arr(audiopaths_and_text)\n"," print(\"Finished Checking\")\n","\n","# ---- Replace .wav with .npy in filelists ----\n","!sed -i -- 's,.wav|,.npy|,g' filelists/*.txt\n","# ---- Replace .wav with .npy in filelists ----\n","\n","warm_start=False#sorry bout that\n","n_gpus=1\n","rank=0\n","group_name=None\n","\n"],"metadata":{"id":"SDDLqhOChx2U"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Model Name\n","model_filename = \"Paimon_test\" #@param {type:\"string\"}"],"metadata":{"id":"GKRvQ1EWiVhn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["hparams = create_hparams()\n","#@title Lists\n","hparams.training_files = \"/content/drive/MyDrive/tacotron2/filelists/new_merge_training.txt\"\n","hparams.validation_files = \"/content/drive/MyDrive/tacotron2/filelists/new_merge_testing.txt\""],"metadata":{"id":"yehA2fOliyUI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Parameters\n","\n","#These two are the most important\n","hparams.batch_size = 30 # Controls how fast the model trains. Don't set this too high, or else it will GPU will OOM (out of memory). 30-ish is usually a good number if you have a bigger dataset. If the number of audio files is more than/about the same as this number, it won't train properly, and you won't be able to use it.\n","hparams.epochs = 600 # Maxmimum epochs (number of times the AI looks through the dataset) to train\n","\n","#The rest aren't that important\n","hparams.p_attention_dropout=0.1\n","hparams.p_decoder_dropout=0.1\n","hparams.decay_start = 15000 # wait till decay_start to start decaying learning rate\n","hparams.A_ = 5e-4 # Start/Max Learning Rate\n","hparams.B_ = 8000 # Decay Rate\n","hparams.C_ = 0 # Shift learning rate equation by this value\n","hparams.min_learning_rate = 1e-5 # Min Learning Rate\n","generate_mels = True # Don't change\n","hparams.show_alignments = True\n","alignment_graph_height = 600\n","alignment_graph_width = 1000\n","hparams.load_mel_from_disk = True\n","hparams.ignore_layers = [] # Layers to reset (None by default, other than foreign languages this param can be ignored)\n","\n","torch.backends.cudnn.enabled = hparams.cudnn_enabled\n","torch.backends.cudnn.benchmark = hparams.cudnn_benchmark\n","output_directory = '/content/drive/My Drive/colab/outdir' # Location to save Checkpoints\n","log_directory = '/content/tacotron2/logs' # Location to save Log files locally\n","log_directory2 = '/content/drive/My Drive/colab/logs' # Location to copy log files (done at the end of each epoch to cut down on I/O)\n","checkpoint_path = output_directory+(r'/')+model_filename"],"metadata":{"id":"JzevuoJnkIsi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title MEL Spectrogram\n","if generate_mels:\n"," create_mels()"],"metadata":{"id":"b_xMcYMfkc9L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Check the Data Set\n","check_dataset(hparams)"],"metadata":{"id":"oJXxqs6kkgLw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#Training"],"metadata":{"id":"62-cfyIubje_"}},{"cell_type":"code","source":["#@title Training\n","print('FP16 Run:', hparams.fp16_run)\n","print('Dynamic Loss Scaling:', hparams.dynamic_loss_scaling)\n","print('Distributed Run:', hparams.distributed_run)\n","print('cuDNN Enabled:', hparams.cudnn_enabled)\n","print('cuDNN Benchmark:', hparams.cudnn_benchmark)\n","train(output_directory, log_directory, checkpoint_path,\n"," warm_start, n_gpus, rank, group_name, hparams, log_directory2)"],"metadata":{"id":"qJTrZhShk8ZR"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#Speech Synthesis"],"metadata":{"id":"jDGVcS77b25R"}},{"cell_type":"markdown","source":["##With HiFi-GAN##"],"metadata":{"id":"V6pX7t0cVlj9"}},{"cell_type":"code","source":["#@markdown Config:\n","\n","#@markdown Restart the code to apply any changes.\n","\n","#Add new characters here.\n","#Universal HiFi-GAN (has some robotic noise): 1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\n","Tacotron2_Model = '/content/drive/MyDrive/colab/outdir/Paimon_test'#@param {type:\"string\"}\n","TACOTRON2_ID = Tacotron2_Model\n","HIFIGAN_ID = \"1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\"\n","from pypinyin import lazy_pinyin,Style\n","\n","# Check if Initilized\n","try:\n"," initilized\n","except NameError:\n"," print(\"Setting up, please wait.\\n\")\n"," !pip install tqdm -q\n"," from tqdm.notebook import tqdm\n"," with tqdm(total=5, leave=False) as pbar:\n"," %tensorflow_version 2.x\n"," import os\n"," from os.path import exists, join, basename, splitext\n"," !pip install gdown\n"," git_repo_url = 'https://github.com/NVIDIA/tacotron2.git'\n"," project_name = splitext(basename(git_repo_url))[0]\n"," if not exists(project_name):\n"," # clone and install\n"," !git clone -q --recursive {git_repo_url}\n"," !git clone -q --recursive https://github.com/SortAnon/hifi-gan\n"," !pip install -q librosa unidecode\n"," pbar.update(1) # downloaded TT2 and HiFi-GAN\n"," import sys\n"," sys.path.append('hifi-gan')\n"," sys.path.append(project_name)\n"," import time\n"," import matplotlib\n"," import matplotlib.pylab as plt\n"," import gdown\n"," d = 'https://drive.google.com/uc?id='\n","\n"," %matplotlib inline\n"," import IPython.display as ipd\n"," import numpy as np\n"," import torch\n"," import json\n"," from hparams import create_hparams\n"," from model import Tacotron2\n"," from layers import TacotronSTFT\n"," from audio_processing import griffin_lim\n"," from text import text_to_sequence\n"," from env import AttrDict\n"," from meldataset import MAX_WAV_VALUE\n"," from models import Generator\n","\n"," pbar.update(1) # initialized Dependancies\n","\n"," graph_width = 900\n"," graph_height = 360\n"," def plot_data(data, figsize=(int(graph_width/100), int(graph_height/100))):\n"," %matplotlib inline\n"," fig, axes = plt.subplots(1, len(data), figsize=figsize)\n"," for i in range(len(data)):\n"," axes[i].imshow(data[i], aspect='auto', origin='bottom', \n"," interpolation='none', cmap='inferno')\n"," fig.canvas.draw()\n"," plt.show()\n","\n"," # Setup Pronounciation Dictionary\n"," !gdown --id '1E12g_sREdcH5vuZb44EZYX8JjGWQ9rRp'\n"," thisdict = {}\n"," for line in reversed((open('merged.dict.txt', \"r\").read()).splitlines()):\n"," thisdict[(line.split(\" \",1))[0]] = (line.split(\" \",1))[1].strip()\n","\n"," pbar.update(1) # Downloaded and Set up Pronounciation Dictionary\n","\n"," def ARPA(text, punctuation=r\"!?,.;\", EOS_Token=True):\n"," out = ''\n"," for word_ in text.split(\" \"):\n"," word=word_; end_chars = ''\n"," while any(elem in word for elem in punctuation) and len(word) > 1:\n"," if word[-1] in punctuation: end_chars = word[-1] + end_chars; word = word[:-1]\n"," else: break\n"," try:\n"," word_arpa = thisdict[word.upper()]\n"," word = \"{\" + str(word_arpa) + \"}\"\n"," except KeyError: pass\n"," out = (out + \" \" + word + end_chars).strip()\n"," if EOS_Token and out[-1] != \";\": out += \";\"\n"," return out\n","\n"," def get_hifigan(MODEL_ID):\n"," # Download HiFi-GAN\n"," hifigan_pretrained_model = 'hifimodel'\n"," gdown.download(d+MODEL_ID, hifigan_pretrained_model, quiet=False)\n"," if not exists(hifigan_pretrained_model):\n"," raise Exception(\"HiFI-GAN model failed to download!\")\n","\n"," # Load HiFi-GAN\n"," conf = os.path.join(\"hifi-gan\", \"config_v1.json\")\n"," with open(conf) as f:\n"," json_config = json.loads(f.read())\n"," h = AttrDict(json_config)\n"," torch.manual_seed(h.seed)\n"," hifigan = Generator(h).to(torch.device(\"cuda\"))\n"," state_dict_g = torch.load(hifigan_pretrained_model, map_location=torch.device(\"cuda\"))\n"," hifigan.load_state_dict(state_dict_g[\"generator\"])\n"," hifigan.eval()\n"," hifigan.remove_weight_norm()\n"," return hifigan, h\n","\n"," hifigan, h = get_hifigan(HIFIGAN_ID)\n"," pbar.update(1) # Downloaded and Set up HiFi-GAN\n","\n"," def has_MMI(STATE_DICT):\n"," return any(True for x in STATE_DICT.keys() if \"mi.\" in x)\n","\n"," def get_Tactron2(MODEL_ID):\n"," # Download Tacotron2\n"," tacotron2_pretrained_model = TACOTRON2_ID\n"," if not exists(tacotron2_pretrained_model):\n"," raise Exception(\"Tacotron2 model failed to download!\")\n"," # Load Tacotron2 and Config\n"," hparams = create_hparams()\n"," hparams.sampling_rate = 22050\n"," hparams.max_decoder_steps = 3000 # Max Duration\n"," hparams.gate_threshold = 0.25 # Model must be 25% sure the clip is over before ending generation\n"," model = Tacotron2(hparams)\n"," state_dict = torch.load(tacotron2_pretrained_model)['state_dict']\n"," if has_MMI(state_dict):\n"," raise Exception(\"ERROR: This notebook does not currently support MMI models.\")\n"," model.load_state_dict(state_dict)\n"," _ = model.cuda().eval().half()\n"," return model, hparams\n","\n"," model, hparams = get_Tactron2(TACOTRON2_ID)\n"," previous_tt2_id = TACOTRON2_ID\n","\n"," pbar.update(1) # Downloaded and Set up Tacotron2\n","\n"," # Extra Info\n"," def end_to_end_infer(text, pronounciation_dictionary, show_graphs):\n"," for i in [x for x in text.split(\"\\n\") if len(x)]:\n"," if not pronounciation_dictionary:\n"," if i[-1] != \";\": i=i+\";\" \n"," else: i = ARPA(i)\n"," with torch.no_grad(): # save VRAM by not including gradients\n"," sequence = np.array(text_to_sequence(i, ['english_cleaners']))[None, :]\n"," sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()\n"," mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n"," if show_graphs:\n"," plot_data((mel_outputs_postnet.float().data.cpu().numpy()[0],\n"," alignments.float().data.cpu().numpy()[0].T))\n"," y_g_hat = hifigan(mel_outputs_postnet.float())\n"," audio = y_g_hat.squeeze()\n"," audio = audio * MAX_WAV_VALUE\n"," print(\"\")\n"," ipd.display(ipd.Audio(audio.cpu().numpy().astype(\"int16\"), rate=hparams.sampling_rate))\n"," from IPython.display import clear_output\n"," clear_output()\n"," initilized = \"Ready\"\n","\n","if previous_tt2_id != TACOTRON2_ID:\n"," print(\"Updating Models\")\n"," model, hparams = get_Tactron2(TACOTRON2_ID)\n"," hifigan, h = get_hifigan(HIFIGAN_ID)\n"," previous_tt2_id = TACOTRON2_ID\n","\n","pronounciation_dictionary = False #@param {type:\"boolean\"}\n","# disables automatic ARPAbet conversion, useful for inputting your own ARPAbet pronounciations or just for testing\n","show_graphs = True #@param {type:\"boolean\"}\n","max_duration = 25 #this does nothing\n","model.decoder.max_decoder_steps = 1000 #@param {type:\"integer\"}\n","stop_threshold = 0.3 #@param {type:\"number\"}\n","model.decoder.gate_threshold = stop_threshold\n","\n","#@markdown ---\n","\n","print(f\"Current Config:\\npronounciation_dictionary: {pronounciation_dictionary}\\nshow_graphs: {show_graphs}\\nmax_duration (in seconds): {max_duration}\\nstop_threshold: {stop_threshold}\\n\\n\")\n","\n","time.sleep(1)\n","print(\"Enter/Paste your text.输入拼音+数字表示声调,支持直接中文输入\")\n","contents = []\n","while True:\n"," try:\n"," print(\"-\"*50)\n"," line = input()\n"," if line != \"\":\n"," line = \" \".join(lazy_pinyin(line, style=Style.TONE3))\n"," print(line)\n"," end_to_end_infer(line, pronounciation_dictionary, show_graphs)\n"," except EOFError:\n"," break\n"," except KeyboardInterrupt:\n"," print(\"Stopping...\")\n"," break"],"metadata":{"id":"mwsEA9fP4qfZ"},"execution_count":null,"outputs":[]}]} |