add ae
parent b534fc6813
commit 24d728704e

@@ -0,0 +1,178 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from datasets import ViewDataSet3D
from completion import CompletionNet, Discriminator
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F


def weights_init(m):
    # DCGAN-style initialisation for conv, linear and batch-norm layers
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def generate_patches_position(input_imgs):
    # create patch position
    batchsize = input_imgs.size(0)
    imgsize = input_imgs.size(2)

    center1 = np.zeros((batchsize, 2)).astype(np.int64)
    center2 = np.zeros((batchsize, 2)).astype(np.int64)

    hole = np.zeros((batchsize, 2)).astype(np.int64)

    for i in range(batchsize):
        center1[i, 0] = np.random.randint(64, imgsize - 64)
        center1[i, 1] = np.random.randint(64, imgsize * 2 - 64)
        center2[i, 0] = np.random.randint(64, imgsize - 64)
        center2[i, 1] = np.random.randint(64, imgsize * 2 - 64)
        hole[i, 0] = np.random.randint(48, 64 + 1)
        hole[i, 1] = np.random.randint(48, 64 + 1)

    return hole, center1, center2

def generate_patches(input_imgs, center):
    # create patch, generate new Variable from variable
    batchsize = input_imgs.size(0)
    patches = Variable(torch.zeros(batchsize, 3, 128, 128)).cuda()
    for i in range(batchsize):
        patches[i] = input_imgs[i, :, center[i, 0] - 64:center[i, 0] + 64, center[i, 1] - 64:center[i, 1] + 64]
    return patches

def prepare_completion_input(input_imgs, center, hole, mean):
    # fill in mean value into holes
    batchsize = input_imgs.size(0)
    img_holed = input_imgs.clone()
    mask = torch.zeros(batchsize, 1, input_imgs.size(2), input_imgs.size(3))
    for i in range(batchsize):
        img_holed[i, :, center[i, 0] - hole[i, 0]:center[i, 0] + hole[i, 0], center[i, 1] - hole[i, 1]:center[i, 1] + hole[i, 1]] = mean.view(3, 1, 1).repeat(1, hole[i, 0] * 2, hole[i, 1] * 2)
        mask[i, :, center[i, 0] - hole[i, 0]:center[i, 0] + hole[i, 0], center[i, 1] - hole[i, 1]:center[i, 1] + hole[i, 1]] = 1

    return img_holed, mask

def prepare_completion_input2(input_imgs, mean):
    # fill in mean value into holes: a random per-pixel mask keeps roughly half
    # of the pixels (mask == 1) and replaces the rest with the dataset mean
    batchsize = input_imgs.size(0)
    img_holed = input_imgs.clone()
    mask = (torch.zeros(batchsize, 1, input_imgs.size(2), input_imgs.size(3)).random_(0, 100) > 50).float()
    img_holed = img_holed * mask.repeat(1, 3, 1, 1) + (1 - mask).repeat(1, 3, 1, 1) * mean.view(1, 3, 1, 1).repeat(batchsize, 1, input_imgs.size(2), input_imgs.size(3))
    return img_holed, mask

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default=256, help='image size')
    parser.add_argument('--batchsize', type=int, default=20, help='batchsize')
    parser.add_argument('--workers', type=int, default=6, help='number of workers')
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_dae", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default=0, help='current epoch')

    # per-channel dataset mean used to fill the masked-out pixels
    mean = torch.from_numpy(np.array([0.45725039, 0.44777581, 0.4146058]).astype(np.float32))

    opt = parser.parse_args()
    print(opt)

    writer = SummaryWriter(opt.outf + '/runs/' + datetime.now().strftime('%B%d %H:%M:%S'))

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    tf = transforms.Compose([
        vision_utils.RandomScale(opt.imgsize, int(opt.imgsize * 1.5)),
        transforms.RandomCrop((opt.imgsize, opt.imgsize * 2)),
        transforms.ToTensor(),
    ])

    d = ViewDataSet3D(root=opt.dataroot, transform=tf, target_transform=tf, seqlen=2, debug=opt.debug, off_3d=True)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last=True, pin_memory=True)

    # reusable GPU buffers sized to the (imgsize x 2*imgsize) panorama crops
    img = Variable(torch.zeros(opt.batchsize, 3, opt.imgsize, opt.imgsize * 2)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 1, opt.imgsize, opt.imgsize * 2)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, opt.imgsize, opt.imgsize * 2)).cuda()

    comp = CompletionNet()
    current_epoch = 0

    comp = torch.nn.DataParallel(comp).cuda()
    comp.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        current_epoch = opt.cepoch

    l2 = nn.MSELoss()
    optimizerG = torch.optim.Adam(comp.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    errG_data = 0

    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            source, _, _ = data
            source = source[0]

            step = i + epoch * len(dataloader)
            optimizerG.zero_grad()

            # randomly mask out pixels and fill them with the dataset mean
            img_holed, mask = prepare_completion_input2(source, mean)

            # Train G
            # MSE Loss
            img.data.copy_(img_holed)
            img_original.data.copy_(source)

            maskv.data.copy_(mask)
            recon = comp(img, maskv)
            loss = l2(recon, img_original)
            loss.backward(retain_variables=True)
            optimizerG.step()

            print('[%d/%d][%d/%d] MSEloss: %f' % (epoch, opt.nepoch, i, len(dataloader), loss.data[0]))

            if i % 500 == 0:
                # log original | masked input | reconstruction side by side
                visual = torch.cat([img_original.data, img.data, recon.data], 3)
                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)

            if i % 10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)

            if i % 10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))


if __name__ == '__main__':
    main()
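A minimal inference sketch, not part of this commit: it assumes CompletionNet and prepare_completion_input2 from the script above are in scope, and the checkpoint path, batch size and 256x512 crop are placeholder assumptions that mirror the training loop.

import numpy as np
import torch
from torch.autograd import Variable

from completion import CompletionNet

# hypothetical checkpoint, following the '%s/compG_epoch%d_%d.pth' pattern used above
comp = torch.nn.DataParallel(CompletionNet()).cuda()
comp.load_state_dict(torch.load('filler_pano_dae/compG_epoch0_0.pth'))
comp.eval()

mean = torch.from_numpy(np.array([0.45725039, 0.44777581, 0.4146058]).astype(np.float32))
source = torch.zeros(1, 3, 256, 512)                       # stand-in for one 256x512 panorama crop
img_holed, mask = prepare_completion_input2(source, mean)  # same mean-fill masking as in training
recon = comp(Variable(img_holed, volatile=True).cuda(), Variable(mask, volatile=True).cuda())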