290 lines
10 KiB
Python
290 lines
10 KiB
Python
from __future__ import print_function
|
|
import torch.utils.data as data
|
|
from PIL import Image
|
|
import os
|
|
import os.path
|
|
import errno
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch.autograd import Variable
|
|
import json
|
|
import codecs
|
|
import numpy as np
|
|
import progressbar
|
|
import sys
|
|
import torchvision.transforms as transforms
|
|
import utils
|
|
import argparse
|
|
import ctypes as ct
|
|
import json
|
|
from numpy.linalg import inv
|
|
import pickle
|
|
from datasets import Places365Dataset, ViewDataSet3D
|
|
import matplotlib.pyplot as plt
|
|
import torch.backends.cudnn as cudnn
|
|
from torch import sin, cos
|
|
import torchvision.utils as vutils
|
|
from tensorboardX import SummaryWriter
|
|
from datetime import datetime
|
|
|
|
|
|
class Depth3DGridGen(nn.Module):
|
|
def __init__(self, height, width):
|
|
super(Depth3DGridGen, self).__init__()
|
|
self.height, self.width = height, width
|
|
self.grid = np.zeros( [self.height, self.width, 3], dtype=np.float32)
|
|
self.grid[:,:,0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.height), 0), repeats = self.width, axis = 0).T, 0)
|
|
self.grid[:,:,1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.width), 0), repeats = self.height, axis = 0), 0)
|
|
self.grid[:,:,2] = np.ones([self.height, width])
|
|
self.grid = torch.from_numpy(self.grid.astype(np.float32))
|
|
|
|
self.theta = self.grid[:,:,0] * np.pi/2 + np.pi/2
|
|
self.phi = self.grid[:,:,1] * np.pi
|
|
|
|
self.x = torch.sin(self.theta) * torch.cos(self.phi)
|
|
self.y = torch.sin(self.theta) * torch.sin(self.phi)
|
|
self.z = torch.cos(self.theta)
|
|
|
|
self.grid3d = torch.from_numpy(np.zeros( [self.height, self.width, 4], dtype=np.float32))
|
|
|
|
self.grid3d[:,:,0] = self.x
|
|
self.grid3d[:,:,1] = self.y
|
|
self.grid3d[:,:,2] = self.z
|
|
self.grid3d[:,:,3] = self.grid[:,:,2]
|
|
|
|
|
|
def forward(self, depth, transformation):
|
|
|
|
batchsize = torch.Size([depth.size(0)])
|
|
bs = depth.size(0)
|
|
|
|
self.batchgrid3d = torch.zeros(batchsize + self.grid3d.size())
|
|
|
|
for i in range(depth.size(0)):
|
|
self.batchgrid3d[i] = self.grid3d
|
|
|
|
self.batchgrid3d = Variable(self.batchgrid3d)
|
|
self.batchgrid = torch.zeros(batchsize + self.grid.size())
|
|
|
|
for i in range(depth.size(0)):
|
|
self.batchgrid[i] = self.grid
|
|
|
|
self.batchgrid = Variable(self.batchgrid)
|
|
|
|
if depth.is_cuda:
|
|
self.batchgrid3d = self.batchgrid3d.cuda()
|
|
|
|
x = self.batchgrid3d[:,:,:,0:1] * depth
|
|
y = self.batchgrid3d[:,:,:,1:2] * depth
|
|
z = self.batchgrid3d[:,:,:,2:3] * depth
|
|
|
|
points = torch.cat([x,y,z,self.batchgrid3d[:,:,:,3:]], -1)
|
|
points = points.view(bs, 1024 * 2048, 4)
|
|
points = torch.bmm(points, transformation)
|
|
points = points.view(bs, 1024, 2048, 4)
|
|
|
|
x = points[:,:,:,0:1]
|
|
y = points[:,:,:,1:2]
|
|
z = points[:,:,:,2:3]
|
|
|
|
r = torch.sqrt(x**2 + y**2 + z**2) + 1e-4
|
|
|
|
theta = torch.acos(z/r)/(np.pi/2) - 1
|
|
phi = torch.atan2(y,x)
|
|
phi = phi/np.pi
|
|
output = torch.cat([phi,theta], 3)
|
|
return output
|
|
|
|
class PoseNet(nn.Module):
|
|
def __init__(self):
|
|
super(PoseNet, self).__init__()
|
|
nf = 16
|
|
self.nf = nf
|
|
alpha = 0.05
|
|
self.convs = nn.Sequential(
|
|
nn.Conv2d(8, nf, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf, nf , kernel_size = 5, stride = 2, padding = 1),
|
|
nn.BatchNorm2d(nf, momentum=alpha),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf, nf, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.BatchNorm2d(nf, momentum=alpha),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf, nf * 2, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.BatchNorm2d(nf * 2, momentum=alpha),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf * 2, nf * 2, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.BatchNorm2d(nf * 2, momentum=alpha),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf * 2, nf * 2, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.BatchNorm2d(nf * 2, momentum=alpha),
|
|
nn.ReLU(),
|
|
nn.Conv2d(nf * 2, nf * 2, kernel_size = 5, stride = 2, padding = 1),
|
|
nn.ReLU()
|
|
)
|
|
self.lin = nn.Linear(210 * nf, 6)
|
|
|
|
def forward(self, x):
|
|
return self.lin(self.convs(x).view(-1, 210 * self.nf))
|
|
|
|
|
|
|
|
class PoseMat(nn.Module):
|
|
def __init__(self):
|
|
super(PoseMat, self).__init__()
|
|
def forward(self, pose):
|
|
bs = pose.size(0)
|
|
alpha = pose[:,0]
|
|
beta = pose[:,1]
|
|
gamma = pose[:,2]
|
|
|
|
x = pose[:,3]
|
|
y = pose[:,4]
|
|
z = pose[:,5]
|
|
|
|
|
|
mat0 = torch.stack([cos(alpha) * cos(beta), cos(alpha) * sin(beta) * sin(gamma) - sin(alpha) * cos(gamma),
|
|
cos(alpha) * sin(beta) * cos(gamma) + sin(alpha) * sin(gamma),x], 1)
|
|
|
|
mat1 = torch.stack([sin(alpha) * cos(beta), sin(alpha) * sin(beta) * sin(gamma) + cos(alpha) * cos(gamma),
|
|
sin(alpha) * sin(beta) * cos(gamma) - cos(alpha) * sin(gamma),y], 1)
|
|
|
|
mat2 = torch.stack([-sin(beta), cos(beta) * sin(gamma), cos(beta) * cos(gamma), z], 1)
|
|
mat3 = Variable(torch.zeros(bs, 4).cuda())
|
|
mat = torch.stack([mat0, mat1, mat2, mat3], -1)
|
|
|
|
return mat
|
|
|
|
|
|
#net = PoseNet().cuda()
|
|
#input = Variable(torch.rand(1,6,1024,2048)).cuda()
|
|
#print(net(input).size())
|
|
#
|
|
#mat = PoseMat().cuda()
|
|
#input = Variable(torch.zeros(1,6).cuda())
|
|
#print(mat(input))
|
|
|
|
def weights_init(m):
|
|
classname = m.__class__.__name__
|
|
if classname.find('Conv') != -1:
|
|
m.weight.data.normal_(0.0, 0.02)
|
|
elif classname.find('Linear') != -1:
|
|
m.weight.data.normal_(0.0, 0.02)
|
|
m.bias.data.fill_(0)
|
|
elif classname.find('BatchNorm') != -1:
|
|
m.weight.data.normal_(1.0, 0.02)
|
|
m.bias.data.fill_(0)
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--debug' , action='store_true', help='debug mode')
|
|
parser.add_argument('--dataroot', required=True, help='path to dataset')
|
|
parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
|
|
parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
|
|
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
|
|
parser.add_argument('--outf', type=str, default="pose_adj", help='output folder')
|
|
parser.add_argument('--model', type=str, default="", help='model path')
|
|
|
|
opt = parser.parse_args(['--dataroot', '/home/fei/Downloads/highres_tiny/'])
|
|
|
|
|
|
|
|
try:
|
|
os.makedirs(opt.outf)
|
|
except OSError:
|
|
pass
|
|
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
|
|
|
|
|
|
tf = transforms.Compose([
|
|
transforms.Scale(1024, 1024 * 2),
|
|
transforms.ToTensor(),
|
|
])
|
|
|
|
mist_tf = transforms.Compose([
|
|
transforms.ToTensor(),
|
|
])
|
|
|
|
d = ViewDataSet3D(root=opt.dataroot, debug=opt.debug, transform=tf, mist_transform=mist_tf, seqlen = 2, off_3d = False, off_pc_render = True)
|
|
print(len(d))
|
|
|
|
cudnn.benchmark = True
|
|
|
|
dataloader = torch.utils.data.DataLoader(d, batch_size=1, shuffle=True, num_workers=1, drop_last = True, pin_memory = False)
|
|
gridgen = Depth3DGridGen(1024, 2048)
|
|
net = PoseNet().cuda()
|
|
mat = PoseMat().cuda()
|
|
net.apply(weights_init)
|
|
|
|
|
|
dll2=np.ctypeslib.load_library('occinf','.')
|
|
l1 = nn.L1Loss()
|
|
optimizer = torch.optim.Adam(net.parameters(), lr = 0.002, betas = (0.5, 0.999))
|
|
|
|
for epoch in range(0, 10):
|
|
for i, data in enumerate(dataloader, 0):
|
|
optimizer.zero_grad()
|
|
step = i + epoch * len(dataloader)
|
|
source = Variable(data[0][0]).cuda()
|
|
target = Variable(data[1]).cuda()
|
|
depth = Variable(data[2][0]).cuda().transpose(2,1).transpose(3,2).float() * 255 * 128
|
|
pose = Variable(data[-1][0]).cuda().float().transpose(2,1)
|
|
|
|
depth_np = (depth.cpu().data[0].numpy()[:,:,0]/128.0).astype(np.float32)
|
|
pose_after_np = pose.transpose(2,1).cpu().data[0].numpy().astype(np.float32)
|
|
occ_np = np.zeros((1024,1024 * 2)).astype(np.bool)
|
|
target_depth_np = np.zeros((1024,1024 * 2)).astype(np.uint32)
|
|
target_depth_np[:] = 65535
|
|
|
|
dll2.occinf(ct.c_int(1024),
|
|
ct.c_int(2048),
|
|
depth_np.ctypes.data_as(ct.c_void_p),
|
|
pose_after_np.ctypes.data_as(ct.c_void_p),
|
|
occ_np.ctypes.data_as(ct.c_void_p),
|
|
target_depth_np.ctypes.data_as(ct.c_void_p)
|
|
)
|
|
|
|
mask = torch.from_numpy(occ_np.astype(np.float32))
|
|
mask_v = Variable(mask.view(1,1,1024,2048)).cuda()
|
|
|
|
grid = gridgen(depth, pose)
|
|
sample = F.grid_sample(target, grid).detach()
|
|
|
|
mask1 = 1 - (F.grid_sample(mask_v, grid) > 0).float()
|
|
masked_source0 = source * mask1.repeat(1,3,1,1)
|
|
masked_source_pred0 = sample * mask1.repeat(1,3,1,1)
|
|
before = l1(masked_source_pred0, masked_source0)
|
|
input_deck = torch.cat([sample, source, depth.view(1,1,1024,2048)/128.0, mask_v], 1)
|
|
|
|
adj_pose = net(input_deck)
|
|
pose_mat2 = mat(adj_pose)
|
|
final_mat = torch.bmm(pose, pose_mat2)
|
|
|
|
grid2 = gridgen(depth, final_mat)
|
|
final_source = F.grid_sample(target, grid2)
|
|
mask2 = 1 - (F.grid_sample(mask_v, grid2) > 0).float()
|
|
|
|
masked_source_pred = final_source * mask2.repeat(1,3,1,1)
|
|
masked_source = source * mask2.repeat(1,3,1,1)
|
|
|
|
loss = l1(masked_source_pred, masked_source) + torch.sum(adj_pose ** 2)
|
|
|
|
after = l1(masked_source_pred, masked_source)
|
|
loss.backward()
|
|
|
|
optimizer.step()
|
|
print(loss.data[0], before.data[0], after.data[0])
|
|
if i%100 == 0:
|
|
visual = torch.cat([sample.data, masked_source_pred.data, masked_source.data, masked_source_pred.data * 0.5 + masked_source.data * 0.5, masked_source_pred0.data * 0.5 + masked_source0.data * 0.5], 3)
|
|
visual = vutils.make_grid(visual, normalize=True)
|
|
vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)
|
|
|
|
if i%10 == 0:
|
|
writer.add_scalar('loss', loss.data[0], step)
|
|
writer.add_scalar('before', before.data[0], step)
|
|
writer.add_scalar('after', after.data[0], step)
|
|
|
|
if i%10000 == 0:
|
|
torch.save(net.state_dict(), '%s/net_epoch%d_%d.pth' % (opt.outf, epoch, i)) |