PulseFocusPlatform/static/ppdet/core/workspace.py

274 lines
8.5 KiB
Python
Raw Permalink Normal View History

2022-06-01 11:18:00 +08:00
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import importlib
import os
import sys
import yaml
import copy
import collections
try:
collectionsAbc = collections.abc
except AttributeError:
collectionsAbc = collections
from .config.schema import SchemaDict, SharedConfig, extract_schema
from .config.yaml_helpers import serializable
__all__ = [
'global_config',
'load_config',
'merge_config',
'get_registered_modules',
'create',
'register',
'serializable',
'dump_value',
]
def dump_value(value):
# XXX this is hackish, but collections.abc is not available in python 2
if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
value = yaml.dump(value, default_flow_style=True)
value = value.replace('\n', '')
value = value.replace('...', '')
return "'{}'".format(value)
else:
# primitive types
return str(value)
class AttrDict(dict):
"""Single level attribute dict, NOT recursive"""
def __init__(self, **kwargs):
super(AttrDict, self).__init__()
super(AttrDict, self).update(kwargs)
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError("object has no attribute '{}'".format(key))
global_config = AttrDict()
READER_KEY = '_READER_'
def load_config(file_path):
"""
Load config from file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
cfg = AttrDict()
with open(file_path) as f:
cfg = merge_config(yaml.load(f, Loader=yaml.Loader), cfg)
if READER_KEY in cfg:
reader_cfg = cfg[READER_KEY]
if reader_cfg.startswith("~"):
reader_cfg = os.path.expanduser(reader_cfg)
if not reader_cfg.startswith('/'):
reader_cfg = os.path.join(os.path.dirname(file_path), reader_cfg)
with open(reader_cfg) as f:
merge_config(yaml.load(f, Loader=yaml.Loader))
del cfg[READER_KEY]
merge_config(cfg)
return global_config
def dict_merge(dct, merge_dct):
""" Recursive dict merge. Inspired by :meth:``dict.update()``, instead of
updating only top-level keys, dict_merge recurses down into dicts nested
to an arbitrary depth, updating keys. The ``merge_dct`` is merged into
``dct``.
Args:
dct: dict onto which the merge is executed
merge_dct: dct merged into dct
Returns: dct
"""
for k, v in merge_dct.items():
if (k in dct and isinstance(dct[k], dict) and
isinstance(merge_dct[k], collectionsAbc.Mapping)):
dict_merge(dct[k], merge_dct[k])
else:
dct[k] = merge_dct[k]
return dct
def merge_config(config, another_cfg=None):
"""
Merge config into global config or another_cfg.
Args:
config (dict): Config to be merged.
Returns: global config
"""
global global_config
dct = another_cfg if another_cfg is not None else global_config
dct = dict_merge(dct, config)
# NOTE: training batch size defined only in TrainReader, sychornized
# batch size config to global, models can get batch size config
# from global config when building model.
# batch size in evaluation or inference can also be added here
if 'TrainReader' in dct and 'batch_size' in dct['TrainReader']:
dct['train_batch_size'] = dct['TrainReader']['batch_size']
return dct
def get_registered_modules():
return {k: v for k, v in global_config.items() if isinstance(v, SchemaDict)}
def make_partial(cls):
if isinstance(cls.__op__, str):
sep = cls.__op__.split('.')
op_name = sep[-1]
op_module = importlib.import_module('.'.join(sep[:-1]))
else:
op_name = cls.__op__.__name__
op_module = importlib.import_module(cls.__op__.__module__)
if not hasattr(op_module, op_name):
import logging
logger = logging.getLogger(__name__)
logger.warning('{} OP not found, maybe a newer version of paddle '
'is required.'.format(cls.__op__))
return cls
op = getattr(op_module, op_name)
cls.__category__ = getattr(cls, '__category__', None) or 'op'
def partial_apply(self, *args, **kwargs):
kwargs_ = self.__dict__.copy()
kwargs_.update(kwargs)
return op(*args, **kwargs_)
if getattr(cls, '__append_doc__', True): # XXX should default to True?
if sys.version_info[0] > 2:
cls.__doc__ = "Wrapper for `{}` OP".format(op.__name__)
cls.__init__.__doc__ = op.__doc__
cls.__call__ = partial_apply
cls.__call__.__doc__ = op.__doc__
else:
# XXX work around for python 2
partial_apply.__doc__ = op.__doc__
cls.__call__ = partial_apply
return cls
def register(cls):
"""
Register a given module class.
Args:
cls (type): Module class to be registered.
Returns: cls
"""
if cls.__name__ in global_config:
raise ValueError("Module class already registered: {}".format(
cls.__name__))
if hasattr(cls, '__op__'):
cls = make_partial(cls)
global_config[cls.__name__] = extract_schema(cls)
return cls
def create(cls_or_name, **kwargs):
"""
Create an instance of given module class.
Args:
cls_or_name (type or str): Class of which to create instance.
Returns: instance of type `cls_or_name`
"""
assert type(cls_or_name) in [type, str
], "should be a class or name of a class"
name = type(cls_or_name) == str and cls_or_name or cls_or_name.__name__
assert name in global_config and \
isinstance(global_config[name], SchemaDict), \
"the module {} is not registered".format(name)
config = global_config[name]
config.update(kwargs)
config.validate()
cls = getattr(config.pymodule, name)
kwargs = {}
kwargs.update(global_config[name])
# parse `shared` annoation of registered modules
if getattr(config, 'shared', None):
for k in config.shared:
target_key = config[k]
shared_conf = config.schema[k].default
assert isinstance(shared_conf, SharedConfig)
if target_key is not None and not isinstance(target_key,
SharedConfig):
continue # value is given for the module
elif shared_conf.key in global_config:
# `key` is present in config
kwargs[k] = global_config[shared_conf.key]
else:
kwargs[k] = shared_conf.default_value
# parse `inject` annoation of registered modules
if getattr(config, 'inject', None):
for k in config.inject:
target_key = config[k]
# optional dependency
if target_key is None:
continue
# also accept dictionaries and serialized objects
if isinstance(target_key, dict) or hasattr(target_key, '__dict__'):
continue
elif isinstance(target_key, str):
if target_key not in global_config:
raise ValueError("Missing injection config:", target_key)
target = global_config[target_key]
if isinstance(target, SchemaDict):
kwargs[k] = create(target_key)
elif hasattr(target, '__dict__'): # serialized object
kwargs[k] = target
else:
raise ValueError("Unsupported injection type:", target_key)
# prevent modification of global config values of reference types
# (e.g., list, dict) from within the created module instances
kwargs = copy.deepcopy(kwargs)
return cls(**kwargs)