PulseFocusPlatform/build/lib/ppdet/utils/cli.py

152 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml
import re
from ppdet.core.workspace import get_registered_modules, dump_value
__all__ = ['ColorTTY', 'ArgsParser']
class ColorTTY(object):
def __init__(self):
super(ColorTTY, self).__init__()
self.colors = ['red', 'green', 'yellow', 'blue', 'magenta', 'cyan']
def __getattr__(self, attr):
if attr in self.colors:
color = self.colors.index(attr) + 31
def color_message(message):
return "[{}m{}".format(color, message)
setattr(self, attr, color_message)
return color_message
def bold(self, message):
return self.with_code('01', message)
def with_code(self, code, message):
return "[{}m{}".format(code, message)
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='*', help="set configuration options")
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
def _parse_opt(self, opts):
config = {}
if not opts:
return config
for s in opts:
s = s.strip()
k, v = s.split('=', 1)
if '.' not in k:
config[k] = yaml.load(v, Loader=yaml.Loader)
else:
keys = k.split('.')
if keys[0] not in config:
config[keys[0]] = {}
cur = config[keys[0]]
for idx, key in enumerate(keys[1:]):
if idx == len(keys) - 2:
cur[key] = yaml.load(v, Loader=yaml.Loader)
else:
cur[key] = {}
cur = cur[key]
return config
def print_total_cfg(config):
modules = get_registered_modules()
color_tty = ColorTTY()
green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {}
for key in config.keys():
if not config[key]: # empty schema
continue
if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key]
continue
elif key in modules:
module = modules[key]
else:
type_name = type(config[key]).__name__
if type_name in modules:
module = modules[type_name].copy()
module.update({
k: v
for k, v in config[key].__dict__.items()
if k in module.schema
})
key += " ({})".format(type_name)
default = module.find_default_keys()
missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys()
dep_missing = []
for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module
dep_missing.append(dep)
else:
dep_mod = modules[module[dep]]
# empty dict but mandatory
if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep)
override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {}
for name in set(override + default + extra + mismatch + missing):
new_name = name
if name in missing:
value = "<missing>"
else:
value = module[name]
if name in extra:
value = dump_value(value) + " <extraneous>"
elif name in mismatch:
value = dump_value(value) + " <type mismatch>"
elif name in dep_missing:
value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>':
mark = green
new_name = mark + name
replacement[new_name] = value
styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"[31m<missing>[0m", buffer))
buffer = (re.sub(r"<extraneous>", r"[33m<extraneous>[0m", buffer))
buffer = (re.sub(r"<type mismatch>", r"[31m<type mismatch>[0m", buffer))
buffer = (re.sub(r"<module config missing>",
r"[31m<module config missing>[0m", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer)
print(buffer)