# Update train.train_net1:
#     Add checkpoint and save function to training model incoherent
#     Add a simple logging feature
#     Detail training log content, add time calculation
#     Add a function to calculate the accuracy of the model
#     Encapsulated model loss calculation and accuracy calculation

# Update hparams:
#     Add Net1 training parameters: checkpoint path, model save steps

# Update model.Net1:
#     The encapsulated functions (loss calculation and accuracy calculation) are moved into this module
This commit is contained in:
miaomiaomiao-LJY 2021-04-14 20:25:48 +08:00
parent 16f5eb68ac
commit 4664efbd99
3 changed files with 97 additions and 12 deletions

View File

@ -31,6 +31,8 @@ net1_logits_t = 1.0
# net1 train
net1_train_device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
net1_train_steps = 10000
net1_train_checkpoint_path = "../checkpoint"
net1_train_lr = 0.0003
net1_train_log_step = 1
net1_train_multiple_flag = True
net1_train_log_step = 10
net1_train_save_step = 1000
net1_train_multiple_flag = False

View File

@ -1,5 +1,5 @@
import torch
from torch.nn import Module, Linear, Softmax
from torch.nn import Module, Linear, Softmax, CrossEntropyLoss
from .modules import PreNet, CBHG
import hparams
@ -66,3 +66,28 @@ class Net1(Module):
# preds : (N, L_in)
# logits_outputs : (N, L_in, phns_len)
return ppgs, preds, logits_outputs
def get_net1_loss(logits, phones, mfccs):
is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
compute_loss = CrossEntropyLoss()
loss = compute_loss(logits.transpose(1, 2) / hparams.net1_logits_t, phones)
loss = loss * is_target
loss = torch.mean(loss)
return loss
def get_net1_acc(preds, phones, mfccs):
is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
hits = torch.eq(preds, phones.int()).float()
num_hits = torch.sum(hits * is_target)
num_targets = torch.sum(is_target)
acc = num_hits / num_targets
return acc

View File

@ -1,7 +1,12 @@
import argparse
import hparams
import torch
from model.Net1 import Net1
import time
import datetime
import logging
import os
from model.Net1 import Net1, get_net1_loss, get_net1_acc
from dataloader.Net1DataLoader import get_net1_data_loader
@ -31,13 +36,34 @@ def train(arg):
batch_size=arg.batch_size,
num_workers=arg.num_workers)
start_step = 1
# Resume checkpoint
if arg.resume_model is not None:
resume_model_path = os.path.join(arg.checkpoint_path, arg.resume_model)
resume_log = "Resume model from : " + resume_model_path
print(resume_log)
logger.info(resume_log)
checkpoint = torch.load(resume_model_path)
print("Load model successfully!")
logger.info("Load model successfully!")
net1.load_state_dict(checkpoint["net"])
net1_optimizer.load_state_dict(checkpoint["optimizer"])
start_step = checkpoint["step"]
if start_step >= arg.train_steps:
logger.error(" Training completed !")
raise Exception(print(" Training completed !"))
# Start training
print("Start training ... ")
start_time = time.time()
data_iter = iter(data_loader)
start_step = 1
for step in range(start_step, arg.train_steps):
for step in range(start_step, arg.train_steps + 1):
# Get input data
try:
@ -55,12 +81,10 @@ def train(arg):
ppgs, preds, logits = net1(mfccs)
# Compute the loss
compute_loss = torch.nn.CrossEntropyLoss()
loss = compute_loss(logits.transpose(1, 2) / hparams.net1_logits_t, phones)
loss = get_net1_loss(logits, phones, mfccs)
is_target = torch.sign(torch.abs(torch.sum(mfccs, -1)))
loss = loss * is_target
loss = torch.mean(loss)
# Compute the accuracy
acc = get_net1_acc(preds, phones, mfccs)
# Backward and optimize
net1_optimizer.zero_grad()
@ -81,8 +105,30 @@ def train(arg):
# Print out training info
if step % arg.log_step == 0:
log = "Iteration [{}/{}], [loss : {:.6f}]".format(step, arg.train_steps, loss)
et = time.time() - start_time
et = str(datetime.timedelta(seconds=et))[:-7]
log = "Elapsed [{}], Iteration [{}/{}], Loss : [{:.6f}], Accuracy : [{:.6f}]".format(et, step,
arg.train_steps, loss,
acc)
print(log)
logger.info(log)
# Save model
if step % arg.save_step == 0:
checkpoint = {
"net": net1.state_dict(),
"optimizer": net1_optimizer.state_dict(),
"step": step
}
if not os.path.isdir(arg.checkpoint_path):
os.mkdir(arg.checkpoint_path)
torch.save(checkpoint, os.path.join(arg.checkpoint_path, 'ckpt_%s.pth' % str(step)))
log = "Net1 training result has been saved to pth : ckpt_%s.pth ." % str(step)
print(log)
logger.info(log)
def get_arguments():
@ -111,12 +157,18 @@ def get_arguments():
# Set Train config
parser.add_argument('-device', default=hparams.net1_train_device, type=str,
help='Net1 training device.')
parser.add_argument('-checkpoint_path', default=hparams.net1_train_checkpoint_path, type=str,
help='Net1 model checkpoint path.')
parser.add_argument('-resume_model', default=None, type=str,
help='Net1 resume model checkpoint.')
parser.add_argument('-train_steps', default=hparams.net1_train_steps, type=int,
help='Net1 training steps.')
parser.add_argument('-learning_rate', default=hparams.net1_train_lr, type=float,
help='Net1 learning rate.')
parser.add_argument('-log_step', default=hparams.net1_train_log_step, type=int,
help='Net1 training log steps.')
parser.add_argument('-save_step', default=hparams.net1_train_save_step, type=int,
help='Net1 training save steps.')
parser.add_argument('-multiple_train', default=hparams.net1_train_multiple_flag, type=bool,
help='Net1 training log steps.')
@ -125,8 +177,14 @@ def get_arguments():
if __name__ == '__main__':
# Set log
logging.basicConfig(level=logging.INFO, filename="../log.txt", filemode="w")
logger = logging.getLogger("log_test")
args = get_arguments()
print("Train Net1 parameters : \n " + str(args))
logger.info(args)
if args.multiple_train and args.device is not 'cuda':
raise Exception("Multi-GPU training mode enabled, but the default computing device does not support it. "