forked from PulseFocusPlatform/PulseFocusPlatform
200 lines
7.2 KiB
Python
200 lines
7.2 KiB
Python
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import os
|
||
|
|
||
|
import cv2
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
|
||
|
from topdown_unite_utils import argsparser
|
||
|
from preprocess import decode_image
|
||
|
from infer import Detector, PredictConfig, print_arguments, get_test_images
|
||
|
from keypoint_infer import KeyPoint_Detector, PredictConfig_KeyPoint
|
||
|
from keypoint_visualize import draw_pose
|
||
|
|
||
|
|
||
|
def expand_crop(images, rect, expand_ratio=0.3):
|
||
|
imgh, imgw, c = images.shape
|
||
|
label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
|
||
|
if label != 0:
|
||
|
return None, None, None
|
||
|
org_rect = [xmin, ymin, xmax, ymax]
|
||
|
h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
|
||
|
w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
|
||
|
if h_half > w_half * 4 / 3:
|
||
|
w_half = h_half * 0.75
|
||
|
center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
|
||
|
ymin = max(0, int(center[0] - h_half))
|
||
|
ymax = min(imgh - 1, int(center[0] + h_half))
|
||
|
xmin = max(0, int(center[1] - w_half))
|
||
|
xmax = min(imgw - 1, int(center[1] + w_half))
|
||
|
return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
|
||
|
|
||
|
|
||
|
def get_person_from_rect(images, results):
|
||
|
det_results = results['boxes']
|
||
|
mask = det_results[:, 1] > FLAGS.det_threshold
|
||
|
valid_rects = det_results[mask]
|
||
|
image_buff = []
|
||
|
org_rects = []
|
||
|
for rect in valid_rects:
|
||
|
rect_image, new_rect, org_rect = expand_crop(images, rect)
|
||
|
if rect_image is None or rect_image.size == 0:
|
||
|
continue
|
||
|
image_buff.append([rect_image, new_rect])
|
||
|
org_rects.append(org_rect)
|
||
|
return image_buff, org_rects
|
||
|
|
||
|
|
||
|
def affine_backto_orgimages(keypoint_result, batch_records):
|
||
|
kpts, scores = keypoint_result['keypoint']
|
||
|
kpts[..., 0] += batch_records[0]
|
||
|
kpts[..., 1] += batch_records[1]
|
||
|
return kpts, scores
|
||
|
|
||
|
|
||
|
def topdown_unite_predict(detector, topdown_keypoint_detector, image_list):
|
||
|
for i, img_file in enumerate(image_list):
|
||
|
image, _ = decode_image(img_file, {})
|
||
|
results = detector.predict([image], FLAGS.det_threshold)
|
||
|
batchs_images, det_rects = get_person_from_rect(image, results)
|
||
|
keypoint_vector = []
|
||
|
score_vector = []
|
||
|
rect_vecotr = det_rects
|
||
|
for batch_images, batch_records in batchs_images:
|
||
|
keypoint_result = topdown_keypoint_detector.predict(
|
||
|
batch_images, FLAGS.keypoint_threshold)
|
||
|
orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
|
||
|
batch_records)
|
||
|
keypoint_vector.append(orgkeypoints)
|
||
|
score_vector.append(scores)
|
||
|
keypoint_res = {}
|
||
|
keypoint_res['keypoint'] = [
|
||
|
np.vstack(keypoint_vector), np.vstack(score_vector)
|
||
|
]
|
||
|
keypoint_res['bbox'] = rect_vecotr
|
||
|
if not os.path.exists(FLAGS.output_dir):
|
||
|
os.makedirs(FLAGS.output_dir)
|
||
|
draw_pose(
|
||
|
img_file,
|
||
|
keypoint_res,
|
||
|
visual_thread=FLAGS.keypoint_threshold,
|
||
|
save_dir=FLAGS.output_dir)
|
||
|
|
||
|
|
||
|
def topdown_unite_predict_video(detector, topdown_keypoint_detector, camera_id):
|
||
|
if camera_id != -1:
|
||
|
capture = cv2.VideoCapture(camera_id)
|
||
|
video_name = 'output.mp4'
|
||
|
else:
|
||
|
capture = cv2.VideoCapture(FLAGS.video_file)
|
||
|
video_name = os.path.splitext(os.path.basename(FLAGS.video_file))[
|
||
|
0] + '.mp4'
|
||
|
fps = 30
|
||
|
width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||
|
height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||
|
# yapf: disable
|
||
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||
|
# yapf: enable
|
||
|
if not os.path.exists(FLAGS.output_dir):
|
||
|
os.makedirs(FLAGS.output_dir)
|
||
|
out_path = os.path.join(FLAGS.output_dir, video_name)
|
||
|
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
|
||
|
index = 0
|
||
|
while (1):
|
||
|
ret, frame = capture.read()
|
||
|
if not ret:
|
||
|
break
|
||
|
index += 1
|
||
|
print('detect frame:%d' % (index))
|
||
|
|
||
|
frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||
|
results = detector.predict([frame2], FLAGS.det_threshold)
|
||
|
batchs_images, rect_vecotr = get_person_from_rect(frame2, results)
|
||
|
keypoint_vector = []
|
||
|
score_vector = []
|
||
|
for batch_images, batch_records in batchs_images:
|
||
|
keypoint_result = topdown_keypoint_detector.predict(
|
||
|
batch_images, FLAGS.keypoint_threshold)
|
||
|
orgkeypoints, scores = affine_backto_orgimages(keypoint_result,
|
||
|
batch_records)
|
||
|
keypoint_vector.append(orgkeypoints)
|
||
|
score_vector.append(scores)
|
||
|
keypoint_res = {}
|
||
|
keypoint_res['keypoint'] = [
|
||
|
np.vstack(keypoint_vector), np.vstack(score_vector)
|
||
|
] if len(keypoint_vector) > 0 else [[], []]
|
||
|
keypoint_res['bbox'] = rect_vecotr
|
||
|
im = draw_pose(
|
||
|
frame,
|
||
|
keypoint_res,
|
||
|
visual_thread=FLAGS.keypoint_threshold,
|
||
|
returnimg=True)
|
||
|
|
||
|
writer.write(im)
|
||
|
if camera_id != -1:
|
||
|
cv2.imshow('Mask Detection', im)
|
||
|
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||
|
break
|
||
|
writer.release()
|
||
|
|
||
|
|
||
|
def main():
|
||
|
pred_config = PredictConfig(FLAGS.det_model_dir)
|
||
|
detector = Detector(
|
||
|
pred_config,
|
||
|
FLAGS.det_model_dir,
|
||
|
use_gpu=FLAGS.use_gpu,
|
||
|
run_mode=FLAGS.run_mode,
|
||
|
use_dynamic_shape=FLAGS.use_dynamic_shape,
|
||
|
trt_min_shape=FLAGS.trt_min_shape,
|
||
|
trt_max_shape=FLAGS.trt_max_shape,
|
||
|
trt_opt_shape=FLAGS.trt_opt_shape,
|
||
|
trt_calib_mode=FLAGS.trt_calib_mode,
|
||
|
cpu_threads=FLAGS.cpu_threads,
|
||
|
enable_mkldnn=FLAGS.enable_mkldnn)
|
||
|
|
||
|
pred_config = PredictConfig_KeyPoint(FLAGS.keypoint_model_dir)
|
||
|
topdown_keypoint_detector = KeyPoint_Detector(
|
||
|
pred_config,
|
||
|
FLAGS.keypoint_model_dir,
|
||
|
use_gpu=FLAGS.use_gpu,
|
||
|
run_mode=FLAGS.run_mode,
|
||
|
use_dynamic_shape=FLAGS.use_dynamic_shape,
|
||
|
trt_min_shape=FLAGS.trt_min_shape,
|
||
|
trt_max_shape=FLAGS.trt_max_shape,
|
||
|
trt_opt_shape=FLAGS.trt_opt_shape,
|
||
|
trt_calib_mode=FLAGS.trt_calib_mode,
|
||
|
cpu_threads=FLAGS.cpu_threads,
|
||
|
enable_mkldnn=FLAGS.enable_mkldnn)
|
||
|
|
||
|
# predict from video file or camera video stream
|
||
|
if FLAGS.video_file is not None or FLAGS.camera_id != -1:
|
||
|
topdown_unite_predict_video(detector, topdown_keypoint_detector,
|
||
|
FLAGS.camera_id)
|
||
|
else:
|
||
|
# predict from image
|
||
|
img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
|
||
|
topdown_unite_predict(detector, topdown_keypoint_detector, img_list)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
paddle.enable_static()
|
||
|
parser = argsparser()
|
||
|
FLAGS = parser.parse_args()
|
||
|
print_arguments(FLAGS)
|
||
|
|
||
|
main()
|