add tensorboard integration

This commit is contained in:
fxia22 2017-07-22 18:13:29 -07:00
parent 537fff75c1
commit 73752a3b99
2 changed files with 21 additions and 15 deletions

View File

@ -2,12 +2,15 @@ import argparse
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision.utils as vutils
from datasets import Places365Dataset
from completion import CompletionNet
from tensorboard import SummaryWriter
from datetime import datetime
def weights_init(m):
classname = m.__class__.__name__
@ -26,7 +29,7 @@ def main():
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--debug' , action='store_true', help='debug mode')
parser.add_argument('--imgsize' ,type=int, default = 256, help='image size')
parser.add_argument('--batchsize' ,type=int, default = 24, help='batchsize')
parser.add_argument('--batchsize' ,type=int, default = 76, help='batchsize')
parser.add_argument('--workers' ,type=int, default = 6, help='number of workers')
parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
@ -35,11 +38,11 @@ def main():
opt = parser.parse_args()
print(opt)
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
try:
os.makedirs(opt.outf)
except OSError:
@ -60,11 +63,12 @@ def main():
cudnn.benchmark = True
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True)
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = True)
img = Variable(torch.rand(opt.batchsize,3, 256, 256)).cuda()
patch = Variable(torch.rand(opt.batchsize,3, 128, 128)).cuda()
comp = CompletionNet().cuda()
comp = CompletionNet()
comp = torch.nn.DataParallel(comp).cuda()
comp.apply(weights_init)
l2 = nn.MSELoss()
optimizer = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
@ -82,7 +86,14 @@ def main():
print('[%d/%d][%d/%d] loss: %f' % (epoch, opt.nepoch, i, len(dataloader), loss.data[0]))
if i%500 == 0:
visual = torch.cat([img.data, recon.data], 3)
vutils.save_image(visual, '%s/results%d_%d.png' % (opt.outf, epoch, i), nrow=1)
#vutils.save_image(visual, '%s/results%d_%d.png' % (opt.outf, epoch, i), nrow=1)
visual = vutils.make_grid(visual, normalize=True)
writer.add_image('image', visual, i + epoch * len(dataloader))
if i%10 == 0:
writer.add_scalar('loss', loss.data[0], i + epoch * len(dataloader))
torch.save(comp.state_dict(), '%s/comp_epoch%d.pth' % (opt.outf, epoch))

View File

@ -27,17 +27,12 @@ def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def default_loader(path):
img = Image.open(path)
ret = img.copy().convert('RGB')
img.close()
return ret
img = Image.open(path).convert('RGB')
return img
def depth_loader(path):
img = Image.open(path)
ret = img.copy().convert('I')
img.close()
return ret
img = Image.open(path).convert('I')
return img
class ViewDataSet3D(data.Dataset):