commit bb6545ac83

@@ -29,3 +29,6 @@ physics/*.txt
 physics/*.json
+# Physics assets
+physics/models


+*/events*

@@ -94,36 +94,34 @@ __global__ void fill(unsigned char * img)


-__global__ void merge(unsigned char * img_all, unsigned char * img, int n, int stride)
+__global__ void merge(unsigned char * img_all, unsigned char * img, float * selection, int n, int stride)
 {
     int x = blockIdx.x * TILE_DIM + threadIdx.x;
     int y = blockIdx.y * TILE_DIM + threadIdx.y;
     int width = gridDim.x * TILE_DIM;
     int idx = 0;

+    float sum = 0;
+    float weight = 0;
     for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {

         int nz = 0;
         for (idx = 0; idx < n; idx++)
             if (img_all[stride * idx + 3*((y+j)*width + x)] + img_all[stride * idx + 3*((y+j)*width + x) + 1] + img_all[stride * idx + 3*((y+j)*width + x) + 2] > 0)
                 nz += 1;
+        sum = 0;
+        for (idx = 0; idx < n; idx++) sum += selection[idx * stride + ((y+j)*width + x)];
+        //printf("%f\n", sum);

         img[3*((y+j)*width + x)] = 0;
         img[3*((y+j)*width + x)+1] = 0;
         img[3*((y+j)*width + x)+2] = 0;

         if (nz > 0)
             for (idx = 0; idx < n; idx++) {
-                img[3*((y+j)*width + x)] += img_all[idx * stride + 3*((y+j)*width + x)] / nz;
-                img[3*((y+j)*width + x)+1] += img_all[idx * stride + 3*((y+j)*width + x) + 1] / nz;
-                img[3*((y+j)*width + x)+2] += img_all[idx * stride + 3*((y+j)*width + x) + 2] / nz;
+                weight = selection[idx * stride + ((y+j)*width + x)] / (sum + 1e-4);
+                img[3*((y+j)*width + x)] += (unsigned char) (img_all[idx * stride * 3 + 3*((y+j)*width + x)] * weight);
+                img[3*((y+j)*width + x)+1] += (unsigned char) (img_all[idx * stride * 3 + 3*((y+j)*width + x) + 1] * weight);
+                img[3*((y+j)*width + x)+2] += (unsigned char) (img_all[idx * stride * 3 + 3*((y+j)*width + x) + 2] * weight);
             }
     }
 }

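The rewritten merge kernel replaces the old uniform average over nonzero candidate views (the `/ nz` path) with a weighted blend: each view contributes in proportion to its per-pixel selection score, normalized across views with a small epsilon. A rough NumPy analogue of the intended blend (hypothetical helper, not part of the commit; `selection` is filled per view by render_final further down in this diff):

import numpy as np

def merge_views(views, selection, eps=1e-4):
    # views: (n, H, W, 3) uint8 renders, one per source view
    # selection: (n, H, W) float scores, higher = more trusted
    weights = selection / (selection.sum(axis=0, keepdims=True) + eps)
    blended = (views.astype(np.float32) * weights[..., None]).sum(axis=0)
    return blended.astype(np.uint8)
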
@@ -211,9 +209,61 @@ __global__ void render_depth(float *points3d_polar, unsigned int * depth_render)
     }
 }

+__global__ void get_average(unsigned char * img, int * nz, int * average, int scale)
+{
+    int x = blockIdx.x * TILE_DIM + threadIdx.x;
+    int y = blockIdx.y * TILE_DIM + threadIdx.y;
+    int width = gridDim.x * TILE_DIM;
+    int h = width / 2;
+
+    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
+    {
+        int iw = x;
+        int ih = y + j;
+
+        if (img[3*(ih*width + iw)] + img[3*(ih*width + iw)+1] + img[3*(ih*width + iw)+2] > 0)
+        {
+            //nz[ih/3 * width + iw/3] += 1;
+            //average[3*(ih/3*width + iw/3)] += (int)img[3*(ih*width + iw)];
+            //average[3*(ih/3*width + iw/3)+1] += (int)img[3*(ih*width + iw)+1];
+            //average[3*(ih/3*width + iw/3)+2] += (int)img[3*(ih*width + iw)+2];
+
+            atomicAdd(&(nz[ih/scale * width + iw/scale]), 1);
+            atomicAdd(&(average[3*(ih/scale*width + iw/scale)]), (int)img[3*(ih*width + iw)]);
+            atomicAdd(&(average[3*(ih/scale*width + iw/scale)+1]), (int)img[3*(ih*width + iw)+1]);
+            atomicAdd(&(average[3*(ih/scale*width + iw/scale)+2]), (int)img[3*(ih*width + iw)+2]);
+        }
+    }
+}
+
+
-__global__ void render_final(float *points3d_polar, float * depth_render, int * img, int * render, int s)
+__global__ void fill_with_average(unsigned char *img, int * nz, int * average, int scale)
+{
+    int x = blockIdx.x * TILE_DIM + threadIdx.x;
+    int y = blockIdx.y * TILE_DIM + threadIdx.y;
+    int width = gridDim.x * TILE_DIM;
+    int h = width / 2;
+
+    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
+    {
+        int iw = x;
+        int ih = y + j;
+
+        if ((img[3*(ih*width + iw)] + img[3*(ih*width + iw)+1] + img[3*(ih*width + iw)+2] == 0) && (nz[ih/scale * width + iw/scale] > 0))
+        {
+            img[3*(ih*width + iw)] = (unsigned char)(average[3*(ih/scale*width + iw/scale)] / nz[ih/scale * width + iw/scale]);
+            img[3*(ih*width + iw) + 1] = (unsigned char)(average[3*(ih/scale*width + iw/scale) + 1] / nz[ih/scale * width + iw/scale]);
+            img[3*(ih*width + iw) + 2] = (unsigned char)(average[3*(ih/scale*width + iw/scale) + 2] / nz[ih/scale * width + iw/scale]);
+        }
+    }
+}
+
+
+__global__ void render_final(float *points3d_polar, float * selection, float * depth_render, int * img, int * render, int s)
 {
     int x = blockIdx.x * TILE_DIM + threadIdx.x;
     int y = blockIdx.y * TILE_DIM + threadIdx.y;

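get_average and fill_with_average together implement a cheap hole-filling pass: non-black pixels accumulate per-cell color sums and counts (cells are scale x scale blocks; atomicAdd is required because many threads land in the same cell), and black pixels then inherit their cell's mean color. A NumPy sketch of the same idea (hypothetical helper, not part of the commit):

import numpy as np

def fill_holes(img, scale):
    # img: (H, W, 3) uint8; all-zero pixels are holes
    h, w, _ = img.shape
    out = img.copy()
    for cy in range(0, h, scale):
        for cx in range(0, w, scale):
            cell = img[cy:cy+scale, cx:cx+scale]
            filled = cell.sum(axis=2) > 0            # non-black pixels in the cell
            if filled.any():
                mean = cell[filled].mean(axis=0).astype(np.uint8)
                out[cy:cy+scale, cx:cx+scale][~filled] = mean
    return out
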
@@ -305,8 +355,11 @@ __global__ void render_final(float *points3d_polar, float * depth_render, int *
             if (r > 255) r = 255;
             if (g > 255) g = 255;
             if (b > 255) b = 255;

             if ((ity > 0) && (ity < h * s) && (itx > 0) && (itx < w * s)) {
                 render[(ity * w * s + itx)] = r * 256 * 256 + g * 256 + b;
+                selection[(ity * w * s + itx)] = 1.0 / abs(det);
             }
+            //printf("%f\n", selection[(ity * w * s + itx)]);
         }
     }

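The new selection output is 1.0 / abs(det). Assuming det is the determinant of the Jacobian of the source-to-target pixel warp (an inference from the surrounding code, which this hunk truncates), a source pixel that gets stretched over a large screen area receives proportionally low confidence, so the merge kernel favors the view that sees the surface most head-on. A toy calculation of the resulting blend weights:

stretch = [1.0, 4.0]                      # |det| for two views covering one output pixel
scores = [1.0 / s for s in stretch]       # selection values: [1.0, 0.25]
total = sum(scores) + 1e-4                # epsilon matches the merge kernel
print([s / total for s in scores])        # roughly [0.8, 0.2]: the less-stretched view dominates
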
@@ -339,6 +392,11 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
     float *d_depth_render;
     float *d_3dpoint, *d_3dpoint_after;

+    float *d_selection;
+
+    int *nz;
+    int *average;

     int *d_render2, *d_img2;

     cudaMalloc((void **)&d_img, frame_mem_size);

@@ -351,10 +409,18 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
     cudaMalloc((void **)&d_pose, sizeof(float) * 16);
     cudaMalloc((void **)&d_render2, render_mem_size * sizeof(int));
     cudaMalloc((void **)&d_img2, render_mem_size * sizeof(int));
+    cudaMalloc((void **)&d_selection, render_mem_size * sizeof(float) * n);
+
+    cudaMalloc((void **)&nz, render_mem_size * sizeof(int));
+    cudaMalloc((void **)&average, render_mem_size * sizeof(int) * 3);

     cudaMemcpy(d_depth_render, depth_render, render_mem_size * sizeof(float), cudaMemcpyHostToDevice);
     cudaMemset(d_render_all, 0, render_mem_size * sizeof(unsigned char) * 3 * n);
+    cudaMemset(d_selection, 0, render_mem_size * sizeof(float) * n);
+
+    cudaMemset(nz, 0, render_mem_size * sizeof(int));
+    cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);

     int idx;
     for (idx = 0; idx < n; idx++) {

@@ -376,15 +442,27 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float

         char_to_int <<< dimGrid, dimBlock >>> (d_img2, d_img);

-        render_final <<< dimGrid, dimBlock >>> (d_3dpoint_after, d_depth_render, d_img2, d_render2, s);
+        render_final <<< dimGrid, dimBlock >>> (d_3dpoint_after, &(d_selection[idx * nx * ny * s * s]), d_depth_render, d_img2, d_render2, s);

-        int_to_char <<< dimGrid2, dimBlock >>> (d_render2, d_render);
+        //int_to_char <<< dimGrid2, dimBlock >>> (d_render2, d_render);
+        int_to_char <<< dimGrid2, dimBlock >>> (d_render2, &(d_render_all[idx * nx * ny * s * s * 3]));

-        fill <<< dimGrid2, dimBlock >>> (&(d_render_all[idx * nx * ny * s * s * 3]));
+        //fill <<< dimGrid2, dimBlock >>> (&(d_render_all[idx * nx * ny * s * s * 3]));
     }

-    merge <<< dimGrid2, dimBlock >>> (d_render_all, d_render, n, nx * ny * s * s * 3);
+    merge <<< dimGrid2, dimBlock >>> (d_render_all, d_render, d_selection, n, nx * ny * s * s);

+    /*int fill_size[8] = {3, 5, 10, 20, 50, 100, 200};
+    for (int j = 0; j < 8; j++) {
+        cudaMemset(nz, 0, render_mem_size * sizeof(int));
+        cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);
+        get_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, fill_size[j]);
+        fill_with_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, fill_size[j]);
+    }*/
+    cudaMemset(nz, 0, render_mem_size * sizeof(int));
+    cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);
+    get_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, 3);
+    fill_with_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, 3);

     cudaMemcpy(render, d_render, render_mem_size * sizeof(unsigned char) * 3, cudaMemcpyDeviceToHost);

@@ -398,6 +476,9 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
     cudaFree(d_3dpoint_after);
     cudaFree(d_pose);
     cudaFree(d_render_all);
+    cudaFree(d_selection);
+    cudaFree(nz);
+    cudaFree(average);
 }

@@ -25,7 +25,7 @@ from numpy import cos, sin
 from profiler import Profiler
 from multiprocessing.dummy import Process

-from datasets import ViewDataSet3D
+from realenv.data.datasets import ViewDataSet3D


 # In[2]:

@@ -45,7 +45,7 @@ d = ViewDataSet3D(root='/home/fei/Downloads/highres_tiny/', transform = np.array

 scene_dict = dict(zip(d.scenes, range(len(d.scenes))))

-model_id = scene_dict.keys()[0]
+model_id = scene_dict.keys()[1]
 scene_id = scene_dict[model_id]

 uuids, rts = d.get_scene_info(scene_id)

@@ -113,11 +113,12 @@ showsz = 1024
 show = np.zeros((showsz,showsz * 2,3),dtype='uint8')

+this_depth = (128 * depths[topk[0]]).astype(np.float32)
-for i in range(50):
+for i in range(5):
     with Profiler("Render pointcloud"):
         cuda_pc.render(ct.c_int(len(imgs_topk)),
                        ct.c_int(imgs_topk[0].shape[0]),
                        ct.c_int(imgs_topk[0].shape[1]),
                        ct.c_int(1),
                        imgs_topk.ctypes.data_as(ct.c_void_p),
                        depths_topk.ctypes.data_as(ct.c_void_p),
                        np.asarray(poses_after, dtype = np.float32).ctypes.data_as(ct.c_void_p),

@@ -126,20 +127,3 @@ for i in range(50):
         )

     Image.fromarray(show).save('imgs/test%04d.png' % i)
-# In[ ]:
-
-
-# In[ ]:
-
-
-
-
-
-# In[ ]:
-
-
-
-
-# In[ ]:
-
-

@@ -41,9 +41,9 @@ def get_model_path(idx=0):


 class ViewDataSet3D(data.Dataset):
-    def __init__(self, train=True, transform=None, mist_transform=None, loader=default_loader, seqlen=5, debug=False, dist_filter = None, off_3d = True, off_pc_render = True):
+    def __init__(self, root, train=True, transform=None, mist_transform=None, loader=default_loader, seqlen=5, debug=False, dist_filter = None, off_3d = True, off_pc_render = True):
         print ('Processing the data:')
-        self.root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dataset')
+        self.root = root
         self.fofn = self.root + '_fofn'+str(int(train))+'.pkl'
         self.train = train
         self.loader = loader

@@ -0,0 +1,104 @@
import numpy as np
import ctypes as ct
import cv2
import sys
import argparse
from datasets import ViewDataSet3D
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import time
from numpy import cos, sin
import matplotlib.pyplot as plt
from PIL import Image
import os
from multiprocessing import Pool, cpu_count
from scipy.signal import convolve2d
from scipy.interpolate import griddata
import scipy
import torch.nn.functional as F
from torchvision import transforms


dll = np.ctypeslib.load_library('../core/render/render_cuda_f', '.')


# In[6]:

def render(imgs, depths, pose, poses, tdepth):
    global fps
    t0 = time.time()
    showsz = imgs[0].shape[0]
    nimgs = len(imgs)
    show = np.zeros((showsz, showsz * 2, 3), dtype='uint8')
    target_depth = (128 * tdepth[:, :, 0]).astype(np.float32)

    imgs = np.array(imgs)
    depths = np.array(depths).flatten()

    pose_after = [pose.dot(np.linalg.inv(poses[0])).dot(poses[i]).astype(np.float32) for i in range(len(imgs))]
    pose_after = np.array(pose_after)

    dll.render(ct.c_int(len(imgs)),
               ct.c_int(imgs[0].shape[0]),
               ct.c_int(imgs[0].shape[1]),
               ct.c_int(1),
               imgs.ctypes.data_as(ct.c_void_p),
               depths.ctypes.data_as(ct.c_void_p),
               pose_after.ctypes.data_as(ct.c_void_p),
               show.ctypes.data_as(ct.c_void_p),
               target_depth.ctypes.data_as(ct.c_void_p)
               )

    return show, target_depth

# In[7]:

def generate_data(args):
    idx = args[0]
    print(idx)
    d = args[1]
    outf = args[2]

    print(idx)
    data = d[idx]  ## This operation stalls 95% of the time, CPU heavy
    sources = data[0]
    target = data[1]
    source_depths = data[2]
    target_depth = data[3]
    poses = [item.numpy() for item in data[-1]]

    show, _ = render(sources, source_depths, poses[0], poses, target_depth)
    print(show.shape)

    if idx % 100 == 0:
        Image.fromarray(show).save('%s/show%d.png' % (outf, idx))
        Image.fromarray(target).save('%s/target%d.png' % (outf, idx))

    filename = "%s/data_%d.npz" % (outf, idx)
    if not os.path.isfile(filename):
        np.savez(file = filename, source = show, depth = target_depth, target = target)

    return show, target_depth, target


parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', help='debug mode')
parser.add_argument('--dataroot', required = True, help='dataset path')
parser.add_argument('--outf', type = str, default = '', help='path of output folder')
opt = parser.parse_args()


d = ViewDataSet3D(root=opt.dataroot, transform = np.array, mist_transform = np.array, seqlen = 5, off_3d = False, train = True)

p = Pool(6)
p.map(generate_data, [(idx, d, opt.outf) for idx in range(len(d))])

#for i in range(len(d)):
#    filename = "%s/data_%d.npz" % (opt.outf, i)
#    print(filename)
#    if not os.path.isfile(filename):
#        generate_data([i, d, opt.outf])

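The pose composition in render() warps each source view into the target frame: pose.dot(inv(poses[0])).dot(poses[i]) expresses view i relative to the first view and then applies the target pose. A toy sanity check (4x4 camera matrices assumed; identities chosen so the result is obvious):

import numpy as np

target_pose = np.eye(4, dtype=np.float32)
poses = [np.eye(4, dtype=np.float32) for _ in range(3)]
rel = [target_pose.dot(np.linalg.inv(poses[0])).dot(p) for p in poses]
assert np.allclose(rel[0], target_pose)   # view 0 maps straight onto the target
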
@@ -0,0 +1,229 @@
from __future__ import print_function

import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import shutil
import time
import matplotlib.pyplot as plt


cudnn.benchmark = True

class AdaptiveNorm2d(nn.Module):
    def __init__(self, nchannel, momentum = 0.05):
        super(AdaptiveNorm2d, self).__init__()
        self.nm = nn.BatchNorm2d(nchannel, momentum = momentum)
        self.w0 = nn.Parameter(torch.zeros(1))
        self.w1 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.w0.repeat(x.size()) * self.nm(x) + self.w1.repeat(x.size()) * x

class CompletionNet2(nn.Module):
    def __init__(self, norm = AdaptiveNorm2d, nf = 64):
        super(CompletionNet2, self).__init__()

        self.nf = nf
        alpha = 0.05
        self.convs = nn.Sequential(
            nn.Conv2d(5, nf/4, kernel_size = 5, stride = 1, padding = 2),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf/4, nf, kernel_size = 5, stride = 2, padding = 2),
            norm(nf, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf, nf, kernel_size = 3, stride = 1, padding = 2),
            norm(nf, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf, nf*4, kernel_size = 5, stride = 2, padding = 1),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf*4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),

            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 2, padding = 2),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 4, padding = 4),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 8, padding = 8),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 16, padding = 16),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 32, padding = 32),
            norm(nf * 4, momentum=alpha),
            nn.LeakyReLU(0.1),

            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
            norm(nf * 4, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
            norm(nf * 4, momentum=alpha),
            nn.ReLU(),

            nn.ConvTranspose2d(nf * 4, nf, kernel_size = 4, stride = 2, padding = 1),
            norm(nf, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf, nf, kernel_size = 3, stride = 1, padding = 1),
            norm(nf, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.ConvTranspose2d(nf, nf/4, kernel_size = 4, stride = 2, padding = 1),
            norm(nf/4, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf/4, nf/4, kernel_size = 3, stride = 1, padding = 1),
            norm(nf/4, momentum=alpha),
            nn.LeakyReLU(0.1),
            nn.Conv2d(nf/4, 3, kernel_size = 3, stride = 1, padding = 1),
        )

    def forward(self, x, mask):
        return self.convs(torch.cat([x, mask], 1))


def identity_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.fill_(0)
        o, i, k1, k2 = m.weight.data.size()
        cx, cy = k1//2, k2//2
        nc = min(o, i)
        print(nc)
        for i in range(nc):
            m.weight.data[i, i, cx, cy] = 1
        m.bias.data.fill_(0)

        if m.stride[0] == 2:
            for i in range(nc):
                m.weight.data[i+nc, i, cx+1, cy] = 1
                m.weight.data[i+nc*2, i, cx, cy+1] = 1
                m.weight.data[i+nc*3, i, cx+1, cy+1] = 1

    elif classname.find('ConvTranspose2d') != -1:
        o, i, k1, k2 = m.weight.data.size()
        nc = min(o, i)
        cx, cy = k1//2-1, k2//2-1
        m.weight.data.fill_(0)
        for i in range(nc):
            m.weight.data[i, i, cx, cy] = 1
            m.weight.data[i+nc, i, cx+1, cy] = 1
            m.weight.data[i+nc*2, i, cx, cy+1] = 1
            m.weight.data[i+nc*3, i, cx+1, cy+1] = 1

        m.bias.data.fill_(0)

    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.fill_(0)


class Perceptual(nn.Module):
    def __init__(self, features, early = False):
        super(Perceptual, self).__init__()
        self.features = features
        self.early = early

    def forward(self, x):
        bs = x.size(0)
        x = self.features[0](x)
        x = self.features[1](x)
        x = self.features[2](x)
        x0 = x.view(bs, -1, 1)
        x = F.relu(x)
        x = self.features[4](x)
        x = self.features[5](x)
        x = self.features[6](x)
        x = self.features[7](x)
        x1 = x.view(bs, -1, 1)
        x = F.relu(x)
        x = self.features[9](x)
        x = self.features[10](x)
        x = self.features[11](x)
        x = self.features[12](x)
        x2 = x.view(bs, -1, 1)
        x = F.relu(x)
        x = self.features[14](x)
        x = self.features[15](x)
        x = self.features[16](x)
        x = self.features[17](x)
        x = self.features[18](x)
        x = self.features[19](x)
        x3 = x.view(bs, -1, 1)
        x = F.relu(x)
        x = self.features[21](x)
        x = self.features[22](x)
        x = self.features[23](x)
        x = self.features[24](x)
        x = self.features[25](x)
        x = self.features[26](x)
        x4 = x.view(bs, -1, 1)

        if self.early:
            perfeat = torch.cat([x0, x1, x2], 1)
        else:
            perfeat = torch.cat([x0, x1, x2, x3, x4], 1)

        return perfeat


class Discriminator2(nn.Module):

    def __init__(self, pano = False):
        super(Discriminator2, self).__init__()
        alpha = 0.05
        self.pano = pano
        nf = 64
        self.nf = nf

        self.convs_global = nn.Sequential(
            nn.Conv2d(3, nf, kernel_size = 5, stride = 2, padding = 1),
            nn.ReLU(),
            nn.Conv2d(nf, nf * 2, kernel_size = 5, stride = 2, padding = 1),
            nn.BatchNorm2d(nf * 2, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf * 2, nf * 4, kernel_size = 5, stride = 2, padding = 1),
            nn.BatchNorm2d(nf * 4, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf * 4, nf * 8, kernel_size = 5, stride = 2, padding = 1),
            nn.BatchNorm2d(nf * 8, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf * 8, nf * 8, kernel_size = 5, stride = 2, padding = 1),
            nn.BatchNorm2d(nf * 8, momentum=alpha),
            nn.ReLU(),
            nn.Conv2d(nf * 8, nf * 8, kernel_size = 5, stride = 2, padding = 1),
            nn.ReLU()
        )

        if self.pano:
            self.fc_global = nn.Linear(nf * 8 * 3 * 7, 1000)
        else:
            self.fc_global = nn.Linear(nf * 8 * 3 * 3, 1000)

    def forward(self, img):
        y = self.convs_global(img)

        if self.pano:
            y = y.view(y.size(0), self.nf * 8 * 3 * 7)
        else:
            y = y.view(y.size(0), self.nf * 8 * 3 * 3)

        y = F.relu(self.fc_global(y))

        x = F.log_softmax(y)

        return x

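AdaptiveNorm2d blends batch-normalized and raw activations with two scalar learnable weights, initialized so the module starts as the identity (w0 = 0, w1 = 1) and can drift toward batch norm during training; identity_init applies the same start-from-identity idea to the convolutions themselves. A minimal standalone check of the normalization rule (shapes assumed):

import torch
import torch.nn as nn

nm = nn.BatchNorm2d(8, momentum=0.05)
w0, w1 = torch.zeros(1), torch.ones(1)    # AdaptiveNorm2d's initial parameters
x = torch.randn(2, 8, 4, 4)
y = w0 * nm(x) + w1 * x                   # at initialization this is exactly x
print(torch.allclose(y, x))               # True
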
@@ -0,0 +1,333 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from realenv.data.datasets import PairDataset
from completion2 import CompletionNet2, identity_init, Perceptual, Discriminator2
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F
import torchvision.models as models


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def crop(source, source_depth, target):
    bs = source.size(0)
    source_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
    source_depth_cropped = Variable(torch.zeros(4*bs, 2, 256, 256)).cuda()
    target_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()

    for i in range(bs):
        for j in range(4):
            idx = i * 4 + j
            blurry_margin = 1024 / 8
            centerx = np.random.randint(blurry_margin + 128, 1024 - blurry_margin - 128)
            centery = np.random.randint(128, 1024 * 2 - 128)
            source_cropped[idx] = source[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
            source_depth_cropped[idx] = source_depth[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
            target_cropped[idx] = target[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]

    return source_cropped, source_depth_cropped, target_cropped


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default = 256, help='image size')
    parser.add_argument('--batchsize', type=int, default = 20, help='batchsize')
    parser.add_argument('--workers', type=int, default = 9, help='number of workers')
    parser.add_argument('--nepoch', type=int, default = 50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc_full", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default = 0, help='current epoch')
    parser.add_argument('--loss', type=str, default="perceptual", help='l1 only')
    parser.add_argument('--init', type=str, default = "iden", help='init method')
    parser.add_argument('--l1', type=float, default = 0, help='add l1 loss')
    parser.add_argument('--color_coeff', type=float, default = 0, help='add color match loss')
    parser.add_argument('--cascade', action='store_true', help='train a second, cascaded network')
    parser.add_argument('--unfiller', action='store_true', help='train an unfiller network')


    mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
    opt = parser.parse_args()
    print(opt)
    writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
    d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
    dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)

    img = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 2, 1024, 2048)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()

    comp = CompletionNet2(norm = nn.BatchNorm2d)

    dis = Discriminator2(pano = False)
    current_epoch = opt.cepoch

    comp = torch.nn.DataParallel(comp).cuda()

    if opt.init == 'iden':
        comp.apply(identity_init)
    else:
        comp.apply(weights_init)
    dis = torch.nn.DataParallel(dis).cuda()
    dis.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        #dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch

    if opt.cascade:
        comp2 = CompletionNet2(norm = nn.BatchNorm2d)
        comp2 = torch.nn.DataParallel(comp2).cuda()
        if opt.model != '':
            comp2.load_state_dict(torch.load(opt.model))
        optimizerG2 = torch.optim.Adam(comp2.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    l2 = nn.MSELoss()
    #if opt.loss == 'train_init':
    #    params = list(comp.parameters())
    #    sel = np.random.choice(len(params), len(params)/2, replace=False)
    #    params_sel = [params[i] for i in sel]
    #    optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
    #
    #else:
    optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
    optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    curriculum = (200000, 300000)  # step to start D training and G training, slightly different from the paper
    alpha = 0.004

    errG_data = 0
    errD_data = 0

    vgg16 = models.vgg16(pretrained = False)
    vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
    feat = vgg16.features
    p = torch.nn.DataParallel(Perceptual(feat, early = (opt.loss == 'early'))).cuda()

    for param in p.parameters():
        param.requires_grad = False

    test_loader_enum = enumerate(dataloader_test)
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            source = data[0]
            source_depth = data[1]
            target = data[2]
            step = i + epoch * len(dataloader)

            mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
            #img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)

            source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
            source_depth = source_depth[:,:,:,0].unsqueeze(1)
            #print(source_depth.size(), mask.size())
            source_depth = torch.cat([source_depth, mask], 1)
            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            imgc, maskvc, img_originalc = crop(img, maskv, img_original)
            #from IPython import embed; embed()
            recon = comp(imgc, maskvc)

            if opt.loss == "train_init":
                loss = l2(recon, imgc[:,:3,:,:])
            elif opt.loss == 'l1':
                loss = l2(recon, img_originalc)
            elif opt.loss == 'perceptual':
                loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
            elif opt.loss == 'color_stable':
                loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
            elif opt.loss == 'color_correction':
                loss = l2(p(recon), p(img_originalc).detach())
                for scale in [32]:
                    img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    recon_patch = recon.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon_patch_mean = recon_patch.mean(dim = -1)
                    recon_patch_cov = []
                    img_originalc_patch_cov = []

                    for j in range(3):
                        recon_patch_cov.append((recon_patch * recon_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
                        img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))

                    recon_patch_cov_cat = torch.cat(recon_patch_cov,1)
                    img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)

                    color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())

                    loss += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

            loss.backward(retain_graph = True)

            if opt.cascade:
                optimizerG2.zero_grad()

                recon2 = comp2(torch.cat([recon, imgc[:,3:]], 1), maskvc)
                loss2 = l2(p(recon2), p(img_originalc).detach())
                for scale in [32]:
                    img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    recon2_patch = recon2.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon2_patch_mean = recon2_patch.mean(dim = -1)
                    recon2_patch_cov = []
                    img_originalc_patch_cov = []

                    for j in range(3):
                        recon2_patch_cov.append((recon2_patch * recon2_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
                        img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))

                    recon2_patch_cov_cat = torch.cat(recon2_patch_cov,1)
                    img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)

                    color_loss = l2(recon2_patch_mean, img_originalc_patch_mean) + l2(recon2_patch_cov_cat, img_originalc_patch_cov_cat.detach())

                    loss2 += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

                loss2.backward(retain_graph = True)
                print("loss2 %f" % loss2.data[0])
                optimizerG2.step()

                if i%10 == 0:
                    writer.add_scalar('MSEloss2', loss2.data[0], step)


            if step > curriculum[1]:
                label.data.fill_(1)
                output = dis(recon)
                errG = alpha * F.nll_loss(output, label)
                errG.backward()
                errG_data = errG.data[0]

            #from IPython import embed; embed()
            if opt.loss == "train_init":
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        #print(param.size())
                        nk = param.size()[2]//2
                        if nk > 5:
                            param.grad[:nk, :,:,:] = 0

            optimizerG.step()


            # Train D:
            if step > curriculum[0]:
                optimizerD.zero_grad()
                label.data.fill_(0)
                output = dis(recon.detach())
                #print(output)
                errD_fake = alpha * F.nll_loss(output, label)
                errD_fake.backward(retain_graph = True)

                output = dis(img_originalc)
                #print(output)
                label.data.fill_(1)
                errD_real = alpha * F.nll_loss(output, label)
                errD_real.backward()
                optimizerD.step()
                errD_data = errD_real.data[0] + errD_fake.data[0]

            print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))

            if i%200 == 0:
                test_i, test_data = test_loader_enum.next()
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)

                source = test_data[0]
                source_depth = test_data[1]
                target = test_data[2]

                mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)

                source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
                source_depth = source_depth[:,:,:,0].unsqueeze(1)
                source_depth = torch.cat([source_depth, mask], 1)
                img.data.copy_(source)
                maskv.data.copy_(source_depth)
                img_original.data.copy_(target)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original)
                comp.eval()
                recon = comp(imgc, maskvc)
                comp.train()

                if opt.cascade:
                    comp2.eval()
                    recon2 = comp2(torch.cat([recon, imgc[:,3:]], 1), maskvc)
                    comp2.train()
                    visual = torch.cat([imgc.data[:,:3,:,:], recon.data, recon2.data, img_originalc.data], 3)
                else:
                    visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)

                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)

            if i%10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)
                writer.add_scalar('G_loss', errG_data, step)
                writer.add_scalar('D_loss', errD_data, step)

            if i%10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
                torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))

if __name__ == '__main__':
    main()

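The color_correction branch adds a patch-statistics term: the images are cut into scale x scale patches, and the loss matches each patch's per-channel mean and its 3x3 matrix of channel second moments between reconstruction and target. A compact restatement of that computation (tensor shapes assumed; same view/transpose trick as the script above):

import torch
import torch.nn.functional as F

def color_match_loss(recon, target, scale=32):
    b, c, h, w = recon.size()
    def patch_stats(x):
        p = x.view(b, c, h // scale, scale, w // scale, scale)
        p = p.transpose(4, 3).contiguous().view(b, c, h // scale, w // scale, -1)
        mean = p.mean(dim=-1)                                           # per-patch channel means
        cov = torch.cat([(p * p[:, j:j+1]).mean(dim=-1) for j in range(c)], 1)
        return mean, cov
    m_r, c_r = patch_stats(recon)
    m_t, c_t = patch_stats(target)
    return F.mse_loss(m_r, m_t) + F.mse_loss(c_r, c_t.detach())
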
@@ -0,0 +1,333 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from realenv.data.datasets import PairDataset
from completion2 import CompletionNet2, identity_init, Perceptual, Discriminator2
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F
import torchvision.models as models


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def crop(source, source_depth, target):
    bs = source.size(0)
    source_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
    source_depth_cropped = Variable(torch.zeros(4*bs, 2, 256, 256)).cuda()
    target_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()

    for i in range(bs):
        for j in range(4):
            idx = i * 4 + j
            blurry_margin = 1024 / 8
            centerx = np.random.randint(blurry_margin + 128, 1024 - blurry_margin - 128)
            centery = np.random.randint(128, 1024 * 2 - 128)
            source_cropped[idx] = source[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
            source_depth_cropped[idx] = source_depth[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
            target_cropped[idx] = target[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]

    return source_cropped, source_depth_cropped, target_cropped


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default = 256, help='image size')
    parser.add_argument('--batchsize', type=int, default = 20, help='batchsize')
    parser.add_argument('--workers', type=int, default = 9, help='number of workers')
    parser.add_argument('--nepoch', type=int, default = 50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc_full", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default = 0, help='current epoch')
    parser.add_argument('--loss', type=str, default="perceptual", help='l1 only')
    parser.add_argument('--init', type=str, default = "iden", help='init method')
    parser.add_argument('--l1', type=float, default = 0, help='add l1 loss')
    parser.add_argument('--color_coeff', type=float, default = 0, help='add color match loss')
    parser.add_argument('--unfiller', action='store_true', help='train an unfiller network')


    mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
    opt = parser.parse_args()
    print(opt)
    writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
    d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
    dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)

    img = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 2, 1024, 2048)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()

    comp = CompletionNet2(norm = nn.BatchNorm2d, nf = 24)

    dis = Discriminator2(pano = False)
    current_epoch = opt.cepoch

    comp = torch.nn.DataParallel(comp).cuda()

    if opt.init == 'iden':
        comp.apply(identity_init)
    else:
        comp.apply(weights_init)
    dis = torch.nn.DataParallel(dis).cuda()
    dis.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))
        #dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
        current_epoch = opt.cepoch

    if opt.unfiller:
        comp2 = CompletionNet2(norm = nn.BatchNorm2d, nf = 24)
        comp2 = torch.nn.DataParallel(comp2).cuda()
        if opt.model != '':
            comp2.load_state_dict(torch.load(opt.model))
        optimizerG2 = torch.optim.Adam(comp2.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    l2 = nn.MSELoss()
    #if opt.loss == 'train_init':
    #    params = list(comp.parameters())
    #    sel = np.random.choice(len(params), len(params)/2, replace=False)
    #    params_sel = [params[i] for i in sel]
    #    optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
    #
    #else:
    optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
    optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))

    curriculum = (200000, 300000)  # step to start D training and G training, slightly different from the paper
    alpha = 0.004

    errG_data = 0
    errD_data = 0

    vgg16 = models.vgg16(pretrained = False)
    vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
    feat = vgg16.features
    p = torch.nn.DataParallel(Perceptual(feat, early = (opt.loss == 'early'))).cuda()

    for param in p.parameters():
        param.requires_grad = False

    test_loader_enum = enumerate(dataloader_test)
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            source = data[0]
            source_depth = data[1]
            target = data[2]
            step = i + epoch * len(dataloader)

            mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
            #img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)

            source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
            source_depth = source_depth[:,:,:,0].unsqueeze(1)
            #print(source_depth.size(), mask.size())
            source_depth = torch.cat([source_depth, mask], 1)
            img.data.copy_(source)
            maskv.data.copy_(source_depth)
            img_original.data.copy_(target)
            imgc, maskvc, img_originalc = crop(img, maskv, img_original)
            #from IPython import embed; embed()
            recon = comp(imgc, maskvc)

            if opt.loss == "train_init":
                loss = l2(recon, imgc[:,:3,:,:])
            elif opt.loss == 'l1':
                loss = l2(recon, img_originalc)
            elif opt.loss == 'perceptual':
                loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
            elif opt.loss == 'color_stable':
                loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
            elif opt.loss == 'color_correction':
                loss = l2(p(recon), p(img_originalc).detach())
                for scale in [32]:
                    img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    recon_patch = recon.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon_patch_mean = recon_patch.mean(dim = -1)
                    recon_patch_cov = []
                    img_originalc_patch_cov = []

                    for j in range(3):
                        recon_patch_cov.append((recon_patch * recon_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
                        img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))

                    recon_patch_cov_cat = torch.cat(recon_patch_cov,1)
                    img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)

                    color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())

                    loss += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

            loss.backward(retain_graph = True)

            if opt.unfiller:
                optimizerG2.zero_grad()
                recon2 = comp2(img_originalc, maskvc)
                loss2 = l2(p(recon2), p(recon).detach())
                for scale in [32]:
                    img_originalc_patch = recon.detach().view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    recon2_patch = recon2.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
                    img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
                    recon2_patch_mean = recon2_patch.mean(dim = -1)
                    recon2_patch_cov = []
                    img_originalc_patch_cov = []

                    for j in range(3):
                        recon2_patch_cov.append((recon2_patch * recon2_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
                        img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))

                    recon2_patch_cov_cat = torch.cat(recon2_patch_cov,1)
                    img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)

                    color_loss = l2(recon2_patch_mean, img_originalc_patch_mean) + l2(recon2_patch_cov_cat, img_originalc_patch_cov_cat.detach())

                    loss2 += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

                loss2.backward(retain_graph = True)
                print("loss2 %f" % loss2.data[0])
                optimizerG2.step()

                if i%10 == 0:
                    writer.add_scalar('MSEloss2', loss2.data[0], step)


            if step > curriculum[1]:
                label.data.fill_(1)
                output = dis(recon)
                errG = alpha * F.nll_loss(output, label)
                errG.backward()
                errG_data = errG.data[0]

            #from IPython import embed; embed()
            if opt.loss == "train_init":
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        #print(param.size())
                        nk = param.size()[2]//2
                        if nk > 5:
                            param.grad[:nk, :,:,:] = 0

            optimizerG.step()


            # Train D:
            if step > curriculum[0]:
                optimizerD.zero_grad()
                label.data.fill_(0)
                output = dis(recon.detach())
                #print(output)
                errD_fake = alpha * F.nll_loss(output, label)
                errD_fake.backward(retain_graph = True)

                output = dis(img_originalc)
                #print(output)
                label.data.fill_(1)
                errD_real = alpha * F.nll_loss(output, label)
                errD_real.backward()
                optimizerD.step()
                errD_data = errD_real.data[0] + errD_fake.data[0]

            print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))

            if i%200 == 0:
                test_i, test_data = test_loader_enum.next()
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)

                source = test_data[0]
                source_depth = test_data[1]
                target = test_data[2]

                mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)

                source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
                source_depth = source_depth[:,:,:,0].unsqueeze(1)
                source_depth = torch.cat([source_depth, mask], 1)
                img.data.copy_(source)
                maskv.data.copy_(source_depth)
                img_original.data.copy_(target)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original)
                comp.eval()
                recon = comp(imgc, maskvc)
                comp.train()

                if opt.unfiller:
                    comp2.eval()
                    recon2 = comp2(img_originalc, maskvc)
                    comp2.train()
                    visual = torch.cat([imgc.data[:,:3,:,:], recon.data, recon2.data, img_originalc.data], 3)
                else:
                    visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)

                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)

            if i%10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)
                writer.add_scalar('G_loss', errG_data, step)
                writer.add_scalar('D_loss', errD_data, step)

            if i%10000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
                torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))

                if opt.unfiller:
                    torch.save(comp2.state_dict(), '%s/compG2_epoch%d_%d.pth' % (opt.outf, epoch, i))

if __name__ == '__main__':
    main()

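Compared with the previous script, the second network here runs in the reverse direction: comp2 takes the clean target and regresses toward the filler's detached output, i.e. it learns to imitate the point-cloud rendering artifacts (an inference from the loss direction; the commit adds no comment). The coupling, reduced to its skeleton with stand-in modules:

import torch
import torch.nn as nn

comp = nn.Conv2d(3, 3, 1)    # stand-in for the filler CompletionNet2
comp2 = nn.Conv2d(3, 3, 1)   # stand-in for the unfiller
optG = torch.optim.Adam(comp.parameters())
optG2 = torch.optim.Adam(comp2.parameters())
l2 = nn.MSELoss()

source = torch.randn(1, 3, 8, 8)
target = torch.randn(1, 3, 8, 8)
recon = comp(source)
loss = l2(recon, target)                    # filler: corrupted -> clean
loss2 = l2(comp2(target), recon.detach())   # unfiller: clean -> filler's output
loss.backward()
loss2.backward()
optG.step()
optG2.step()
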
@@ -0,0 +1,57 @@
import torch
import math
import random
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections


class RandomScale(object):
    """Rescale the input PIL.Image to a randomly drawn size.

    The target size of the smaller edge is sampled uniformly from
    [minsize, maxsize] and the aspect ratio is preserved, i.e. if
    height > width, the image is rescaled to (size * height / width, size).

    Args:
        minsize (int): lower bound for the size of the smaller edge.
        maxsize (int): upper bound for the size of the smaller edge.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, minsize, maxsize, interpolation=Image.BILINEAR):
        assert isinstance(minsize, int)
        assert isinstance(maxsize, int)
        self.minsize = minsize
        self.maxsize = maxsize
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        size = random.randint(self.minsize, self.maxsize)

        if isinstance(size, int):
            w, h = img.size
            if (w <= h and w == size) or (h <= w and h == size):
                return img
            if w < h:
                ow = size
                oh = int(size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = size
                ow = int(size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            raise NotImplementedError()

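RandomScale draws a target size uniformly from [minsize, maxsize] and resizes the smaller image edge to it, preserving aspect ratio. Typical use in a torchvision pipeline (import path assumed; the training scripts import this module as vision_utils):

from torchvision import transforms
from vision_utils import RandomScale

tf = transforms.Compose([
    RandomScale(256, 512),     # smaller edge resized to a random value in [256, 512]
    transforms.ToTensor(),
])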