PulseFocusPlatform/static/ppdet/utils/export_utils.py

202 lines
7.0 KiB
Python

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import yaml
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
import paddle.fluid as fluid
__all__ = ['dump_infer_config', 'save_infer_model']
# Global dictionary
TRT_MIN_SUBGRAPH = {
'YOLO': 3,
'SSD': 3,
'RCNN': 40,
'RetinaNet': 40,
'S2ANet': 40,
'EfficientDet': 40,
'Face': 3,
'TTFNet': 3,
'FCOS': 33,
'SOLOv2': 60,
}
RESIZE_SCALE_SET = {
'RCNN',
'RetinaNet',
'S2ANet',
'FCOS',
'SOLOv2',
}
def parse_reader(reader_cfg, metric, arch):
preprocess_list = []
image_shape = reader_cfg['inputs_def'].get('image_shape', [3, None, None])
has_shape_def = not None in image_shape
dataset = reader_cfg['dataset']
anno_file = dataset.get_anno()
with_background = dataset.with_background
use_default_label = dataset.use_default_label
if metric == 'COCO':
from ppdet.utils.coco_eval import get_category_info
elif metric == "VOC":
from ppdet.utils.voc_eval import get_category_info
elif metric == "WIDERFACE":
from ppdet.utils.widerface_eval_utils import get_category_info
elif cfg.metric == 'OID':
from ppdet.utils.oid_eval import get_category_info
else:
raise ValueError(
"metric only supports COCO, VOC, WIDERFACE, but received {}".format(
metric))
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
label_list = [str(cat) for cat in catid2name.values()]
sample_transforms = reader_cfg['sample_transforms']
for st in sample_transforms[1:]:
method = st.__class__.__name__
p = {'type': method.replace('Image', '')}
params = st.__dict__
params.pop('_id')
if p['type'] == 'Resize' and has_shape_def:
params['target_size'] = min(image_shape[
1:]) if arch in RESIZE_SCALE_SET else image_shape[1]
params['max_size'] = max(image_shape[
1:]) if arch in RESIZE_SCALE_SET else 0
params['image_shape'] = image_shape[1:]
if 'target_dim' in params:
params.pop('target_dim')
if p['type'] == 'ResizeAndPad':
assert has_shape_def, "missing input shape"
p['type'] = 'Resize'
p['target_size'] = params['target_dim']
p['max_size'] = params['target_dim']
p['interp'] = params['interp']
p['image_shape'] = image_shape[1:]
preprocess_list.append(p)
continue
p.update(params)
preprocess_list.append(p)
batch_transforms = reader_cfg.get('batch_transforms', None)
if batch_transforms:
methods = [bt.__class__.__name__ for bt in batch_transforms]
for bt in batch_transforms:
method = bt.__class__.__name__
if method == 'PadBatch':
preprocess_list.append({'type': 'PadStride'})
params = bt.__dict__
preprocess_list[-1].update({'stride': params['pad_to_stride']})
break
return with_background, preprocess_list, label_list
def dump_infer_config(FLAGS, config):
arch_state = 0
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
from ppdet.core.config.yaml_helpers import setup_orderdict
setup_orderdict()
infer_cfg = OrderedDict({
'use_python_inference': False,
'mode': 'fluid',
'draw_threshold': 0.5,
'metric': config['metric']
})
infer_arch = config['architecture']
for arch, min_subgraph_size in TRT_MIN_SUBGRAPH.items():
if arch in infer_arch:
infer_cfg['arch'] = arch
infer_cfg['min_subgraph_size'] = min_subgraph_size
arch_state = 1
break
if not arch_state:
logger.error(
'Architecture: {} is not supported for exporting model now'.format(
infer_arch))
os._exit(0)
# support land mark output
if 'with_lmk' in config and config['with_lmk'] == True:
infer_cfg['with_lmk'] = True
if 'Mask' in config['architecture']:
infer_cfg['mask_resolution'] = config['MaskHead']['resolution']
infer_cfg['with_background'], infer_cfg['Preprocess'], infer_cfg[
'label_list'] = parse_reader(config['TestReader'], config['metric'],
infer_cfg['arch'])
yaml.dump(infer_cfg, open(os.path.join(save_dir, 'infer_cfg.yml'), 'w'))
logger.info("Export inference config file to {}".format(
os.path.join(save_dir, 'infer_cfg.yml')))
def prune_feed_vars(feeded_var_names, target_vars, prog):
"""
Filter out feed variables which are not in program,
pruned feed variables are only used in post processing
on model output, which are not used in program, such
as im_id to identify image order, im_shape to clip bbox
in image.
"""
exist_var_names = []
prog = prog.clone()
prog = prog._prune(targets=target_vars)
global_block = prog.global_block()
for name in feeded_var_names:
try:
v = global_block.var(name)
exist_var_names.append(str(v.name))
except Exception:
logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name))
pass
return exist_var_names
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feed_var_names = [var.name for var in feed_vars.values()]
fetch_list = sorted(test_fetches.items(), key=lambda i: i[0])
target_vars = [var[1] for var in fetch_list]
feed_var_names = prune_feed_vars(feed_var_names, target_vars, infer_prog)
logger.info("Export inference model to {}, input: {}, output: "
"{}...".format(save_dir, feed_var_names,
[str(var.name) for var in target_vars]))
fluid.io.save_inference_model(
save_dir,
feeded_var_names=feed_var_names,
target_vars=target_vars,
executor=exe,
main_program=infer_prog,
params_filename="__params__")