# Temperature field reconstruction implementation package
## Introduction
This project provides the implementation of the paper "A Machine Learning Modelling Benchmark for
Temperature Field Reconstruction of Heat-Source Systems". [[paper](https://arxiv.org/abs/2108.08298)] [[data generator](https://github.com/shendu-sw/recon-data-generator)]
## Requirements
* Software
  * python >= 3.6
  * cuda (only required when running on a GPU)
  * pytorch
* Hardware
  * GPU with at least 16GB of memory (recommended)
  * CPU
## Environment construction
1. Install the required packages listed in `requirements.txt`.
```
pip install -r requirements.txt
```
2. Install the `torch-cluster`, `torch-scatter`, and `torch-sparse` packages (matching your `torch` and `cuda` versions).
- Automatic installation [[install instruction](https://github.com/rusty1s/pytorch_geometric#pip-wheels)]
```
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${version}+${CUDA}.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${version}+${CUDA}.html
pip install torch-geometric
pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${version}+${CUDA}.html
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${version}+${CUDA}.html
```
`${version}` is the `torch` version and should be replaced by one of `1.4.0`, `1.5.0`, `1.6.0`, `1.7.0`, `1.7.1`, `1.8.0`, `1.8.1`, `1.9.0`.
`${CUDA}` should be replaced by one of `cpu`, `cu92`, `cu101`, `cu102`, `cu111`.
- Manual installation [[download](https://pytorch-geometric.com/whl)]
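
As a rough illustration of how the wheel index URL in the commands above is assembled, the following sketch builds it from a `torch` version and a CUDA tag. The helper names `wheel_index_url` and `pip_command` are hypothetical and not part of this repository; they only mirror the URL pattern shown above.

```python
BASE_URL = "https://pytorch-geometric.com/whl"


def wheel_index_url(torch_version, cuda_tag):
    """Return the wheel index page for a torch version and CUDA tag."""
    return f"{BASE_URL}/torch-{torch_version}+{cuda_tag}.html"


def pip_command(package, torch_version, cuda_tag):
    """Return the pip command that installs `package` from that index."""
    return f"pip install {package} -f {wheel_index_url(torch_version, cuda_tag)}"


# Example: torch 1.9.0 built against CUDA 11.1
print(pip_command("torch-scatter", "1.9.0", "cu111"))
```

If `torch` is already installed, `torch.__version__` reports the installed version, which helps when picking the matching values.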
## Running
> All the methods for the TFR-HSS task can be run through the `main.py` file.
>
> Parameters can be set in two ways: in the `yaml` file (default `config/config.yml`) or as `command line` parameters. Priority: `yaml` < `command line`, so a value given on the command line overrides the one in the `yaml` file.
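
The priority rule can be sketched as follows. This is a minimal illustration using the standard `argparse` module, not the actual parsing code in `main.py`; the function `merge_config` is hypothetical, and a plain dictionary stands in for the values that would be read from the `yaml` file.

```python
import argparse


def merge_config(yaml_defaults, argv):
    """Overlay command-line values on the yaml defaults (command line wins)."""
    parser = argparse.ArgumentParser()
    # default=None lets us tell "flag not given" apart from a real value.
    parser.add_argument("-m", "--mode", default=None)
    parser.add_argument("-v", "--test_check_num", type=int, default=None)
    args = parser.parse_args(argv)

    merged = dict(yaml_defaults)
    for key, value in vars(args).items():
        if value is not None:
            merged[key] = value
    return merged


# Stand-in for the parsed yaml file (values are illustrative only).
yaml_defaults = {"mode": "train", "test_check_num": 0}
print(merge_config(yaml_defaults, ["--mode=test", "--test_check_num=21"]))
```

Options that are not given on the command line keep the value from the `yaml` stand-in.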
### Image-based and Vector-based methods
> The image-based and vector-based methods use the same commands.
- Training
```
python main.py -m train
```
or
```
python main.py --mode=train
```
- Test
```
python main.py -m test --test_check_num=21
```
or
```
python main.py --mode=test --test_check_num=21
```
or
```
python main.py -m=test -v=21
```
where `test_check_num` is the number of the saved model to test.
- Prediction visualization
```
python main.py -m plot --test_check_num=21
```
or
```
python main.py --mode=plot --test_check_num=21
```
or
```
python main.py -m=plot -v=21
```
where `test_check_num` is the number of the saved model to plot.
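
To run the three modes one after another, the commands above can be assembled in a small script. The sketch below only builds and prints the command lines; `build_command` is a hypothetical helper, and the model number 21 is simply the value used in the examples above.

```python
def build_command(mode, check_num=None):
    """Build the argument list for one `main.py` invocation."""
    command = ["python", "main.py", f"--mode={mode}"]
    if check_num is not None:
        command.append(f"--test_check_num={check_num}")
    return command


# Train first, then test and plot saved model number 21.
pipeline = [
    build_command("train"),
    build_command("test", 21),
    build_command("plot", 21),
]
for command in pipeline:
    print(" ".join(command))
```

Each list can be passed to `subprocess.run(command, check=True)` from the repository root to execute the step.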
### Point-based methods
> Only testing is permitted for point-based methods.
- Testing
```
python main.py
```
- Testing with prediction visualization
```
python main.py --plot
```
## Project architecture
- `config`: the configuration file
  - `config.yml` describes configurations
    - `model_name`: model for reconstruction
    - `backbone`: backbone network, used only for deep surrogate models
    - `data_root`: root path of data
    - `train_list`: train samples
    - `test_list`: test samples
    - others
- `samples`: examples
- `outputs`: the output results of the `test` and `plot` modules. The test results are saved at `outputs/*.csv` and the plotted figures are saved at `outputs/predict_plot/`.
- `src`: including surrogate model, training and testing files.
  - `test.py`: testing files.
  - `train.py`: training files.
  - `plot.py`: prediction visualization files.
  - `point.py`: Model and testing files for point-based methods.
  - `DeepRegression.py`: Model configurations for image-based and vector-based methods.
  - `data`: data preprocessing and data loading files.
  - `models`: interpolation and machine learning models for the TFR-HSS task.
  - `utils`: useful tool function files.
- `docker`: start with docker.
- `lightning_logs`: saved models.
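
After a test run, the CSV files under `outputs` can be inspected with a few lines of Python. The sketch below is only an illustration: `load_results` is a hypothetical helper, and because the column layout of the CSV files is not described here, the rows are read generically without assuming any header names.

```python
import csv
from pathlib import Path


def load_results(output_dir="outputs"):
    """Read every CSV file in `output_dir` into a list of rows, keyed by file name."""
    directory = Path(output_dir)
    if not directory.is_dir():
        return {}  # nothing has been written yet
    results = {}
    for csv_path in sorted(directory.glob("*.csv")):
        with csv_path.open(newline="") as handle:
            results[csv_path.name] = list(csv.reader(handle))
    return results


# Print a one-line summary per result file.
for name, rows in load_results().items():
    print(f"{name}: {len(rows)} rows")
```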
## One tiny example
A tiny example for training and testing can be run as follows.
- Some training and testing data are available at `samples/data`.
- With the original configuration file unchanged, run `python main.py` directly to try this tiny example.
## Citing this work
If you find this work helpful for your research, please consider citing:
```
@article{gong2021,
Author = {Xiaoqian Chen and Zhiqiang Gong and Xiaoyu Zhao and Weien Zhou and Wen Yao},
Title = {A Machine Learning Modelling Benchmark for Temperature Field Reconstruction of Heat-Source Systems},
Journal = {arXiv preprint arXiv:2108.08298},
Year = {2021}
}
```