diff --git a/detection b/detection new file mode 100644 index 0000000..cf6d9d6 --- /dev/null +++ b/detection @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +import torchvision +from torch.utils.data import DataLoader +import torch.optim as optim +from torch.optim import lr_scheduler +import argparse +import os +import cv2 +from network.models import model_selection +from dataset.transform import xception_default_data_transforms +from dataset.mydataset import MyDataset +def main(): + args = parse.parse_args() + test_list = args.test_list + batch_size = args.batch_size + model_path = args.model_path + torch.backends.cudnn.benchmark=True + test_dataset = MyDataset(txt_path=test_list, transform=xception_default_data_transforms['test']) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=8) + test_dataset_size = len(test_dataset) + corrects = 0 + acc = 0 + #model = torchvision.models.densenet121(num_classes=2) + model = model_selection(modelname='xception', num_out_classes=2, dropout=0.5) + model.load_state_dict(torch.load(model_path)) + if isinstance(model, torch.nn.DataParallel): + model = model.module + model = model.cuda() + model.eval() + with torch.no_grad(): + for (image, labels) in test_loader: + image = image.cuda() + labels = labels.cuda() + outputs = model(image) + _, preds = torch.max(outputs.data, 1) + corrects += torch.sum(preds == labels.data).to(torch.float32) + print('Iteration Acc {:.4f}'.format(torch.sum(preds == labels.data).to(torch.float32)/batch_size)) + acc = corrects / test_dataset_size + print('Test Acc: {:.4f}'.format(acc)) + + + +if __name__ == '__main__': + parse = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parse.add_argument('--batch_size', '-bz', type=int, default=32) + parse.add_argument('--test_list', '-tl', type=str, default='./data_list/Deepfakes_c0_test.txt') + parse.add_argument('--model_path', '-mp', type=str, default='./pretrained_model/df_c0_best.pkl') + main() + print('Hello world!!!') \ No newline at end of file