forked from PulseFocusPlatform/PulseFocusPlatform
259 lines
7.7 KiB
Python
259 lines
7.7 KiB
Python
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import os
|
||
|
import re
|
||
|
import glob
|
||
|
|
||
|
import numpy as np
|
||
|
from multiprocessing import Pool
|
||
|
from functools import partial
|
||
|
from shapely.geometry import Polygon
|
||
|
import argparse
|
||
|
|
||
|
nms_thresh = 0.1
|
||
|
|
||
|
class_name_15 = [
|
||
|
'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
|
||
|
'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
|
||
|
'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
|
||
|
'harbor', 'swimming-pool', 'helicopter'
|
||
|
]
|
||
|
|
||
|
class_name_16 = [
|
||
|
'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
|
||
|
'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
|
||
|
'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
|
||
|
'harbor', 'swimming-pool', 'helicopter', 'container-crane'
|
||
|
]
|
||
|
|
||
|
|
||
|
def rbox_iou(g, p):
|
||
|
"""
|
||
|
iou of rbox
|
||
|
"""
|
||
|
g = np.array(g)
|
||
|
p = np.array(p)
|
||
|
g = Polygon(g[:8].reshape((4, 2)))
|
||
|
p = Polygon(p[:8].reshape((4, 2)))
|
||
|
g = g.buffer(0)
|
||
|
p = p.buffer(0)
|
||
|
if not g.is_valid or not p.is_valid:
|
||
|
return 0
|
||
|
inter = Polygon(g).intersection(Polygon(p)).area
|
||
|
union = g.area + p.area - inter
|
||
|
if union == 0:
|
||
|
return 0
|
||
|
else:
|
||
|
return inter / union
|
||
|
|
||
|
|
||
|
def py_cpu_nms_poly_fast(dets, thresh):
|
||
|
"""
|
||
|
Args:
|
||
|
dets: pred results
|
||
|
thresh: nms threshold
|
||
|
|
||
|
Returns: index of keep
|
||
|
"""
|
||
|
obbs = dets[:, 0:-1]
|
||
|
x1 = np.min(obbs[:, 0::2], axis=1)
|
||
|
y1 = np.min(obbs[:, 1::2], axis=1)
|
||
|
x2 = np.max(obbs[:, 0::2], axis=1)
|
||
|
y2 = np.max(obbs[:, 1::2], axis=1)
|
||
|
scores = dets[:, 8]
|
||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||
|
|
||
|
polys = []
|
||
|
for i in range(len(dets)):
|
||
|
tm_polygon = [
|
||
|
dets[i][0], dets[i][1], dets[i][2], dets[i][3], dets[i][4],
|
||
|
dets[i][5], dets[i][6], dets[i][7]
|
||
|
]
|
||
|
polys.append(tm_polygon)
|
||
|
polys = np.array(polys)
|
||
|
order = scores.argsort()[::-1]
|
||
|
|
||
|
keep = []
|
||
|
while order.size > 0:
|
||
|
ovr = []
|
||
|
i = order[0]
|
||
|
keep.append(i)
|
||
|
|
||
|
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||
|
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||
|
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||
|
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||
|
w = np.maximum(0.0, xx2 - xx1)
|
||
|
h = np.maximum(0.0, yy2 - yy1)
|
||
|
hbb_inter = w * h
|
||
|
hbb_ovr = hbb_inter / (areas[i] + areas[order[1:]] - hbb_inter)
|
||
|
# h_keep_inds = np.where(hbb_ovr == 0)[0]
|
||
|
h_inds = np.where(hbb_ovr > 0)[0]
|
||
|
tmp_order = order[h_inds + 1]
|
||
|
for j in range(tmp_order.size):
|
||
|
iou = rbox_iou(polys[i], polys[tmp_order[j]])
|
||
|
hbb_ovr[h_inds[j]] = iou
|
||
|
# ovr.append(iou)
|
||
|
# ovr_index.append(tmp_order[j])
|
||
|
|
||
|
try:
|
||
|
if math.isnan(ovr[0]):
|
||
|
pdb.set_trace()
|
||
|
except:
|
||
|
pass
|
||
|
inds = np.where(hbb_ovr <= thresh)[0]
|
||
|
|
||
|
order = order[inds + 1]
|
||
|
return keep
|
||
|
|
||
|
|
||
|
def poly2origpoly(poly, x, y, rate):
|
||
|
origpoly = []
|
||
|
for i in range(int(len(poly) / 2)):
|
||
|
tmp_x = float(poly[i * 2] + x) / float(rate)
|
||
|
tmp_y = float(poly[i * 2 + 1] + y) / float(rate)
|
||
|
origpoly.append(tmp_x)
|
||
|
origpoly.append(tmp_y)
|
||
|
return origpoly
|
||
|
|
||
|
|
||
|
def nmsbynamedict(nameboxdict, nms, thresh):
|
||
|
"""
|
||
|
Args:
|
||
|
nameboxdict: nameboxdict
|
||
|
nms: nms
|
||
|
thresh: nms threshold
|
||
|
|
||
|
Returns: nms result as dict
|
||
|
"""
|
||
|
nameboxnmsdict = {x: [] for x in nameboxdict}
|
||
|
for imgname in nameboxdict:
|
||
|
keep = nms(np.array(nameboxdict[imgname]), thresh)
|
||
|
outdets = []
|
||
|
for index in keep:
|
||
|
outdets.append(nameboxdict[imgname][index])
|
||
|
nameboxnmsdict[imgname] = outdets
|
||
|
return nameboxnmsdict
|
||
|
|
||
|
|
||
|
def merge_single(output_dir, nms, pred_class_lst):
|
||
|
"""
|
||
|
Args:
|
||
|
output_dir: output_dir
|
||
|
nms: nms
|
||
|
pred_class_lst: pred_class_lst
|
||
|
class_name: class_name
|
||
|
|
||
|
Returns:
|
||
|
|
||
|
"""
|
||
|
class_name, pred_bbox_list = pred_class_lst
|
||
|
nameboxdict = {}
|
||
|
for line in pred_bbox_list:
|
||
|
splitline = line.split(' ')
|
||
|
subname = splitline[0]
|
||
|
splitname = subname.split('__')
|
||
|
oriname = splitname[0]
|
||
|
pattern1 = re.compile(r'__\d+___\d+')
|
||
|
x_y = re.findall(pattern1, subname)
|
||
|
x_y_2 = re.findall(r'\d+', x_y[0])
|
||
|
x, y = int(x_y_2[0]), int(x_y_2[1])
|
||
|
|
||
|
pattern2 = re.compile(r'__([\d+\.]+)__\d+___')
|
||
|
|
||
|
rate = re.findall(pattern2, subname)[0]
|
||
|
|
||
|
confidence = splitline[1]
|
||
|
poly = list(map(float, splitline[2:]))
|
||
|
origpoly = poly2origpoly(poly, x, y, rate)
|
||
|
det = origpoly
|
||
|
det.append(confidence)
|
||
|
det = list(map(float, det))
|
||
|
if (oriname not in nameboxdict):
|
||
|
nameboxdict[oriname] = []
|
||
|
nameboxdict[oriname].append(det)
|
||
|
nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)
|
||
|
|
||
|
# write result
|
||
|
dstname = os.path.join(output_dir, class_name + '.txt')
|
||
|
with open(dstname, 'w') as f_out:
|
||
|
for imgname in nameboxnmsdict:
|
||
|
for det in nameboxnmsdict[imgname]:
|
||
|
confidence = det[-1]
|
||
|
bbox = det[0:-1]
|
||
|
outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(
|
||
|
map(str, bbox))
|
||
|
f_out.write(outline + '\n')
|
||
|
|
||
|
|
||
|
def dota_generate_test_result(pred_txt_dir,
|
||
|
output_dir='output',
|
||
|
dota_version='v1.0'):
|
||
|
"""
|
||
|
pred_txt_dir: dir of pred txt
|
||
|
output_dir: dir of output
|
||
|
dota_version: dota_version v1.0 or v1.5 or v2.0
|
||
|
"""
|
||
|
pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir))
|
||
|
|
||
|
# step1: summary pred bbox
|
||
|
pred_classes = {}
|
||
|
class_lst = class_name_15 if dota_version == 'v1.0' else class_name_16
|
||
|
for class_name in class_lst:
|
||
|
pred_classes[class_name] = []
|
||
|
|
||
|
for current_txt in pred_txt_list:
|
||
|
img_id = os.path.split(current_txt)[1]
|
||
|
img_id = img_id.split('.txt')[0]
|
||
|
with open(current_txt) as f:
|
||
|
res = f.readlines()
|
||
|
for item in res:
|
||
|
item = item.split(' ')
|
||
|
pred_class = item[0]
|
||
|
item[0] = img_id
|
||
|
pred_bbox = ' '.join(item)
|
||
|
pred_classes[pred_class].append(pred_bbox)
|
||
|
|
||
|
pred_classes_lst = []
|
||
|
for class_name in pred_classes.keys():
|
||
|
print('class_name: {}, count: {}'.format(class_name,
|
||
|
len(pred_classes[class_name])))
|
||
|
pred_classes_lst.append((class_name, pred_classes[class_name]))
|
||
|
|
||
|
# step2: merge
|
||
|
pool = Pool(len(class_lst))
|
||
|
nms = py_cpu_nms_poly_fast
|
||
|
mergesingle_fn = partial(merge_single, output_dir, nms)
|
||
|
pool.map(mergesingle_fn, pred_classes_lst)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
parser = argparse.ArgumentParser(description='dota anno to coco')
|
||
|
parser.add_argument('--pred_txt_dir', help='path of pred txt dir')
|
||
|
parser.add_argument(
|
||
|
'--output_dir', help='path of output dir', default='output')
|
||
|
parser.add_argument(
|
||
|
'--dota_version',
|
||
|
help='dota_version, v1.0 or v1.5 or v2.0',
|
||
|
type=str,
|
||
|
default='v1.0')
|
||
|
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
# process
|
||
|
dota_generate_test_result(args.pred_txt_dir, args.output_dir,
|
||
|
args.dota_version)
|
||
|
print('done!')
|