forked from PulseFocusPlatform/PulseFocusPlatform
422 lines
16 KiB
Python
422 lines
16 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import cv2
|
|
import glob
|
|
import paddle
|
|
import numpy as np
|
|
|
|
from ppdet.core.workspace import create
|
|
from ppdet.utils.checkpoint import load_weight
|
|
|
|
from ppdet.modeling.mot.utils import Timer, load_det_results
|
|
from ppdet.modeling.mot import visualization as mot_vis
|
|
|
|
from ppdet.metrics import Metric, MOTMetric
|
|
|
|
from .callbacks import Callback, ComposeCallback
|
|
|
|
from ppdet.utils.logger import setup_logger
|
|
logger = setup_logger(__name__)
|
|
|
|
__all__ = ['Tracker']
|
|
|
|
|
|
class Tracker(object):
    """Run multi-object tracking over image sequences or a video file.

    The tracker builds the model named by ``cfg.architecture``, takes the
    dataset from ``cfg['EvalMOTDataset']`` or ``cfg['TestMOTDataset']``
    (depending on ``mode``), and offers two entry points:

    * ``mot_evaluate`` -- track every sequence under a data root, write one
      result file per sequence and feed the registered metrics.
    * ``mot_predict`` -- track a single video file.

    Args:
        cfg: configuration object; must expose ``architecture``, ``metric``
            (in 'eval' mode) and the ``<Mode>MOTDataset`` entry.
        mode (str): 'eval' or 'test' (case-insensitive).
    """

    def __init__(self, cfg, mode='eval'):
        self.cfg = cfg
        assert mode.lower() in ['test', 'eval'], \
            "mode should be 'test' or 'eval'"
        self.mode = mode.lower()
        # No optimizer is used for tracking; it is only forwarded to
        # load_weight(), which takes it as its third argument.
        self.optimizer = None

        # build MOT data loader
        self.dataset = cfg['{}MOTDataset'.format(self.mode.capitalize())]

        # build model
        self.model = create(cfg.architecture)

        self.status = {}
        self.start_epoch = 0

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()

    def _init_callbacks(self):
        """Start with no callbacks registered."""
        self._callbacks = []
        self._compose_callback = None

    def _init_metrics(self):
        """Create the default metric list for the current mode.

        'test' mode has no metrics. Otherwise ``cfg.metric == 'MOT'`` yields
        a single ``MOTMetric``; any other value logs a warning and leaves the
        list empty.
        """
        if self.mode in ['test']:
            self._metrics = []
            return

        if self.cfg.metric == 'MOT':
            self._metrics = [MOTMetric(), ]
        else:
            logger.warning("Metric not support for metric type {}".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        """Call ``reset()`` on every registered metric."""
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        """Append callbacks (``None`` entries are dropped) and rebuild the
        composed callback.

        Args:
            callbacks (iterable): ``Callback`` instances, possibly with
                ``None`` placeholders.
        """
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            # The original message talked about metrics here, which was
            # misleading when a bad callback was passed.
            assert isinstance(c, Callback), \
                "callbacks should be instances of subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        """Append metrics (``None`` entries are dropped).

        Args:
            metrics (iterable): ``Metric`` instances, possibly with ``None``
                placeholders.
        """
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                "metrics should be instances of subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights_jde(self, weights):
        """Load ``weights`` into the whole model."""
        load_weight(self.model, weights, self.optimizer)

    def load_weights_sde(self, det_weights, reid_weights):
        """Load detector weights (only if the model has a detector) and
        ReID weights."""
        if self.model.detector:
            load_weight(self.model.detector, det_weights, self.optimizer)
        load_weight(self.model.reid, reid_weights, self.optimizer)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and parents) when it does not exist yet."""
        if not os.path.exists(path):
            os.makedirs(path)

    def _eval_seq_jde(self,
                      dataloader,
                      save_dir=None,
                      show_image=False,
                      frame_rate=30):
        """Track one sequence with a model that returns tracks directly.

        Args:
            dataloader: iterable of per-frame input dicts.
            save_dir (str|None): directory for rendered frames; created when
                truthy.
            show_image (bool): display each rendered frame.
            frame_rate (int): used to scale ``tracker.max_time_lost``
                relative to 30 fps.

        Returns:
            tuple: ``(results, n_frames, timer.average_time, timer.calls)``
            where ``results`` holds ``(frame_id + 1, tlwhs, ids)`` per frame.
        """
        if save_dir:
            self._ensure_dir(save_dir)
        tracker = self.model.tracker
        tracker.max_time_lost = int(frame_rate / 30.0 * tracker.track_buffer)

        timer = Timer()
        results = []
        frame_id = 0
        self.status['mode'] = 'track'
        self.model.eval()
        for step_id, data in enumerate(dataloader):
            self.status['step_id'] = step_id
            if frame_id % 40 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(
                    frame_id, 1. / max(1e-5, timer.average_time)))

            # forward
            timer.tic()
            online_targets = self.model(data)

            online_tlwhs, online_ids = [], []
            for t in online_targets:
                tlwh = t.tlwh
                tid = t.track_id
                # tlwh[2] / tlwh[3] is width / height (see the unpacking in
                # write_mot_results); boxes wider than 1.6x their height are
                # dropped. The flag keeps its original name.
                vertical = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > tracker.min_box_area and not vertical:
                    online_tlwhs.append(tlwh)
                    online_ids.append(tid)
            timer.toc()

            # save results (frame ids are stored 1-based)
            results.append((frame_id + 1, online_tlwhs, online_ids))
            self.save_results(data, frame_id, online_ids, online_tlwhs,
                              timer.average_time, show_image, save_dir)
            frame_id += 1

        return results, frame_id, timer.average_time, timer.calls

    def _eval_seq_sde(self,
                      dataloader,
                      save_dir=None,
                      show_image=False,
                      frame_rate=30,
                      det_file=''):
        """Track one sequence with a detector + ReID model.

        When the model has no detector, detections are loaded from
        ``det_file`` and injected into each frame's data dict as
        ``pred_bboxes`` / ``pred_scores``.

        Args:
            dataloader: iterable of per-frame input dicts; ``len()`` is used
                when detections are loaded from file.
            save_dir (str|None): directory for rendered frames; created when
                truthy.
            show_image (bool): display each rendered frame.
            frame_rate (int): accepted for signature parity with
                ``_eval_seq_jde``; not used here.
            det_file (str): detection results file, read only when the model
                has no detector.

        Returns:
            tuple: ``(results, n_frames, timer.average_time, timer.calls)``
            where ``results`` holds ``(frame_id + 1, tlwhs, ids)`` per frame.
        """
        if save_dir:
            self._ensure_dir(save_dir)
        tracker = self.model.tracker
        use_detector = bool(self.model.detector)

        timer = Timer()
        results = []
        frame_id = 0
        self.status['mode'] = 'track'
        self.model.eval()
        self.model.reid.eval()
        if not use_detector:
            dets_list = load_det_results(det_file, len(dataloader))
            logger.info('Finish loading detection results file {}.'.format(
                det_file))

        for step_id, data in enumerate(dataloader):
            self.status['step_id'] = step_id
            if frame_id % 40 == 0:
                logger.info('Processing frame {} ({:.2f} fps)'.format(
                    frame_id, 1. / max(1e-5, timer.average_time)))

            # NOTE(review): tic() is called up to three times before the
            # single toc() below; presumably only the last one matters --
            # confirm against Timer before removing the earlier calls.
            timer.tic()
            if not use_detector:
                timer.tic()
                dets = dets_list[frame_id]
                bbox_tlwh = paddle.to_tensor(dets['bbox'], dtype='float32')
                pred_scores = paddle.to_tensor(dets['score'], dtype='float32')
                if bbox_tlwh.shape[0] > 0:
                    # (x, y, w, h) -> (x1, y1, x2, y2)
                    pred_bboxes = paddle.concat(
                        (bbox_tlwh[:, 0:2],
                         bbox_tlwh[:, 2:4] + bbox_tlwh[:, 0:2]),
                        axis=1)
                else:
                    pred_bboxes = []
                    pred_scores = []
                data.update({
                    'pred_bboxes': pred_bboxes,
                    'pred_scores': pred_scores
                })

            # forward
            timer.tic()
            online_targets = self.model(data)

            online_tlwhs = []
            online_ids = []
            for track in online_targets:
                # keep only confirmed tracks updated within the last frame
                if not track.is_confirmed() or track.time_since_update > 1:
                    continue
                online_tlwhs.append(track.to_tlwh())
                online_ids.append(track.track_id)
            timer.toc()

            # save results (frame ids are stored 1-based)
            results.append((frame_id + 1, online_tlwhs, online_ids))
            self.save_results(data, frame_id, online_ids, online_tlwhs,
                              timer.average_time, show_image, save_dir)
            frame_id += 1

        return results, frame_id, timer.average_time, timer.calls

    def _track_sequence(self, dataloader, model_type, save_dir, show_image,
                        frame_rate, det_results_dir, seq):
        """Dispatch to the JDE/FairMOT or DeepSORT sequence runner.

        Returns whatever the chosen runner returns; raises ``ValueError``
        for an unknown ``model_type``.
        """
        if model_type in ['JDE', 'FairMOT']:
            return self._eval_seq_jde(
                dataloader,
                save_dir=save_dir,
                show_image=show_image,
                frame_rate=frame_rate)
        if model_type in ['DeepSORT']:
            return self._eval_seq_sde(
                dataloader,
                save_dir=save_dir,
                show_image=show_image,
                frame_rate=frame_rate,
                det_file=os.path.join(det_results_dir, '{}.txt'.format(seq)))
        raise ValueError(model_type)

    def _read_frame_rate(self, seq_dir):
        """Return the integer ``frameRate`` from ``<seq_dir>/seqinfo.ini``.

        The value is taken as the text between ``frameRate=`` and the
        following ``seqLength`` line, so the file must list ``seqLength``
        right after ``frameRate``.
        """
        # 'with' guarantees the handle is closed; the original leaked it.
        with open(os.path.join(seq_dir, 'seqinfo.ini')) as f:
            meta_info = f.read()
        start = meta_info.find('frameRate') + 10  # len('frameRate=')
        stop = meta_info.find('\nseqLength')
        return int(meta_info[start:stop])

    def _images_to_video(self, save_dir, seq):
        """Encode ``<save_dir>/%05d.jpg`` into ``<save_dir>/../<seq>_vis.mp4``.

        ffmpeg is invoked with an argument list rather than a shell string so
        that paths containing spaces or shell metacharacters are passed
        through verbatim. Failures are logged, never raised, matching the
        previous best-effort behaviour.

        Returns:
            str: the output video path (whether or not encoding succeeded).
        """
        import subprocess  # local import: only needed when saving videos

        output_video_path = os.path.join(save_dir, '..',
                                         '{}_vis.mp4'.format(seq))
        cmd = [
            'ffmpeg', '-f', 'image2', '-i', '{}/%05d.jpg'.format(save_dir),
            output_video_path
        ]
        try:
            ret = subprocess.call(cmd)
        except OSError as e:
            # e.g. ffmpeg is not installed
            logger.warning('Failed to launch ffmpeg: {}'.format(e))
        else:
            if ret != 0:
                logger.warning('ffmpeg exited with code {}'.format(ret))
        return output_video_path

    def mot_evaluate(self,
                     data_root,
                     seqs,
                     output_dir,
                     data_type='mot',
                     model_type='JDE',
                     save_images=False,
                     save_videos=False,
                     show_image=False,
                     det_results_dir=''):
        """Track every sequence in ``seqs`` and update the metrics.

        For each sequence the images under ``<data_root>/<seq>/img1`` are
        tracked, results are written to
        ``<output_dir>/mot_results/<seq>.txt`` and every registered metric is
        updated. Metrics are accumulated, logged and reset at the end.

        Args:
            data_root (str): directory containing one sub-directory per
                sequence, each with ``img1/`` and ``seqinfo.ini``.
            seqs (iterable[str]): sequence names.
            output_dir (str): root directory for all outputs.
            data_type (str): 'mot' or 'kitti'.
            model_type (str): 'JDE', 'FairMOT' or 'DeepSORT'.
            save_images (bool): keep rendered frames.
            save_videos (bool): additionally encode frames with ffmpeg.
            show_image (bool): display rendered frames.
            det_results_dir (str): directory of ``<seq>.txt`` detection
                files (used by DeepSORT).
        """
        self._ensure_dir(output_dir)
        result_root = os.path.join(output_dir, 'mot_results')
        self._ensure_dir(result_root)
        assert data_type in ['mot', 'kitti'], \
            "data_type should be 'mot' or 'kitti'"
        assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
            "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"

        # run tracking
        n_frame = 0
        timer_avgs, timer_calls = [], []
        for seq in seqs:
            save_dir = os.path.join(output_dir, 'mot_outputs',
                                    seq) if save_images or save_videos else None
            logger.info('start seq: {}'.format(seq))

            infer_dir = os.path.join(data_root, seq, 'img1')
            images = self.get_infer_images(infer_dir)
            self.dataset.set_images(images)

            dataloader = create('EvalMOTReader')(self.dataset, 0)

            result_filename = os.path.join(result_root, '{}.txt'.format(seq))
            frame_rate = self._read_frame_rate(os.path.join(data_root, seq))

            results, nf, ta, tc = self._track_sequence(
                dataloader, model_type, save_dir, show_image, frame_rate,
                det_results_dir, seq)

            self.write_mot_results(result_filename, results, data_type)
            n_frame += nf
            timer_avgs.append(ta)
            timer_calls.append(tc)

            if save_videos:
                output_video_path = self._images_to_video(save_dir, seq)
                logger.info('Save video in {}.'.format(output_video_path))

            logger.info('Evaluate seq: {}'.format(seq))
            # update metrics
            for metric in self._metrics:
                metric.update(data_root, seq, data_type, result_root,
                              result_filename)

        # total time = sum over sequences of (average time * number of calls)
        timer_avgs = np.asarray(timer_avgs)
        timer_calls = np.asarray(timer_calls)
        all_time = np.dot(timer_avgs, timer_calls)
        avg_time = all_time / np.sum(timer_calls)
        logger.info('Time elapsed: {:.2f} seconds, FPS: {:.2f}'.format(
            all_time, 1.0 / avg_time))

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        # reset metric states for metric may performed multiple times
        self._reset_metrics()

    def get_infer_images(self, infer_dir):
        """Return the sorted list of image files directly inside ``infer_dir``.

        Files are matched by the extensions jpg, jpeg, png and bmp, in
        all-lower or all-upper case.

        Raises:
            AssertionError: if ``infer_dir`` is ``None``, is not a directory,
                or contains no matching image.
        """
        # One check instead of two contradictory ones: the original first
        # allowed None and then crashed with TypeError in os.path.isdir(None).
        assert infer_dir is not None and os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)
        exts = ['jpg', 'jpeg', 'png', 'bmp']
        exts += [ext.upper() for ext in exts]
        # a set de-duplicates files matched by both cases on
        # case-insensitive filesystems
        images = set()
        for ext in exts:
            images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
        images = sorted(images)
        assert len(images) > 0, "no image found in {}".format(infer_dir)
        logger.info("Found {} inference images in total.".format(len(images)))
        return images

    def mot_predict(self,
                    video_file,
                    output_dir,
                    data_type='mot',
                    model_type='JDE',
                    save_images=False,
                    save_videos=True,
                    show_image=False,
                    det_results_dir=''):
        """Track a single video file.

        The sequence name is the video's base name up to its first dot.
        Rendered frames go to ``<output_dir>/mot_outputs/<seq>`` and, when
        ``save_videos`` is set, are encoded with ffmpeg.

        Args:
            video_file (str): path of the input video.
            output_dir (str): root directory for all outputs.
            data_type (str): 'mot' or 'kitti'.
            model_type (str): 'JDE', 'FairMOT' or 'DeepSORT'.
            save_images (bool): keep rendered frames.
            save_videos (bool): encode rendered frames into a video.
            show_image (bool): display rendered frames.
            det_results_dir (str): directory of ``<seq>.txt`` detection
                files (used by DeepSORT).
        """
        self._ensure_dir(output_dir)
        result_root = os.path.join(output_dir, 'mot_results')
        self._ensure_dir(result_root)
        assert data_type in ['mot', 'kitti'], \
            "data_type should be 'mot' or 'kitti'"
        assert model_type in ['JDE', 'DeepSORT', 'FairMOT'], \
            "model_type should be 'JDE', 'DeepSORT' or 'FairMOT'"

        # run tracking
        # basename() equals split('/')[-1] on POSIX and additionally strips
        # backslash-separated directories on Windows.
        seq = os.path.basename(video_file).split('.')[0]
        save_dir = os.path.join(output_dir, 'mot_outputs',
                                seq) if save_images or save_videos else None
        logger.info('Starting tracking {}'.format(video_file))

        self.dataset.set_video(video_file)
        dataloader = create('TestMOTReader')(self.dataset, 0)
        # Kept from the original although it is not used afterwards.
        result_filename = os.path.join(result_root, '{}.txt'.format(seq))
        frame_rate = self.dataset.frame_rate

        # The returned results/timings are not consumed in prediction mode.
        self._track_sequence(dataloader, model_type, save_dir, show_image,
                             frame_rate, det_results_dir, seq)

        if save_videos:
            output_video_path = self._images_to_video(save_dir, seq)
            logger.info('Save video in {}'.format(output_video_path))

    def write_mot_results(self, filename, results, data_type='mot'):
        """Write tracking results to ``filename`` as text.

        Args:
            filename (str): output path (overwritten).
            results (iterable): ``(frame_id, tlwhs, track_ids)`` tuples; each
                tlwh unpacks as ``x1, y1, w, h``.
            data_type (str): 'mot', 'mcmot' or 'lab' for the comma-separated
                layout, 'kitti' for the space-separated layout.

        Raises:
            ValueError: for any other ``data_type``.
        """
        if data_type in ['mot', 'mcmot', 'lab']:
            save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
        elif data_type == 'kitti':
            save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
        else:
            raise ValueError(data_type)

        with open(filename, 'w') as f:
            for frame_id, tlwhs, track_ids in results:
                if data_type == 'kitti':
                    # the kitti layout is written with frame ids one lower
                    frame_id -= 1
                for tlwh, track_id in zip(tlwhs, track_ids):
                    if track_id < 0:
                        continue
                    x1, y1, w, h = tlwh
                    x2, y2 = x1 + w, y1 + h
                    line = save_format.format(
                        frame=frame_id,
                        id=track_id,
                        x1=x1,
                        y1=y1,
                        x2=x2,
                        y2=y2,
                        w=w,
                        h=h)
                    f.write(line)
        logger.info('MOT results save in {}'.format(filename))

    def save_results(self, data, frame_id, online_ids, online_tlwhs,
                     average_time, show_image, save_dir):
        """Render the tracked boxes on the frame and show and/or save it.

        Does nothing when ``show_image`` is false and ``save_dir`` is
        ``None``. Saved frames are named ``<frame_id:05d>.jpg``.

        Args:
            data (dict): frame data; must contain 'ori_image' when rendering.
            frame_id (int): zero-based frame index.
            online_ids (list): track ids for this frame.
            online_tlwhs (list): boxes matching ``online_ids``.
            average_time (float): average per-frame time, shown as FPS.
            show_image (bool): display the rendered frame.
            save_dir (str|None): directory to write the rendered frame to.
        """
        if not show_image and save_dir is None:
            return

        assert 'ori_image' in data
        img0 = data['ori_image'].numpy()[0]
        online_im = mot_vis.plot_tracking(
            img0,
            online_tlwhs,
            online_ids,
            frame_id=frame_id,
            # clamp like the progress logging does, so a zero average time
            # cannot raise ZeroDivisionError
            fps=1. / max(1e-5, average_time))
        if show_image:
            # NOTE(review): no cv2.waitKey() follows; OpenCV windows usually
            # need it to refresh -- confirm whether display works as intended.
            cv2.imshow('online_im', online_im)
        if save_dir is not None:
            cv2.imwrite(
                os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
                online_im)