supervised_layout_benchmark/main.py

66 lines
2.2 KiB
Python
Raw Normal View History

2021-04-12 20:52:20 +08:00
# encoding: utf-8
"""
Entry point of the project: parses the configuration and runs training, testing, or plotting.
Usage:
python main.py [FLAGS]
@author: gongzhiqiang
@contact: gongzhiqiang@alumni.sjtu.edu.cn
@version: 1.0
@file: main.py
@time: 2020-12-22
"""
from pathlib import Path
import configargparse
from src.LayoutDeepRegression import Model
from src import train, test, plot
def _str2bool(value):
    """Convert a textual flag value into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag could never be switched
    off from the command line or from a config file.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The corresponding ``bool``.

    Raises:
        ValueError: if the text is not a recognised boolean spelling. argparse
            turns a ``ValueError`` raised by a ``type`` callable into a regular
            usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, exactly as the former ``bool("")`` did.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise ValueError(f"not a boolean value: {value!r}")


def main():
    """Parse the hyper-parameters and run the selected mode.

    Options are read from ``config/config.yml`` next to this file (or from the
    file given with ``--config``) and can be overridden on the command line.
    Depending on ``--mode`` the ``main`` function of the ``train``, ``test`` or
    ``plot`` module is called with the parsed namespace. With ``--test_args``
    the parsed namespace is only printed and nothing is run.

    An unknown ``--mode`` is reported as a usage error (exit status 2).
    """
    # Explicit dispatch table instead of ``eval`` on a user-supplied string.
    runners = {"train": train, "test": test, "plot": plot}

    # default configuration file
    config_path = Path(__file__).absolute().parent / "config/config.yml"
    parser = configargparse.ArgParser(
        default_config_files=[str(config_path)],
        description="Hyper-parameters.",
    )

    # configuration file
    parser.add_argument("--config", is_config_file=True, default=False, help="config file path")

    # mode
    parser.add_argument("-m", "--mode", type=str, default="train", help="model: train or test or plot")

    # args for training
    parser.add_argument("--gpus", type=int, default=0, help="how many gpus")
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--max_epochs", default=20, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--resume_from_checkpoint", type=str, help="resume from checkpoint")
    parser.add_argument("--num_workers", default=2, type=int, help="num_workers in DataLoader")
    parser.add_argument("--seed", type=int, default=1, help="seed")
    # _str2bool rather than bool, so that "False"/"false" really disables it.
    parser.add_argument("--use_16bit", type=_str2bool, default=False, help="use 16bit precision")
    parser.add_argument("--profiler", action="store_true", help="use profiler")

    # args for validation
    # NOTE(review): the default is the integer 1, not the float 1.0 (argparse
    # only applies ``type`` to string defaults). If the consumer distinguishes
    # int from float here, confirm that 1 is what is intended.
    parser.add_argument("--val_check_interval", type=float, default=1,
                        help="how often within one training epoch to check the validation set")

    # args for testing
    parser.add_argument("--test_check_num", default='0', type=str, help="checkpoint for test")
    parser.add_argument("--test_args", action="store_true", help="print args")

    # args from Model
    parser = Model.add_model_specific_args(parser)
    hparams = parser.parse_args()

    # running
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if hparams.mode not in runners:
        parser.error(
            f"invalid mode {hparams.mode!r}: expected one of {', '.join(runners)}"
        )

    if hparams.test_args:
        print(hparams)
    else:
        runners[hparams.mode].main(hparams)
if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()