add new trainer

This commit is contained in:
fxia22 2017-08-06 16:14:23 -07:00
parent 7a9f987530
commit e7c79657d0
3 changed files with 197 additions and 19 deletions

View File

@@ -169,6 +169,7 @@ class ViewDataSet3D(data.Dataset):
normal_imgs = [self.loader(item) for item in normal_img_paths]
normal_target = self.loader(normal_target_path)
org_img = imgs[0].copy()
if not self.transform is None:
imgs = [self.transform(item) for item in imgs]
@ -176,9 +177,14 @@ class ViewDataSet3D(data.Dataset):
target = self.target_transform(target)
if not self.off_3d:
if not self.transform is None:
mist_imgs = [np.expand_dims(np.array(item).astype(np.float32)/65536.0, 2) for item in mist_imgs]
org_mist = mist_imgs[0][:,:,0].copy()
mist_target = np.expand_dims(np.array(mist_target).astype(np.float32)/65536.0,2)
if not self.depth_trans is None:
mist_imgs = [self.depth_trans(item) for item in mist_imgs]
if not self.target_transform is None:
if not self.depth_trans is None:
mist_target = self.depth_trans(mist_target)
if not self.transform is None:
@ -186,16 +192,13 @@ class ViewDataSet3D(data.Dataset):
if not self.target_transform is None:
normal_target = self.target_transform(normal_target)
mist_imgs = [np.array(item).astype(np.float32)/65536.0 for item in mist_imgs]
mist_target = np.array(mist_target).astype(np.float32)/65536.0
if not self.off_pc_render:
h,w = mist_target.shape
img = np.array(org_img)
h,w,_ = img.shape
render=np.zeros((h,w,3),dtype='uint8')
target_depth = np.zeros((h,w)).astype(np.float32)
print(h,w)
img = np.array(imgs[0])
depth = mist_imgs[0]
depth = org_mist
pose = poses_relative[0].numpy()
x = -pose[1]
y = -pose[0]
@ -213,16 +216,14 @@ class ViewDataSet3D(data.Dataset):
target_depth.ctypes.data_as(ct.c_void_p)
)
if not self.transform is None:
render = self.transform(render)
render = self.transform(Image.fromarray(render))
if not self.depth_trans is None:
target_depth = self.depth_trans(target_depth)
target_depth = self.depth_trans(np.expand_dims(target_depth,2))
if self.off_3d and self.off_pc_render:
if self.off_3d:
return imgs, target, poses_relative
elif self.off_pc_render and (not self.off_3d):
elif self.off_pc_render:
return imgs, target, mist_imgs, mist_target, normal_imgs, normal_target, poses_relative
elif (not self.off_pc_render) and self.off_3d:
return imgs, target, poses_relative, render, target_depth
else:
return imgs, target, mist_imgs, mist_target, normal_imgs, normal_target, poses_relative, render, target_depth

177
filler_panorama_pc.py Normal file
View File

@@ -0,0 +1,177 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from datasets import ViewDataSet3D
from completion import CompletionNet, Discriminator
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F
def weights_init(m):
    """DCGAN-style initializer, meant to be used via ``model.apply(weights_init)``.

    Conv and Linear weights ~ N(0, 0.02); BatchNorm scales ~ N(1, 0.02).
    Biases of Linear/BatchNorm layers are zeroed; Conv biases are left
    untouched (the usual DCGAN convention).
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'Linear' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif 'BatchNorm' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def generate_patches_position(input_imgs):
    """Sample random patch centers and a hole half-size for each batch element.

    The panorama is assumed to be (imgsize, 2 * imgsize); both centers are
    kept at least 64 px from every border so a 128x128 crop fits, and hole
    half-extents are drawn from [48, 64].

    Returns (hole, center1, center2), each an int64 array of shape (batch, 2).
    """
    n_batch = input_imgs.size(0)
    side = input_imgs.size(2)
    center1 = np.zeros((n_batch, 2)).astype(np.int64)
    center2 = np.zeros((n_batch, 2)).astype(np.int64)
    hole = np.zeros((n_batch, 2)).astype(np.int64)
    for b in range(n_batch):
        # Draw order matches per-row (center1, center2, hole) so the RNG
        # stream is consumed identically for every batch element.
        center1[b] = (np.random.randint(64, side - 64),
                      np.random.randint(64, side * 2 - 64))
        center2[b] = (np.random.randint(64, side - 64),
                      np.random.randint(64, side * 2 - 64))
        hole[b] = (np.random.randint(48, 64 + 1),
                   np.random.randint(48, 64 + 1))
    return hole, center1, center2
def generate_patches(input_imgs, center):
    """Crop a 128x128 patch around each per-sample center.

    ``center`` is a (batch, 2) integer array of (row, col) positions; each
    crop spans center +/- 64 on both axes.  Returns a freshly allocated
    CUDA Variable of shape (batch, 3, 128, 128).
    """
    n_batch = input_imgs.size(0)
    crops = Variable(torch.zeros(n_batch, 3, 128, 128)).cuda()
    for b in range(n_batch):
        r, c = center[b, 0], center[b, 1]
        crops[b] = input_imgs[b, :, r - 64:r + 64, c - 64:c + 64]
    return crops
def prepare_completion_input(input_imgs, center, hole, mean):
    """Punch one rectangular hole per sample and fill it with the mean colour.

    ``center`` and ``hole`` are (batch, 2) integer arrays giving the hole
    midpoint and half-extents; ``mean`` is a 3-vector of per-channel fill
    values.  Returns (img_holed, mask) where ``img_holed`` is a copy of
    ``input_imgs`` with each hole overwritten by ``mean`` and ``mask`` is a
    (batch, 1, H, W) tensor equal to 1 inside the hole and 0 elsewhere.
    """
    n_batch = input_imgs.size(0)
    img_holed = input_imgs.clone()
    mask = torch.zeros(n_batch, 1, input_imgs.size(2), input_imgs.size(3))
    for b in range(n_batch):
        hr, hc = hole[b, 0], hole[b, 1]
        r0, r1 = center[b, 0] - hr, center[b, 0] + hr
        c0, c1 = center[b, 1] - hc, center[b, 1] + hc
        img_holed[b, :, r0:r1, c0:c1] = mean.view(3, 1, 1).repeat(1, 2 * hr, 2 * hc)
        mask[b, :, r0:r1, c0:c1] = 1
    return img_holed, mask
def main():
    """Train CompletionNet on panoramic point-cloud renders with an MSE loss.

    Loads ViewDataSet3D with point-cloud rendering enabled, copies each
    minibatch's rendered image + rendered depth into pre-allocated CUDA
    buffers, and minimizes MSE between the network output and the target
    view.  A Discriminator, its optimizer, and GAN bookkeeping variables
    are constructed (and its checkpoint is loaded on resume) but the
    adversarial branch is not exercised in this loop yet.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug' , action='store_true', help='debug mode')
    parser.add_argument('--imgsize' ,type=int, default = 256, help='image size')
    parser.add_argument('--batchsize' ,type=int, default = 20, help='batchsize')
    parser.add_argument('--workers' ,type=int, default = 6, help='number of workers')
    parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch' ,type=int, default = 0, help='current epoch')
    # Per-channel RGB dataset mean.  NOTE(review): unused in this loop —
    # presumably intended for hole filling via prepare_completion_input; confirm.
    mean = torch.from_numpy(np.array([ 0.45725039, 0.44777581, 0.4146058 ]).astype(np.float32))
    opt = parser.parse_args()
    print(opt)
    # One TensorBoard run directory per launch, named by wall-clock time.
    writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # output directory already exists
    # NOTE(review): transforms.Scale's second positional argument is
    # `interpolation`, not a second size — this likely should be
    # Scale((opt.imgsize, opt.imgsize * 2)); confirm against torchvision version.
    tf = transforms.Compose([
        transforms.Scale(opt.imgsize, opt.imgsize * 2),
        transforms.ToTensor(),
    ])
    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])
    # off_3d/off_pc_render both False: dataset yields (..., render, target_depth) last.
    d = ViewDataSet3D(root = opt.dataroot, transform=tf, mist_transform = mist_tf, seqlen = 2, debug=opt.debug, off_3d = False, off_pc_render = False)
    cudnn.benchmark = True
    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = True)
    # Pre-allocated CUDA buffers each minibatch is copied into.
    # NOTE(review): shapes are hard-coded to 256x512 and ignore opt.imgsize — confirm.
    img = Variable(torch.zeros(opt.batchsize,3, 256, 512)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize,1, 256, 512)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize,3, 256, 512)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize)).cuda()  # NOTE(review): unused in this loop
    comp = CompletionNet()
    dis = Discriminator(pano = True)
    current_epoch = 0
    comp = torch.nn.DataParallel(comp).cuda()
    comp.apply(weights_init)
    dis = torch.nn.DataParallel(dis).cuda()
    dis.apply(weights_init)
    if opt.model != '':
        # Resume: --model is the generator checkpoint; the discriminator
        # path is derived by substituting "G" -> "D" in the filename.
        comp.load_state_dict(torch.load(opt.model))
        dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch
    l2 = nn.MSELoss()
    optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
    optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
    curriculum = (30000, 50000) # step to start D training and G training, slightly different from the paper
    alpha = 0.0004
    # NOTE(review): optimizerD, curriculum, alpha, errG_data, errD_data are
    # set up for the adversarial phase but never used below yet.
    errG_data = 0
    errD_data = 0
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            # Last two items of the dataset tuple are (render, target_depth);
            # data[1] is the target RGB view.
            source = data[-2]
            source_depth = data[-1]
            target = data[1]
            print(source.size(),source_depth.size(),target.size())  # debug shape dump
            step = i + epoch * len(dataloader)  # global step for TensorBoard
            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            recon = comp(img, maskv)
            loss = l2(recon, img_original)
            loss.backward(retain_variables = True)  # legacy (pre-0.4) spelling of retain_graph
            optimizerG.step()
            print('[%d/%d][%d/%d] MSEloss: %f' % (epoch, opt.nepoch, i, len(dataloader), loss.data[0]))
            if i%500 == 0:
                # Side-by-side (source | target) grid to TensorBoard and disk.
                visual = torch.cat([source, target], 3)
                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)
            if i%10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)
            if i%10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))

if __name__ == '__main__':
    main()

View File

@@ -13,10 +13,10 @@ void render(int h,int w,unsigned char * img, float * depth,float * pose, unsigne
float * points3d_after = (float*) malloc(sizeof(float) * h * w * 3);
float * points3d_polar = (float*) malloc(sizeof(float) * h * w * 3);
for (i = 0; i < 5; i++) {
printf("%f ", pose[i]);
}
printf("\n");
//for (i = 0; i < 5; i++) {
// printf("%f ", pose[i]);
//}
//printf("\n");
for (ih = 0; ih < h; ih ++ ) {
for (iw = 0; iw < w; iw ++ ) {
for (ic = 0; ic < 3; ic ++) {