TFR-HSS-Benchmark/main.py

114 lines
3.2 KiB
Python
Raw Permalink Normal View History

2021-08-24 17:44:06 +08:00
# encoding: utf-8
"""
Main entry point to train, test, or plot a model.
Usage:
python main.py [FLAGS]
@author: gongzhiqiang
@contact: gongzhiqiang@alumni.sjtu.edu.cn
@version: 1.0
@file: main.py
@time: 2021-06-29
"""
from pathlib import Path
import configargparse
from src.DeepRegression import Model
from src import train, test, plot, point
# Run modes accepted by ``--mode``; each maps to the module of the same name
# (``src.train`` / ``src.test`` / ``src.plot``) whose ``main`` is invoked.
_MODES = ("train", "test", "plot")

# Point-based surrogate models. Any of these is routed to ``point.main``
# regardless of ``--mode``.
_POINT_MODELS = frozenset(
    {
        "RSVR",
        "RBM",
        "RandomForest",
        "Polynomial",
        "MLPP",
        "Kriging",
        "KInterpolation",
        "GInterpolation",
    }
)


def _str2bool(value):
    """Parse a command-line / config-file value into a bool.

    ``type=bool`` is an argparse trap: ``bool("False")`` is ``True``, so any
    non-empty string would switch the option on. Config-file values usually
    arrive as strings, so the common spellings are matched explicitly.

    Args:
        value: A bool (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    import argparse  # local import keeps this helper self-contained

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Create the hyper-parameter parser, including model-specific options.

    ``config/config.yml`` and ``config/data.yml`` (next to this file) are
    registered as default YAML config files.
    """
    config_dir = Path(__file__).absolute().parent / "config"
    parser = configargparse.ArgParser(
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        default_config_files=[
            str(config_dir / "config.yml"),
            str(config_dir / "data.yml"),
        ],
        description="Hyper-parameters.",
    )

    # configuration file
    parser.add_argument(
        "--config", is_config_file=True, default=False, help="config file path"
    )
    # mode: validated by argparse rather than an ``assert`` (stripped under -O)
    parser.add_argument(
        "-m",
        "--mode",
        type=str,
        default="train",
        choices=_MODES,
        help="mode: train or test or plot",
    )
    # problem dimension
    parser.add_argument(
        "--prob_dim", default=2, type=int, help="dimension of the problem"
    )
    # args for plot in point-based methods
    parser.add_argument(
        "--plot", action="store_true", help="plot results for point-based methods"
    )

    # args for training
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="which gpu: 0 for cpu, 1 for gpu 0, 2 for gpu 1, ...",
    )
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--max_epochs", default=20, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument(
        "--resume_from_checkpoint", type=str, help="resume from checkpoint"
    )
    parser.add_argument(
        "--num_workers", default=2, type=int, help="num_workers in DataLoader"
    )
    parser.add_argument("--seed", type=int, default=1, help="seed")
    # _str2bool instead of type=bool so that "False" actually means False.
    parser.add_argument(
        "--use_16bit", type=_str2bool, default=False, help="use 16bit precision"
    )
    parser.add_argument("--profiler", action="store_true", help="use profiler")

    # args for validation
    # Default is 1.0 (float) to match type=float: argparse does not convert
    # non-string defaults. NOTE(review): downstream trainers (presumably
    # PyTorch Lightning) may read an int as a batch count -- verify.
    parser.add_argument(
        "--val_check_interval",
        type=float,
        default=1.0,
        help="how often within one training epoch to check the validation set",
    )

    # args for testing
    parser.add_argument(
        "-v", "--test_check_num", default="0", type=str, help="checkpoint for test"
    )
    parser.add_argument("--test_args", action="store_true", help="print args")

    # args from Model (presumably includes --model_name, read in main())
    return Model.add_model_specific_args(parser)


def main():
    """Parse hyper-parameters and dispatch to the selected run.

    - ``--test_args``: only print the parsed arguments.
    - A point-based model (see ``_POINT_MODELS``): run ``point.main``.
    - Otherwise: call ``main`` of the ``train``/``test``/``plot`` module
      selected by ``--mode``.
    """
    hparams = _build_parser().parse_args()

    if hparams.test_args:
        print(hparams)
    elif hparams.model_name in _POINT_MODELS:
        print(f"Only testing is permitted for {hparams.model_name}")
        point.main(hparams)
    else:
        # Explicit dispatch table instead of eval() on user-supplied text.
        runners = {"train": train, "test": test, "plot": plot}
        runners[hparams.mode].main(hparams)
2021-08-25 13:33:37 +08:00
if __name__ == "__main__":
    # Run only when executed as a script (``python main.py [FLAGS]``), not on import.
    main()