333 lines
15 KiB
Python
333 lines
15 KiB
Python
|
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
|
||
|
|
||
|
|
||
|
import argparse
|
||
|
import datetime
|
||
|
import json
|
||
|
import random
|
||
|
import time
|
||
|
from pathlib import Path
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
import datasets
|
||
|
|
||
|
import datasets.samplers as samplers
|
||
|
from datasets import build_dataset, get_coco_api_from_dataset
|
||
|
from models import build_model
|
||
|
|
||
|
|
||
|
def get_args_parser():
    """Build the option parser shared by training and evaluation.

    The parser is created with ``add_help=False`` so that it can be passed
    through ``parents=[...]`` to a top-level parser, which then supplies
    ``-h/--help`` itself.

    Returns:
        argparse.ArgumentParser: parser carrying every option of this script.
    """
    parser = argparse.ArgumentParser('Deformable DETR Detector', add_help=False)
    add = parser.add_argument  # short alias; options are registered in order below

    # Optimisation
    add('--lr', type=float, default=2e-4)
    add('--lr_backbone_names', type=str, nargs='+', default=["backbone.0"])
    add('--lr_backbone', type=float, default=2e-5)
    add('--lr_linear_proj_names', type=str, nargs='+',
        default=['reference_points', 'sampling_offsets'])
    add('--lr_linear_proj_mult', type=float, default=0.1)
    add('--batch_size', type=int, default=2)
    add('--weight_decay', type=float, default=1e-4)
    add('--epochs', type=int, default=15)
    add('--lr_drop', type=int, default=5)
    add('--lr_drop_epochs', type=int, nargs='+', default=None)
    add('--clip_max_norm', type=float, default=0.1,
        help='gradient clipping max norm')

    add('--num_ref_frames', type=int, default=3, help='number of reference frames')

    add('--sgd', action='store_true')

    # Variants of Deformable DETR
    add('--with_box_refine', action='store_true', default=False)
    add('--two_stage', action='store_true', default=False)

    # Model parameters
    add('--frozen_weights', type=str, default=None,
        help="Path to the pretrained model. If set, only the mask head will be trained")

    # * Backbone
    add('--backbone', type=str, default='resnet50',
        help="Name of the convolutional backbone to use")
    add('--dilation', action='store_true',
        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    add('--position_embedding', type=str, default='sine', choices=('sine', 'learned'),
        help="Type of positional embedding to use on top of the image features")
    add('--position_embedding_scale', type=float, default=2 * np.pi,
        help="position / size * scale")
    add('--num_feature_levels', type=int, default=4, help='number of feature levels')

    # * Transformer
    add('--enc_layers', type=int, default=6,
        help="Number of encoding layers in the transformer")
    add('--dec_layers', type=int, default=6,
        help="Number of decoding layers in the transformer")
    add('--dim_feedforward', type=int, default=1024,
        help="Intermediate size of the feedforward layers in the transformer blocks")
    add('--hidden_dim', type=int, default=256,
        help="Size of the embeddings (dimension of the transformer)")
    add('--dropout', type=float, default=0.1,
        help="Dropout applied in the transformer")
    add('--nheads', type=int, default=8,
        help="Number of attention heads inside the transformer's attentions")
    add('--num_queries', type=int, default=300,
        help="Number of query slots")
    add('--dec_n_points', type=int, default=4)
    add('--enc_n_points', type=int, default=4)
    add('--n_temporal_decoder_layers', type=int, default=1)
    add('--interval1', type=int, default=20)
    add('--interval2', type=int, default=60)

    add("--fixed_pretrained_model", action='store_true', default=False)

    # * Segmentation
    add('--masks', action='store_true',
        help="Train segmentation head if the flag is provided")

    # Loss
    add('--no_aux_loss', dest='aux_loss', action='store_false',
        help="Disables auxiliary decoding losses (loss at each layer)")

    # * Matcher
    add('--set_cost_class', type=float, default=2,
        help="Class coefficient in the matching cost")
    add('--set_cost_bbox', type=float, default=5,
        help="L1 box coefficient in the matching cost")
    add('--set_cost_giou', type=float, default=2,
        help="giou box coefficient in the matching cost")

    # * Loss coefficients
    add('--mask_loss_coef', type=float, default=1)
    add('--dice_loss_coef', type=float, default=1)
    add('--cls_loss_coef', type=float, default=2)
    add('--bbox_loss_coef', type=float, default=5)
    add('--giou_loss_coef', type=float, default=2)
    add('--focal_alpha', type=float, default=0.25)

    # Dataset parameters
    add('--dataset_file', default='vid_multi')
    add('--coco_path', type=str, default='./data/coco')
    add('--vid_path', type=str, default='./data/vid')
    add('--coco_pretrain', action='store_true', default=False)
    add('--coco_panoptic_path', type=str)
    add('--remove_difficult', action='store_true')

    # Runtime / bookkeeping
    add('--output_dir', default='',
        help='path where to save, empty for no saving')
    add('--device', default='cuda',
        help='device to use for training / testing')
    add('--seed', type=int, default=42)
    add('--resume', default='', help='resume from checkpoint')
    add('--start_epoch', type=int, default=0, metavar='N',
        help='start epoch')
    add('--eval', action='store_true')
    add('--num_workers', type=int, default=0)
    add('--cache_mode', action='store_true', default=False,
        help='whether to cache images on memory')

    return parser
|
||
|
|
||
|
|
||
|
def main(args):
    """Train or evaluate the detector described by ``args``.

    Steps, in order: select the engine/utility modules for the dataset, set
    up distributed mode and seeding, build the model, datasets, samplers and
    loaders, build the optimizer and LR scheduler, optionally load weights
    from ``args.frozen_weights`` / ``args.resume``, then either run a single
    evaluation pass (``args.eval``) or the training loop.

    Args:
        args: parsed options as produced by ``get_args_parser()``. The code
            also reads ``args.distributed`` and ``args.gpu``, which are not
            parser options; presumably ``utils.init_distributed_mode`` sets
            them -- TODO confirm.

    Returns:
        None.
    """
    # The single-frame and multi-frame pipelines use different engine and
    # utility modules; import the pair that matches the requested dataset.
    if args.dataset_file == "vid_single":
        from engine_single import evaluate, train_one_epoch
        import util.misc as utils
    else:
        from engine_multi import evaluate, train_one_epoch
        import util.misc_multi as utils

    print(args.dataset_file)
    device = torch.device(args.device)
    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))

    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model, criterion, postprocessors = build_model(args)
    model.to(device)

    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    dataset_train = build_dataset(image_set='train_vid', args=args)
    dataset_val = build_dataset(image_set='val', args=args)

    if args.distributed:
        if args.cache_mode:
            sampler_train = samplers.NodeDistributedSampler(dataset_train)
            sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
        else:
            sampler_train = samplers.DistributedSampler(dataset_train)
            sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True)

    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                   collate_fn=utils.collate_fn, num_workers=args.num_workers,
                                   pin_memory=True)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                 drop_last=False, collate_fn=utils.collate_fn,
                                 num_workers=args.num_workers, pin_memory=True)

    def match_name_keywords(n, name_keywords):
        """Return True if any keyword is a substring of parameter name ``n``."""
        return any(keyword in n for keyword in name_keywords)

    for n, _ in model_without_ddp.named_parameters():
        print(n)

    # Three parameter groups: everything else (base lr), backbone parameters
    # (their own lr) and the linear projection parameters (scaled base lr).
    # Only parameters that currently require gradients are included.
    param_dicts = [
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if not match_name_keywords(n, args.lr_backbone_names)
                and not match_name_keywords(n, args.lr_linear_proj_names)
                and p.requires_grad
            ],
            "lr": args.lr,
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
        {
            "params": [
                p for n, p in model_without_ddp.named_parameters()
                if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
            ],
            "lr": args.lr * args.lr_linear_proj_mult,
        },
    ]
    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
                                      weight_decay=args.weight_decay)
    print(args.lr_drop_epochs)
    # `--lr_drop_epochs` defaults to None; pass an explicit empty milestone
    # list in that case (no LR drop) instead of handing None to the scheduler.
    # NOTE(review): `--lr_drop` is not consulted when building the schedule.
    milestones = args.lr_drop_epochs if args.lr_drop_epochs is not None else []
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on original coco DS
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)

    # SECURITY NOTE: torch.load / load_state_dict_from_url deserialize pickled
    # data; only point --frozen_weights / --resume at trusted checkpoints.
    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location='cpu')
        model_without_ddp.detr.load_state_dict(checkpoint['model'])

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')

        # Only model weights are taken from the checkpoint; optimizer,
        # scheduler and epoch entries are not restored here.
        if args.eval:
            state_to_load = checkpoint['model']
        elif args.coco_pretrain:  # singleBaseline
            # Start from the current weights and overwrite everything except
            # the classification head entries, which are reported and skipped.
            state_to_load = model_without_ddp.state_dict().copy()
            for k, v in checkpoint['model'].items():
                if 'class_embed' not in k:
                    state_to_load[k] = v
                else:
                    print('k', k)
        else:
            state_to_load = checkpoint['model']
            # Train only parameters whose name contains 'temp'; freeze the rest.
            # NOTE(review): this runs after the optimizer groups were built.
            for name, param in model_without_ddp.named_parameters():
                param.requires_grad = 'temp' in name
        missing_keys, unexpected_keys = model_without_ddp.load_state_dict(
            state_to_load, strict=False)

        unexpected_keys = [
            k for k in unexpected_keys
            if not (k.endswith('total_params') or k.endswith('total_ops'))
        ]
        if len(missing_keys) > 0:
            print('Missing Keys: {}'.format(missing_keys))
        if len(unexpected_keys) > 0:
            print('Unexpected Keys: {}'.format(unexpected_keys))

    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
                                              data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)
        lr_scheduler.step()
        print('args.output_dir', args.output_dir)
        if args.output_dir:
            # Refresh the rolling checkpoint and keep a numbered one per epoch.
            checkpoint_paths = [
                output_dir / 'checkpoint.pth',
                output_dir / f'checkpoint{epoch:04}.pth',
            ]
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        # No validation pass is run inside the loop; only training statistics
        # are logged for each epoch.
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Wrap the shared option set in a top-level parser; the parent was built
    # with add_help=False, so this parser is the one that provides -h/--help.
    cli = argparse.ArgumentParser(
        'Deformable DETR training and evaluation script',
        parents=[get_args_parser()],
    )
    args = cli.parse_args()
    # Create the output directory up front when one was requested.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|