# Forked from PulseFocusPlatform/PulseFocusPlatform (Python, 192 lines, 6.4 KiB).
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import os
|
||
|
import numpy as np
|
||
|
import copy
|
||
|
import motmetrics as mm
|
||
|
# Select the `lap` backend as py-motmetrics' default linear-assignment solver
# (MOTEvaluator.eval_frame calls mm.lap.linear_sum_assignment without an
# explicit solver).
# NOTE(review): this presumably requires the third-party `lap` package to be
# installed — confirm it is in the project requirements.
mm.lap.default_solver = 'lap'

# Public names of this module (what `from <module> import *` exposes).
__all__ = ['read_mot_results', 'unzip_objs', 'MOTEvaluator']
|
||
|
|
||
|
|
||
|
def read_mot_results(filename, is_gt=False, is_ignore=False):
    """Parse a MOT-format text file into per-frame lists of boxes.

    Every line is split on commas and read as
    ``frame, id, left, top, width, height, score/mark[, label, visibility]``.
    Lines with fewer than 7 fields, or whose frame id is < 1, are skipped.

    Args:
        filename (str): path of the file; if it does not exist ``{}`` is
            returned.
        is_gt (bool): parse ground truth. For MOT15/16/17/20 sequences
            (detected from ``filename``) a box is kept only if its mark
            (column 7) is non-zero and its label (column 8) is 1; for other
            files every box is kept.
        is_ignore (bool): parse ignore regions (only used when ``is_gt`` is
            False). Only MOT15/16/17/20 sequences yield boxes: those whose
            label is in {2, 7, 8, 12} or whose visibility (column 9) is
            negative.

    Returns:
        dict: ``{frame_id: [(tlwh, target_id, score), ...]}`` where ``tlwh``
        is a 4-tuple of floats, ``target_id`` an int and ``score`` the
        float in column 7 for result files, or ``1`` for gt/ignore boxes.
        A frame key is created for every accepted line before filtering, so
        frames whose boxes were all filtered out map to an empty list.

    Raises:
        ValueError: if a numeric field cannot be parsed.
        IndexError: if a MOT gt/ignore line lacks the label/visibility
            columns.
    """
    valid_labels = {1}
    ignore_labels = {2, 7, 8, 12}
    results_dict = dict()
    if not os.path.isfile(filename):
        return results_dict

    # The dataset check depends only on the file name, so evaluate it once
    # instead of once per line.
    is_mot_challenge = any(
        tag in filename for tag in ('MOT15-', 'MOT16-', 'MOT17-', 'MOT20-'))

    with open(filename, 'r') as f:
        for line in f:
            linelist = line.split(',')
            if len(linelist) < 7:
                continue
            fid = int(linelist[0])
            if fid < 1:
                continue
            # Register the frame before any filtering below, so filtered
            # frames still appear with an empty list.
            results_dict.setdefault(fid, list())

            if is_gt:
                if is_mot_challenge:
                    label = int(float(linelist[7]))
                    mark = int(float(linelist[6]))
                    if mark == 0 or label not in valid_labels:
                        continue
                score = 1
            elif is_ignore:
                if not is_mot_challenge:
                    continue
                label = int(float(linelist[7]))
                vis_ratio = float(linelist[8])
                if label not in ignore_labels and vis_ratio >= 0:
                    continue
                score = 1
            else:
                score = float(linelist[6])

            tlwh = tuple(map(float, linelist[2:6]))
            target_id = int(linelist[1])
            results_dict[fid].append((tlwh, target_id, score))
    return results_dict
|
||
|
|
||
|
|
||
|
"""
|
||
|
MOT ground-truth class labels, stored in the 8th comma-separated column of a
MOTChallenge gt.txt file (read as ``linelist[7]`` in read_mot_results):
     1: 'ped'
     2: 'person_on_vhcl'
     3: 'car'
     4: 'bicycle'
     5: 'mbike'
     6: 'non_mot_vhcl'
     7: 'static_person'
     8: 'distractor'
     9: 'occluder'
    10: 'occluder_on_grnd'
    11: 'occluder_full'
    12: 'reflection'
    13: 'crowd'
For MOT15/16/17/20 sequences, read_mot_results keeps only label 1 as ground
truth and treats labels 2, 7, 8 and 12 as ignore regions.
|
||
|
"""
|
||
|
|
||
|
|
||
|
def unzip_objs(objs):
    """Split ``(tlwh, id, score)`` triples into three parallel sequences.

    Args:
        objs (list): triples as produced by ``read_mot_results``.

    Returns:
        tuple: ``(tlwhs, ids, scores)``. ``tlwhs`` is a float ndarray of
        shape ``(N, 4)`` (``(0, 4)`` when ``objs`` is empty). ``ids`` and
        ``scores`` are tuples, or empty lists when ``objs`` is empty.
    """
    if len(objs) == 0:
        boxes, ids, scores = [], [], []
    else:
        boxes, ids, scores = zip(*objs)
    # reshape keeps the (N, 4) layout even when there are no boxes.
    return np.asarray(boxes, dtype=float).reshape(-1, 4), ids, scores
|
||
|
|
||
|
|
||
|
class MOTEvaluator(object):
    """Evaluate tracking results of one sequence with py-motmetrics.

    Ground truth and ignore regions are loaded at construction time from
    ``<data_root>/<seq_name>/gt/gt.txt``; evaluated frames are accumulated
    into ``self.acc`` (an ``mm.MOTAccumulator``).

    Args:
        data_root (str): directory that contains the sequence folders.
        seq_name (str): name of the sequence folder under ``data_root``.
        data_type (str): annotation format; only ``'mot'`` is supported.
    """

    def __init__(self, data_root, seq_name, data_type):
        self.data_root = data_root
        self.seq_name = seq_name
        self.data_type = data_type

        self.load_annotations()
        self.reset_accumulator()

    def load_annotations(self):
        """Read gt boxes and ignore regions of the sequence into dicts.

        Sets ``self.gt_frame_dict`` and ``self.gt_ignore_frame_dict``
        (both ``{frame_id: [(tlwh, id, score), ...]}``).

        Raises:
            AssertionError: if ``data_type`` is not ``'mot'``.
        """
        assert self.data_type == 'mot'
        gt_filename = os.path.join(self.data_root, self.seq_name, 'gt',
                                   'gt.txt')
        self.gt_frame_dict = read_mot_results(gt_filename, is_gt=True)
        self.gt_ignore_frame_dict = read_mot_results(
            gt_filename, is_ignore=True)

    def reset_accumulator(self):
        """Discard accumulated events and start a new accumulator."""
        # auto_id=True: the accumulator numbers frames itself, which is why
        # eval_frame calls update() without a frame id.
        self.acc = mm.MOTAccumulator(auto_id=True)

    def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
        """Add one frame of tracking output to the accumulator.

        Tracks assigned to an ignore region are removed first; the remaining
        tracks are then matched against the ground truth of ``frame_id``.

        Args:
            frame_id (int): frame index used to look up gt/ignore boxes.
            trk_tlwhs (array-like): track boxes, one ``tlwh`` row per track.
            trk_ids (array-like): track ids aligned with ``trk_tlwhs``.
            rtn_events (bool): if True, return this frame's events when the
                accumulator exposes ``last_mot_events``.

        Returns:
            The frame's events, or ``None`` when not requested, not
            supported by the accumulator, or the distance matrix is empty.
        """
        # np.copy also turns list inputs into arrays, which the boolean
        # indexing below requires.
        trk_tlwhs = np.copy(trk_tlwhs)
        trk_ids = np.copy(trk_ids)

        # Ground truth of this frame.
        gt_objs = self.gt_frame_dict.get(frame_id, [])
        gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]

        # Ignore regions of this frame.
        ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
        ignore_tlwhs = unzip_objs(ignore_objs)[0]

        # Remove tracks assigned to an ignore region so they are not scored.
        keep = np.ones(len(trk_tlwhs), dtype=bool)
        iou_distance = mm.distances.iou_matrix(
            ignore_tlwhs, trk_tlwhs, max_iou=0.5)
        if len(iou_distance) > 0:
            match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
            match_is = np.asarray(match_is, dtype=int)
            match_js = np.asarray(match_js, dtype=int)
            match_ious = iou_distance[match_is, match_js]
            # Assigned pairs whose distance is NaN (presumably beyond the
            # max_iou cutoff) are not treated as matches.
            match_js = match_js[np.logical_not(np.isnan(match_ious))]
            keep[match_js] = False
            trk_tlwhs = trk_tlwhs[keep]
            trk_ids = trk_ids[keep]

        # Distance between the gt boxes and the remaining tracks.
        iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)

        self.acc.update(gt_ids, trk_ids, iou_distance)

        if (rtn_events and iou_distance.size > 0
                and hasattr(self.acc, 'last_mot_events')):
            # only supported by https://github.com/longcw/py-motmetrics
            return self.acc.last_mot_events
        return None

    def eval_file(self, filename):
        """Evaluate a MOT results file from scratch and return the accumulator.

        NOTE(review): only frames that appear in ``filename`` are evaluated,
        so gt boxes in frames without any tracking output are never counted
        as misses. Confirm this is intended.
        """
        self.reset_accumulator()

        result_frame_dict = read_mot_results(filename, is_gt=False)
        for frame_id in sorted(result_frame_dict):
            trk_objs = result_frame_dict[frame_id]
            trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
            self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)

        return self.acc

    @staticmethod
    def get_summary(accs,
                    names,
                    metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1',
                             'precision', 'recall')):
        """Compute a metrics summary over several accumulators.

        Args:
            accs (list): accumulators, e.g. returned by ``eval_file``.
            names (list): one display name per accumulator.
            metrics (sequence | None): metric names to compute; ``None``
                selects ``mm.metrics.motchallenge_metrics``.

        Returns:
            The table produced by ``compute_many`` with an overall row
            (``generate_overall=True``).
        """
        # Defensive copies of the caller's arguments.
        names = copy.deepcopy(names)
        if metrics is None:
            metrics = mm.metrics.motchallenge_metrics
        metrics = copy.deepcopy(metrics)

        mh = mm.metrics.create()
        summary = mh.compute_many(
            accs, metrics=metrics, names=names, generate_overall=True)
        return summary

    @staticmethod
    def save_summary(summary, filename):
        """Write a summary table (as returned by ``get_summary``) to Excel.

        Args:
            summary: object providing ``to_excel`` (a pandas DataFrame).
            filename (str): path of the Excel file to create.
        """
        import pandas as pd
        # ExcelWriter.save() was removed in pandas 2.0; the context manager
        # saves and closes the file on all versions, also when to_excel raises.
        with pd.ExcelWriter(filename) as writer:
            summary.to_excel(writer)
|