Add color loss (new `color_correction` loss option)

This commit is contained in:
fxia22 2017-10-04 14:41:23 -07:00
parent 1473d76b23
commit 17633bc697
2 changed files with 58 additions and 22 deletions

View File

@ -70,12 +70,9 @@ def main():
mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
opt = parser.parse_args()
print(opt)
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
try:
os.makedirs(opt.outf)
except OSError:
@ -90,10 +87,12 @@ def main():
])
d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)
cudnn.benchmark = True
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
img = Variable(torch.zeros(opt.batchsize,3 + 4, 1024, 2048)).cuda()
maskv = Variable(torch.zeros(opt.batchsize,2, 1024, 2048)).cuda()
@ -118,16 +117,14 @@ def main():
current_epoch = opt.cepoch
l2 = nn.MSELoss()
if opt.loss == 'train_init':
params = list(comp.parameters())
sel = np.random.choice(len(params), len(params)/2, replace=False)
params_sel = [params[i] for i in sel]
optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
else:
optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
#if opt.loss == 'train_init':
# params = list(comp.parameters())
# sel = np.random.choice(len(params), len(params)/2, replace=False)
# params_sel = [params[i] for i in sel]
# optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
#
#else:
optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
curriculum = (200000, 300000) # (step at which D training starts, step at which G adversarial training starts); values differ slightly from the paper
@ -144,6 +141,7 @@ def main():
for param in p.parameters():
param.requires_grad = False
test_loader_enum = enumerate(dataloader_test)
for epoch in range(current_epoch, opt.nepoch):
for i, data in enumerate(dataloader, 0):
optimizerG.zero_grad()
@ -153,20 +151,14 @@ def main():
step = i + epoch * len(dataloader)
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * img_mean.view(opt.batchsize,3,1,1).repeat(1,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
#print(source_depth.size(), mask.size())
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
#from IPython import embed; embed()
recon = comp(imgc, maskvc)
@ -178,9 +170,20 @@ def main():
loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
elif opt.loss == 'color_stable':
loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
elif opt.loss == 'color_correction':
#from IPython import embed; embed()
img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,8,32,8,32).transpose(4,3).contiguous().view(opt.batchsize * 4,3,8,8,-1)
recon_patch = recon.view(opt.batchsize * 4,3,8,32,8,32).transpose(4,3).contiguous().view(opt.batchsize * 4,3,8,8,-1)
img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
img_originalc_patch_std = img_originalc_patch.std(dim=-1)
recon_patch_mean = recon_patch.mean(dim = -1)
recon_patch_std = recon_patch.std(dim = -1)
loss.backward(retain_graph = True)
color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_std, img_originalc_patch_std)
loss = l2(p(recon), p(img_originalc).detach()) + 100 * color_loss
print("color loss %f" % color_loss.data[0])
loss.backward(retain_graph = True)
if step > curriculum[1]:
label.data.fill_(1)
@ -188,6 +191,17 @@ def main():
errG = alpha * F.nll_loss(output, label)
errG.backward()
errG_data = errG.data[0]
#from IPython import embed; embed()
if opt.loss == "train_init":
for param in comp.parameters():
if len(param.size()) == 4:
#print(param.size())
nk = param.size()[2]//2
if nk > 5:
param.grad[:nk, :,:,:] = 0
optimizerG.step()
# Train D:
@ -211,6 +225,26 @@ def main():
print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))
if i%200 == 0:
test_i, test_data = test_loader_enum.next()
if test_i > len(dataloader_test) - 5:
test_loader_enum = enumerate(dataloader_test)
source = test_data[0]
source_depth = test_data[1]
target = test_data[2]
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * img_mean.view(opt.batchsize,3,1,1).repeat(1,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
recon = comp(imgc, maskvc)
visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)
visual = vutils.make_grid(visual, normalize=True)
writer.add_image('image', visual, step)
@ -221,7 +255,6 @@ def main():
writer.add_scalar('G_loss', errG_data, step)
writer.add_scalar('D_loss', errD_data, step)
if i%10000 == 0:
torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))

View File

@ -175,6 +175,9 @@ conv_mask.weight.data[0,0,:,:] = torch.from_numpy(gkern())
for i in range(len(d)):
generate_data([i, d, opt.outf, convs, convs2, conv_mask])
filename = "%s/data_%d.npz" % (opt.outf, i)
if not os.path.isfile(filename):
generate_data([i, d, opt.outf, convs, convs2, conv_mask])