# Add convert:
#     Implement audio convert
#     Load trained net1 and net2 model, make spectrogram to waveform
#     Save converted audio
This commit is contained in:
miaomiaomiao-LJY 2021-04-22 17:38:04 +08:00
parent 68c796fda3
commit ab4497878f
1 changed files with 173 additions and 0 deletions

173
convert.py Normal file
View File

@ -0,0 +1,173 @@
import argparse
import torch
import os
import numpy as np
import hparams
from model.Net1 import Net1
from model.Net2 import Net2
from dataloader.Net2DataLoader import get_net2_data_loader
from audio_operation import db2amp, spec2wav, save_wav, denormalize_0_1, inv_preemphasis
def convert(x_spec, y_spec, x_mel, y_mel):
x_spec = x_spec.cpu().detach().numpy()
y_spec = y_spec.cpu().detach().numpy()
# Denormalizatoin
x_spec = denormalize_0_1(x_spec, hparams.timit_max_db, hparams.timit_min_db)
y_spec = denormalize_0_1(y_spec, hparams.timit_max_db, hparams.timit_min_db)
# Db to amp
x_spec = db2amp(x_spec)
y_spec = db2amp(y_spec)
# Emphasize the magnitude
x_spec = np.power(x_spec, hparams.convert_emphasis_magnitude)
x_spec = np.power(x_spec, hparams.convert_emphasis_magnitude)
print(x_spec.shape)
print(y_spec.shape)
x_spec = np.squeeze(x_spec)
y_spec = np.squeeze(y_spec)
# Spectrogram to waveform
spec_x_audio = np.array(spec2wav(mag=x_spec.T,
n_fft=hparams.timit_n_fft,
win_length=hparams.timit_wim_length,
hop_length=hparams.timit_hop_length,
num_iters=hparams.convert_num_iters))
spec_y_audio = np.array(spec2wav(mag=y_spec.T,
n_fft=hparams.timit_n_fft,
win_length=hparams.timit_wim_length,
hop_length=hparams.timit_hop_length,
num_iters=hparams.convert_num_iters))
print(spec_x_audio.shape)
print(spec_y_audio.shape)
spec_x_audio = inv_preemphasis(spec_x_audio, hparams.timit_preemphasis).astype(np.float32)
spec_y_audio = inv_preemphasis(spec_y_audio, hparams.timit_preemphasis).astype(np.float32)
return spec_x_audio, spec_y_audio
def do_convert(arg):
device = torch.device(arg.device)
# Build model
net1 = Net1(in_dims=hparams.net1_in_dims,
hidden_units=hparams.net1_hidden_units,
dropout_rate=hparams.net1_dropout_rate,
num_conv1d_banks=hparams.net1_num_conv1d_banks,
num_highway_blocks=hparams.net1_num_highway_blocks)
net2 = Net2(in_dims=hparams.net2_in_dims,
hidden_units=hparams.net2_hidden_units,
dropout_rate=hparams.net2_dropout_rate,
num_conv1d_banks=hparams.net2_num_conv1d_banks,
num_highway_blocks=hparams.net2_num_highway_blocks)
# Move model into the computing device
net1.to(device)
net2.to(device)
# Set data loader
data_loader = get_net2_data_loader(data_path=arg.data_path,
batch_size=arg.batch_size,
num_workers=arg.num_workers)
# Resume net1 model
if arg.resume_net1_model is None:
raise Exception(print("Need net1 pre-trained model!"))
resume_net1_model_path = os.path.join(hparams.net1_convert_checkpoint_path, arg.resume_net1_model)
resume_log = "Resume net1 model from : " + resume_net1_model_path
print(resume_log)
checkpoint_net1 = torch.load(resume_net1_model_path)
print("Load net1 model successfully!")
net1.load_state_dict(checkpoint_net1["net"])
# Resume net2 model
if arg.resume_net2_model is None:
raise Exception(print("Need net2 pre-trained model!"))
resume_net2_model_path = os.path.join(hparams.net2_convert_checkpoint_path, arg.resume_net2_model)
resume_log = "Resume net2 model from : " + resume_net2_model_path
print(resume_log)
checkpoint_net2 = torch.load(resume_net2_model_path)
print("Load net1 model successfully!")
net2.load_state_dict(checkpoint_net2["net"])
# Start Converting
print("Start Converting ... ")
data_iter = iter(data_loader)
# Get input data
mfccs, spec, mel = next(data_iter)
# Moving input data into the computing device
mfccs = mfccs.to(device)
spec = spec.to(device)
mel = mel.to(device)
# Set net1 and net2 model
net1 = net1.eval()
net2 = net2.eval()
# Compute net1
net1_outputs_ppgs, _, _ = net1(mfccs)
net2_inputs_ppgs = net1_outputs_ppgs.detach()
# Compute net2
pred_spec, pred_mel = net2(net2_inputs_ppgs)
# pred_spec = net2(net2_inputs_ppgs)
pred_spec_audio, real_spec_audio = convert(pred_spec, spec, pred_mel, mel)
# pred_spec_audio, real_spec_audio = convert(pred_spec, spec, None, mel)
# Write the result
save_wav(hparams.convert_save_path + "-" + arg.resume_net2_model + "-spec_converted.wav",
pred_spec_audio, sr=hparams.timit_sr)
save_wav(hparams.convert_save_path + "-" + arg.resume_net2_model + "-spec_real.wav",
real_spec_audio, sr=hparams.timit_sr)
def get_arguments():
parser = argparse.ArgumentParser()
# Set DataLoader
parser.add_argument('-data_path', default=hparams.convert_dataset, type=str,
help='Path of Net2 dataset.')
parser.add_argument('-batch_size', default=hparams.convert_batch_size, type=int,
help='Batch size.')
parser.add_argument('-num_workers', default=hparams.convert_num_workers, type=int,
help='Number of workers.')
# Set convert config
parser.add_argument('-device', default=hparams.convert_device, type=str,
help='Convert device.')
parser.add_argument('-resume_net1_model', default=None, type=str,
help='Net1 resume model checkpoint.')
parser.add_argument('-resume_net2_model', default=None, type=str,
help='Net2 resume model checkpoint.')
arguments = parser.parse_args()
return arguments
if __name__ == '__main__':
args = get_arguments()
print("Convert parameters : \n " + str(args))
do_convert(args)