add color loss
parent 1473d76b23
commit 17633bc697
@@ -70,12 +70,9 @@ def main():


    mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))

    opt = parser.parse_args()
    print(opt)

    writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))

    try:
        os.makedirs(opt.outf)
    except OSError:
@@ -90,10 +87,12 @@ def main():
    ])

    d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
    d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
    dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)

    img = Variable(torch.zeros(opt.batchsize,3 + 4, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize,2, 1024, 2048)).cuda()
@@ -118,16 +117,14 @@ def main():
    current_epoch = opt.cepoch

    l2 = nn.MSELoss()

    if opt.loss == 'train_init':
        params = list(comp.parameters())
        sel = np.random.choice(len(params), len(params)/2, replace=False)
        params_sel = [params[i] for i in sel]
        optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))

    else:
        optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    #if opt.loss == 'train_init':
    # params = list(comp.parameters())
    # sel = np.random.choice(len(params), len(params)/2, replace=False)
    # params_sel = [params[i] for i in sel]
    # optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
    #
    #else:
    optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
    optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    curriculum = (200000, 300000) # step to start D training and G training, slightly different from the paper
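The 'train_init' branch (now commented out) hands the optimizer a random half of the network's parameter tensors. A minimal, hedged restatement in current Python/PyTorch (the helper name half_params is mine; the explicit // keeps the count an integer on Python 3 as well):

import numpy as np
import torch

def half_params(model, seed=None):
    # Collect the parameter tensors and keep a random half of them,
    # mirroring the commented-out 'train_init' branch above.
    rng = np.random.default_rng(seed)
    params = list(model.parameters())
    sel = rng.choice(len(params), len(params) // 2, replace=False)
    return [params[i] for i in sel]

# e.g. optimizerG = torch.optim.Adam(half_params(comp), lr=opt.lr, betas=(opt.beta1, 0.999))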
@@ -144,6 +141,7 @@ def main():
    for param in p.parameters():
        param.requires_grad = False

    test_loader_enum = enumerate(dataloader_test)
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
@@ -153,20 +151,14 @@ def main():
            step = i + epoch * len(dataloader)

            mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)

            img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)

            source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * img_mean.view(opt.batchsize,3,1,1).repeat(1,1,1024,2048)

            source_depth = source_depth[:,:,:,0].unsqueeze(1)

            #print(source_depth.size(), mask.size())
            source_depth = torch.cat([source_depth, mask], 1)

            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)

            imgc, maskvc, img_originalc = crop(img, maskv, img_original)
            #from IPython import embed; embed()
            recon = comp(imgc, maskvc)
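The masking step above treats all-zero RGB pixels as holes and fills them with the per-image mean color of the valid pixels. A self-contained sketch of the same idea in current PyTorch (the function name and the clamp guard against an empty mask are mine):

import torch

def fill_holes_with_mean(source):
    # source: (N, C, H, W) with RGB in the first three channels; zero pixels are holes.
    rgb = source[:, :3]
    mask = (rgb.sum(dim=1, keepdim=True) > 0).float()        # 1 where the pixel is valid
    valid = mask.sum(dim=(2, 3)).clamp(min=1)                # valid pixel count per image
    mean = (rgb.sum(dim=(2, 3)) / valid).view(-1, 3, 1, 1)   # per-image mean color
    filled = rgb * mask + (1 - mask) * mean                  # paint the holes with the mean
    return filled, mask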
@@ -178,9 +170,20 @@ def main():
                loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
            elif opt.loss == 'color_stable':
                loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
            elif opt.loss == 'color_correction':
                #from IPython import embed; embed()
                img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,8,32,8,32).transpose(4,3).contiguous().view(opt.batchsize * 4,3,8,8,-1)
                recon_patch = recon.view(opt.batchsize * 4,3,8,32,8,32).transpose(4,3).contiguous().view(opt.batchsize * 4,3,8,8,-1)
                img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                img_originalc_patch_std = img_originalc_patch.std(dim=-1)
                recon_patch_mean = recon_patch.mean(dim = -1)
                recon_patch_std = recon_patch.std(dim = -1)

            loss.backward(retain_graph = True)
                color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_std, img_originalc_patch_std)
                loss = l2(p(recon), p(img_originalc).detach()) + 100 * color_loss
                print("color loss %f" % color_loss.data[0])

            loss.backward(retain_graph = True)

            if step > curriculum[1]:
                label.data.fill_(1)
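The new 'color_correction' branch is the color loss named in the commit message: each crop (assumed (batchsize*4, 3, 256, 256) per the view above) is split into an 8x8 grid of 32x32 patches, and the per-patch, per-channel mean and standard deviation of the reconstruction are pulled toward those of the target, on top of the usual perceptual MSE. A hedged, stand-alone sketch in current PyTorch (the function name and the grid argument are mine; the reshaping follows the view/transpose in the hunk):

import torch
import torch.nn.functional as F

def patch_color_loss(recon, target, grid=8):
    # recon, target: (N, 3, H, W) with H and W divisible by grid
    # (256x256 crops and 32x32 patches in the hunk above).
    n, c, h, w = recon.shape
    ph, pw = h // grid, w // grid

    def patch_stats(x):
        # (N, C, grid, ph, grid, pw) -> (N, C, grid, grid, ph*pw)
        x = x.view(n, c, grid, ph, grid, pw).transpose(4, 3).contiguous()
        x = x.view(n, c, grid, grid, -1)
        return x.mean(dim=-1), x.std(dim=-1)

    r_mean, r_std = patch_stats(recon)
    t_mean, t_std = patch_stats(target)
    return F.mse_loss(r_mean, t_mean) + F.mse_loss(r_std, t_std)

# As in the hunk above, the total loss is the perceptual term plus a heavily weighted
# color term, e.g. loss = l2(p(recon), p(target).detach()) + 100 * patch_color_loss(recon, target)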
@@ -188,6 +191,17 @@ def main():
                errG = alpha * F.nll_loss(output, label)
                errG.backward()
                errG_data = errG.data[0]


            #from IPython import embed; embed()
            if opt.loss == "train_init":
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        #print(param.size())
                        nk = param.size()[2]//2
                        if nk > 5:
                            param.grad[:nk, :,:,:] = 0

            optimizerG.step()

            # Train D:
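For the 'train_init' loss the generator gradients are edited in place between backward() and optimizerG.step(): for every 4-D conv weight with a large kernel (kernel_size // 2 > 5), the gradient of the first nk output-channel slices is zeroed. A commented restatement of that step (a sketch of the logic above, not a change to it):

# after loss.backward(), before optimizerG.step()
for param in comp.parameters():
    if param.dim() == 4 and param.grad is not None:  # conv weights only
        nk = param.size(2) // 2                      # half the kernel height
        if nk > 5:                                   # only large kernels
            param.grad[:nk].zero_()                  # freeze the first nk output channels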
@@ -211,6 +225,26 @@ def main():
            print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))

            if i%200 == 0:

                test_i, test_data = test_loader_enum.next()
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)

                source = test_data[0]
                source_depth = test_data[1]
                target = test_data[2]

                mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
                img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)
                source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * img_mean.view(opt.batchsize,3,1,1).repeat(1,1,1024,2048)
                source_depth = source_depth[:,:,:,0].unsqueeze(1)
                source_depth = torch.cat([source_depth, mask], 1)
                img.data.copy_(source)
                maskv.data.copy_(source_depth)
                img_original.data.copy_(target)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original)
                recon = comp(imgc, maskvc)

                visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)
                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
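test_loader_enum.next() is Python 2 spelling; under Python 3 it would be next(test_loader_enum). A small sketch of the same cycling pattern (helper name is mine), which rebuilds the enumerator a few batches before the test loader runs out, exactly as the branch above does:

def next_test_batch(state, loader, margin=5):
    # state is a one-element dict holding the current enumerator,
    # e.g. state = {'enum': enumerate(loader)}
    test_i, test_data = next(state['enum'])
    if test_i > len(loader) - margin:
        state['enum'] = enumerate(loader)  # rewind before exhaustion
    return test_data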
@@ -221,7 +255,6 @@ def main():
            writer.add_scalar('G_loss', errG_data, step)
            writer.add_scalar('D_loss', errD_data, step)


            if i%10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
                torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))
@@ -175,6 +175,9 @@ conv_mask.weight.data[0,0,:,:] = torch.from_numpy(gkern())


for i in range(len(d)):
    generate_data([i, d, opt.outf, convs, convs2, conv_mask])

    filename = "%s/data_%d.npz" % (opt.outf, i)
    if not os.path.isfile(filename):
        generate_data([i, d, opt.outf, convs, convs2, conv_mask])