forked from BIT_SCST_STIA/manifold_face_tamper
ADD file via upload
This commit is contained in:
parent
5144866ef5
commit
598b1b1fed
|
@ -0,0 +1,156 @@
|
|||
"""
|
||||
|
||||
Author: Andreas Rössler
|
||||
"""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
import torch
|
||||
#import pretrainedmodels
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from network.xception import xception, xception_concat
|
||||
import math
|
||||
import torchvision
|
||||
|
||||
|
||||
def return_pytorch04_xception(pretrained=False):
|
||||
# Raises warning "src not broadcastable to dst" but thats fine
|
||||
model = xception(pretrained=False)
|
||||
if pretrained:
|
||||
# Load model in torch 0.4+
|
||||
model.fc = model.last_linear
|
||||
del model.last_linear
|
||||
state_dict = torch.load(
|
||||
'/public/liuhonggu/.torch/models/xception-b5690688.pth')
|
||||
for name, weights in state_dict.items():
|
||||
if 'pointwise' in name:
|
||||
state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
|
||||
model.load_state_dict(state_dict)
|
||||
model.last_linear = model.fc
|
||||
del model.fc
|
||||
return model
|
||||
|
||||
|
||||
class TransferModel(nn.Module):
|
||||
"""
|
||||
Simple transfer learning model that takes an imagenet pretrained model with
|
||||
a fc layer as base model and retrains a new fc layer for num_out_classes
|
||||
"""
|
||||
def __init__(self, modelchoice, num_out_classes=2, dropout=0.5):
|
||||
super(TransferModel, self).__init__()
|
||||
self.modelchoice = modelchoice
|
||||
if modelchoice == 'xception':
|
||||
self.model = return_pytorch04_xception(pretrained=False)
|
||||
# Replace fc
|
||||
num_ftrs = self.model.last_linear.in_features
|
||||
if not dropout:
|
||||
self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
|
||||
else:
|
||||
print('Using dropout', dropout)
|
||||
self.model.last_linear = nn.Sequential(
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Linear(num_ftrs, num_out_classes)
|
||||
)
|
||||
elif modelchoice == 'xception_concat':
|
||||
self.model = xception_concat()
|
||||
num_ftrs = self.model.last_linear.in_features
|
||||
if not dropout:
|
||||
self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
|
||||
else:
|
||||
print('Using dropout', dropout)
|
||||
self.model.last_linear = nn.Sequential(
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Linear(num_ftrs, num_out_classes)
|
||||
)
|
||||
elif modelchoice == 'resnet50' or modelchoice == 'resnet18':
|
||||
if modelchoice == 'resnet50':
|
||||
self.model = torchvision.models.resnet50(pretrained=True)
|
||||
if modelchoice == 'resnet18':
|
||||
self.model = torchvision.models.resnet18(pretrained=True)
|
||||
# Replace fc
|
||||
num_ftrs = self.model.fc.in_features
|
||||
if not dropout:
|
||||
self.model.fc = nn.Linear(num_ftrs, num_out_classes)
|
||||
else:
|
||||
self.model.fc = nn.Sequential(
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Linear(num_ftrs, num_out_classes)
|
||||
)
|
||||
else:
|
||||
raise Exception('Choose valid model, e.g. resnet50')
|
||||
|
||||
def set_trainable_up_to(self, boolean, layername="Conv2d_4a_3x3"):
|
||||
"""
|
||||
Freezes all layers below a specific layer and sets the following layers
|
||||
to true if boolean else only the fully connected final layer
|
||||
:param boolean:
|
||||
:param layername: depends on network, for inception e.g. Conv2d_4a_3x3
|
||||
:return:
|
||||
"""
|
||||
# Stage-1: freeze all the layers
|
||||
if layername is None:
|
||||
for i, param in self.model.named_parameters():
|
||||
param.requires_grad = True
|
||||
return
|
||||
else:
|
||||
for i, param in self.model.named_parameters():
|
||||
param.requires_grad = False
|
||||
if boolean:
|
||||
# Make all layers following the layername layer trainable
|
||||
ct = []
|
||||
found = False
|
||||
for name, child in self.model.named_children():
|
||||
if layername in ct:
|
||||
found = True
|
||||
for params in child.parameters():
|
||||
params.requires_grad = True
|
||||
ct.append(name)
|
||||
if not found:
|
||||
raise Exception('Layer not found, cant finetune!'.format(
|
||||
layername))
|
||||
else:
|
||||
if self.modelchoice == 'xception':
|
||||
# Make fc trainable
|
||||
for param in self.model.last_linear.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
else:
|
||||
# Make fc trainable
|
||||
for param in self.model.fc.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
def forward(self, x):
|
||||
x = self.model(x)
|
||||
return x
|
||||
|
||||
|
||||
def model_selection(modelname, num_out_classes,
|
||||
dropout=None):
|
||||
"""
|
||||
:param modelname:
|
||||
:return: model, image size, pretraining<yes/no>, input_list
|
||||
"""
|
||||
if modelname == 'xception':
|
||||
return TransferModel(modelchoice='xception',
|
||||
num_out_classes=num_out_classes)
|
||||
# , 299, \True, ['image'], None
|
||||
elif modelname == 'resnet18':
|
||||
return TransferModel(modelchoice='resnet18', dropout=dropout,
|
||||
num_out_classes=num_out_classes)
|
||||
# , \224, True, ['image'], None
|
||||
elif modelname == 'xception_concat':
|
||||
return TransferModel(modelchoice='xception_concat',
|
||||
num_out_classes=num_out_classes)
|
||||
else:
|
||||
raise NotImplementedError(modelname)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model, image_size, *_ = model_selection('xception', num_out_classes=2)
|
||||
print(model)
|
||||
model = model.cuda()
|
||||
from torchsummary import summary
|
||||
input_s = (3, image_size, image_size)
|
||||
print(summary(model, input_s))
|
Loading…
Reference in New Issue