ADD
# Add dataloader.Net2DataLoader:
#   - Implement Net2Dataset
#   - Implement get_net2_data_loader
#   - Test Net2DataLoader
# Add train.train_net2:
#   - Add net2 train main function
#   - Implement net2 train function
#   - Implement net2 model save and load function
#   - Implement simple debugging function
# Update audio_operation:
#   - Turn off debugging function
# Update hparams:
#   - Add net2 dataset default setting
#   - Add net2 model default setting
#   - Add net2 train default setting
This commit is contained in:
parent
080136ba6f
commit
1904696bc7
|
@ -158,7 +158,7 @@ def get_mfccs_and_spectrogram(wav_file, trim=True, random_crop=False):
|
|||
length = sr * default_duration
|
||||
wav = librosa.util.fix_length(wav, length)
|
||||
|
||||
debug = True
|
||||
debug = False
|
||||
if debug:
|
||||
print("wav.shape : " + str(wav.shape))
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
import glob
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
from audio_operation import get_mfccs_and_spectrogram
|
||||
|
||||
|
||||
class Net2Dataset(Dataset):
    """Wav-file dataset for Net2 training.

    Collects every file matching the glob pattern ``data_path``; each item
    is whatever ``get_mfccs_and_spectrogram`` returns for that wav file
    (per the call in train_net2, a (mfccs, spec, mel) tuple).
    """

    def __init__(self, data_path):
        # ``data_path`` is a glob pattern, e.g. ".../arctic/slt/*.wav".
        self.wav_files = glob.glob(data_path)

    def __getitem__(self, index):
        # Features are computed lazily, on access, not cached.
        return get_mfccs_and_spectrogram(self.wav_files[index])

    def __len__(self):
        return len(self.wav_files)
|
||||
|
||||
|
||||
def get_net2_data_loader(data_path, batch_size, num_workers):
    """Build a shuffling DataLoader over a :class:`Net2Dataset`.

    Args:
        data_path: glob pattern of wav files passed to ``Net2Dataset``.
        batch_size: samples per batch; incomplete last batch is dropped.
        num_workers: number of worker processes for loading.

    Returns:
        torch.utils.data.DataLoader over the dataset.
    """
    dataset = Net2Dataset(data_path)

    # BUG FIX: the original passed ``np.random.seed(...)`` — i.e. it called
    # np.random.seed once in the parent process and handed its return value
    # (None) to DataLoader, so worker processes were never re-seeded and
    # numpy-based augmentation repeated across workers. worker_init_fn must
    # be a callable invoked in each worker; torch.initial_seed() is already
    # per-worker there.
    def _worker_init_fn(worker_id):
        np.random.seed(torch.initial_seed() % (2 ** 32))

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers,
                             drop_last=True,
                             worker_init_fn=_worker_init_fn)

    return data_loader
|
25
hparams.py
25
hparams.py
|
@ -30,9 +30,30 @@ net1_logits_t = 1.0
|
|||
|
||||
# net1 train
# Device for Net1 training; falls back to CPU when CUDA is unavailable.
net1_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
net1_train_steps = 10000
net1_train_checkpoint_path = "../checkpoint"
# NOTE: the two assignments below supersede the two above (last write wins);
# the earlier values look like stale leftovers — consider deleting them.
net1_train_steps = 100000
net1_train_checkpoint_path = "../checkpoint/net1"
net1_train_lr = 0.0003
net1_train_log_step = 10      # print a training log every N steps
net1_train_save_step = 1000   # save a checkpoint every N steps
net1_train_multiple_flag = False

# Net2
# net2 dataset
# Glob pattern of training wav files (CMU Arctic, speaker slt).
net2_dataset = "../data/dataset/arctic/slt/*.wav"
net2_batch_size = 16
net2_num_workers = 5

# net2 model
# Net2 consumes Net1's PPG output, so its input width is the phoneme count.
net2_in_dims = phns_len
net2_hidden_units = 256
net2_dropout_rate = 0
net2_num_conv1d_banks = 8
net2_num_highway_blocks = 8

# net2 train
net2_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
net2_train_steps = 100000
net2_train_checkpoint_path = "../checkpoint/net2"
net2_train_lr = 0.0003
net2_train_log_step = 10       # print a training log every N steps
net2_train_save_step = 10000   # save a checkpoint every N steps
|
|
|
@ -0,0 +1,198 @@
|
|||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import time
|
||||
import datetime
|
||||
|
||||
import hparams
|
||||
|
||||
from model.Net1 import Net1
|
||||
from model.Net2 import Net2
|
||||
from dataloader.Net2DataLoader import get_net2_data_loader
|
||||
|
||||
|
||||
def train(arg):
    """Train Net2 (PPG -> spectrogram) on top of a frozen, pre-trained Net1.

    Args:
        arg: parsed argparse namespace (see ``get_arguments``) carrying the
            Net2 hyper-parameters, data-loader settings and train config.

    Raises:
        Exception: if no Net1 checkpoint is given, or if the resumed Net2
            checkpoint has already reached ``arg.train_steps``.
    """
    device = torch.device(arg.device)

    # Build Net1 model (MFCC -> PPG); its hyper-parameters come from hparams
    # because they must match the pre-trained checkpoint.
    net1 = Net1(in_dims=hparams.net1_in_dims,
                hidden_units=hparams.net1_hidden_units,
                dropout_rate=hparams.net1_dropout_rate,
                num_conv1d_banks=hparams.net1_num_conv1d_banks,
                num_highway_blocks=hparams.net1_num_highway_blocks)

    # Move net1 model into the computing device
    net1.to(device)

    # Build Net2 model (PPG -> linear/mel spectrogram)
    net2 = Net2(in_dims=arg.in_dims,
                hidden_units=arg.hidden_units,
                dropout_rate=arg.dropout_rate,
                num_conv1d_banks=arg.num_conv1d_banks,
                num_highway_blocks=arg.num_highway_blocks)

    # Create optimizer (only Net2's parameters are trained)
    net2_optimizer = torch.optim.Adam(net2.parameters(), lr=arg.learning_rate)

    # Move net2 model into the computing device
    net2.to(device)

    # Set data loader
    data_loader = get_net2_data_loader(data_path=arg.data_path,
                                       batch_size=arg.batch_size,
                                       num_workers=arg.num_workers)

    start_step = 1

    # Net1 weights are mandatory: Net2 trains on Net1's (frozen) PPG output.
    # BUG FIX: the original used ``raise Exception(print(...))`` which raised
    # Exception(None); raise with the message directly.
    if arg.resume_net1_model is None:
        raise Exception("Need net1 pre-trained model!")

    resume_net1_model_path = os.path.join(hparams.net1_train_checkpoint_path, arg.resume_net1_model)
    print("Resume net1 model from : " + resume_net1_model_path)

    # map_location lets a checkpoint saved on another device (e.g. cuda:3)
    # load onto the current one.
    checkpoint_net1 = torch.load(resume_net1_model_path, map_location=device)
    print("Load net1 model successfully!")

    net1.load_state_dict(checkpoint_net1["net"])

    # Freeze Net1 so no gradients are computed for it.
    for p in net1.parameters():
        p.requires_grad = False

    # Optionally resume Net2 model, optimizer state and step counter.
    if arg.resume_net2_model is not None:
        resume_net2_model_path = os.path.join(arg.checkpoint_path, arg.resume_net2_model)
        print("Resume net2 model from : " + resume_net2_model_path)

        checkpoint_net2 = torch.load(resume_net2_model_path, map_location=device)
        print("Load net2 model successfully!")

        net2.load_state_dict(checkpoint_net2["net"])
        net2_optimizer.load_state_dict(checkpoint_net2["optimizer"])
        start_step = checkpoint_net2["step"]

    if start_step >= arg.train_steps:
        raise Exception(" Training completed !")

    # Start training
    print("Start training ... ")
    start_time = time.time()

    # Hoisted out of the loop: the loss module, and the train/eval modes
    # (nothing inside the loop switches modes, so once is enough).
    criterion = torch.nn.MSELoss(reduction='mean')
    net1.eval()
    net2.train()

    data_iter = iter(data_loader)

    for step in range(start_step, arg.train_steps + 1):

        # Get input data; restart the loader when an epoch is exhausted.
        # BUG FIX: catch StopIteration specifically instead of a bare
        # ``except:`` that would also hide real loader errors.
        try:
            mfccs, spec, mel = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            mfccs, spec, mel = next(data_iter)

        # Moving input data into the computing device
        mfccs = mfccs.to(device)
        spec = spec.to(device)
        mel = mel.to(device)

        # Compute net1 PPGs; detach so no gradient flows into frozen Net1.
        net1_outputs_ppgs, _, _ = net1(mfccs)
        net2_inputs_ppgs = net1_outputs_ppgs.detach()

        # Compute net2 predictions
        pred_spec, pred_mel = net2(net2_inputs_ppgs)

        # Compute the loss: sum of MSE on linear and mel spectrograms.
        loss_spec = criterion(pred_spec, spec)
        loss_mel = criterion(pred_mel, mel)
        loss = loss_spec + loss_mel

        # Backward and optimize
        net2_optimizer.zero_grad()
        loss.backward()
        net2_optimizer.step()

        # Print out training info
        if step % arg.log_step == 0:
            et = time.time() - start_time
            et = str(datetime.timedelta(seconds=et))[:-7]
            log = "Elapsed [{}], Iteration [{}/{}], Loss : [{:.6f}], Loss_spec : [{:.6f}], Loss_mel : [{:.6f}]" \
                .format(et, step, arg.train_steps, loss, loss_spec, loss_mel)
            print(log)

        # Save model
        if step % arg.save_step == 0:
            checkpoint = {
                "net": net2.state_dict(),
                "optimizer": net2_optimizer.state_dict(),
                "step": step
            }

            # makedirs creates missing parent directories too, and exist_ok
            # avoids the isdir/mkdir race (os.mkdir fails if the parent of
            # "../checkpoint/net2" does not exist).
            os.makedirs(arg.checkpoint_path, exist_ok=True)

            torch.save(checkpoint, os.path.join(arg.checkpoint_path, 'ckpt_%s.pth' % str(step)))

            log = "Net2 training result has been saved to pth : ckpt_%s.pth ." % str(step)
            print(log)
|
||||
|
||||
|
||||
def get_arguments():
    """Parse the Net2-training command line and return the namespace.

    Every option defaults to the corresponding value in ``hparams``.
    """
    p = argparse.ArgumentParser()

    # Net2 model hyper-parameters
    p.add_argument('-in_dims', default=hparams.net2_in_dims, type=int,
                   help='Number of Net2 input dimensions.')
    p.add_argument('-hidden_units', default=hparams.net2_hidden_units, type=int,
                   help='Number of Net2 hidden units.')
    p.add_argument('-dropout_rate', default=hparams.net2_dropout_rate, type=float,
                   help='Rate of net2 Dropout layers.')
    p.add_argument('-num_conv1d_banks', default=hparams.net2_num_conv1d_banks, type=int,
                   help='Number of Net2 conv1d banks.')
    p.add_argument('-num_highway_blocks', default=hparams.net2_num_highway_blocks, type=int,
                   help='Number of Net2 Highway blocks.')

    # Data-loader settings
    p.add_argument('-data_path', default=hparams.net2_dataset, type=str,
                   help='Path of Net2 dataset.')
    p.add_argument('-batch_size', default=hparams.net2_batch_size, type=int,
                   help='Batch size.')
    p.add_argument('-num_workers', default=hparams.net2_num_workers, type=int,
                   help='Number of workers.')

    # Training configuration
    p.add_argument('-device', default=hparams.net2_train_device, type=str,
                   help='Net2 training device.')
    p.add_argument('-checkpoint_path', default=hparams.net2_train_checkpoint_path, type=str,
                   help='Net2 model checkpoint path.')
    p.add_argument('-resume_net1_model', default=None, type=str,
                   help='Net1 resume model checkpoint.')
    p.add_argument('-resume_net2_model', default=None, type=str,
                   help='Net2 resume model checkpoint.')
    p.add_argument('-train_steps', default=hparams.net2_train_steps, type=int,
                   help='Net2 training steps.')
    p.add_argument('-learning_rate', default=hparams.net2_train_lr, type=float,
                   help='Net2 learning rate.')
    p.add_argument('-log_step', default=hparams.net2_train_log_step, type=int,
                   help='Net2 training log steps.')
    p.add_argument('-save_step', default=hparams.net2_train_save_step, type=int,
                   help='Net2 training save steps.')

    return p.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = get_arguments()

    # Echo the effective configuration before training starts.
    print(f"Train Net2 parameters : \n {args}")

    train(args)
|
Loading…
Reference in New Issue