Merge pull request #36 from fxia22/dev

Dev
hzyjerry 2017-10-22 15:50:06 -07:00 committed by GitHub
commit bb6545ac83
11 changed files with 1175 additions and 51 deletions

.gitignore vendored

@ -28,4 +28,7 @@ dev/transfer.c
physics/*.txt
physics/*.json
# Physics assets
physics/models
physics/models
*/events*


@ -94,35 +94,33 @@ __global__ void fill(unsigned char * img)
__global__ void merge(unsigned char * img_all, unsigned char * img, int n, int stride)
__global__ void merge(unsigned char * img_all, unsigned char * img, float * selection, int n, int stride)
{
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
int width = gridDim.x * TILE_DIM;
int idx = 0;
float sum = 0;
float weight = 0;
for (int j = 0; j < TILE_DIM; j+= BLOCK_ROWS) {
sum = 0;
for (idx = 0; idx < n; idx ++) sum += selection[idx * stride + ((y+j)*width + x)];
//printf("%f\n", sum);
int nz = 0;
for (idx = 0; idx < n; idx ++)
if (img_all[stride * idx + 3*((y+j)*width + x)] + img_all[stride * idx + 3*((y+j)*width + x) + 1] + img_all[stride * idx + 3*((y+j)*width + x) + 2] > 0)
nz += 1;
img[3*((y+j)*width + x)] = 0;
img[3*((y+j)*width + x)+1] = 0;
img[3*((y+j)*width + x)+2] = 0;
if (nz > 0)
for (idx = 0; idx < n; idx ++) {
img[3*((y+j)*width + x)] += img_all[idx * stride + 3*((y+j)*width + x)] / nz;
img[3*((y+j)*width + x)+1] += img_all[idx * stride + 3*((y+j)*width + x) + 1] / nz;
img[3*((y+j)*width + x)+2] += img_all[idx * stride + 3*((y+j)*width + x) + 2] / nz;
weight = selection[idx * stride + ((y+j)*width + x)] / (sum + 1e-4);
img[3*((y+j)*width + x)] += (unsigned char) (img_all[idx * stride * 3 + 3*((y+j)*width + x)] * weight);
img[3*((y+j)*width + x)+1] += (unsigned char) (img_all[idx * stride * 3 + 3*((y+j)*width + x) + 1] * weight);
img[3*((y+j)*width + x)+2] += (unsigned char)(img_all[idx * stride * 3 + 3*((y+j)*width + x) + 2] * weight);
}
}
}
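The new merge kernel replaces the uniform average over non-empty source renders with a blend weighted by the per-pixel selection values that render_final now writes (1 / |det| per source view). A rough NumPy restatement of that weighting, with illustrative array names that are not from the kernel:

import numpy as np

def merge_weighted(img_all, selection, eps=1e-4):
    # img_all: (n, H, W, 3) uint8 renders from the n source views
    # selection: (n, H, W) float weights, one per source view and pixel
    w = selection.astype(np.float32)
    w = w / (w.sum(axis=0, keepdims=True) + eps)          # normalize per pixel
    out = (img_all.astype(np.float32) * w[..., None]).sum(axis=0)
    return out.astype(np.uint8)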
@ -211,9 +209,61 @@ __global__ void render_depth(float *points3d_polar, unsigned int * depth_render)
}
}
__global__ void get_average(unsigned char * img, int * nz, int * average, int scale)
{
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
int width = gridDim.x * TILE_DIM;
int h = width /2;
for (int j = 0; j < TILE_DIM; j+= BLOCK_ROWS)
{
int iw = x;
int ih = y + j;
if (img[3*(ih*width + iw)] + img[3*(ih*width + iw)+1] + img[3*(ih*width + iw)+2] > 0)
{
//nz[ih/3 * width + iw/3] += 1;
//average[3*(ih/3*width + iw/3)] += (int)img[3*(ih*width + iw)];
//average[3*(ih/3*width + iw/3)+1] += (int)img[3*(ih*width + iw)+1];
//average[3*(ih/3*width + iw/3)+2] += (int)img[3*(ih*width + iw)+2];
atomicAdd(&(nz[ih/scale * width + iw/scale]), 1);
atomicAdd(&(average[3*(ih/scale*width + iw/scale)]), (int)img[3*(ih*width + iw)]);
atomicAdd(&(average[3*(ih/scale*width + iw/scale)+1]), (int)img[3*(ih*width + iw)+1]);
atomicAdd(&(average[3*(ih/scale*width + iw/scale)+2]), (int)img[3*(ih*width + iw)+2]);
}
}
}
__global__ void render_final(float *points3d_polar, float * depth_render, int * img, int * render, int s)
__global__ void fill_with_average(unsigned char *img, int * nz, int * average, int scale)
{
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
int width = gridDim.x * TILE_DIM;
int h = width /2;
for (int j = 0; j < TILE_DIM; j+= BLOCK_ROWS)
{
int iw = x;
int ih = y + j;
if ((img[3*(ih*width + iw)] + img[3*(ih*width + iw)+1] + img[3*(ih*width + iw)+2] == 0) && (nz[ih/scale * width + iw/scale] > 0))
{
img[3*(ih*width + iw)] = (unsigned char)(average[3*(ih/scale*width + iw/scale)] / nz[ih/scale * width + iw/scale]);
img[3*(ih*width + iw) + 1] = (unsigned char)(average[3*(ih/scale*width + iw/scale) + 1] / nz[ih/scale * width + iw/scale]);
img[3*(ih*width + iw) + 2] = (unsigned char)(average[3*(ih/scale*width + iw/scale) + 2] / nz[ih/scale * width + iw/scale]);
}
}
}
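get_average and fill_with_average together form a simple hole-filling pass: pixels are grouped into scale x scale blocks, the average color of the non-empty pixels in each block is accumulated with atomicAdd, and empty pixels then take their block's average. A sequential NumPy sketch of the same idea (my restatement, not the kernel code):

def fill_holes_block_average(img, scale):
    # img: (H, W, 3) uint8; all-zero pixels are holes left by the point-cloud render
    out = img.astype(np.float32).copy()
    valid = img.sum(axis=2) > 0
    H, W = valid.shape
    for by in range(0, H, scale):
        for bx in range(0, W, scale):
            block_valid = valid[by:by+scale, bx:bx+scale]
            if block_valid.any():
                avg = out[by:by+scale, bx:bx+scale][block_valid].mean(axis=0)
                out[by:by+scale, bx:bx+scale][~block_valid] = avg
    return out.astype(np.uint8)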
__global__ void render_final(float *points3d_polar, float * selection, float * depth_render, int * img, int * render, int s)
{
int x = blockIdx.x * TILE_DIM + threadIdx.x;
int y = blockIdx.y * TILE_DIM + threadIdx.y;
@ -305,8 +355,11 @@ __global__ void render_final(float *points3d_polar, float * depth_render, int *
if (r > 255) r = 255;
if (g > 255) g = 255;
if (b > 255) b = 255;
render[(ity * w * s + itx)] = r * 256 * 256 + g * 256 + b;
if ((ity > 0) && (ity < h * s) && (itx > 0) && (itx < w * s)) {
render[(ity * w * s + itx)] = r * 256 * 256 + g * 256 + b;
selection[(ity * w * s + itx)] = 1.0 / abs(det);
}
//printf("%f\n", selection[(ity * w * s + itx)]);
}
}
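render_final clamps each channel to [0, 255] and writes one packed integer per pixel (r * 256 * 256 + g * 256 + b) into the int buffer, which the int_to_char kernel presumably converts back into three bytes. The packing and its inverse, in Python for reference:

def pack_rgb(r, g, b):
    clamp = lambda v: min(max(int(v), 0), 255)
    return clamp(r) * 65536 + clamp(g) * 256 + clamp(b)

def unpack_rgb(v):
    return (v >> 16) & 0xFF, (v >> 8) & 0xFF, v & 0xFF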
@ -339,6 +392,11 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
float *d_depth_render;
float *d_3dpoint, *d_3dpoint_after;
float * d_selection;
int * nz;
int * average;
int *d_render2, *d_img2;
cudaMalloc((void **)&d_img, frame_mem_size);
@ -351,10 +409,18 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
cudaMalloc((void **)&d_pose, sizeof(float) * 16);
cudaMalloc((void **)&d_render2, render_mem_size * sizeof(int));
cudaMalloc((void **)&d_img2, render_mem_size * sizeof(int));
cudaMalloc((void **)&d_selection, render_mem_size * sizeof(float) * n);
cudaMalloc((void **)&nz, render_mem_size * sizeof(int));
cudaMalloc((void **)&average, render_mem_size * sizeof(int) * 3);
cudaMemcpy(d_depth_render, depth_render, render_mem_size * sizeof(float), cudaMemcpyHostToDevice);
cudaMemset(d_render_all, 0, render_mem_size * sizeof(unsigned char) * 3 * n);
cudaMemset(d_selection, 0, render_mem_size * sizeof(float) * n);
cudaMemset(nz, 0, render_mem_size * sizeof(int));
cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);
int idx;
for (idx = 0; idx < n; idx ++) {
@ -376,15 +442,27 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
char_to_int <<< dimGrid, dimBlock >>> (d_img2, d_img);
render_final <<< dimGrid, dimBlock >>> (d_3dpoint_after, d_depth_render, d_img2, d_render2, s);
render_final <<< dimGrid, dimBlock >>> (d_3dpoint_after, &(d_selection[idx * nx * ny * s * s]), d_depth_render, d_img2, d_render2, s);
int_to_char <<< dimGrid2, dimBlock >>> (d_render2, d_render);
//int_to_char <<< dimGrid2, dimBlock >>> (d_render2, d_render);
int_to_char <<< dimGrid2, dimBlock >>> (d_render2, &(d_render_all[idx * nx * ny * s * s * 3]));
fill <<< dimGrid2, dimBlock >>> (&(d_render_all[idx * nx * ny * s * s * 3]));
//fill <<< dimGrid2, dimBlock >>> (&(d_render_all[idx * nx * ny * s * s * 3]));
}
merge <<< dimGrid2, dimBlock >>> (d_render_all, d_render, n, nx * ny * s * s * 3);
merge <<< dimGrid2, dimBlock >>> (d_render_all, d_render, d_selection, n, nx * ny * s * s);
/*int fill_size[7] = {3, 5, 10, 20, 50, 100, 200};
for (int j = 0; j < 7; j++) {
cudaMemset(nz, 0, render_mem_size * sizeof(int));
cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);
get_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, fill_size[j]);
fill_with_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, fill_size[j]);
}*/
cudaMemset(nz, 0, render_mem_size * sizeof(int));
cudaMemset(average, 0, render_mem_size * sizeof(int) * 3);
get_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, 3);
fill_with_average <<< dimGrid2, dimBlock >>> (d_render, nz, average, 3);
cudaMemcpy(render, d_render, render_mem_size * sizeof(unsigned char) * 3 , cudaMemcpyDeviceToHost);
@ -398,6 +476,9 @@ void render(int n, int h,int w, int s, unsigned char * img, float * depth,float
cudaFree(d_3dpoint_after);
cudaFree(d_pose);
cudaFree(d_render_all);
cudaFree(d_selection);
cudaFree(nz);
cudaFree(average);
}


@ -25,7 +25,7 @@ from numpy import cos, sin
from profiler import Profiler
from multiprocessing.dummy import Process
from datasets import ViewDataSet3D
from realenv.data.datasets import ViewDataSet3D
# In[2]:
@ -44,8 +44,8 @@ d = ViewDataSet3D(root='/home/fei/Downloads/highres_tiny/', transform = np.array
# In[ ]:
scene_dict = dict(zip(d.scenes, range(len(d.scenes))))
model_id = scene_dict.keys()[0]
model_id = scene_dict.keys()[1]
scene_id = scene_dict[model_id]
uuids, rts = d.get_scene_info(scene_id)
@ -104,7 +104,7 @@ poses = relative_poses_topk
poses_after = [
pose.dot(np.linalg.inv(poses[i])).astype(np.float32)
for i in range(len(imgs_topk))]
# In[ ]:
@ -113,33 +113,17 @@ showsz = 1024
show = np.zeros((showsz,showsz * 2,3),dtype='uint8')
this_depth = (128 * depths[topk[0]]).astype(np.float32)
for i in range(50):
for i in range(5):
with Profiler("Render pointcloud"):
cuda_pc.render(ct.c_int(len(imgs_topk)),
cuda_pc.render(ct.c_int(len(imgs_topk)),
ct.c_int(imgs_topk[0].shape[0]),
ct.c_int(imgs_topk[0].shape[1]),
ct.c_int(1),
imgs_topk.ctypes.data_as(ct.c_void_p),
depths_topk.ctypes.data_as(ct.c_void_p),
np.asarray(poses_after, dtype = np.float32).ctypes.data_as(ct.c_void_p),
show.ctypes.data_as(ct.c_void_p),
this_depth.ctypes.data_as(ct.c_void_p)
)
Image.fromarray(show).save('imgs/test%04d.png' % i)
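The ctypes call hands raw pointers to the CUDA library, so each array has to match the C entry point shown earlier (render(int n, int h, int w, int s, unsigned char *img, float *depth, ...)) in both dtype and memory layout. A small defensive helper one could use before the call; this is my own sketch, not part of the script:

def c_ptr(arr, dtype):
    # enforce the expected dtype and C-contiguous layout before taking the raw pointer
    arr = np.ascontiguousarray(arr, dtype=dtype)
    return arr, arr.ctypes.data_as(ct.c_void_p)

imgs_topk, imgs_ptr = c_ptr(imgs_topk, np.uint8)
depths_topk, depths_ptr = c_ptr(depths_topk, np.float32)
show, show_ptr = c_ptr(show, np.uint8)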
# In[ ]:
# In[ ]:
# In[ ]:
# In[ ]:


@ -41,9 +41,9 @@ def get_model_path(idx=0):
class ViewDataSet3D(data.Dataset):
def __init__(self, train=True, transform=None, mist_transform=None, loader=default_loader, seqlen=5, debug=False, dist_filter = None, off_3d = True, off_pc_render = True):
def __init__(self, root, train=True, transform=None, mist_transform=None, loader=default_loader, seqlen=5, debug=False, dist_filter = None, off_3d = True, off_pc_render = True):
print ('Processing the data:')
self.root = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dataset')
self.root = root
self.fofn = self.root + '_fofn'+str(int(train))+'.pkl'
self.train = train
self.loader = loader
@ -152,9 +152,9 @@ class ViewDataSet3D(data.Dataset):
## DEPTH DEBUG
#p = p #np.dot(rotation, p)
#rotation = np.array([[0,-1,0,0],[-1,0,0,0],[0,0,1,0],[0,0,0,1]])
rotation = np.array([[0,1,0,0],[0,0,1,0],[-1,0,0,0],[0,0,0,1]])
p = np.dot(p, rotation)
poses.append(p)
f.close()
@ -168,7 +168,7 @@ class ViewDataSet3D(data.Dataset):
uuids = [item[1] for item in self.select[index]]
#print("selection length", len(self.select), len(self.select[0]))
#print(uuids)
#print(uuids)
#poses = ([self.meta[scene][item][1:] for item in uuids])
#poses = [item[0] + item[1] for item in poses]


@ -0,0 +1,104 @@
import numpy as np
import ctypes as ct
import cv2
import sys
import argparse
from datasets import ViewDataSet3D
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
import time
from numpy import cos, sin
import matplotlib.pyplot as plt
from PIL import Image
import os
import time
from multiprocessing import Pool, cpu_count
from scipy.signal import convolve2d
from scipy.interpolate import griddata
import scipy
import torch.nn.functional as F
from torchvision import transforms
dll=np.ctypeslib.load_library('../core/render/render_cuda_f','.')
# In[6]:
def render(imgs, depths, pose, poses, tdepth):
global fps
t0 = time.time()
showsz = imgs[0].shape[0]
nimgs = len(imgs)
show=np.zeros((showsz,showsz * 2,3),dtype='uint8')
target_depth = (128 * tdepth[:,:,0]).astype(np.float32)
imgs = np.array(imgs)
depths = np.array(depths).flatten()
pose_after = [pose.dot(np.linalg.inv(poses[0])).dot(poses[i]).astype(np.float32) for i in range(len(imgs))]
pose_after = np.array(pose_after)
dll.render(ct.c_int(len(imgs)),
ct.c_int(imgs[0].shape[0]),
ct.c_int(imgs[0].shape[1]),
ct.c_int(1),
imgs.ctypes.data_as(ct.c_void_p),
depths.ctypes.data_as(ct.c_void_p),
pose_after.ctypes.data_as(ct.c_void_p),
show.ctypes.data_as(ct.c_void_p),
target_depth.ctypes.data_as(ct.c_void_p)
)
return show, target_depth
# In[7]:
def generate_data(args):
idx = args[0]
print(idx)
d = args[1]
outf = args[2]
print(idx)
data = d[idx] ## this indexing call dominates the runtime (~95%) and is CPU heavy
sources = data[0]
target = data[1]
source_depths = data[2]
target_depth = data[3]
poses = [item.numpy() for item in data[-1]]
show, _ = render(sources, source_depths, poses[0], poses, target_depth)
print(show.shape)
if idx % 100 == 0:
Image.fromarray(show).save('%s/show%d.png' % (outf, idx))
Image.fromarray(target).save('%s/target%d.png' % (outf, idx))
filename = "%s/data_%d.npz" % (outf, idx)
if not os.path.isfile(filename):
np.savez(file = filename, source = show, depth = target_depth, target = target)
return show, target_depth, target
parser = argparse.ArgumentParser()
parser.add_argument('--debug' , action='store_true', help='debug mode')
parser.add_argument('--dataroot' , required = True, help='dataset path')
parser.add_argument('--outf' , type = str, default = '', help='path of output folder')
opt = parser.parse_args()
d = ViewDataSet3D(root=opt.dataroot, transform = np.array, mist_transform = np.array, seqlen = 5, off_3d = False, train = True)
p = Pool(6)
p.map(generate_data, [(idx, d, opt.outf) for idx in range(len(d))])
#for i in range(len(d)):
# filename = "%s/data_%d.npz" % (opt.outf, i)
# print(filename)
# if not os.path.isfile(filename):
# generate_data([i, d, opt.outf])


@ -0,0 +1,229 @@
from __future__ import print_function
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.nn.functional as F
import shutil
import time
import matplotlib.pyplot as plt
cudnn.benchmark = True
class AdaptiveNorm2d(nn.Module):
def __init__(self, nchannel, momentum = 0.05):
super(AdaptiveNorm2d, self).__init__()
self.nm = nn.BatchNorm2d(nchannel, momentum = momentum)
self.w0 = nn.Parameter(torch.zeros(1))
self.w1 = nn.Parameter(torch.ones(1))
def forward(self, x):
return self.w0.repeat(x.size()) * self.nm(x) + self.w1.repeat(x.size()) * x
class CompletionNet2(nn.Module):
def __init__(self, norm = AdaptiveNorm2d, nf = 64):
super(CompletionNet2, self).__init__()
self.nf = nf
alpha = 0.05
self.convs = nn.Sequential(
nn.Conv2d(5, nf/4, kernel_size = 5, stride = 1, padding = 2),
nn.LeakyReLU(0.1),
nn.Conv2d(nf/4, nf, kernel_size = 5, stride = 2, padding = 2),
norm(nf, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf, nf, kernel_size = 3, stride = 1, padding = 2),
norm(nf, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf, nf*4, kernel_size = 5, stride = 2, padding = 1),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf*4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 2, padding = 2),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 4, padding = 4),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 8, padding = 8),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 16, padding = 16),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, dilation = 32, padding = 32),
norm(nf * 4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
norm(nf * 4, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf * 4, nf * 4, kernel_size = 3, stride = 1, padding = 1),
norm(nf * 4, momentum=alpha),
nn.ReLU(),
nn.ConvTranspose2d(nf * 4, nf , kernel_size = 4, stride = 2, padding = 1),
norm(nf , momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf, nf, kernel_size = 3, stride = 1, padding = 1),
norm(nf, momentum=alpha),
nn.LeakyReLU(0.1),
nn.ConvTranspose2d(nf, nf/4, kernel_size = 4, stride = 2, padding = 1),
norm(nf/4, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf/4, nf/4, kernel_size = 3, stride = 1, padding = 1),
norm(nf/4, momentum=alpha),
nn.LeakyReLU(0.1),
nn.Conv2d(nf/4, 3, kernel_size = 3, stride = 1, padding = 1),
)
def forward(self, x, mask):
return self.convs(torch.cat([x, mask], 1))
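CompletionNet2 is an hourglass: two stride-2 convolutions bring a 256 x 256 crop down to 64 x 64, a stack of increasingly dilated 3 x 3 convolutions grows the receptive field at that resolution, and the two ConvTranspose2d layers upsample back to the input size. A quick shape sanity check (a sketch; the nf/4 channel counts assume Python 2 integer division):

net = CompletionNet2(norm=nn.BatchNorm2d, nf=64)
x = Variable(torch.zeros(1, 3, 256, 256))      # an RGB crop
mask = Variable(torch.zeros(1, 2, 256, 256))   # depth + validity mask channels
out = net(x, mask)
print(out.size())   # expected to match the input resolution: (1, 3, 256, 256)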
def identity_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
m.weight.data.fill_(0)
o, i, k1, k2 = m.weight.data.size()
cx, cy = k1//2, k2//2
nc = min(o,i)
print(nc)
for i in range(nc):
m.weight.data[i,i,cx,cy] = 1
m.bias.data.fill_(0)
if m.stride[0] == 2:
for i in range(nc):
m.weight.data[i+nc,i,cx+1,cy] = 1
m.weight.data[i+nc*2,i,cx,cy+1] = 1
m.weight.data[i+nc*3,i,cx+1,cy+1] = 1
elif classname.find('ConvTranspose2d') != -1:
o, i, k1, k2 = m.weight.data.size()
nc = min(o,i)
cx, cy = k1//2-1, k2//2-1
m.weight.data.fill_(0)
for i in range(nc):
m.weight.data[i,i,cx,cy] = 1
m.weight.data[i+nc,i,cx+1,cy] = 1
m.weight.data[i+nc*2,i,cx,cy+1] = 1
m.weight.data[i+nc*3,i,cx+1,cy+1] = 1
m.bias.data.fill_(0)
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1)
m.bias.data.fill_(0)
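identity_init is applied through nn.Module.apply so that, with the 'iden' option in the training scripts, the filler starts out close to a copy of its input (stride-2 layers spread shifted copies of the input across extra output channels):

comp = CompletionNet2(norm=nn.BatchNorm2d)
comp.apply(identity_init)   # used when opt.init == 'iden' in the training scripts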
class Perceptual(nn.Module):
def __init__(self, features, early = False):
super(Perceptual, self).__init__()
self.features = features
self.early = early
def forward(self, x):
bs = x.size(0)
x = self.features[0](x)
x = self.features[1](x)
x = self.features[2](x)
x0 = x.view(bs,-1,1)
x = F.relu(x)
x = self.features[4](x)
x = self.features[5](x)
x = self.features[6](x)
x = self.features[7](x)
x1 = x.view(bs, -1, 1)
x = F.relu(x)
x = self.features[9](x)
x = self.features[10](x)
x = self.features[11](x)
x = self.features[12](x)
x2 = x.view(bs, -1, 1)
x = F.relu(x)
x = self.features[14](x)
x = self.features[15](x)
x = self.features[16](x)
x = self.features[17](x)
x = self.features[18](x)
x = self.features[19](x)
x3 = x.view(bs, -1, 1)
x = F.relu(x)
x = self.features[21](x)
x = self.features[22](x)
x = self.features[23](x)
x = self.features[24](x)
x = self.features[25](x)
x = self.features[26](x)
x4 = x.view(bs, -1, 1)
if self.early:
perfeat = torch.cat([x0, x1, x2], 1)
else:
perfeat = torch.cat([x0, x1, x2, x3, x4], 1)
return perfeat
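Perceptual runs a frozen VGG16 feature extractor and concatenates flattened activations from several depths; the training scripts compare these stacks between the reconstruction and the target with an MSE loss. A minimal sketch of that use, loading the same vgg16-397923af.pth weights file the training scripts expect:

import torchvision.models as models

vgg16 = models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
p = Perceptual(vgg16.features)
l2 = nn.MSELoss()
recon = Variable(torch.zeros(4, 3, 256, 256))    # stand-ins for network output and ground truth
target = Variable(torch.zeros(4, 3, 256, 256))
perceptual_loss = l2(p(recon), p(target).detach())   # gradients flow only through recon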
class Discriminator2(nn.Module):
def __init__(self, pano = False):
super(Discriminator2, self).__init__()
alpha = 0.05
self.pano = pano
nf = 64
self.nf = nf
self.convs_global = nn.Sequential(
nn.Conv2d(3, nf, kernel_size = 5, stride = 2, padding = 1),
nn.ReLU(),
nn.Conv2d(nf, nf * 2, kernel_size = 5, stride = 2, padding = 1),
nn.BatchNorm2d(nf * 2, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf * 2, nf * 4, kernel_size = 5, stride = 2, padding = 1),
nn.BatchNorm2d(nf * 4, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf * 4, nf * 8, kernel_size = 5, stride = 2, padding = 1),
nn.BatchNorm2d(nf * 8, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf * 8, nf * 8, kernel_size = 5, stride = 2, padding = 1),
nn.BatchNorm2d(nf * 8, momentum=alpha),
nn.ReLU(),
nn.Conv2d(nf * 8, nf * 8, kernel_size = 5, stride = 2, padding = 1),
nn.ReLU()
)
if self.pano:
self.fc_global = nn.Linear(nf * 8 * 3 * 7, 1000)
else:
self.fc_global = nn.Linear(nf * 8 * 3 * 3, 1000)
def forward(self, img):
y = self.convs_global(img)
if self.pano:
y = y.view(y.size(0), self.nf * 8 * 3 * 7)
else:
y = y.view(y.size(0), self.nf * 8 * 3 * 3)
y = F.relu(self.fc_global(y))
x = F.log_softmax(y)
return x


@ -0,0 +1,333 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from realenv.data.datasets import PairDataset
from completion2 import CompletionNet2, identity_init, Perceptual, Discriminator2
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F
import torchvision.models as models
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('Linear') != -1:
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def crop(source, source_depth, target):
bs = source.size(0)
source_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
source_depth_cropped = Variable(torch.zeros(4*bs, 2, 256, 256)).cuda()
target_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
for i in range(bs):
for j in range(4):
idx = i * 4 + j
blurry_margin = 1024 / 8
centerx = np.random.randint(blurry_margin + 128, 1024 - blurry_margin - 128)
centery = np.random.randint(128, 1024 * 2 - 128)
source_cropped[idx] = source[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
source_depth_cropped[idx] = source_depth[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
target_cropped[idx] = target[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
return source_cropped, source_depth_cropped, target_cropped
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--debug' , action='store_true', help='debug mode')
parser.add_argument('--imgsize' ,type=int, default = 256, help='image size')
parser.add_argument('--batchsize' ,type=int, default = 20, help='batchsize')
parser.add_argument('--workers' ,type=int, default = 9, help='number of workers')
parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--outf', type=str, default="filler_pano_pc_full", help='output folder')
parser.add_argument('--model', type=str, default="", help='model path')
parser.add_argument('--cepoch', type=int, default = 0, help='current epoch')
parser.add_argument('--loss', type=str, default="perceptual", help='loss type: perceptual | l1 | train_init | color_stable | color_correction')
parser.add_argument('--init', type=str, default = "iden", help='init method')
parser.add_argument('--l1', type=float, default = 0, help='add l1 loss')
parser.add_argument('--color_coeff', type=float, default = 0, help='add color match loss')
parser.add_argument('--cascade' , action='store_true', help='train a second, cascaded completion network on top of the first')
parser.add_argument('--unfiller' , action='store_true', help='train an additional "unfiller" network')
mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
opt = parser.parse_args()
print(opt)
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
try:
os.makedirs(opt.outf)
except OSError:
pass
tf = transforms.Compose([
transforms.ToTensor(),
])
mist_tf = transforms.Compose([
transforms.ToTensor(),
])
d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)
cudnn.benchmark = True
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
img = Variable(torch.zeros(opt.batchsize,3 , 1024, 2048)).cuda()
maskv = Variable(torch.zeros(opt.batchsize,2, 1024, 2048)).cuda()
img_original = Variable(torch.zeros(opt.batchsize,3, 1024, 2048)).cuda()
label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()
comp = CompletionNet2(norm = nn.BatchNorm2d)
dis = Discriminator2(pano = False)
current_epoch = opt.cepoch
comp = torch.nn.DataParallel(comp).cuda()
if opt.init == 'iden':
comp.apply(identity_init)
else:
comp.apply(weights_init)
dis = torch.nn.DataParallel(dis).cuda()
dis.apply(weights_init)
if opt.model != '':
comp.load_state_dict(torch.load(opt.model))
#dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
current_epoch = opt.cepoch
if opt.cascade:
comp2 = CompletionNet2(norm = nn.BatchNorm2d)
comp2 = torch.nn.DataParallel(comp2).cuda()
if opt.model != '':
comp2.load_state_dict(torch.load(opt.model))
optimizerG2 = torch.optim.Adam(comp2.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
l2 = nn.MSELoss()
#if opt.loss == 'train_init':
# params = list(comp.parameters())
# sel = np.random.choice(len(params), len(params)/2, replace=False)
# params_sel = [params[i] for i in sel]
# optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
#
#else:
optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
curriculum = (200000, 300000) # step to start D training and G training, slightly different from the paper
alpha = 0.004
errG_data = 0
errD_data = 0
vgg16 = models.vgg16(pretrained = False)
vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
feat = vgg16.features
p = torch.nn.DataParallel(Perceptual(feat, early = (opt.loss == 'early'))).cuda()
for param in p.parameters():
param.requires_grad = False
test_loader_enum = enumerate(dataloader_test)
for epoch in range(current_epoch, opt.nepoch):
for i, data in enumerate(dataloader, 0):
optimizerG.zero_grad()
source = data[0]
source_depth = data[1]
target = data[2]
step = i + epoch * len(dataloader)
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
#img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
#print(source_depth.size(), mask.size())
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
#from IPython import embed; embed()
recon = comp(imgc, maskvc)
if opt.loss == "train_init":
loss = l2(recon, imgc[:,:3,:,:])
elif opt.loss == 'l1':
loss = l2(recon, img_originalc)
elif opt.loss == 'perceptual':
loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
elif opt.loss == 'color_stable':
loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
elif opt.loss == 'color_correction':
loss = l2(p(recon), p(img_originalc).detach())
for scale in [32]:
img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
recon_patch = recon.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
recon_patch_mean = recon_patch.mean(dim = -1)
recon_patch_cov = []
img_originalc_patch_cov = []
for j in range(3):
recon_patch_cov.append((recon_patch * recon_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
recon_patch_cov_cat = torch.cat(recon_patch_cov,1)
img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())
loss += opt.color_coeff * color_loss
print("color loss %f" % color_loss.data[0])
loss.backward(retain_graph = True)
if opt.cascade:
optimizerG2.zero_grad()
recon2 = comp2(torch.cat([recon, imgc[:,3:]], 1), maskvc)
loss2 = l2(p(recon2), p(img_originalc).detach())
for scale in [32]:
img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
recon2_patch = recon2.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
recon2_patch_mean = recon2_patch.mean(dim = -1)
recon2_patch_cov = []
img_originalc_patch_cov = []
for j in range(3):
recon2_patch_cov.append((recon2_patch * recon2_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
recon2_patch_cov_cat = torch.cat(recon2_patch_cov,1)
img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
color_loss = l2(recon2_patch_mean, img_originalc_patch_mean) + l2(recon2_patch_cov_cat, img_originalc_patch_cov_cat.detach())
loss2 += opt.color_coeff * color_loss
print("color loss %f" % color_loss.data[0])
loss2.backward(retain_graph = True)
print("loss2 %f" % loss2.data[0])
optimizerG2.step()
if i%10 == 0:
writer.add_scalar('MSEloss2', loss2.data[0], step)
if step > curriculum[1]:
label.data.fill_(1)
output = dis(recon)
errG = alpha * F.nll_loss(output, label)
errG.backward()
errG_data = errG.data[0]
#from IPython import embed; embed()
if opt.loss == "train_init":
for param in comp.parameters():
if len(param.size()) == 4:
#print(param.size())
nk = param.size()[2]//2
if nk > 5:
param.grad[:nk, :,:,:] = 0
optimizerG.step()
# Train D:
if step > curriculum[0]:
optimizerD.zero_grad()
label.data.fill_(0)
output = dis(recon.detach())
#print(output)
errD_fake = alpha * F.nll_loss(output, label)
errD_fake.backward(retain_graph = True)
output = dis(img_originalc)
#print(output)
label.data.fill_(1)
errD_real = alpha * F.nll_loss(output, label)
errD_real.backward()
optimizerD.step()
errD_data = errD_real.data[0] + errD_fake.data[0]
print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))
if i%200 == 0:
test_i, test_data = test_loader_enum.next()
if test_i > len(dataloader_test) - 5:
test_loader_enum = enumerate(dataloader_test)
source = test_data[0]
source_depth = test_data[1]
target = test_data[2]
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
comp.eval()
recon = comp(imgc, maskvc)
comp.train()
if opt.cascade:
comp2.eval()
recon2 = comp2(torch.cat([recon, imgc[:,3:]], 1), maskvc)
comp2.train()
visual = torch.cat([imgc.data[:,:3,:,:], recon.data, recon2.data, img_originalc.data], 3)
else:
visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)
visual = vutils.make_grid(visual, normalize=True)
writer.add_image('image', visual, step)
vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)
if i%10 == 0:
writer.add_scalar('MSEloss', loss.data[0], step)
writer.add_scalar('G_loss', errG_data, step)
writer.add_scalar('D_loss', errD_data, step)
if i%10000 == 0:
torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))
if __name__ == '__main__':
main()
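For reference, the patch-statistics color loss used above (per-cell means and second moments over 32 x 32 cells) pulled out into a standalone helper; this is a sketch of the same computation rather than code from the PR:

def color_match_loss(recon, target, scale=32):
    # recon, target: (B, 3, H, W) with H and W divisible by scale
    l2 = nn.MSELoss()
    b, c, h, w = recon.size()
    def to_cells(x):
        # (B, 3, H/scale, W/scale, scale*scale): each cell's pixels gathered in the last dim
        return x.view(b, c, h // scale, scale, w // scale, scale) \
                .transpose(4, 3).contiguous().view(b, c, h // scale, w // scale, -1)
    rp, tp = to_cells(recon), to_cells(target)
    loss = l2(rp.mean(dim=-1), tp.mean(dim=-1))
    for j in range(c):
        # per-cell second moments E[x_i * x_j], one channel pair at a time
        loss = loss + l2((rp * rp[:, j:j+1]).mean(dim=-1),
                         (tp * tp[:, j:j+1]).mean(dim=-1).detach())
    return loss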


@ -0,0 +1,333 @@
import argparse
import os
import re
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from realenv.data.datasets import PairDataset
from completion2 import CompletionNet2, identity_init, Perceptual, Discriminator2
from tensorboard import SummaryWriter
from datetime import datetime
import vision_utils
import torch.nn.functional as F
import torchvision.models as models
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('Linear') != -1:
m.weight.data.normal_(0.0, 0.02)
m.bias.data.fill_(0)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def crop(source, source_depth, target):
bs = source.size(0)
source_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
source_depth_cropped = Variable(torch.zeros(4*bs, 2, 256, 256)).cuda()
target_cropped = Variable(torch.zeros(4*bs, 3, 256, 256)).cuda()
for i in range(bs):
for j in range(4):
idx = i * 4 + j
blurry_margin = 1024 / 8
centerx = np.random.randint(blurry_margin + 128, 1024 - blurry_margin - 128)
centery = np.random.randint(128, 1024 * 2 - 128)
source_cropped[idx] = source[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
source_depth_cropped[idx] = source_depth[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
target_cropped[idx] = target[i, :, centerx-128:centerx + 128, centery - 128:centery + 128]
return source_cropped, source_depth_cropped, target_cropped
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--debug' , action='store_true', help='debug mode')
parser.add_argument('--imgsize' ,type=int, default = 256, help='image size')
parser.add_argument('--batchsize' ,type=int, default = 20, help='batchsize')
parser.add_argument('--workers' ,type=int, default = 9, help='number of workers')
parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--outf', type=str, default="filler_pano_pc_full", help='output folder')
parser.add_argument('--model', type=str, default="", help='model path')
parser.add_argument('--cepoch', type=int, default = 0, help='current epoch')
parser.add_argument('--loss', type=str, default="perceptual", help='loss type: perceptual | l1 | train_init | color_stable | color_correction')
parser.add_argument('--init', type=str, default = "iden", help='init method')
parser.add_argument('--l1', type=float, default = 0, help='add l1 loss')
parser.add_argument('--color_coeff', type=float, default = 0, help='add color match loss')
parser.add_argument('--unfiller' , action='store_true', help='train a second "unfiller" network that maps target panoramas back toward the filler outputs')
mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
opt = parser.parse_args()
print(opt)
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
try:
os.makedirs(opt.outf)
except OSError:
pass
tf = transforms.Compose([
transforms.ToTensor(),
])
mist_tf = transforms.Compose([
transforms.ToTensor(),
])
d = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf)
d_test = PairDataset(root = opt.dataroot, transform=tf, mist_transform = mist_tf, train = False)
cudnn.benchmark = True
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = False)
img = Variable(torch.zeros(opt.batchsize,3 , 1024, 2048)).cuda()
maskv = Variable(torch.zeros(opt.batchsize,2, 1024, 2048)).cuda()
img_original = Variable(torch.zeros(opt.batchsize,3, 1024, 2048)).cuda()
label = Variable(torch.LongTensor(opt.batchsize * 4)).cuda()
comp = CompletionNet2(norm = nn.BatchNorm2d, nf = 24)
dis = Discriminator2(pano = False)
current_epoch = opt.cepoch
comp = torch.nn.DataParallel(comp).cuda()
if opt.init == 'iden':
comp.apply(identity_init)
else:
comp.apply(weights_init)
dis = torch.nn.DataParallel(dis).cuda()
dis.apply(weights_init)
if opt.model != '':
comp.load_state_dict(torch.load(opt.model))
#dis.load_state_dict(torch.load(opt.model.replace("G", "D")))
current_epoch = opt.cepoch
if opt.unfiller:
comp2 = CompletionNet2(norm = nn.BatchNorm2d, nf = 24)
comp2 = torch.nn.DataParallel(comp2).cuda()
if opt.model != '':
comp2.load_state_dict(torch.load(opt.model))
optimizerG2 = torch.optim.Adam(comp2.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
l2 = nn.MSELoss()
#if opt.loss == 'train_init':
# params = list(comp.parameters())
# sel = np.random.choice(len(params), len(params)/2, replace=False)
# params_sel = [params[i] for i in sel]
# optimizerG = torch.optim.Adam(params_sel, lr = opt.lr, betas = (opt.beta1, 0.999))
#
#else:
optimizerG = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
optimizerD = torch.optim.Adam(dis.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
curriculum = (200000, 300000) # step to start D training and G training, slightly different from the paper
alpha = 0.004
errG_data = 0
errD_data = 0
vgg16 = models.vgg16(pretrained = False)
vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
feat = vgg16.features
p = torch.nn.DataParallel(Perceptual(feat, early = (opt.loss == 'early'))).cuda()
for param in p.parameters():
param.requires_grad = False
test_loader_enum = enumerate(dataloader_test)
for epoch in range(current_epoch, opt.nepoch):
for i, data in enumerate(dataloader, 0):
optimizerG.zero_grad()
source = data[0]
source_depth = data[1]
target = data[2]
step = i + epoch * len(dataloader)
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
#img_mean = torch.sum(torch.sum(source[:,:3,:,:], 2),2) / torch.sum(torch.sum(mask, 2),2).view(opt.batchsize,1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
#print(source_depth.size(), mask.size())
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
#from IPython import embed; embed()
recon = comp(imgc, maskvc)
if opt.loss == "train_init":
loss = l2(recon, imgc[:,:3,:,:])
elif opt.loss == 'l1':
loss = l2(recon, img_originalc)
elif opt.loss == 'perceptual':
loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
elif opt.loss == 'color_stable':
loss = l2(p(recon.view(recon.size(0) * 3, 1, 256, 256).repeat(1,3,1,1)), p(img_originalc.view(img_originalc.size(0)*3,1,256,256).repeat(1,3,1,1)).detach())
elif opt.loss == 'color_correction':
loss = l2(p(recon), p(img_originalc).detach())
for scale in [32]:
img_originalc_patch = img_originalc.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
recon_patch = recon.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
recon_patch_mean = recon_patch.mean(dim = -1)
recon_patch_cov = []
img_originalc_patch_cov = []
for j in range(3):
recon_patch_cov.append((recon_patch * recon_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
recon_patch_cov_cat = torch.cat(recon_patch_cov,1)
img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
color_loss = l2(recon_patch_mean, img_originalc_patch_mean) + l2(recon_patch_cov_cat, img_originalc_patch_cov_cat.detach())
loss += opt.color_coeff * color_loss
print("color loss %f" % color_loss.data[0])
loss.backward(retain_graph = True)
if opt.unfiller:
optimizerG2.zero_grad()
recon2 = comp2(img_originalc, maskvc)
loss2 = l2(p(recon2), p(recon).detach())
for scale in [32]:
img_originalc_patch = recon.detach().view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
recon2_patch = recon2.view(opt.batchsize * 4,3,256/scale,scale,256/scale,scale).transpose(4,3).contiguous().view(opt.batchsize * 4,3,256/scale,256/scale,-1)
img_originalc_patch_mean = img_originalc_patch.mean(dim=-1)
recon2_patch_mean = recon2_patch.mean(dim = -1)
recon2_patch_cov = []
img_originalc_patch_cov = []
for j in range(3):
recon2_patch_cov.append((recon2_patch * recon2_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
img_originalc_patch_cov.append((img_originalc_patch * img_originalc_patch[:,j:j+1].repeat(1,3,1,1,1)).mean(dim=-1))
recon2_patch_cov_cat = torch.cat(recon2_patch_cov,1)
img_originalc_patch_cov_cat = torch.cat(img_originalc_patch_cov, 1)
color_loss = l2(recon2_patch_mean, img_originalc_patch_mean) + l2(recon2_patch_cov_cat, img_originalc_patch_cov_cat.detach())
loss2 += opt.color_coeff * color_loss
print("color loss %f" % color_loss.data[0])
loss2.backward(retain_graph = True)
print("loss2 %f" % loss2.data[0])
optimizerG2.step()
if i%10 == 0:
writer.add_scalar('MSEloss2', loss2.data[0], step)
if step > curriculum[1]:
label.data.fill_(1)
output = dis(recon)
errG = alpha * F.nll_loss(output, label)
errG.backward()
errG_data = errG.data[0]
#from IPython import embed; embed()
if opt.loss == "train_init":
for param in comp.parameters():
if len(param.size()) == 4:
#print(param.size())
nk = param.size()[2]//2
if nk > 5:
param.grad[:nk, :,:,:] = 0
optimizerG.step()
# Train D:
if step > curriculum[0]:
optimizerD.zero_grad()
label.data.fill_(0)
output = dis(recon.detach())
#print(output)
errD_fake = alpha * F.nll_loss(output, label)
errD_fake.backward(retain_graph = True)
output = dis(img_originalc)
#print(output)
label.data.fill_(1)
errD_real = alpha * F.nll_loss(output, label)
errD_real.backward()
optimizerD.step()
errD_data = errD_real.data[0] + errD_fake.data[0]
print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))
if i%200 == 0:
test_i, test_data = test_loader_enum.next()
if test_i > len(dataloader_test) - 5:
test_loader_enum = enumerate(dataloader_test)
source = test_data[0]
source_depth = test_data[1]
target = test_data[2]
mask = (torch.sum(source[:,:3,:,:],1)>0).float().unsqueeze(1)
source[:,:3,:,:] += (1-mask.repeat(1,3,1,1)) * mean.view(1,3,1,1).repeat(opt.batchsize,1,1024,2048)
source_depth = source_depth[:,:,:,0].unsqueeze(1)
source_depth = torch.cat([source_depth, mask], 1)
img.data.copy_(source)
maskv.data.copy_(source_depth)
img_original.data.copy_(target)
imgc, maskvc, img_originalc = crop(img, maskv, img_original)
comp.eval()
recon = comp(imgc, maskvc)
comp.train()
if opt.unfiller:
comp2.eval()
recon2 = comp2(img_originalc, maskvc)
comp2.train()
visual = torch.cat([imgc.data[:,:3,:,:], recon.data, recon2.data, img_originalc.data], 3)
else:
visual = torch.cat([imgc.data[:,:3,:,:], recon.data, img_originalc.data], 3)
visual = vutils.make_grid(visual, normalize=True)
writer.add_image('image', visual, step)
vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)
if i%10 == 0:
writer.add_scalar('MSEloss', loss.data[0], step)
writer.add_scalar('G_loss', errG_data, step)
writer.add_scalar('D_loss', errD_data, step)
if i%10000 == 0:
torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))
torch.save(dis.state_dict(), '%s/compD_epoch%d_%d.pth' % (opt.outf, epoch, i))
if opt.unfiller:
torch.save(comp2.state_dict(), '%s/compG2_epoch%d_%d.pth' % (opt.outf, epoch, i))
if __name__ == '__main__':
main()


@ -0,0 +1,57 @@
import torch
import math
import random
from PIL import Image, ImageOps
try:
import accimage
except ImportError:
accimage = None
import numpy as np
import numbers
import types
import collections
class RandomScale(object):
"""Rescale the input PIL.Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(w, h), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, minsize, maxsize, interpolation=Image.BILINEAR):
assert isinstance(minsize, int)
assert isinstance(maxsize, int)
self.minsize = minsize
self.maxsize = maxsize
self.interpolation = interpolation
def __call__(self, img):
"""
Args:
img (PIL.Image): Image to be scaled.
Returns:
PIL.Image: Rescaled image.
"""
size = random.randint(self.minsize, self.maxsize)
if isinstance(size, int):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), self.interpolation)
else:
raise NotImplementedError()
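Example usage of RandomScale, assuming torchvision's transform pipeline and a hypothetical input image; the smaller edge ends up somewhere in [minsize, maxsize]:

from torchvision import transforms

tf = transforms.Compose([
    RandomScale(240, 320),
    transforms.ToTensor(),
])
img = Image.open('example.png')   # hypothetical input path
tensor = tf(img)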


@ -22,7 +22,7 @@ setup(name='realenv',
zip_safe=False,
install_requires=[
'numpy>=1.10.4',
'go-vncdriver>=0.4.19',
#'go-vncdriver>=0.4.19',
'pyglet>=1.2.0',
'gym>=0.9.2',
'Pillow>=3.3.0',