UPDATE
# Update train.train_net1:
#   Add checkpoint saving and resuming to the training model
#   Add a simple logging feature
#   Flesh out the training log content and add elapsed-time calculation
#   Add a function to calculate the accuracy of the model
#   Encapsulate the model loss and accuracy calculations
# Update hparams:
#   Add Net1 training parameters: checkpoint path, model save steps
# Update model.Net1:
#   Move the encapsulated functions (loss and accuracy calculation) into this module
commit 4664efbd99
parent 16f5eb68ac
hparams.py

```diff
@@ -31,6 +31,8 @@ net1_logits_t = 1.0
 # net1 train
 net1_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
 net1_train_steps = 10000
+net1_train_checkpoint_path = "../checkpoint"
 net1_train_lr = 0.0003
-net1_train_log_step = 1
-net1_train_multiple_flag = True
+net1_train_log_step = 10
+net1_train_save_step = 1000
+net1_train_multiple_flag = False
```
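A caveat on the new path: `"../checkpoint"` is resolved against the process's working directory, so launching the trainer from another directory moves where checkpoints land. A sketch of one way to pin it to the file layout instead (an assumption, not what the commit does):

```python
import os

# Resolve ../checkpoint relative to this hparams file rather than the CWD,
# so the checkpoint location does not depend on where the trainer is launched.
net1_train_checkpoint_path = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "checkpoint"))
```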
model/Net1.py

```diff
@@ -1,5 +1,5 @@
 import torch
-from torch.nn import Module, Linear, Softmax
+from torch.nn import Module, Linear, Softmax, CrossEntropyLoss
 from .modules import PreNet, CBHG
 
 import hparams
@@ -66,3 +66,28 @@ class Net1(Module):
         # preds : (N, L_in)
         # logits_outputs : (N, L_in, phns_len)
         return ppgs, preds, logits_outputs
+
+
+def get_net1_loss(logits, phones, mfccs):
+    is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
+
+    compute_loss = CrossEntropyLoss()
+    loss = compute_loss(logits.transpose(1, 2) / hparams.net1_logits_t, phones)
+
+    loss = loss * is_target
+    loss = torch.mean(loss)
+
+    return loss
+
+
+def get_net1_acc(preds, phones, mfccs):
+    is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
+
+    hits = torch.eq(preds, phones.int()).float()
+
+    num_hits = torch.sum(hits * is_target)
+    num_targets = torch.sum(is_target)
+
+    acc = num_hits / num_targets
+
+    return acc
```
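One thing to note about `get_net1_loss`: `CrossEntropyLoss()` defaults to `reduction='mean'`, so `loss` is already a scalar by the time it is multiplied by the frame mask `is_target`. Padded frames still enter the average, and the mask only rescales the result by the valid-frame fraction. A minimal sketch of true per-frame masking, assuming the `(N, L_in, phns_len)` logits layout from the comments above (`masked_net1_loss` is a hypothetical name, not part of the commit):

```python
import torch
from torch.nn import CrossEntropyLoss


def masked_net1_loss(logits, phones, mfccs, logits_t=1.0):
    # 1.0 for real frames, 0.0 for zero-padded frames: (N, L_in)
    is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))

    # reduction='none' keeps one loss per frame so the mask can exclude padding
    compute_loss = CrossEntropyLoss(reduction='none')
    loss = compute_loss(logits.transpose(1, 2) / logits_t, phones)  # (N, L_in)

    # average over valid frames only
    return torch.sum(loss * is_target) / torch.sum(is_target)
```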
train/train_net1.py

```diff
@@ -1,7 +1,12 @@
 import argparse
 import hparams
 import torch
-from model.Net1 import Net1
+import time
+import datetime
+import logging
+import os
+
+from model.Net1 import Net1, get_net1_loss, get_net1_acc
 from dataloader.Net1DataLoader import get_net1_data_loader
 
 
```
```diff
@@ -31,13 +36,34 @@ def train(arg):
                                      batch_size=arg.batch_size,
                                      num_workers=arg.num_workers)
 
+    start_step = 1
+
+    # Resume checkpoint
+    if arg.resume_model is not None:
+        resume_model_path = os.path.join(arg.checkpoint_path, arg.resume_model)
+        resume_log = "Resume model from : " + resume_model_path
+        print(resume_log)
+        logger.info(resume_log)
+
+        checkpoint = torch.load(resume_model_path)
+        print("Load model successfully!")
+        logger.info("Load model successfully!")
+
+        net1.load_state_dict(checkpoint["net"])
+        net1_optimizer.load_state_dict(checkpoint["optimizer"])
+        start_step = checkpoint["step"]
+
+        if start_step >= arg.train_steps:
+            logger.error(" Training completed !")
+            raise Exception(print(" Training completed !"))
+
     # Start training
     print("Start training ... ")
     start_time = time.time()
 
     data_iter = iter(data_loader)
 
-    start_step = 1
-    for step in range(start_step, arg.train_steps):
+    for step in range(start_step, arg.train_steps + 1):
 
         # Get input data
         try:
```
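Two caveats in the resume block worth flagging: `torch.load` without `map_location` tries to restore tensors to the device they were saved from (here `cuda:3`), which fails on machines without that GPU, and `raise Exception(print(...))` raises with `None` as the message because `print` returns `None`. A hedged sketch of both fixes, reusing the names from the hunk:

```python
# Map the checkpoint onto whatever device this run actually uses.
checkpoint = torch.load(resume_model_path, map_location=arg.device)

net1.load_state_dict(checkpoint["net"])
net1_optimizer.load_state_dict(checkpoint["optimizer"])
start_step = checkpoint["step"]

if start_step >= arg.train_steps:
    logger.error("Training completed!")
    raise Exception("Training completed!")  # pass the message itself, not print(...)
```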
```diff
@@ -55,12 +81,10 @@ def train(arg):
         ppgs, preds, logits = net1(mfccs)
 
         # Compute the loss
-        compute_loss = torch.nn.CrossEntropyLoss()
-        loss = compute_loss(logits.transpose(1, 2) / hparams.net1_logits_t, phones)
+        loss = get_net1_loss(logits, phones, mfccs)
 
-        is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
-        loss = loss * is_target
-        loss = torch.mean(loss)
+        # Compute the accuracy
+        acc = get_net1_acc(preds, phones, mfccs)
 
         # Backward and optimize
         net1_optimizer.zero_grad()
```
```diff
@@ -81,8 +105,30 @@ def train(arg):
         # Print out training info
         if step % arg.log_step == 0:
-            log = "Iteration [{}/{}], [loss : {:.6f}]".format(step, arg.train_steps, loss)
+            et = time.time() - start_time
+            et = str(datetime.timedelta(seconds=et))[:-7]
+            log = "Elapsed [{}], Iteration [{}/{}], Loss : [{:.6f}], Accuracy : [{:.6f}]".format(et, step,
+                                                                                                 arg.train_steps, loss,
+                                                                                                 acc)
             print(log)
             logger.info(log)
 
+        # Save model
+        if step % arg.save_step == 0:
+            checkpoint = {
+                "net": net1.state_dict(),
+                "optimizer": net1_optimizer.state_dict(),
+                "step": step
+            }
+
+            if not os.path.isdir(arg.checkpoint_path):
+                os.mkdir(arg.checkpoint_path)
+
+            torch.save(checkpoint, os.path.join(arg.checkpoint_path, 'ckpt_%s.pth' % str(step)))
+
+            log = "Net1 training result has been saved to pth : ckpt_%s.pth ." % str(step)
+            print(log)
+            logger.info(log)
 
 
 def get_arguments():
```
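The `isdir`/`mkdir` pair in the save block only works when the parent of `checkpoint_path` already exists; `os.mkdir` raises `FileNotFoundError` for nested paths. `os.makedirs` with `exist_ok=True` covers both cases; a sketch reusing the hunk's names:

```python
import os

# Creates intermediate directories too, and tolerates an existing one,
# replacing the isdir() check and os.mkdir() call above.
os.makedirs(arg.checkpoint_path, exist_ok=True)
torch.save(checkpoint, os.path.join(arg.checkpoint_path, 'ckpt_%s.pth' % step))
```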
```diff
@@ -111,12 +157,18 @@ def get_arguments():
     # Set Train config
     parser.add_argument('-device', default=hparams.net1_train_device, type=str,
                         help='Net1 training device.')
+    parser.add_argument('-checkpoint_path', default=hparams.net1_train_checkpoint_path, type=str,
+                        help='Net1 model checkpoint path.')
+    parser.add_argument('-resume_model', default=None, type=str,
+                        help='Net1 resume model checkpoint.')
     parser.add_argument('-train_steps', default=hparams.net1_train_steps, type=int,
                         help='Net1 training steps.')
     parser.add_argument('-learning_rate', default=hparams.net1_train_lr, type=float,
                         help='Net1 learning rate.')
     parser.add_argument('-log_step', default=hparams.net1_train_log_step, type=int,
                         help='Net1 training log steps.')
+    parser.add_argument('-save_step', default=hparams.net1_train_save_step, type=int,
+                        help='Net1 training save steps.')
     parser.add_argument('-multiple_train', default=hparams.net1_train_multiple_flag, type=bool,
                         help='Net1 training log steps.')
```
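Two small issues in this hunk: the `-multiple_train` help string repeats 'Net1 training log steps.' (presumably a copy-paste slip for the multi-GPU flag), and `type=bool` is an argparse pitfall, because `bool()` of any non-empty string is `True`, so `-multiple_train False` still enables the flag. A sketch of the usual workaround (`str2bool` is a hypothetical helper, not part of the commit):

```python
def str2bool(v):
    # argparse passes the raw string; map common spellings to a real bool
    return str(v).lower() in ('true', '1', 'yes')

parser.add_argument('-multiple_train', default=hparams.net1_train_multiple_flag,
                    type=str2bool, help='Net1 multi-GPU training flag.')
```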
```diff
@@ -125,8 +177,14 @@ def get_arguments():
 
 if __name__ == '__main__':
 
+    # Set log
+    logging.basicConfig(level=logging.INFO, filename="../log.txt", filemode="w")
+    logger = logging.getLogger("log_test")
+
     args = get_arguments()
     print("Train Net1 parameters : \n " + str(args))
+    logger.info(args)
+
 
     if args.multiple_train and args.device is not 'cuda':
         raise Exception("Multi-GPU training mode enabled, but the default computing device does not support it. "
```
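In the last context line, `args.device is not 'cuda'` compares object identity rather than value (CPython emits a `SyntaxWarning` for it), and since the default device string is `'cuda:3'`, even an equality test against `'cuda'` would always be true. A prefix check expresses the intended "any CUDA device" condition; a sketch:

```python
# "is not" on a string literal checks identity; a prefix test matches
# 'cuda', 'cuda:0', 'cuda:3', etc.
if args.multiple_train and not args.device.startswith('cuda'):
    raise Exception("Multi-GPU training mode enabled, but the selected "
                    "computing device does not support it.")
```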