# Add test_net.test_net1:
#     Implement net1 test
#     Load trained net1 model, compute loss and accuracy
This commit is contained in:
miaomiaomiao-LJY 2021-04-20 15:32:35 +08:00
parent 1904696bc7
commit a8e872db57
1 changed files with 89 additions and 0 deletions

89
test_net/test_net1.py Normal file
View File

@ -0,0 +1,89 @@
import argparse
import os
import torch
import time
import datetime
import hparams
from model.Net1 import Net1, get_net1_loss, get_net1_acc
from dataloader.Net1DataLoader import get_net1_data_loader
def test(arg):
device = torch.device(arg.device)
# Build model
net1 = Net1(in_dims=hparams.net1_in_dims,
hidden_units=hparams.net1_hidden_units,
dropout_rate=hparams.net1_dropout_rate,
num_conv1d_banks=hparams.net1_num_conv1d_banks,
num_highway_blocks=hparams.net1_num_highway_blocks)
# Move net1 model into the computing device
net1.to(device)
# Set data loader
data_loader = get_net1_data_loader(data_path=hparams.net1_test_dataset,
batch_size=hparams.net1_batch_size,
num_workers=hparams.net1_num_workers)
# Resume net1 model
resume_model_path = os.path.join(hparams.net1_train_checkpoint_path, arg.resume_model)
resume_log = "Resume net1 model from : " + resume_model_path
print(resume_log)
checkpoint = torch.load(resume_model_path)
print("Load model successfully!")
net1.load_state_dict(checkpoint["net"])
# Start testing
print("Start testing ... ")
start_time = time.time()
data_iter = iter(data_loader)
# Get input data
mfccs, phones = next(data_iter)
# Moving input data into the computing device
mfccs = mfccs.to(device)
phones = phones.long().to(device)
# Test net1 model
net1 = net1.eval()
ppgs, preds, logits = net1(mfccs)
# Compute the loss
loss = get_net1_loss(logits, phones, mfccs)
# Compute the accuracy
acc = get_net1_acc(preds, phones, mfccs)
# Print loss and acc
et = time.time() - start_time
et = str(datetime.timedelta(seconds=et))[:-7]
log = "Elapsed [{}], Loss : [{:.6f}], Accuracy : [{:.6f}]".format(et, loss, acc)
print(log)
def get_arguments():
parser = argparse.ArgumentParser()
# Set test net1 config
parser.add_argument('-device', default=hparams.net1_train_device, type=str,
help='Net1 training device.')
parser.add_argument('resume_model', type=str,
help='Net1 resume model checkpoint.')
arguments = parser.parse_args()
return arguments
if __name__ == '__main__':
args = get_arguments()
print("Test Net1 parameters : \n " + str(args))
test(args)