add tensorboard integration
This commit is contained in:
parent
537fff75c1
commit
73752a3b99
|
@ -2,12 +2,15 @@ import argparse
|
|||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from torchvision import datasets, transforms
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.autograd import Variable
|
||||
import torchvision.utils as vutils
|
||||
from datasets import Places365Dataset
|
||||
from completion import CompletionNet
|
||||
from tensorboard import SummaryWriter
|
||||
from datetime import datetime
|
||||
|
||||
def weights_init(m):
|
||||
classname = m.__class__.__name__
|
||||
|
@ -26,7 +29,7 @@ def main():
|
|||
parser.add_argument('--dataroot', required=True, help='path to dataset')
|
||||
parser.add_argument('--debug' , action='store_true', help='debug mode')
|
||||
parser.add_argument('--imgsize' ,type=int, default = 256, help='image size')
|
||||
parser.add_argument('--batchsize' ,type=int, default = 24, help='batchsize')
|
||||
parser.add_argument('--batchsize' ,type=int, default = 76, help='batchsize')
|
||||
parser.add_argument('--workers' ,type=int, default = 6, help='number of workers')
|
||||
parser.add_argument('--nepoch' ,type=int, default = 50, help='number of epochs')
|
||||
parser.add_argument('--lr', type=float, default=0.002, help='learning rate, default=0.002')
|
||||
|
@ -35,11 +38,11 @@ def main():
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
writer = SummaryWriter(opt.outf + '/runs/'+datetime.now().strftime('%B%d %H:%M:%S'))
|
||||
|
||||
try:
|
||||
os.makedirs(opt.outf)
|
||||
except OSError:
|
||||
|
@ -60,11 +63,12 @@ def main():
|
|||
|
||||
cudnn.benchmark = True
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True)
|
||||
dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers), drop_last = True, pin_memory = True)
|
||||
|
||||
img = Variable(torch.rand(opt.batchsize,3, 256, 256)).cuda()
|
||||
patch = Variable(torch.rand(opt.batchsize,3, 128, 128)).cuda()
|
||||
comp = CompletionNet().cuda()
|
||||
comp = CompletionNet()
|
||||
comp = torch.nn.DataParallel(comp).cuda()
|
||||
comp.apply(weights_init)
|
||||
l2 = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(comp.parameters(), lr = opt.lr, betas = (opt.beta1, 0.999))
|
||||
|
@ -82,7 +86,14 @@ def main():
|
|||
print('[%d/%d][%d/%d] loss: %f' % (epoch, opt.nepoch, i, len(dataloader), loss.data[0]))
|
||||
if i%500 == 0:
|
||||
visual = torch.cat([img.data, recon.data], 3)
|
||||
vutils.save_image(visual, '%s/results%d_%d.png' % (opt.outf, epoch, i), nrow=1)
|
||||
#vutils.save_image(visual, '%s/results%d_%d.png' % (opt.outf, epoch, i), nrow=1)
|
||||
|
||||
visual = vutils.make_grid(visual, normalize=True)
|
||||
writer.add_image('image', visual, i + epoch * len(dataloader))
|
||||
|
||||
if i%10 == 0:
|
||||
writer.add_scalar('loss', loss.data[0], i + epoch * len(dataloader))
|
||||
|
||||
torch.save(comp.state_dict(), '%s/comp_epoch%d.pth' % (opt.outf, epoch))
|
||||
|
||||
|
||||
|
|
13
datasets.py
13
datasets.py
|
@ -27,17 +27,12 @@ def is_image_file(filename):
|
|||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
def default_loader(path):
|
||||
img = Image.open(path)
|
||||
ret = img.copy().convert('RGB')
|
||||
img.close()
|
||||
return ret
|
||||
|
||||
img = Image.open(path).convert('RGB')
|
||||
return img
|
||||
|
||||
def depth_loader(path):
|
||||
img = Image.open(path)
|
||||
ret = img.copy().convert('I')
|
||||
img.close()
|
||||
return ret
|
||||
img = Image.open(path).convert('I')
|
||||
return img
|
||||
|
||||
class ViewDataSet3D(data.Dataset):
|
||||
|
||||
|
|
Loading…
Reference in New Issue