66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
|
# encoding: utf-8
|
||
|
"""
|
||
|
This function denotes the main function to train/test/plot
|
||
|
Usage:
|
||
|
python main.py [FLAGS]
|
||
|
|
||
|
@author: gongzhiqiang
|
||
|
@contact: gongzhiqiang@alumni.sjtu.edu.cn
|
||
|
|
||
|
@version: 1.0
|
||
|
@file: main.py
|
||
|
@time: 2020-12-22
|
||
|
|
||
|
"""
|
||
|
from pathlib import Path
|
||
|
import configargparse
|
||
|
|
||
|
from src.LayoutDeepRegression import Model
|
||
|
from src import train, test, plot
|
||
|
|
||
|
|
||
|
def main():
|
||
|
# default configuration file
|
||
|
config_path = Path(__file__).absolute().parent / "config/config.yml"
|
||
|
parser = configargparse.ArgParser(default_config_files=[str(config_path)], description="Hyper-parameters.")
|
||
|
|
||
|
# configuration file
|
||
|
parser.add_argument("--config", is_config_file=True, default=False, help="config file path")
|
||
|
|
||
|
# mode
|
||
|
parser.add_argument("-m", "--mode", type=str, default="train", help="model: train or test or plot")
|
||
|
|
||
|
# args for training
|
||
|
parser.add_argument("--gpus", type=int, default=0, help="how many gpus")
|
||
|
parser.add_argument("--batch_size", default=16, type=int)
|
||
|
parser.add_argument("--max_epochs", default=20, type=int)
|
||
|
parser.add_argument("--lr", default="0.01", type=float)
|
||
|
parser.add_argument("--resume_from_checkpoint", type=str, help="resume from checkpoint")
|
||
|
parser.add_argument("--num_workers", default=2, type=int, help="num_workers in DataLoader")
|
||
|
parser.add_argument("--seed", type=int, default=1, help="seed")
|
||
|
parser.add_argument("--use_16bit", type=bool, default=False, help="use 16bit precision")
|
||
|
parser.add_argument("--profiler", action="store_true", help="use profiler")
|
||
|
|
||
|
# args for validation
|
||
|
parser.add_argument("--val_check_interval", type=float, default=1,
|
||
|
help="how often within one training epoch to check the validation set")
|
||
|
|
||
|
# args for testing
|
||
|
parser.add_argument("--test_check_num", default='0', type=str, help="checkpoint for test")
|
||
|
parser.add_argument("--test_args", action="store_true", help="print args")
|
||
|
|
||
|
# args from Model
|
||
|
parser = Model.add_model_specific_args(parser)
|
||
|
hparams = parser.parse_args()
|
||
|
|
||
|
# running
|
||
|
assert hparams.mode in ["train", "test", "plot"]
|
||
|
if hparams.test_args:
|
||
|
print(hparams)
|
||
|
else:
|
||
|
getattr(eval(hparams.mode), "main")(hparams)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
main()
|