manifold_face_tamper/detection/models.py

157 lines
5.6 KiB
Python

"""
Author: Andreas Rössler
"""
import os
import argparse
import torch
#import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F
from network.xception import xception, xception_concat
import math
import torchvision
def return_pytorch04_xception(pretrained=False):
# Raises warning "src not broadcastable to dst" but thats fine
model = xception(pretrained=False)
if pretrained:
# Load model in torch 0.4+
model.fc = model.last_linear
del model.last_linear
state_dict = torch.load(
'/public/liuhonggu/.torch/models/xception-b5690688.pth')
for name, weights in state_dict.items():
if 'pointwise' in name:
state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
model.load_state_dict(state_dict)
model.last_linear = model.fc
del model.fc
return model
class TransferModel(nn.Module):
"""
Simple transfer learning model that takes an imagenet pretrained model with
a fc layer as base model and retrains a new fc layer for num_out_classes
"""
def __init__(self, modelchoice, num_out_classes=2, dropout=0.5):
super(TransferModel, self).__init__()
self.modelchoice = modelchoice
if modelchoice == 'xception':
self.model = return_pytorch04_xception(pretrained=False)
# Replace fc
num_ftrs = self.model.last_linear.in_features
if not dropout:
self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
else:
print('Using dropout', dropout)
self.model.last_linear = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(num_ftrs, num_out_classes)
)
elif modelchoice == 'xception_concat':
self.model = xception_concat()
num_ftrs = self.model.last_linear.in_features
if not dropout:
self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
else:
print('Using dropout', dropout)
self.model.last_linear = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(num_ftrs, num_out_classes)
)
elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
if modelchoice == 'resnet50':
self.model = torchvision.models.resnet50(pretrained=True)
if modelchoice == 'resnet18':
self.model = torchvision.models.resnet18(pretrained=True)
# Replace fc
num_ftrs = self.model.fc.in_features
if not dropout:
self.model.fc = nn.Linear(num_ftrs, num_out_classes)
else:
self.model.fc = nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(num_ftrs, num_out_classes)
)
else:
raise Exception('Choose valid model, e.g. resnet50')
def set_trainable_up_to(self, boolean, layername="Conv2d_4a_3x3"):
"""
Freezes all layers below a specific layer and sets the following layers
to true if boolean else only the fully connected final layer
:param boolean:
:param layername: depends on network, for inception e.g. Conv2d_4a_3x3
:return:
"""
# Stage-1: freeze all the layers
if layername is None:
for i, param in self.model.named_parameters():
param.requires_grad = True
return
else:
for i, param in self.model.named_parameters():
param.requires_grad = False
if boolean:
# Make all layers following the layername layer trainable
ct = []
found = False
for name, child in self.model.named_children():
if layername in ct:
found = True
for params in child.parameters():
params.requires_grad = True
ct.append(name)
if not found:
raise Exception('Layer not found, cant finetune!'.format(
layername))
else:
if self.modelchoice == 'xception':
# Make fc trainable
for param in self.model.last_linear.parameters():
param.requires_grad = True
else:
# Make fc trainable
for param in self.model.fc.parameters():
param.requires_grad = True
def forward(self, x):
x = self.model(x)
return x
def model_selection(modelname, num_out_classes,
dropout=None):
"""
:param modelname:
:return: model, image size, pretraining<yes/no>, input_list
"""
if modelname == 'xception':
return TransferModel(modelchoice='xception',
num_out_classes=num_out_classes)
# , 299, \True, ['image'], None
elif modelname == 'resnet18':
return TransferModel(modelchoice='resnet18', dropout=dropout,
num_out_classes=num_out_classes)
# , \224, True, ['image'], None
elif modelname == 'xception_concat':
return TransferModel(modelchoice='xception_concat',
num_out_classes=num_out_classes)
else:
raise NotImplementedError(modelname)
if __name__ == '__main__':
model, image_size, *_ = model_selection('xception', num_out_classes=2)
print(model)
model = model.cuda()
from torchsummary import summary
input_s = (3, image_size, image_size)
print(summary(model, input_s))