PulseFocusPlatform/dataset/dota_coco/dota_generate_test_result.py

259 lines
7.7 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import glob
import numpy as np
from multiprocessing import Pool
from functools import partial
from shapely.geometry import Polygon
import argparse
# IoU threshold handed to the NMS routine when sub-image results are merged.
nms_thresh = 0.1

# Category names selected when dota_version is 'v1.0'.
class_name_15 = [
    'plane',
    'baseball-diamond',
    'bridge',
    'ground-track-field',
    'small-vehicle',
    'large-vehicle',
    'ship',
    'tennis-court',
    'basketball-court',
    'storage-tank',
    'soccer-ball-field',
    'roundabout',
    'harbor',
    'swimming-pool',
    'helicopter',
]

# Category names selected for every other dota_version: the 15 names above
# followed by 'container-crane' (a separate list object).
class_name_16 = class_name_15 + ['container-crane']
def rbox_iou(g, p):
    """
    iou of rbox

    Args:
        g: sequence whose first 8 values are the corner coordinates
            (x1, y1, x2, y2, x3, y3, x4, y4) of the first quadrilateral
        p: same layout for the second quadrilateral
    Returns: intersection area divided by union area; 0 when either
        geometry is invalid or the union has zero area
    """
    g = np.array(g)
    p = np.array(p)
    # buffer(0) is applied to clean up self-crossing / self-touching rings
    # before any area is measured.
    g = Polygon(g[:8].reshape((4, 2))).buffer(0)
    p = Polygon(p[:8].reshape((4, 2))).buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    # Intersect the cleaned geometries directly. Re-wrapping them in
    # Polygon(...) is redundant and fails when buffer(0) hands back something
    # that is not a single Polygon (e.g. a MultiPolygon).
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    return inter / union
def py_cpu_nms_poly_fast(dets, thresh):
    """
    Polygon NMS that uses axis-aligned bounding boxes as a cheap pre-filter.

    Args:
        dets: pred results, array-like of shape (N, 9); every row holds the
            8 corner coordinates of a quadrilateral followed by its score
        thresh: nms threshold; a candidate is discarded when its IoU with an
            already kept detection is greater than this value
    Returns: index of keep (row indices into dets, highest score first)
    """
    dets = np.asarray(dets)
    if dets.size == 0:
        # Nothing to suppress. This also covers 1-D empty input such as
        # np.array([]), which cannot be sliced column-wise below.
        return []

    obbs = dets[:, 0:-1]
    x1 = np.min(obbs[:, 0::2], axis=1)
    y1 = np.min(obbs[:, 1::2], axis=1)
    x2 = np.max(obbs[:, 0::2], axis=1)
    y2 = np.max(obbs[:, 1::2], axis=1)
    scores = dets[:, 8]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    polys = dets[:, 0:8]

    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        # Overlap of the horizontal bounding boxes. It is only used to decide
        # which pairs need the exact (and much slower) polygon IoU.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        hbb_inter = w * h
        hbb_ovr = hbb_inter / (areas[i] + areas[rest] - hbb_inter)

        # Replace the bounding-box overlap by the polygon IoU wherever the
        # bounding boxes intersect; disjoint pairs keep an overlap of 0.
        for idx in np.where(hbb_ovr > 0)[0]:
            hbb_ovr[idx] = rbox_iou(polys[i], polys[rest[idx]])

        inds = np.where(hbb_ovr <= thresh)[0]
        order = rest[inds]
    return keep
def poly2origpoly(poly, x, y, rate):
    """
    Shift and rescale a flat polygon coordinate list.

    Args:
        poly: flat coordinates [x1, y1, x2, y2, ...]; a trailing unpaired
            value is ignored
        x: value added to every x coordinate
        y: value added to every y coordinate
        rate: divisor applied after the shift; anything float() accepts
    Returns: flat list of float coordinates
    """
    origpoly = []
    # Walk the list pairwise: even positions are x values, odd ones y values.
    for px, py in zip(poly[0::2], poly[1::2]):
        origpoly.append(float(px + x) / float(rate))
        origpoly.append(float(py + y) / float(rate))
    return origpoly
def nmsbynamedict(nameboxdict, nms, thresh):
    """
    Apply NMS separately to the detections of every image.

    Args:
        nameboxdict: dict mapping an image name to its list of detections
        nms: callable invoked as nms(np.array(detections), thresh) that
            returns the indices of the detections to keep
        thresh: nms threshold
    Returns: nms result as dict, with the same keys as nameboxdict and the
        kept detections in the order returned by nms
    """
    nameboxnmsdict = {}
    for imgname, boxes in nameboxdict.items():
        kept_indices = nms(np.array(boxes), thresh)
        nameboxnmsdict[imgname] = [boxes[index] for index in kept_indices]
    return nameboxnmsdict
def merge_single(output_dir, nms, pred_class_lst):
    """
    Merge the sub-image predictions of one class and write them to
    ``<output_dir>/<class_name>.txt``.

    Args:
        output_dir: output_dir
        nms: nms function, forwarded to nmsbynamedict
        pred_class_lst: (class_name, pred_bbox_list) pair; every entry of
            pred_bbox_list is a space separated string
            "<sub image name> <confidence> <coord> <coord> ...", where the
            sub image name looks like "<oriname>__<rate>__<x>___<y>"
    Returns: None
    """
    class_name, pred_bbox_list = pred_class_lst

    # The patterns are loop invariant, so compile them once up front.
    offset_pattern = re.compile(r'__\d+___\d+')
    rate_pattern = re.compile(r'__([\d+\.]+)__\d+___')

    nameboxdict = {}
    for pred in pred_bbox_list:
        fields = pred.split(' ')
        subname = fields[0]
        oriname = subname.split('__')[0]

        # Offset of the sub image and the rate it was produced with are both
        # encoded in the sub image name.
        offset_text = offset_pattern.findall(subname)[0]
        offset_numbers = re.findall(r'\d+', offset_text)
        x, y = int(offset_numbers[0]), int(offset_numbers[1])
        rate = rate_pattern.findall(subname)[0]

        confidence = fields[1]
        poly = [float(value) for value in fields[2:]]
        det = [float(value) for value in poly2origpoly(poly, x, y, rate)]
        det.append(float(confidence))
        nameboxdict.setdefault(oriname, []).append(det)

    nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)

    # write result
    dstname = os.path.join(output_dir, class_name + '.txt')
    with open(dstname, 'w') as f_out:
        for imgname, kept_dets in nameboxnmsdict.items():
            for det in kept_dets:
                coords = ' '.join(map(str, det[0:-1]))
                f_out.write('{} {} {}\n'.format(imgname, str(det[-1]),
                                                coords))
def dota_generate_test_result(pred_txt_dir,
                              output_dir='output',
                              dota_version='v1.0'):
    """
    Collect the per-sub-image prediction files, regroup them by class and
    write one merged, NMS-filtered result file per class.

    Args:
        pred_txt_dir: dir of pred txt; every ``*.txt`` file in it is read and
            the file name without ``.txt`` is used as the sub image name.
            Each non-blank line is "<class name> <confidence> <coords...>"
        output_dir: dir of output, created if it does not exist yet
        dota_version: dota_version v1.0 or v1.5 or v2.0; 'v1.0' selects the
            15 class list, any other value the 16 class list
    Returns: None
    Raises:
        KeyError: if a prediction line names a class that is not in the
            selected class list
    """
    pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir))

    # step1: summary pred bbox
    class_lst = class_name_15 if dota_version == 'v1.0' else class_name_16
    pred_classes = {class_name: [] for class_name in class_lst}

    for current_txt in pred_txt_list:
        img_id = os.path.split(current_txt)[1]
        img_id = img_id.split('.txt')[0]
        with open(current_txt) as f:
            for line in f:
                # Drop the line terminator and skip blank lines, which would
                # otherwise be looked up as the class name ''.
                line = line.strip()
                if not line:
                    continue
                item = line.split(' ')
                pred_class = item[0]
                # The class is implied by the list the line is stored in, so
                # its slot is reused for the sub image name.
                item[0] = img_id
                pred_classes[pred_class].append(' '.join(item))

    pred_classes_lst = []
    for class_name, pred_bboxes in pred_classes.items():
        print('class_name: {}, count: {}'.format(class_name,
                                                 len(pred_bboxes)))
        pred_classes_lst.append((class_name, pred_bboxes))

    # step2: merge
    # Workers open their result file inside output_dir, so it has to exist
    # before they start.
    os.makedirs(output_dir, exist_ok=True)
    nms = py_cpu_nms_poly_fast
    mergesingle_fn = partial(merge_single, output_dir, nms)
    # One worker per class; the context manager releases the worker processes
    # once the blocking map() call has returned.
    with Pool(len(class_lst)) as pool:
        pool.map(mergesingle_fn, pred_classes_lst)
if __name__ == '__main__':
    # Command line entry point: parse the options and merge the predictions.
    parser = argparse.ArgumentParser(
        description='merge dota sub-image predictions into test results')
    # Without this option the glob pattern would become "None/*.txt" and the
    # script would silently write empty result files, so it is mandatory.
    parser.add_argument(
        '--pred_txt_dir', help='path of pred txt dir', required=True)
    parser.add_argument(
        '--output_dir', help='path of output dir', default='output')
    parser.add_argument(
        '--dota_version',
        help='dota_version, v1.0 or v1.5 or v2.0',
        type=str,
        default='v1.0')

    args = parser.parse_args()

    # process
    dota_generate_test_result(args.pred_txt_dir, args.output_dir,
                              args.dota_version)
    print('done!')