add new trainer
This commit is contained in:
parent
7a9f987530
commit
e7c79657d0
31
datasets.py
31
datasets.py
|
@ -169,6 +169,7 @@ class ViewDataSet3D(data.Dataset):
|
|||
normal_imgs = [self.loader(item) for item in normal_img_paths]
|
||||
normal_target = self.loader(normal_target_path)
|
||||
|
||||
org_img = imgs[0].copy()
|
||||
|
||||
if not self.transform is None:
|
||||
imgs = [self.transform(item) for item in imgs]
|
||||
|
@ -176,9 +177,14 @@ class ViewDataSet3D(data.Dataset):
|
|||
target = self.target_transform(target)
|
||||
|
||||
if not self.off_3d:
|
||||
if not self.transform is None:
|
||||
|
||||
mist_imgs = [np.expand_dims(np.array(item).astype(np.float32)/65536.0, 2) for item in mist_imgs]
|
||||
org_mist = mist_imgs[0][:,:,0].copy()
|
||||
mist_target = np.expand_dims(np.array(mist_target).astype(np.float32)/65536.0,2)
|
||||
|
||||
if not self.depth_trans is None:
|
||||
mist_imgs = [self.depth_trans(item) for item in mist_imgs]
|
||||
if not self.target_transform is None:
|
||||
if not self.depth_trans is None:
|
||||
mist_target = self.depth_trans(mist_target)
|
||||
|
||||
if not self.transform is None:
|
||||
|
@ -186,16 +192,13 @@ class ViewDataSet3D(data.Dataset):
|
|||
if not self.target_transform is None:
|
||||
normal_target = self.target_transform(normal_target)
|
||||
|
||||
mist_imgs = [np.array(item).astype(np.float32)/65536.0 for item in mist_imgs]
|
||||
mist_target = np.array(mist_target).astype(np.float32)/65536.0
|
||||
|
||||
if not self.off_pc_render:
|
||||
h,w = mist_target.shape
|
||||
img = np.array(org_img)
|
||||
h,w,_ = img.shape
|
||||
render=np.zeros((h,w,3),dtype='uint8')
|
||||
target_depth = np.zeros((h,w)).astype(np.float32)
|
||||
print(h,w)
|
||||
img = np.array(imgs[0])
|
||||
depth = mist_imgs[0]
|
||||
|
||||
depth = org_mist
|
||||
pose = poses_relative[0].numpy()
|
||||
x = -pose[1]
|
||||
y = -pose[0]
|
||||
|
@ -213,16 +216,14 @@ class ViewDataSet3D(data.Dataset):
|
|||
target_depth.ctypes.data_as(ct.c_void_p)
|
||||
)
|
||||
if not self.transform is None:
|
||||
render = self.transform(render)
|
||||
render = self.transform(Image.fromarray(render))
|
||||
if not self.depth_trans is None:
|
||||
target_depth = self.depth_trans(target_depth)
|
||||
target_depth = self.depth_trans(np.expand_dims(target_depth,2))
|
||||
|
||||
if self.off_3d and self.off_pc_render:
|
||||
if self.off_3d:
|
||||
return imgs, target, poses_relative
|
||||
elif self.off_pc_render and (not self.off_3d):
|
||||
elif self.off_pc_render:
|
||||
return imgs, target, mist_imgs, mist_target, normal_imgs, normal_target, poses_relative
|
||||
elif (not self.off_pc_render) and self.off_3d:
|
||||
return imgs, target, poses_relative, render, target_depth
|
||||
else:
|
||||
return imgs, target, mist_imgs, mist_target, normal_imgs, normal_target, poses_relative, render, target_depth
|
||||
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
import argparse
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from torchvision import datasets, transforms
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.autograd import Variable
|
||||
import torchvision.utils as vutils
|
||||
from datasets import ViewDataSet3D
|
||||
from completion import CompletionNet, Discriminator
|
||||
from tensorboard import SummaryWriter
|
||||
from datetime import datetime
|
||||
import vision_utils
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def weights_init(m):
    """Initialize a module's parameters DCGAN-style.

    Conv and Linear weights are drawn from N(0, 0.02), BatchNorm scales
    from N(1, 0.02); Linear and BatchNorm biases are zeroed. Intended to
    be applied via ``model.apply(weights_init)``.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'Linear' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
||||
|
||||
|
||||
def generate_patches_position(input_imgs):
    """Sample per-image patch centers and hole half-extents for a batch.

    ``input_imgs`` is an (N, C, H, W) batch; the code treats the width as
    twice the height (panoramas), so horizontal centers range over
    [64, 2*H - 64). Returns ``(hole, center1, center2)``: three (N, 2)
    int64 arrays — hole half-sizes drawn uniformly from [48, 64], and two
    independent patch centers kept at least 64 px from every border.
    """
    n = input_imgs.size(0)
    side = input_imgs.size(2)

    center1 = np.zeros((n, 2)).astype(np.int64)
    center2 = np.zeros((n, 2)).astype(np.int64)
    hole = np.zeros((n, 2)).astype(np.int64)

    # Draw per-element (not vectorized) to keep the RNG call sequence
    # identical for a given seed.
    for idx in range(n):
        center1[idx, 0] = np.random.randint(64, side - 64)
        center1[idx, 1] = np.random.randint(64, side * 2 - 64)
        center2[idx, 0] = np.random.randint(64, side - 64)
        center2[idx, 1] = np.random.randint(64, side * 2 - 64)
        hole[idx, 0] = np.random.randint(48, 64 + 1)
        hole[idx, 1] = np.random.randint(48, 64 + 1)

    return hole, center1, center2
|
||||
|
||||
def generate_patches(input_imgs, center, half_size=64):
    """Crop a square patch around each per-image center.

    Generalized: the previously hard-coded 64/128 patch geometry is now a
    keyword parameter (default preserves the original behavior).

    Args:
        input_imgs: (N, 3, H, W) batch (Variable/Tensor).
        center: (N, 2) int array of (row, col) centers; each must lie at
            least ``half_size`` px from the image borders, otherwise the
            slice is short and the assignment below fails.
        half_size: half the patch edge length (default 64 -> 128x128).

    Returns:
        A CUDA Variable of shape (N, 3, 2*half_size, 2*half_size) holding
        copies of the cropped regions.
    """
    batchsize = input_imgs.size(0)
    side = half_size * 2
    patches = Variable(torch.zeros(batchsize, 3, side, side)).cuda()
    for i in range(batchsize):
        r, c = center[i, 0], center[i, 1]
        patches[i] = input_imgs[i, :, r - half_size : r + half_size, c - half_size : c + half_size]
    return patches
|
||||
|
||||
def prepare_completion_input(input_imgs, center, hole, mean):
    """Punch a rectangular hole into each image and fill it with the mean color.

    Args:
        input_imgs: (N, 3, H, W) tensor; left untouched (a clone is edited).
        center: (N, 2) int array of per-image hole centers (row, col).
        hole: (N, 2) int array of per-image hole half-extents.
        mean: length-3 tensor of per-channel fill values.

    Returns:
        ``(img_holed, mask)`` — the holed copy of the batch and an
        (N, 1, H, W) mask that is 1 inside each hole, 0 elsewhere.
    """
    n = input_imgs.size(0)
    img_holed = input_imgs.clone()
    mask = torch.zeros(n, 1, input_imgs.size(2), input_imgs.size(3))
    for i in range(n):
        rows = slice(center[i, 0] - hole[i, 0], center[i, 0] + hole[i, 0])
        cols = slice(center[i, 1] - hole[i, 1], center[i, 1] + hole[i, 1])
        fill = mean.view(3, 1, 1).repeat(1, hole[i, 0] * 2, hole[i, 1] * 2)
        img_holed[i, :, rows, cols] = fill
        mask[i, :, rows, cols] = 1
    return img_holed, mask
|
||||
|
||||
|
||||
def main():
    """Train the panorama point-cloud filler (CompletionNet) with a plain MSE
    reconstruction loss.

    Reads (render, target_depth, target) triples from ViewDataSet3D, feeds the
    rendered view + its depth mask through CompletionNet, regresses to the
    ground-truth target view, logs images/losses to tensorboard and saves
    generator checkpoints. The Discriminator/adversarial path is instantiated
    but not yet trained in this loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default=256, help='image size')
    parser.add_argument('--batchsize', type=int, default=20, help='batchsize')
    parser.add_argument('--workers', type=int, default=6, help='number of workers')
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default=0, help='current epoch')

    # Dataset mean RGB; used as the hole fill colour by the patch helpers.
    mean = torch.from_numpy(np.array([0.45725039, 0.44777581, 0.4146058]).astype(np.float32))

    opt = parser.parse_args()
    print(opt)

    writer = SummaryWriter(opt.outf + '/runs/' + datetime.now().strftime('%B%d %H:%M:%S'))

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass  # output folder already exists

    tf = transforms.Compose([
        # BUGFIX: Scale/Resize takes the output size as ONE argument — an int
        # or an (h, w) tuple. The previous call Scale(imgsize, imgsize * 2)
        # passed imgsize*2 as the *interpolation* positional parameter, so the
        # images were resized to a square, not to the 1:2 panorama shape.
        transforms.Scale((opt.imgsize, opt.imgsize * 2)),
        transforms.ToTensor(),
    ])

    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    d = ViewDataSet3D(root=opt.dataroot, transform=tf, mist_transform=mist_tf, seqlen=2, debug=opt.debug, off_3d=False, off_pc_render=False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last=True, pin_memory=True)

    # Pre-allocated CUDA buffers; batches are copied in each iteration.
    # NOTE(review): the 256x512 shapes assume imgsize == 256 — confirm they
    # should track opt.imgsize instead of being hard-coded.
    img = Variable(torch.zeros(opt.batchsize, 3, 256, 512)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 1, 256, 512)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 256, 512)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize)).cuda()

    comp = CompletionNet()
    dis = Discriminator(pano=True)
    current_epoch = 0

    comp = torch.nn.DataParallel(comp).cuda()
    comp.apply(weights_init)
    dis = torch.nn.DataParallel(dis).cuda()
    dis.apply(weights_init)

    # Optionally resume: generator weights from --model, discriminator from
    # the same path with "G" -> "D", epoch counter from --cepoch.
    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch

    l2 = nn.MSELoss()
    optimizerG = torch.optim.Adam(comp.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerD = torch.optim.Adam(dis.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    curriculum = (30000, 50000)  # step to start D training and G training, slightly different from the paper
    alpha = 0.0004

    errG_data = 0
    errD_data = 0

    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            # Dataset returns (..., render, target_depth) last; target view is data[1].
            source = data[-2]
            source_depth = data[-1]
            target = data[1]
            step = i + epoch * len(dataloader)

            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            recon = comp(img, maskv)
            loss = l2(recon, img_original)
            loss.backward(retain_variables=True)

            optimizerG.step()

            print('[%d/%d][%d/%d] MSEloss: %f' % (epoch, opt.nepoch, i, len(dataloader), loss.data[0]))

            if i % 500 == 0:
                visual = torch.cat([source, target], 3)
                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)

            if i % 10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)

            if i % 10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))


if __name__ == '__main__':
    main()
|
|
@ -13,10 +13,10 @@ void render(int h,int w,unsigned char * img, float * depth,float * pose, unsigne
|
|||
float * points3d_after = (float*) malloc(sizeof(float) * h * w * 3);
|
||||
float * points3d_polar = (float*) malloc(sizeof(float) * h * w * 3);
|
||||
|
||||
for (i = 0; i < 5; i++) {
|
||||
printf("%f ", pose[i]);
|
||||
}
|
||||
printf("\n");
|
||||
//for (i = 0; i < 5; i++) {
|
||||
// printf("%f ", pose[i]);
|
||||
//}
|
||||
//printf("\n");
|
||||
for (ih = 0; ih < h; ih ++ ) {
|
||||
for (iw = 0; iw < w; iw ++ ) {
|
||||
for (ic = 0; ic < 3; ic ++) {
|
||||
|
|
Loading…
Reference in New Issue