ADD
# Add dataloader.Net2DataLoader: # Implement Net2Dataset # Implement get_net2_data_loader # Test Net2DataLoader # Add train.train_net2: # Add net2 train main function # Implement net2 train function # Implement net2 model save and load function # Implement simple debugging function # Update audio_operation: # Turn off debugging function # Update hparams: # Add net2 dataset default setting # Add net2 model default setting # Add net2 train default setting
This commit is contained in:
parent
080136ba6f
commit
1904696bc7
|
@ -158,7 +158,7 @@ def get_mfccs_and_spectrogram(wav_file, trim=True, random_crop=False):
|
||||||
length = sr * default_duration
|
length = sr * default_duration
|
||||||
wav = librosa.util.fix_length(wav, length)
|
wav = librosa.util.fix_length(wav, length)
|
||||||
|
|
||||||
debug = True
|
debug = False
|
||||||
if debug:
|
if debug:
|
||||||
print("wav.shape : " + str(wav.shape))
|
print("wav.shape : " + str(wav.shape))
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
import glob
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
|
||||||
|
from audio_operation import get_mfccs_and_spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
class Net2Dataset(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, data_path):
|
||||||
|
self.wav_files = glob.glob(data_path)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
wav = self.wav_files[item]
|
||||||
|
return get_mfccs_and_spectrogram(wav)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.wav_files)
|
||||||
|
|
||||||
|
|
||||||
|
def get_net2_data_loader(data_path, batch_size, num_workers):
|
||||||
|
dataset = Net2Dataset(data_path)
|
||||||
|
|
||||||
|
data_loader = DataLoader(dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers,
|
||||||
|
drop_last=True,
|
||||||
|
worker_init_fn=np.random.seed((torch.initial_seed()) % (2 ** 32)))
|
||||||
|
|
||||||
|
return data_loader
|
25
hparams.py
25
hparams.py
|
@ -30,9 +30,30 @@ net1_logits_t = 1.0
|
||||||
|
|
||||||
# net1 train
|
# net1 train
|
||||||
net1_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
|
net1_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
|
||||||
net1_train_steps = 10000
|
net1_train_steps = 100000
|
||||||
net1_train_checkpoint_path = "../checkpoint"
|
net1_train_checkpoint_path = "../checkpoint/net1"
|
||||||
net1_train_lr = 0.0003
|
net1_train_lr = 0.0003
|
||||||
net1_train_log_step = 10
|
net1_train_log_step = 10
|
||||||
net1_train_save_step = 1000
|
net1_train_save_step = 1000
|
||||||
net1_train_multiple_flag = False
|
net1_train_multiple_flag = False
|
||||||
|
|
||||||
|
# Net2
|
||||||
|
# net2 dataset
|
||||||
|
net2_dataset = "../data/dataset/arctic/slt/*.wav"
|
||||||
|
net2_batch_size = 16
|
||||||
|
net2_num_workers = 5
|
||||||
|
|
||||||
|
# net2 model
|
||||||
|
net2_in_dims = phns_len
|
||||||
|
net2_hidden_units = 256
|
||||||
|
net2_dropout_rate = 0
|
||||||
|
net2_num_conv1d_banks = 8
|
||||||
|
net2_num_highway_blocks = 8
|
||||||
|
|
||||||
|
# net2 train
|
||||||
|
net2_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
|
||||||
|
net2_train_steps = 100000
|
||||||
|
net2_train_checkpoint_path = "../checkpoint/net2"
|
||||||
|
net2_train_lr = 0.0003
|
||||||
|
net2_train_log_step = 10
|
||||||
|
net2_train_save_step = 10000
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import hparams
|
||||||
|
|
||||||
|
from model.Net1 import Net1
|
||||||
|
from model.Net2 import Net2
|
||||||
|
from dataloader.Net2DataLoader import get_net2_data_loader
|
||||||
|
|
||||||
|
|
||||||
|
def train(arg):
|
||||||
|
device = torch.device(arg.device)
|
||||||
|
|
||||||
|
# Build Net1 model
|
||||||
|
net1 = Net1(in_dims=hparams.net1_in_dims,
|
||||||
|
hidden_units=hparams.net1_hidden_units,
|
||||||
|
dropout_rate=hparams.net1_dropout_rate,
|
||||||
|
num_conv1d_banks=hparams.net1_num_conv1d_banks,
|
||||||
|
num_highway_blocks=hparams.net1_num_highway_blocks)
|
||||||
|
|
||||||
|
# Move net1 model into the computing device
|
||||||
|
net1.to(device)
|
||||||
|
|
||||||
|
# Build Net2 model
|
||||||
|
net2 = Net2(in_dims=arg.in_dims,
|
||||||
|
hidden_units=arg.hidden_units,
|
||||||
|
dropout_rate=arg.dropout_rate,
|
||||||
|
num_conv1d_banks=arg.num_conv1d_banks,
|
||||||
|
num_highway_blocks=arg.num_highway_blocks)
|
||||||
|
|
||||||
|
# Create optimizer
|
||||||
|
net2_optimizer = torch.optim.Adam(net2.parameters(), lr=arg.learning_rate)
|
||||||
|
|
||||||
|
# Move net2 model into the computing device
|
||||||
|
net2.to(device)
|
||||||
|
|
||||||
|
# Set data loader
|
||||||
|
data_loader = get_net2_data_loader(data_path=arg.data_path,
|
||||||
|
batch_size=arg.batch_size,
|
||||||
|
num_workers=arg.num_workers)
|
||||||
|
|
||||||
|
start_step = 1
|
||||||
|
|
||||||
|
# Resume net1 model
|
||||||
|
if arg.resume_net1_model is None:
|
||||||
|
raise Exception(print("Need net1 pre-trained model!"))
|
||||||
|
|
||||||
|
resume_net1_model_path = os.path.join(hparams.net1_train_checkpoint_path, arg.resume_net1_model)
|
||||||
|
resume_log = "Resume net1 model from : " + resume_net1_model_path
|
||||||
|
print(resume_log)
|
||||||
|
|
||||||
|
checkpoint_net1 = torch.load(resume_net1_model_path)
|
||||||
|
print("Load net1 model successfully!")
|
||||||
|
|
||||||
|
net1.load_state_dict(checkpoint_net1["net"])
|
||||||
|
|
||||||
|
# Fixed parameters of the net1 model
|
||||||
|
for p in net1.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
# Resume net2 model
|
||||||
|
if arg.resume_net2_model is not None:
|
||||||
|
resume_net2_model_path = os.path.join(arg.checkpoint_path, arg.resume_net2_model)
|
||||||
|
resume_log = "Resume net2 model from : " + resume_net2_model_path
|
||||||
|
print(resume_log)
|
||||||
|
|
||||||
|
checkpoint_net2 = torch.load(resume_net2_model_path)
|
||||||
|
print("Load net2 model successfully!")
|
||||||
|
|
||||||
|
net2.load_state_dict(checkpoint_net2["net"])
|
||||||
|
net2_optimizer.load_state_dict(checkpoint_net2["optimizer"])
|
||||||
|
start_step = checkpoint_net2["step"]
|
||||||
|
|
||||||
|
if start_step >= arg.train_steps:
|
||||||
|
raise Exception(print(" Training completed !"))
|
||||||
|
|
||||||
|
# Start training
|
||||||
|
print("Start training ... ")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
data_iter = iter(data_loader)
|
||||||
|
|
||||||
|
for step in range(start_step, arg.train_steps + 1):
|
||||||
|
|
||||||
|
# Get input data
|
||||||
|
try:
|
||||||
|
mfccs, spec, mel = next(data_iter)
|
||||||
|
except:
|
||||||
|
data_iter = iter(data_loader)
|
||||||
|
mfccs, spec, mel = next(data_iter)
|
||||||
|
|
||||||
|
# Moving input data into the computing device
|
||||||
|
mfccs = mfccs.to(device)
|
||||||
|
spec = spec.to(device)
|
||||||
|
mel = mel.to(device)
|
||||||
|
|
||||||
|
# Set net1 and net2 model
|
||||||
|
net1 = net1.eval()
|
||||||
|
net2 = net2.train()
|
||||||
|
|
||||||
|
# Compute net1
|
||||||
|
net1_outputs_ppgs, _, _ = net1(mfccs)
|
||||||
|
|
||||||
|
net2_inputs_ppgs = net1_outputs_ppgs.detach()
|
||||||
|
|
||||||
|
# Compute net2
|
||||||
|
pred_spec, pred_mel = net2(net2_inputs_ppgs)
|
||||||
|
|
||||||
|
# Compute the loss
|
||||||
|
criterion = torch.nn.MSELoss(reduction='mean')
|
||||||
|
loss_spec = criterion(pred_spec, spec)
|
||||||
|
loss_mel = criterion(pred_mel, mel)
|
||||||
|
loss = loss_spec + loss_mel
|
||||||
|
|
||||||
|
# Backward and optimize
|
||||||
|
net2_optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
net2_optimizer.step()
|
||||||
|
|
||||||
|
# Print out training info
|
||||||
|
if step % arg.log_step == 0:
|
||||||
|
et = time.time() - start_time
|
||||||
|
et = str(datetime.timedelta(seconds=et))[:-7]
|
||||||
|
log = "Elapsed [{}], Iteration [{}/{}], Loss : [{:.6f}], Loss_spec : [{:.6f}], Loss_mel : [{:.6f}]" \
|
||||||
|
.format(et, step, arg.train_steps, loss, loss_spec, loss_mel)
|
||||||
|
print(log)
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
if step % arg.save_step == 0:
|
||||||
|
checkpoint = {
|
||||||
|
"net": net2.state_dict(),
|
||||||
|
"optimizer": net2_optimizer.state_dict(),
|
||||||
|
"step": step
|
||||||
|
}
|
||||||
|
|
||||||
|
if not os.path.isdir(arg.checkpoint_path):
|
||||||
|
os.mkdir(arg.checkpoint_path)
|
||||||
|
|
||||||
|
torch.save(checkpoint, os.path.join(arg.checkpoint_path, 'ckpt_%s.pth' % str(step)))
|
||||||
|
|
||||||
|
log = "Net2 training result has been saved to pth : ckpt_%s.pth ." % str(step)
|
||||||
|
print(log)
|
||||||
|
|
||||||
|
|
||||||
|
def get_arguments():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
# Set Net1
|
||||||
|
parser.add_argument('-in_dims', default=hparams.net2_in_dims, type=int,
|
||||||
|
help='Number of Net2 input dimensions.')
|
||||||
|
parser.add_argument('-hidden_units', default=hparams.net2_hidden_units, type=int,
|
||||||
|
help='Number of Net2 hidden units.')
|
||||||
|
parser.add_argument('-dropout_rate', default=hparams.net2_dropout_rate, type=float,
|
||||||
|
help='Rate of net2 Dropout layers.')
|
||||||
|
parser.add_argument('-num_conv1d_banks', default=hparams.net2_num_conv1d_banks, type=int,
|
||||||
|
help='Number of Net2 conv1d banks.')
|
||||||
|
parser.add_argument('-num_highway_blocks', default=hparams.net2_num_highway_blocks, type=int,
|
||||||
|
help='Number of Net2 Highway blocks.')
|
||||||
|
|
||||||
|
# Set DataLoader
|
||||||
|
parser.add_argument('-data_path', default=hparams.net2_dataset, type=str,
|
||||||
|
help='Path of Net2 dataset.')
|
||||||
|
parser.add_argument('-batch_size', default=hparams.net2_batch_size, type=int,
|
||||||
|
help='Batch size.')
|
||||||
|
parser.add_argument('-num_workers', default=hparams.net2_num_workers, type=int,
|
||||||
|
help='Number of workers.')
|
||||||
|
|
||||||
|
# Set Train config
|
||||||
|
parser.add_argument('-device', default=hparams.net2_train_device, type=str,
|
||||||
|
help='Net2 training device.')
|
||||||
|
parser.add_argument('-checkpoint_path', default=hparams.net2_train_checkpoint_path, type=str,
|
||||||
|
help='Net2 model checkpoint path.')
|
||||||
|
parser.add_argument('-resume_net1_model', default=None, type=str,
|
||||||
|
help='Net1 resume model checkpoint.')
|
||||||
|
parser.add_argument('-resume_net2_model', default=None, type=str,
|
||||||
|
help='Net2 resume model checkpoint.')
|
||||||
|
parser.add_argument('-train_steps', default=hparams.net2_train_steps, type=int,
|
||||||
|
help='Net2 training steps.')
|
||||||
|
parser.add_argument('-learning_rate', default=hparams.net2_train_lr, type=float,
|
||||||
|
help='Net2 learning rate.')
|
||||||
|
parser.add_argument('-log_step', default=hparams.net2_train_log_step, type=int,
|
||||||
|
help='Net2 training log steps.')
|
||||||
|
parser.add_argument('-save_step', default=hparams.net2_train_save_step, type=int,
|
||||||
|
help='Net2 training save steps.')
|
||||||
|
|
||||||
|
arguments = parser.parse_args()
|
||||||
|
return arguments
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = get_arguments()
|
||||||
|
|
||||||
|
print("Train Net2 parameters : \n " + str(args))
|
||||||
|
|
||||||
|
train(args)
|
Loading…
Reference in New Issue