commit f4672490d1 (parent a53df132bf)

    pep8

main.py: 67 changed lines (further files follow in the diff below)
@@ -23,45 +23,80 @@ def main():
# default configuration file
config_path = Path(__file__).absolute().parent / "config/config.yml"
data_path = Path(__file__).absolute().parent / "config/data.yml"
parser = configargparse.ArgParser(config_file_parser_class= configargparse.YAMLConfigFileParser, \
default_config_files=[str(config_path), str(data_path)], description="Hyper-parameters.")

parser = configargparse.ArgParser(
config_file_parser_class=configargparse.YAMLConfigFileParser,
default_config_files=[str(config_path), str(data_path)],
description="Hyper-parameters.",
)

# configuration file
parser.add_argument("--config", is_config_file=True, default=False, help="config file path")

parser.add_argument(
"--config", is_config_file=True, default=False, help="config file path"
)

# mode
parser.add_argument("-m", "--mode", type=str, default="train", help="model: train or test or plot")
parser.add_argument(
"-m", "--mode", type=str, default="train", help="model: train or test or plot"
)

# problem dimension
parser.add_argument("--prob_dim", default=2, type=int, help="dimension of the problem")
parser.add_argument(
"--prob_dim", default=2, type=int, help="dimension of the problem"
)

# args for plot in point-based methods
parser.add_argument("--plot", action="store_true", help="use profiler")

# args for training
parser.add_argument("--gpu", type=int, default=0, help="which gpu: 0 for cpu, 1 for gpu 0, 2 for gpu 1, ...")
parser.add_argument(
"--gpu",
type=int,
default=0,
help="which gpu: 0 for cpu, 1 for gpu 0, 2 for gpu 1, ...",
)
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--max_epochs", default=20, type=int)
parser.add_argument("--lr", default="0.01", type=float)
parser.add_argument("--resume_from_checkpoint", type=str, help="resume from checkpoint")
parser.add_argument("--num_workers", default=2, type=int, help="num_workers in DataLoader")
parser.add_argument(
"--resume_from_checkpoint", type=str, help="resume from checkpoint"
)
parser.add_argument(
"--num_workers", default=2, type=int, help="num_workers in DataLoader"
)
parser.add_argument("--seed", type=int, default=1, help="seed")
parser.add_argument("--use_16bit", type=bool, default=False, help="use 16bit precision")
parser.add_argument(
"--use_16bit", type=bool, default=False, help="use 16bit precision"
)
parser.add_argument("--profiler", action="store_true", help="use profiler")

# args for validation
parser.add_argument("--val_check_interval", type=float, default=1,
help="how often within one training epoch to check the validation set")
parser.add_argument(
"--val_check_interval",
type=float,
default=1,
help="how often within one training epoch to check the validation set",
)

# args for testing
parser.add_argument("-v", "--test_check_num", default='0', type=str, help="checkpoint for test")
parser.add_argument(
"-v", "--test_check_num", default="0", type=str, help="checkpoint for test"
)
parser.add_argument("--test_args", action="store_true", help="print args")

# args from Model
parser = Model.add_model_specific_args(parser)
hparams = parser.parse_args()

PointModel = ["RSVR", "RBM", "RandomForest", "Polynomial", "MLPP", "Kriging", "KInterpolation", "GInterpolation"]
PointModel = [
"RSVR",
"RBM",
"RandomForest",
"Polynomial",
"MLPP",
"Kriging",
"KInterpolation",
"GInterpolation",
]
dpoint = True if hparams.model_name in PointModel else False
# running
assert hparams.mode in ["train", "test", "plot"]

@@ -74,5 +109,5 @@ def main():
getattr(eval(hparams.mode), "main")(hparams)


if __name__ == '__main__':
if __name__ == "__main__":
main()
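For context, a minimal runnable sketch of how the YAML-backed configargparse setup above behaves; the file names, option values, and the empty argv are illustrative assumptions, not part of this commit:

import configargparse

parser = configargparse.ArgParser(
    config_file_parser_class=configargparse.YAMLConfigFileParser,
    default_config_files=["config/config.yml", "config/data.yml"],  # skipped if missing
    description="Hyper-parameters.",
)
parser.add_argument("--config", is_config_file=True, help="config file path")
parser.add_argument("-m", "--mode", type=str, default="train")
parser.add_argument("--batch_size", default=16, type=int)

# Values given on the command line override the YAML defaults, e.g.
#   python main.py -m test --batch_size 32
hparams = parser.parse_args([])  # empty argv: fall back to YAML/defaults
print(hparams.mode, hparams.batch_size)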
@@ -18,7 +18,6 @@ import src.models as models
|
|||
|
||||
|
||||
class Model(LightningModule):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
self.hparams = hparams
|
||||
|
@@ -33,21 +32,57 @@ class Model(LightningModule):
|
|||
self.default_layout = None
|
||||
|
||||
def _build_model(self):
|
||||
model_list = ["SegNet_AlexNet", "SegNet_VGG", "SegNet_ResNet18", "SegNet_ResNet50",
|
||||
"SegNet_ResNet101", "SegNet_ResNet34", "SegNet_ResNet152",
|
||||
"FPN_ResNet18", "FPN_ResNet50", "FPN_ResNet101", "FPN_ResNet34", "FPN_ResNet152",
|
||||
"FCN_AlexNet", "FCN_VGG", "FCN_ResNet18", "FCN_ResNet50", "FCN_ResNet101",
|
||||
"FCN_ResNet34", "FCN_ResNet152",
|
||||
"UNet_VGG",
|
||||
"MLP", "ConditionalNeuralProcess", "TransformerRecon",
|
||||
"DenseDeepGCN"]
|
||||
layout_model = self.hparams.model_name + '_' + self.hparams.backbone
|
||||
assert (layout_model in model_list or self.hparams.model_name in model_list)
|
||||
self.layout_model = layout_model if layout_model in model_list else self.hparams.model_name
|
||||
model_list = [
|
||||
"SegNet_AlexNet",
|
||||
"SegNet_VGG",
|
||||
"SegNet_ResNet18",
|
||||
"SegNet_ResNet50",
|
||||
"SegNet_ResNet101",
|
||||
"SegNet_ResNet34",
|
||||
"SegNet_ResNet152",
|
||||
"FPN_ResNet18",
|
||||
"FPN_ResNet50",
|
||||
"FPN_ResNet101",
|
||||
"FPN_ResNet34",
|
||||
"FPN_ResNet152",
|
||||
"FCN_AlexNet",
|
||||
"FCN_VGG",
|
||||
"FCN_ResNet18",
|
||||
"FCN_ResNet50",
|
||||
"FCN_ResNet101",
|
||||
"FCN_ResNet34",
|
||||
"FCN_ResNet152",
|
||||
"UNet_VGG",
|
||||
"MLP",
|
||||
"ConditionalNeuralProcess",
|
||||
"TransformerRecon",
|
||||
"DenseDeepGCN",
|
||||
]
|
||||
layout_model = self.hparams.model_name + "_" + self.hparams.backbone
|
||||
assert layout_model in model_list or self.hparams.model_name in model_list
|
||||
self.layout_model = (
|
||||
layout_model if layout_model in model_list else self.hparams.model_name
|
||||
)
|
||||
|
||||
self.vec = True if self.layout_model in ["MLP", "ConditionalNeuralProcess", "TransformerRecon", "DenseDeepGCN"] else False
|
||||
self.vec = (
|
||||
True
|
||||
if self.layout_model
|
||||
in ["MLP", "ConditionalNeuralProcess", "TransformerRecon", "DenseDeepGCN"]
|
||||
else False
|
||||
)
|
||||
|
||||
self.model = nn.ModuleList([getattr(models, self.layout_model)(input_dim=self.input_dim, output_dim=self.output_dim) for i in range(self.hparams.div_num*self.hparams.div_num)]) if self.vec else getattr(models, self.layout_model)(in_channels=1)
|
||||
self.model = (
|
||||
nn.ModuleList(
|
||||
[
|
||||
getattr(models, self.layout_model)(
|
||||
input_dim=self.input_dim, output_dim=self.output_dim
|
||||
)
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num)
|
||||
]
|
||||
)
|
||||
if self.vec
|
||||
else getattr(models, self.layout_model)(in_channels=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
|
@@ -57,19 +92,36 @@ class Model(LightningModule):
|
|||
if num == 0:
|
||||
output = submodel(x[1]).unsqueeze(1)
|
||||
else:
|
||||
output = torch.cat((output, submodel(x[1]).unsqueeze(1)), axis=1)
|
||||
elif self.layout_model == "ConditionalNeuralProcess" or self.layout_model == "TransformerRecon":
|
||||
output = torch.cat(
|
||||
(output, submodel(x[1]).unsqueeze(1)), axis=1
|
||||
)
|
||||
elif (
|
||||
self.layout_model == "ConditionalNeuralProcess"
|
||||
or self.layout_model == "TransformerRecon"
|
||||
):
|
||||
for num, submodel in enumerate(self.model):
|
||||
if num == 0:
|
||||
output = submodel(x[0], x[1], (x[2])[:,0,...], (x[3])[:,0,...]).unsqueeze(1)
|
||||
output = submodel(
|
||||
x[0], x[1], (x[2])[:, 0, ...], (x[3])[:, 0, ...]
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
output = torch.cat((output, submodel(x[0], x[1], (x[2])[:,num,...], (x[3])[:,num,...]).unsqueeze(1)), axis=1)
|
||||
elif self.layout_model =="DenseDeepGCN":
|
||||
output = torch.cat(
|
||||
(
|
||||
output,
|
||||
submodel(
|
||||
x[0], x[1], (x[2])[:, num, ...], (x[3])[:, num, ...]
|
||||
).unsqueeze(1),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
elif self.layout_model == "DenseDeepGCN":
|
||||
for num, submodel in enumerate(self.model):
|
||||
if num == 0:
|
||||
output = submodel(x[num, ...]).unsqueeze(1)
|
||||
else:
|
||||
output = torch.cat((output, submodel(x[num, ...]).unsqueeze(1)), axis=1)
|
||||
output = torch.cat(
|
||||
(output, submodel(x[num, ...]).unsqueeze(1)), axis=1
|
||||
)
|
||||
else:
|
||||
output = self.model(x)
|
||||
|
||||
|
@@ -85,8 +137,7 @@ class Model(LightningModule):
|
|||
return loader
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(),
|
||||
lr=self.hparams.lr)
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||
scheduler = ExponentialLR(optimizer, gamma=0.9)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
|
@@ -154,14 +205,18 @@ class Model(LightningModule):
|
|||
return trainval_dataset, test_dataset
|
||||
|
||||
def prepare_data(self):
|
||||
"""Prepare dataset
|
||||
"""
|
||||
trainval_dataset, test_dataset = self.read_vec_data() if self.vec else self.read_image_data()
|
||||
"""Prepare dataset"""
|
||||
trainval_dataset, test_dataset = (
|
||||
self.read_vec_data() if self.vec else self.read_image_data()
|
||||
)
|
||||
|
||||
# split train/val set
|
||||
train_length, val_length = int(len(trainval_dataset) * 0.8), len(trainval_dataset)-int(len(trainval_dataset) * 0.8)
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset,
|
||||
[train_length, val_length])
|
||||
train_length, val_length = int(len(trainval_dataset) * 0.8), len(
|
||||
trainval_dataset
|
||||
) - int(len(trainval_dataset) * 0.8)
|
||||
train_dataset, val_dataset = torch.utils.data.random_split(
|
||||
trainval_dataset, [train_length, val_length]
|
||||
)
|
||||
print(
|
||||
f"Prepared dataset, train:{int(len(train_dataset))},\
|
||||
val:{int(len(val_dataset))}, test:{len(test_dataset)}"
|
||||
|
@@ -173,8 +228,10 @@ class Model(LightningModule):
|
|||
self.test_dataset = self.__dataloader(test_dataset, shuffle=False)
|
||||
|
||||
self.default_layout = trainval_dataset._layout()
|
||||
|
||||
self.default_layout = torch.from_numpy(self.default_layout).unsqueeze(0).unsqueeze(0)
|
||||
|
||||
self.default_layout = (
|
||||
torch.from_numpy(self.default_layout).unsqueeze(0).unsqueeze(0)
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_dataset
|
||||
|
@@ -194,24 +251,57 @@ class Model(LightningModule):
|
|||
heat_obs, heat = batch
|
||||
heat_info = heat_obs
|
||||
|
||||
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon":
|
||||
heat_info[1] = heat_info[1].transpose(1,2)
|
||||
heat_info[3] = heat_info[3].transpose(2,3)
|
||||
heat = heat.transpose(2,3)
|
||||
elif self.layout_model=="DenseDeepGCN":
|
||||
heat_obs=heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat[:,0,:]).squeeze()
|
||||
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
if (
|
||||
self.layout_model == "ConditionalNeuralProcess"
|
||||
or self.layout_model == "TransformerRecon"
|
||||
):
|
||||
heat_info[1] = heat_info[1].transpose(1, 2)
|
||||
heat_info[3] = heat_info[3].transpose(2, 3)
|
||||
heat = heat.transpose(2, 3)
|
||||
elif self.layout_model == "DenseDeepGCN":
|
||||
heat_obs = heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat[:, 0, :]).squeeze()
|
||||
inputs = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
|
||||
input_single = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
inputs = torch.cat((inputs, input_single), 0)
|
||||
|
||||
heat_info = inputs
|
||||
|
||||
labels = torch.cat((heat_obs,heat[:,0,:].squeeze()), 1).unsqueeze(1).unsqueeze(1)
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
label = torch.cat((heat_obs,heat[:,i,:].squeeze()), 1).unsqueeze(1).unsqueeze(1)
|
||||
labels = (
|
||||
torch.cat((heat_obs, heat[:, 0, :].squeeze()), 1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
)
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
|
||||
label = (
|
||||
torch.cat((heat_obs, heat[:, i, :].squeeze()), 1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
)
|
||||
labels = torch.cat((labels, label), 1)
|
||||
|
||||
heat = labels
|
||||
|
@@ -233,26 +323,59 @@ class Model(LightningModule):
|
|||
heat_obs, heat = batch
|
||||
heat_info = heat_obs
|
||||
|
||||
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon":
|
||||
heat_info[1] = heat_info[1].transpose(1,2)
|
||||
heat_info[3] = heat_info[3].transpose(2,3)
|
||||
heat = heat.transpose(2,3)
|
||||
elif self.layout_model=="DenseDeepGCN":
|
||||
heat_obs=heat_obs.squeeze()
|
||||
if (
|
||||
self.layout_model == "ConditionalNeuralProcess"
|
||||
or self.layout_model == "TransformerRecon"
|
||||
):
|
||||
heat_info[1] = heat_info[1].transpose(1, 2)
|
||||
heat_info[3] = heat_info[3].transpose(2, 3)
|
||||
heat = heat.transpose(2, 3)
|
||||
elif self.layout_model == "DenseDeepGCN":
|
||||
heat_obs = heat_obs.squeeze()
|
||||
|
||||
pseudo_heat = torch.zeros_like(heat[:,0,:]).squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat[:, 0, :]).squeeze()
|
||||
|
||||
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
inputs = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
|
||||
input_single = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
inputs = torch.cat((inputs, input_single), 0)
|
||||
|
||||
heat_info = inputs
|
||||
|
||||
labels = torch.cat((heat_obs,heat[:,0,:].squeeze()), 1).unsqueeze(1).unsqueeze(1)
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
label = torch.cat((heat_obs,heat[:,i,:].squeeze()), 1).unsqueeze(1).unsqueeze(1)
|
||||
labels = (
|
||||
torch.cat((heat_obs, heat[:, 0, :].squeeze()), 1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
)
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
|
||||
label = (
|
||||
torch.cat((heat_obs, heat[:, i, :].squeeze()), 1)
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(1)
|
||||
)
|
||||
labels = torch.cat((labels, label), 1)
|
||||
|
||||
heat = labels
|
||||
|
@@ -276,17 +399,42 @@ class Model(LightningModule):
|
|||
else:
|
||||
heat_obs, heat = batch
|
||||
heat_info = heat_obs
|
||||
|
||||
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon":
|
||||
heat_info[1] = heat_info[1].transpose(1,2)
|
||||
heat_info[3] = heat_info[3].transpose(2,3)
|
||||
elif self.layout_model=="DenseDeepGCN":
|
||||
heat_obs = heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat0[:,0,:]).squeeze()
|
||||
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
if (
|
||||
self.layout_model == "ConditionalNeuralProcess"
|
||||
or self.layout_model == "TransformerRecon"
|
||||
):
|
||||
heat_info[1] = heat_info[1].transpose(1, 2)
|
||||
heat_info[3] = heat_info[3].transpose(2, 3)
|
||||
elif self.layout_model == "DenseDeepGCN":
|
||||
heat_obs = heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat0[:, 0, :]).squeeze()
|
||||
inputs = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
|
||||
input_single = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
inputs = torch.cat((inputs, input_single), 0)
|
||||
|
||||
heat_info = inputs
|
||||
|
@@ -294,17 +442,28 @@ class Model(LightningModule):
|
|||
heat_pred0 = self(heat_info)
|
||||
|
||||
if self.vec:
|
||||
if self.layout_model=="DenseDeepGCN":
|
||||
heat_pred0 = heat_pred0[...,-self.output_dim:]
|
||||
if self.layout_model == "DenseDeepGCN":
|
||||
heat_pred0 = heat_pred0[..., -self.output_dim :]
|
||||
else:
|
||||
pass
|
||||
|
||||
heat_pred0 = heat_pred0.reshape((-1, self.hparams.div_num*self.hparams.div_num, int(200 / self.hparams.div_num), int(200 / self.hparams.div_num)))
|
||||
|
||||
heat_pred0 = heat_pred0.reshape(
|
||||
(
|
||||
-1,
|
||||
self.hparams.div_num * self.hparams.div_num,
|
||||
int(200 / self.hparams.div_num),
|
||||
int(200 / self.hparams.div_num),
|
||||
)
|
||||
)
|
||||
heat_pred = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200))
|
||||
for i in range(self.hparams.div_num):
|
||||
for j in range(self.hparams.div_num):
|
||||
heat_pred[..., 0+i:200:self.hparams.div_num, 0+j:200:self.hparams.div_num] = heat_pred0[:, self.hparams.div_num*i+j,...].unsqueeze(1)
|
||||
heat_pred = heat_pred.transpose(2,3)
|
||||
heat_pred[
|
||||
...,
|
||||
0 + i : 200 : self.hparams.div_num,
|
||||
0 + j : 200 : self.hparams.div_num,
|
||||
] = heat_pred0[:, self.hparams.div_num * i + j, ...].unsqueeze(1)
|
||||
heat_pred = heat_pred.transpose(2, 3)
|
||||
heat = heat.unsqueeze(1)
|
||||
|
||||
else:
|
||||
|
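A minimal sketch of the interleaved reassembly used above, assuming random tensors and div_num = 4: the sub-grid predicted for offset (i, j) fills every div_num-th row and column of the full 200 x 200 field.

import torch

div_num = 4
sub = torch.randn(2, div_num * div_num, 200 // div_num, 200 // div_num)
full = torch.zeros(2, 1, 200, 200)
for i in range(div_num):
    for j in range(div_num):
        # sub-grid (i, j) fills every div_num-th row/column, offset by (i, j)
        full[..., i::div_num, j::div_num] = sub[:, div_num * i + j, ...].unsqueeze(1)
print(full.shape)  # torch.Size([2, 1, 200, 200])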
@@ -312,49 +471,110 @@ class Model(LightningModule):
|
|||
|
||||
loss = self.criterion(heat_pred, heat) * self.hparams.std_heat
|
||||
|
||||
default_layout = torch.repeat_interleave(self.default_layout, repeats=heat_pred.size(0), dim=0).float().to(device=heat.device)
|
||||
default_layout = (
|
||||
torch.repeat_interleave(
|
||||
self.default_layout, repeats=heat_pred.size(0), dim=0
|
||||
)
|
||||
.float()
|
||||
.to(device=heat.device)
|
||||
)
|
||||
ones = torch.ones_like(default_layout).to(device=heat.device)
|
||||
zeros = torch.zeros_like(default_layout).to(device=heat.device)
|
||||
layout_ind = torch.where(default_layout<1e-2,zeros,ones)
|
||||
loss_2 = torch.sum(torch.abs(torch.sub(heat, heat_pred)) *layout_ind )* self.hparams.std_heat/ torch.sum(layout_ind)
|
||||
#---------------------------------
|
||||
loss_1 = torch.sum(torch.max(torch.max(torch.max(torch.abs(torch.sub(heat,heat_pred)) * layout_ind, 3).values, 2).values * self.hparams.std_heat,1).values)/heat_pred.size(0)
|
||||
#---------------------------------
|
||||
layout_ind = torch.where(default_layout < 1e-2, zeros, ones)
|
||||
loss_2 = (
|
||||
torch.sum(torch.abs(torch.sub(heat, heat_pred)) * layout_ind)
|
||||
* self.hparams.std_heat
|
||||
/ torch.sum(layout_ind)
|
||||
)
|
||||
# ---------------------------------
|
||||
loss_1 = (
|
||||
torch.sum(
|
||||
torch.max(
|
||||
torch.max(
|
||||
torch.max(
|
||||
torch.abs(torch.sub(heat, heat_pred)) * layout_ind, 3
|
||||
).values,
|
||||
2,
|
||||
).values
|
||||
* self.hparams.std_heat,
|
||||
1,
|
||||
).values
|
||||
)
|
||||
/ heat_pred.size(0)
|
||||
)
|
||||
# ---------------------------------
|
||||
boundary_ones = torch.zeros_like(default_layout).to(device=heat.device)
|
||||
boundary_ones[..., -2:, :] = ones[..., -2:, :]
|
||||
boundary_ones[..., :2, :] = ones[..., :2, :]
|
||||
boundary_ones[..., :, :2] = ones[..., :, :2]
|
||||
boundary_ones[..., :, -2:] = ones[..., :, -2:]
|
||||
loss_3 = torch.sum(torch.abs(torch.sub(heat, heat_pred)) *boundary_ones )* self.hparams.std_heat/ torch.sum(boundary_ones)
|
||||
#----------------------------------
|
||||
loss_4 = torch.sum(torch.max(torch.max(torch.max(torch.abs(torch.sub(heat,heat_pred)), 3).values, 2).values * self.hparams.std_heat,1).values)/heat_pred.size(0)
|
||||
|
||||
return {"test_loss": loss, "test_loss_1": loss_1, "test_loss_2": loss_2, "test_loss_3": loss_3, "test_loss_4": loss_4}
|
||||
loss_3 = (
|
||||
torch.sum(torch.abs(torch.sub(heat, heat_pred)) * boundary_ones)
|
||||
* self.hparams.std_heat
|
||||
/ torch.sum(boundary_ones)
|
||||
)
|
||||
# ----------------------------------
|
||||
loss_4 = (
|
||||
torch.sum(
|
||||
torch.max(
|
||||
torch.max(
|
||||
torch.max(torch.abs(torch.sub(heat, heat_pred)), 3).values, 2
|
||||
).values
|
||||
* self.hparams.std_heat,
|
||||
1,
|
||||
).values
|
||||
)
|
||||
/ heat_pred.size(0)
|
||||
)
|
||||
|
||||
return {
|
||||
"test_loss": loss,
|
||||
"test_loss_1": loss_1,
|
||||
"test_loss_2": loss_2,
|
||||
"test_loss_3": loss_3,
|
||||
"test_loss_4": loss_4,
|
||||
}
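The nested torch.max reduction used for loss_1 and loss_4 above amounts to the per-sample maximum absolute error averaged over the batch; a minimal sketch with assumed random inputs:

import torch

heat = torch.randn(4, 1, 200, 200)
heat_pred = torch.randn(4, 1, 200, 200)
err = torch.abs(heat - heat_pred)
# max over width, then height, then channel -> one value per sample
per_sample_max = torch.max(torch.max(torch.max(err, 3).values, 2).values, 1).values
maxae = torch.sum(per_sample_max) / heat_pred.size(0)  # batch mean of the per-sample maxima
print(maxae)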
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
test_loss_mean = torch.stack([x["test_loss"] for x in outputs]).mean()
|
||||
self.log("test_loss (" + "MAE" +")", test_loss_mean.item())
|
||||
#test_loss_max = torch.max(torch.stack([x["test_loss_1"] for x in outputs]))
|
||||
self.log("test_loss (" + "MAE" + ")", test_loss_mean.item())
|
||||
# test_loss_max = torch.max(torch.stack([x["test_loss_1"] for x in outputs]))
|
||||
test_loss_max = torch.stack([x["test_loss_1"] for x in outputs]).mean()
|
||||
self.log("test_loss_1 (" + "M-CAE" +")", test_loss_max.item())
|
||||
self.log("test_loss_1 (" + "M-CAE" + ")", test_loss_max.item())
|
||||
test_loss_com_mean = torch.stack([x["test_loss_2"] for x in outputs]).mean()
|
||||
self.log("test_loss_2 (" + "CMAE" +")", test_loss_com_mean.item())
|
||||
self.log("test_loss_2 (" + "CMAE" + ")", test_loss_com_mean.item())
|
||||
test_loss_bc_mean = torch.stack([x["test_loss_3"] for x in outputs]).mean()
|
||||
self.log("test_loss_3 (" + "BMAE" +")", test_loss_bc_mean.item())
|
||||
self.log("test_loss_3 (" + "BMAE" + ")", test_loss_bc_mean.item())
|
||||
test_loss_max_1 = torch.stack([x["test_loss_4"] for x in outputs]).mean()
|
||||
self.log("test_loss_4 (" + "MaxAE" + ")", test_loss_max_1.item())
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser): # pragma: no-cover
|
||||
"""Parameters you define here will be available to your model through `self.hparams`.
|
||||
"""
|
||||
"""Parameters you define here will be available to your model through `self.hparams`."""
|
||||
# dataset args
|
||||
parser.add_argument("--data_root", type=str, required=True, help="path of dataset")
|
||||
parser.add_argument("--train_list", type=str, required=True, help="path of train dataset list")
|
||||
parser.add_argument("--train_size", default=0.8, type=float, help="train_size in train_test_split")
|
||||
parser.add_argument("--test_list", type=str, required=True, help="path of test dataset list")
|
||||
#parser.add_argument("--boundary", type=str, default="rm_wall", help="boundary condition")
|
||||
parser.add_argument("--data_format", type=str, default="mat", choices=["mat", "h5"], help="dataset format")
|
||||
parser.add_argument(
|
||||
"--data_root", type=str, required=True, help="path of dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_list", type=str, required=True, help="path of train dataset list"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_size",
|
||||
default=0.8,
|
||||
type=float,
|
||||
help="train_size in train_test_split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_list", type=str, required=True, help="path of test dataset list"
|
||||
)
|
||||
# parser.add_argument("--boundary", type=str, default="rm_wall", help="boundary condition")
|
||||
parser.add_argument(
|
||||
"--data_format",
|
||||
type=str,
|
||||
default="mat",
|
||||
choices=["mat", "h5"],
|
||||
help="dataset format",
|
||||
)
|
||||
|
||||
# Normalization params
|
||||
parser.add_argument("--mean_layout", default=0, type=float)
|
||||
|
@@ -364,10 +584,19 @@ class Model(LightningModule):
|
|||
|
||||
# Model params (opt)
|
||||
parser.add_argument("--input_size", default=200, type=int)
|
||||
parser.add_argument("--model_name", type=str, default='FCN', help="the name of chosen model")
|
||||
parser.add_argument("--backbone", type=str, default='AlexNet', help="the used backbone in the regression model")
|
||||
parser.add_argument(
|
||||
"--model_name", type=str, default="FCN", help="the name of chosen model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backbone",
|
||||
type=str,
|
||||
default="AlexNet",
|
||||
help="the used backbone in the regression model",
|
||||
)
|
||||
|
||||
# div_num for vec (opt)
|
||||
parser.add_argument("--div_num", default=4, type=int, help="division of heat source systems")
|
||||
|
||||
parser.add_argument(
|
||||
"--div_num", default=4, type=int, help="division of heat source systems"
|
||||
)
|
||||
|
||||
return parser
@@ -6,8 +6,7 @@ from .loadresponse import LoadResponse, LoadPointResponse, LoadVecResponse, mat_
|
|||
|
||||
|
||||
class LayoutDataset(LoadResponse):
|
||||
"""Layout dataset (mutiple files) generated by 'layout-generator'.
|
||||
"""
|
||||
"""Layout dataset (mutiple files) generated by 'layout-generator'."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -20,8 +19,9 @@ class LayoutDataset(LoadResponse):
|
|||
resp_name="u",
|
||||
):
|
||||
test_name = os.path.splitext(os.path.basename(list_path))[0]
|
||||
subdir = os.path.join("train", "train") \
|
||||
if train else os.path.join("test", test_name)
|
||||
subdir = (
|
||||
os.path.join("train", "train") if train else os.path.join("test", test_name)
|
||||
)
|
||||
|
||||
# find the path of the list of train/test samples
|
||||
list_path = os.path.join(root, list_path)
|
||||
|
@@ -42,7 +42,6 @@ class LayoutDataset(LoadResponse):
|
|||
|
||||
|
||||
class LayoutPointDataset(LoadPointResponse):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
|
@@ -53,8 +52,9 @@ class LayoutPointDataset(LoadPointResponse):
|
|||
layout_name="F",
|
||||
):
|
||||
test_name = os.path.splitext(os.path.basename(list_path))[0]
|
||||
subdir = os.path.join("train", "train") \
|
||||
if train else os.path.join("test", test_name)
|
||||
subdir = (
|
||||
os.path.join("train", "train") if train else os.path.join("test", test_name)
|
||||
)
|
||||
|
||||
# find the path of the list of train/test samples
|
||||
list_path = os.path.join(root, list_path)
|
||||
|
@@ -74,8 +74,7 @@ class LayoutPointDataset(LoadPointResponse):
|
|||
|
||||
|
||||
class LayoutVecDataset(LoadVecResponse):
|
||||
"""Layout dataset (mutiple files) generated by 'layout-generator'.
|
||||
"""
|
||||
"""Layout dataset (mutiple files) generated by 'layout-generator'."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -89,8 +88,9 @@ class LayoutVecDataset(LoadVecResponse):
|
|||
resp_name="u",
|
||||
):
|
||||
test_name = os.path.splitext(os.path.basename(list_path))[0]
|
||||
subdir = os.path.join("train", "train") \
|
||||
if train else os.path.join("test", test_name)
|
||||
subdir = (
|
||||
os.path.join("train", "train") if train else os.path.join("test", test_name)
|
||||
)
|
||||
|
||||
# find the path of the list of train/test samples
|
||||
list_path = os.path.join(root, list_path)
|
||||
|
@@ -108,4 +108,4 @@ class LayoutVecDataset(LoadVecResponse):
|
|||
div_num=div_num,
|
||||
transform=transform,
|
||||
target_transform=target_transform,
|
||||
)
|
||||
)
@@ -28,33 +28,35 @@ class LoadResponse(VisionDataset):
|
|||
target_transform=None,
|
||||
is_valid_file=None,
|
||||
):
|
||||
super().__init__(
|
||||
root, transform=transform, target_transform=target_transform
|
||||
)
|
||||
super().__init__(root, transform=transform, target_transform=target_transform)
|
||||
self.list_path = list_path
|
||||
self.loader = loader
|
||||
self.load_name = load_name
|
||||
self.resp_name = resp_name
|
||||
self.layout_name = layout_name
|
||||
self.extensions = extensions
|
||||
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file)
|
||||
self.sample_files = make_dataset_list(
|
||||
root, list_path, extensions, is_valid_file
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.sample_files[index]
|
||||
load, resp, _ = self.loader(path, self.load_name, self.resp_name)
|
||||
|
||||
load[np.where(load<TOL)] = 298
|
||||
load[np.where(load < TOL)] = 298
|
||||
|
||||
if self.transform is not None:
|
||||
load = self.transform(load)
|
||||
if self.target_transform is not None:
|
||||
resp = self.target_transform(resp)
|
||||
|
||||
|
||||
return load, resp
|
||||
|
||||
def _layout(self):
|
||||
path = self.sample_files[0]
|
||||
_, _, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name)
|
||||
_, _, layout = self.loader(
|
||||
path, self.load_name, self.resp_name, self.layout_name
|
||||
)
|
||||
return layout
|
||||
|
||||
def __len__(self):
|
||||
|
@@ -75,21 +77,23 @@ class LoadPointResponse(VisionDataset):
|
|||
extensions=None,
|
||||
is_valid_file=None,
|
||||
):
|
||||
super().__init__(
|
||||
root
|
||||
)
|
||||
super().__init__(root)
|
||||
self.list_path = list_path
|
||||
self.loader = loader
|
||||
self.load_name = load_name
|
||||
self.resp_name = resp_name
|
||||
self.layout_name = layout_name
|
||||
self.extensions = extensions
|
||||
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file)
|
||||
self.sample_files = make_dataset_list(
|
||||
root, list_path, extensions, is_valid_file
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.sample_files[index]
|
||||
load, resp, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name)
|
||||
|
||||
load, resp, layout = self.loader(
|
||||
path, self.load_name, self.resp_name, self.layout_name
|
||||
)
|
||||
|
||||
return load, resp, layout
|
||||
|
||||
def __len__(self):
|
||||
|
@@ -98,14 +102,14 @@ class LoadPointResponse(VisionDataset):
|
|||
|
||||
class LoadVecResponse(VisionDataset):
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
root,
|
||||
loader,
|
||||
list_path,
|
||||
load_name="u_obs",
|
||||
resp_name="u",
|
||||
layout_name="F",
|
||||
div_num = 4,
|
||||
div_num=4,
|
||||
extensions=None,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
|
@@ -121,13 +125,15 @@ class LoadVecResponse(VisionDataset):
|
|||
self.resp_name = resp_name
|
||||
self.layout_name = layout_name
|
||||
self.div_num = div_num
|
||||
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file)
|
||||
self.sample_files = make_dataset_list(
|
||||
root, list_path, extensions, is_valid_file
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
path = self.sample_files[index]
|
||||
x_context, y_context, x_target, y_target, resp = self._loader(path)
|
||||
|
||||
|
||||
if self.transform is not None:
|
||||
y_context = (y_context - self.transform[0]) / self.transform[1]
|
||||
else:
|
||||
|
@@ -135,11 +141,17 @@ class LoadVecResponse(VisionDataset):
|
|||
|
||||
if self.target_transform is not None:
|
||||
y_target = (y_target - self.target_transform[0]) / self.transform[1]
|
||||
resp = (resp-self.target_transform[0]) / self.transform[1]
|
||||
resp = (resp - self.target_transform[0]) / self.transform[1]
|
||||
else:
|
||||
pass
|
||||
|
||||
return x_context, y_context.type(torch.FloatTensor), x_target, y_target.type(torch.FloatTensor), resp.type(torch.FloatTensor)
|
||||
|
||||
return (
|
||||
x_context,
|
||||
y_context.type(torch.FloatTensor),
|
||||
x_target,
|
||||
y_target.type(torch.FloatTensor),
|
||||
resp.type(torch.FloatTensor),
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_files)
|
||||
|
@@ -149,22 +161,34 @@ class LoadVecResponse(VisionDataset):
|
|||
load, resp, _ = self.loader(path, self.load_name, self.resp_name)
|
||||
|
||||
monitor_x, monitor_y = np.where(load > TOL)
|
||||
y_context = torch.from_numpy(load[monitor_x, monitor_y].reshape(1,-1)).float()
|
||||
y_context = torch.from_numpy(load[monitor_x, monitor_y].reshape(1, -1)).float()
|
||||
|
||||
monitor_x, monitor_y = monitor_x / load.shape[0], monitor_y / load.shape[1]
|
||||
x_context = torch.from_numpy(np.concatenate([monitor_x.reshape(-1,1),monitor_y.reshape(-1,1)], axis=1)).float()
|
||||
|
||||
x = np.linspace(0, load.shape[0]-1, load.shape[0]).astype(int)
|
||||
y = np.linspace(1, load.shape[1]-1, load.shape[1]).astype(int)
|
||||
x_context = torch.from_numpy(
|
||||
np.concatenate([monitor_x.reshape(-1, 1), monitor_y.reshape(-1, 1)], axis=1)
|
||||
).float()
|
||||
|
||||
x = np.linspace(0, load.shape[0] - 1, load.shape[0]).astype(int)
|
||||
y = np.linspace(1, load.shape[1] - 1, load.shape[1]).astype(int)
|
||||
|
||||
x_target = None
|
||||
y_target = None
|
||||
for i in range(self.div_num):
|
||||
for j in range(self.div_num):
|
||||
x1, y1 = x[0+i:np.size(x):self.div_num], y[0+j:np.size(y):self.div_num]
|
||||
x1, y1 = (
|
||||
x[0 + i : np.size(x) : self.div_num],
|
||||
y[0 + j : np.size(y) : self.div_num],
|
||||
)
|
||||
x1, y1 = np.meshgrid(x1, y1)
|
||||
x_target0 = torch.from_numpy(np.concatenate([x1.reshape(-1, 1), y1.reshape(-1, 1)], axis=1) / np.max(load.shape)).float().unsqueeze(0)
|
||||
y_target0 = torch.from_numpy(resp[x1, y1].reshape(1,-1)).unsqueeze(0)
|
||||
x_target0 = (
|
||||
torch.from_numpy(
|
||||
np.concatenate([x1.reshape(-1, 1), y1.reshape(-1, 1)], axis=1)
|
||||
/ np.max(load.shape)
|
||||
)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
)
|
||||
y_target0 = torch.from_numpy(resp[x1, y1].reshape(1, -1)).unsqueeze(0)
|
||||
if x_target is not None:
|
||||
x_target = torch.cat((x_target, x_target0), 0)
|
||||
else:
|
||||
|
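A minimal sketch of the div_num-strided sampling in _loader above, assuming a 200 x 200 response array: each (i, j) offset selects one interleaved sub-grid of target points.

import numpy as np

div_num = 4
resp = np.random.rand(200, 200)
x = np.arange(200)
y = np.arange(200)
targets = []
for i in range(div_num):
    for j in range(div_num):
        x1, y1 = x[i::div_num], y[j::div_num]
        x1, y1 = np.meshgrid(x1, y1)
        targets.append(resp[x1, y1].reshape(1, -1))
print(len(targets), targets[0].shape)  # 16 sub-grids, each (1, 2500)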
@@ -179,7 +203,9 @@ class LoadVecResponse(VisionDataset):
|
|||
|
||||
def _layout(self):
|
||||
path = self.sample_files[0]
|
||||
_, _, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name)
|
||||
_, _, layout = self.loader(
|
||||
path, self.load_name, self.resp_name, self.layout_name
|
||||
)
|
||||
return layout
|
||||
|
||||
def _inputdim(self):
|
||||
|
@@ -188,9 +214,9 @@ class LoadVecResponse(VisionDataset):
|
|||
monitor_x, _ = np.where(load > TOL)
|
||||
return np.size(monitor_x)
|
||||
|
||||
|
||||
def make_dataset(root_dir, extensions=None, is_valid_file=None):
|
||||
"""make_dataset() from torchvision.
|
||||
"""
|
||||
"""make_dataset() from torchvision."""
|
||||
files = []
|
||||
root_dir = os.path.expanduser(root_dir)
|
||||
if not ((extensions is None) ^ (is_valid_file is None)):
|
||||
|
@@ -211,8 +237,7 @@ def make_dataset(root_dir, extensions=None, is_valid_file=None):
|
|||
|
||||
|
||||
def make_dataset_list(root_dir, list_path, extensions=None, is_valid_file=None):
|
||||
"""make_dataset() from torchvision.
|
||||
"""
|
||||
"""make_dataset() from torchvision."""
|
||||
files = []
|
||||
root_dir = os.path.expanduser(root_dir)
|
||||
if not ((extensions is None) ^ (is_valid_file is None)):
|
||||
|
@@ -224,7 +249,7 @@ def make_dataset_list(root_dir, list_path, extensions=None, is_valid_file=None):
|
|||
is_valid_file = lambda x: has_allowed_extension(x, extensions)
|
||||
|
||||
assert os.path.isdir(root_dir), root_dir
|
||||
with open(list_path, 'r') as rf:
|
||||
with open(list_path, "r") as rf:
|
||||
for line in rf.readlines():
|
||||
data_path = line.strip()
|
||||
path = os.path.join(root_dir, data_path)
|
||||
|
@@ -247,9 +272,9 @@ def mat_loader(path, load_name, resp_name=None, layout_name=None):
|
|||
|
||||
if __name__ == "__main__":
|
||||
total_num = 50000
|
||||
with open('train'+str(total_num)+'.txt', 'w') as wf:
|
||||
for idx in range(int(total_num*0.8)):
|
||||
wf.write('Example'+str(idx)+'.mat'+'\n')
|
||||
with open('val'+str(total_num)+'.txt', 'w') as wf:
|
||||
for idx in range(int(total_num*0.8), total_num):
|
||||
wf.write('Example'+str(idx)+'.mat'+'\n')
|
||||
with open("train" + str(total_num) + ".txt", "w") as wf:
|
||||
for idx in range(int(total_num * 0.8)):
|
||||
wf.write("Example" + str(idx) + ".mat" + "\n")
|
||||
with open("val" + str(total_num) + ".txt", "w") as wf:
|
||||
for idx in range(int(total_num * 0.8), total_num):
|
||||
wf.write("Example" + str(idx) + ".mat" + "\n")
|
||||
|
|
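A small sketch of the list-file convention used by make_dataset_list and the __main__ block above, assuming a hypothetical dataset root "data": one sample filename per line, written once and later joined onto the root directory.

import os

# write a small sample list, as the __main__ block above does
with open("train10.txt", "w") as wf:
    for idx in range(8):
        wf.write("Example" + str(idx) + ".mat" + "\n")

# read it back the way make_dataset_list does
root_dir = os.path.expanduser("data")  # hypothetical dataset root
files = []
with open("train10.txt", "r") as rf:
    for line in rf.readlines():
        data_path = line.strip()
        files.append(os.path.join(root_dir, data_path))
print(files[:2])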
|
@@ -5,4 +5,4 @@ from .fpn import *
|
|||
from .cnp import *
|
||||
from .mlp import *
|
||||
from .transformer import *
|
||||
from .gnn import *
|
||||
from .gnn import *
@@ -69,7 +69,13 @@ class DeterministicDecoder(nn.Module):
|
|||
|
||||
|
||||
class ConditionalNeuralProcess(nn.Module):
|
||||
def __init__(self, input_dim=None, output_dim=None, encoder_sizes=[2+1, 128, 128, 128, 256], decoder_sizes=[256+2, 256, 256, 128, 128, 2]):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=None,
|
||||
output_dim=None,
|
||||
encoder_sizes=[2 + 1, 128, 128, 128, 256],
|
||||
decoder_sizes=[256 + 2, 256, 256, 128, 128, 2],
|
||||
):
|
||||
super(ConditionalNeuralProcess, self).__init__()
|
||||
self._encoder = DeterministicEncoder(encoder_sizes)
|
||||
self._decoder = DeterministicDecoder(decoder_sizes)
|
||||
|
@@ -79,15 +85,15 @@ class ConditionalNeuralProcess(nn.Module):
|
|||
dist, mu, sigma = self._decoder(representation, x_target)
|
||||
|
||||
log_p = None if y_target is None else dist.log_prob(y_target)
|
||||
#return log_p, mu, sigma
|
||||
# return log_p, mu, sigma
|
||||
return mu
|
||||
|
||||
|
||||
|
||||
def input_mapping(x, B):
|
||||
if B is None:
|
||||
return x
|
||||
else:
|
||||
x_proj = (2. * np.pi * x) @ B.t()
|
||||
x_proj = (2.0 * np.pi * x) @ B.t()
|
||||
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
||||
|
||||
|
||||
|
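A minimal usage sketch for the input_mapping helper above, assuming a random projection matrix B of 64 frequencies for 2-D coordinates:

import numpy as np
import torch

def input_mapping(x, B):
    if B is None:
        return x
    x_proj = (2.0 * np.pi * x) @ B.t()
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

x = torch.rand(8, 2)           # 8 points with 2-D coordinates
B = torch.randn(64, 2) * 10.0  # 64 random frequencies (assumed scale)
z = input_mapping(x, B)
print(z.shape)                 # torch.Size([8, 128])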
@@ -106,4 +112,4 @@ class ConditionalNeuralProcessFourier(nn.Module):
|
|||
dist, mu, sigma = self._decoder(representation, x_target)
|
||||
|
||||
log_p = None if y_target is None else dist.log_prob(y_target)
|
||||
return log_p, mu, sigma
|
||||
return log_p, mu, sigma
@@ -7,18 +7,24 @@ from .util.backbone import *
|
|||
|
||||
|
||||
__all__ = [
|
||||
"FCN_VGG", "FCN_AlexNet", "FCN_ResNet18", "FCN_ResNet34",
|
||||
"FCN_ResNet50", "FCN_ResNet101", "FCN_ResNet152",
|
||||
"FCN_VGG",
|
||||
"FCN_AlexNet",
|
||||
"FCN_ResNet18",
|
||||
"FCN_ResNet34",
|
||||
"FCN_ResNet50",
|
||||
"FCN_ResNet101",
|
||||
"FCN_ResNet152",
|
||||
]
|
||||
|
||||
|
||||
class Conv3x3GNReLU(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, upsample=False):
|
||||
super().__init__()
|
||||
self.upsample = upsample
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
|
||||
nn.Conv2d(
|
||||
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
|
||||
),
|
||||
nn.GroupNorm(32, out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
@@ -31,18 +37,19 @@ class Conv3x3GNReLU(nn.Module):
|
|||
|
||||
|
||||
class FCN_VGG(nn.Module):
|
||||
|
||||
def __init__(self, inter_channels=256, in_channels=1, bn=False):
|
||||
super(FCN_VGG, self).__init__()
|
||||
vgg = vgg16()
|
||||
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())
|
||||
features, classifier = list(vgg.features.children()), list(
|
||||
vgg.classifier.children()
|
||||
)
|
||||
|
||||
if in_channels != 3:
|
||||
features[0] = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
|
||||
for f in features:
|
||||
if 'MaxPool' in f.__class__.__name__:
|
||||
if "MaxPool" in f.__class__.__name__:
|
||||
f.ceil_mode = True
|
||||
elif 'ReLU' in f.__class__.__name__:
|
||||
elif "ReLU" in f.__class__.__name__:
|
||||
f.inplace = True
|
||||
|
||||
features_temp = []
|
||||
|
@@ -53,7 +60,7 @@ class FCN_VGG(nn.Module):
|
|||
features_temp.append(nn.GroupNorm(32, features[i].out_channels))
|
||||
|
||||
self.features3 = nn.Sequential(*features[:17])
|
||||
self.features4 = nn.Sequential(*features[17: 24])
|
||||
self.features4 = nn.Sequential(*features[17:24])
|
||||
self.features5 = nn.Sequential(*features[24:])
|
||||
|
||||
self.score_pool3 = nn.Conv2d(256, inter_channels, kernel_size=1)
|
||||
|
@@ -67,7 +74,9 @@ class FCN_VGG(nn.Module):
|
|||
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
|
||||
)
|
||||
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(
|
||||
inter_channels, inter_channels, upsample=True
|
||||
)
|
||||
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
@@ -82,12 +91,16 @@ class FCN_VGG(nn.Module):
|
|||
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
|
||||
|
||||
score_pool3 = self.score_pool3(pool3)
|
||||
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:], mode='bilinear', align_corners=True)
|
||||
upscore8 = F.interpolate(
|
||||
self.final_conv(score_pool3 + upscore_pool4),
|
||||
x.size()[-2:],
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
return upscore8
|
||||
|
||||
|
||||
class FCN_AlexNet(nn.Module):
|
||||
|
||||
def __init__(self, inter_channels=256, in_channels=1):
|
||||
super(FCN_AlexNet, self).__init__()
|
||||
self.alexnet = AlexNet(in_channels=in_channels)
|
||||
|
@@ -103,7 +116,9 @@ class FCN_AlexNet(nn.Module):
|
|||
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
|
||||
)
|
||||
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(
|
||||
inter_channels, inter_channels, upsample=True
|
||||
)
|
||||
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
@@ -118,31 +133,39 @@ class FCN_AlexNet(nn.Module):
|
|||
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
|
||||
|
||||
score_pool3 = self.score_pool3(pool3)
|
||||
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:],
|
||||
mode='bilinear', align_corners=True)
|
||||
upscore8 = F.interpolate(
|
||||
self.final_conv(score_pool3 + upscore_pool4),
|
||||
x.size()[-2:],
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
return upscore8
|
||||
|
||||
|
||||
class FCN_ResNet(nn.Module):
|
||||
|
||||
def __init__(self, backbone, inter_channels=256):
|
||||
super(FCN_ResNet, self).__init__()
|
||||
self.backbone = backbone
|
||||
|
||||
self.score_pool3 = nn.Conv2d(backbone.layer2[0].downsample[1].num_features,
|
||||
inter_channels, kernel_size=1)
|
||||
self.score_pool4 = nn.Conv2d(backbone.layer3[0].downsample[1].num_features,
|
||||
inter_channels, kernel_size=1)
|
||||
self.score_pool3 = nn.Conv2d(
|
||||
backbone.layer2[0].downsample[1].num_features, inter_channels, kernel_size=1
|
||||
)
|
||||
self.score_pool4 = nn.Conv2d(
|
||||
backbone.layer3[0].downsample[1].num_features, inter_channels, kernel_size=1
|
||||
)
|
||||
|
||||
fc6 = nn.Conv2d(backbone.layer4[0].downsample[1].num_features,
|
||||
512, kernel_size=3, padding=1)
|
||||
fc6 = nn.Conv2d(
|
||||
backbone.layer4[0].downsample[1].num_features, 512, kernel_size=3, padding=1
|
||||
)
|
||||
fc7 = nn.Conv2d(512, 512, kernel_size=1)
|
||||
score_fr = nn.Conv2d(512, inter_channels, kernel_size=1)
|
||||
self.score_fr = nn.Sequential(
|
||||
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
|
||||
)
|
||||
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
|
||||
self.upscore_pool4 = Conv3x3GNReLU(
|
||||
inter_channels, inter_channels, upsample=True
|
||||
)
|
||||
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
@@ -155,7 +178,12 @@ class FCN_ResNet(nn.Module):
|
|||
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
|
||||
|
||||
score_pool3 = self.score_pool3(pool3)
|
||||
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:], mode='bilinear', align_corners=True)
|
||||
upscore8 = F.interpolate(
|
||||
self.final_conv(score_pool3 + upscore_pool4),
|
||||
x.size()[-2:],
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
return upscore8
|
||||
|
||||
|
||||
|
@@ -209,9 +237,9 @@ def FCN_ResNet152(in_channels=1, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
model = FCN_AlexNet(in_channels=1, inter_channels=128)
|
||||
x = torch.randn(1, 1, 200, 200)
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
print(y.shape)
|
||||
print(y.shape)
@@ -6,7 +6,13 @@ from src.utils.model_init import weights_init
|
|||
from .util.backbone import *
|
||||
|
||||
|
||||
__all__ = ["FPN_ResNet18", "FPN_ResNet34", "FPN_ResNet50", "FPN_ResNet101", "FPN_ResNet152"]
|
||||
__all__ = [
|
||||
"FPN_ResNet18",
|
||||
"FPN_ResNet34",
|
||||
"FPN_ResNet50",
|
||||
"FPN_ResNet101",
|
||||
"FPN_ResNet152",
|
||||
]
|
||||
|
||||
|
||||
class Conv3x3GNReLU(nn.Module):
|
||||
|
@@ -14,7 +20,9 @@ class Conv3x3GNReLU(nn.Module):
|
|||
super().__init__()
|
||||
self.upsample = upsample
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
|
||||
nn.Conv2d(
|
||||
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
|
||||
),
|
||||
nn.GroupNorm(32, out_channels),
|
||||
# nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
|
@@ -45,11 +53,15 @@ class SegmentationBlock(nn.Module):
|
|||
def __init__(self, in_channels, out_channels, n_upsamples=0):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]
|
||||
self.blocks = [
|
||||
Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
|
||||
]
|
||||
|
||||
if n_upsamples > 1:
|
||||
for _ in range(1, n_upsamples):
|
||||
self.blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))
|
||||
self.blocks.append(
|
||||
Conv3x3GNReLU(out_channels, out_channels, upsample=True)
|
||||
)
|
||||
|
||||
self.blocks_name = []
|
||||
for i, block in enumerate(self.blocks):
|
||||
|
@@ -77,32 +89,31 @@ class FPN_ResNet(nn.Module):
|
|||
self.backbone = backbone
|
||||
self.backbone.apply(weights_init)
|
||||
self.final_upsampling = final_upsampling
|
||||
self.conv1 = nn.Conv2d(encoder_channels[0],
|
||||
pyramid_channels,
|
||||
kernel_size=(1, 1))
|
||||
self.conv1 = nn.Conv2d(
|
||||
encoder_channels[0], pyramid_channels, kernel_size=(1, 1)
|
||||
)
|
||||
|
||||
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
|
||||
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
|
||||
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
|
||||
|
||||
self.s5 = SegmentationBlock(pyramid_channels,
|
||||
segmentation_channels,
|
||||
n_upsamples=3)
|
||||
self.s4 = SegmentationBlock(pyramid_channels,
|
||||
segmentation_channels,
|
||||
n_upsamples=2)
|
||||
self.s3 = SegmentationBlock(pyramid_channels,
|
||||
segmentation_channels,
|
||||
n_upsamples=1)
|
||||
self.s2 = SegmentationBlock(pyramid_channels,
|
||||
segmentation_channels,
|
||||
n_upsamples=0)
|
||||
self.s5 = SegmentationBlock(
|
||||
pyramid_channels, segmentation_channels, n_upsamples=3
|
||||
)
|
||||
self.s4 = SegmentationBlock(
|
||||
pyramid_channels, segmentation_channels, n_upsamples=2
|
||||
)
|
||||
self.s3 = SegmentationBlock(
|
||||
pyramid_channels, segmentation_channels, n_upsamples=1
|
||||
)
|
||||
self.s2 = SegmentationBlock(
|
||||
pyramid_channels, segmentation_channels, n_upsamples=0
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout2d(p=dropout, inplace=True)
|
||||
self.final_conv = nn.Conv2d(segmentation_channels,
|
||||
final_channels,
|
||||
kernel_size=1,
|
||||
padding=0)
|
||||
self.final_conv = nn.Conv2d(
|
||||
segmentation_channels, final_channels, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)
|
||||
|
@@ -126,45 +137,45 @@ class FPN_ResNet(nn.Module):
|
|||
x = self.final_conv(x)
|
||||
|
||||
if self.final_upsampling is not None and self.final_upsampling > 1:
|
||||
x = F.interpolate(x, scale_factor=self.final_upsampling, mode="bilinear", align_corners=True)
|
||||
x = F.interpolate(
|
||||
x,
|
||||
scale_factor=self.final_upsampling,
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def FPN_ResNet18(in_channels=1, **kwargs):
|
||||
"""FPN with ResNet18 as backbone
|
||||
"""
|
||||
"""FPN with ResNet18 as backbone"""
|
||||
backbone = resnet18(in_channels=in_channels)
|
||||
model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def FPN_ResNet34(in_channels=1, **kwargs):
|
||||
"""FPN with ResNet18 as backbone
|
||||
"""
|
||||
"""FPN with ResNet18 as backbone"""
|
||||
backbone = resnet34(in_channels=in_channels)
|
||||
model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def FPN_ResNet50(in_channels=1, **kwargs):
|
||||
"""FPN with ResNet50 as backbone
|
||||
"""
|
||||
"""FPN with ResNet50 as backbone"""
|
||||
backbone = resnet50(in_channels=in_channels)
|
||||
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def FPN_ResNet101(in_channels=1, **kwargs):
|
||||
"""FPN with ResNet101 as backbone
|
||||
"""
|
||||
"""FPN with ResNet101 as backbone"""
|
||||
backbone = resnet101(in_channels=in_channels)
|
||||
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def FPN_ResNet152(in_channels=1, **kwargs):
|
||||
"""FPN with ResNet101 as backbone
|
||||
"""
|
||||
"""FPN with ResNet101 as backbone"""
|
||||
backbone = resnet152(in_channels=in_channels)
|
||||
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
|
||||
return model
|
||||
return model
@@ -2,8 +2,23 @@ import torch
|
|||
from torch.nn import Sequential as Seq
|
||||
from torch.nn import Linear as Lin
|
||||
import torch_geometric as tg
|
||||
from .util.gcn_lib.dense import BasicConv, GraphConv2d, PlainDynBlock2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph
|
||||
from .util.gcn_lib.sparse import MultiSeq, MLP, GraphConv, PlainDynBlock, ResDynBlock, DenseDynBlock, DilatedKnnGraph
|
||||
from .util.gcn_lib.dense import (
|
||||
BasicConv,
|
||||
GraphConv2d,
|
||||
PlainDynBlock2d,
|
||||
ResDynBlock2d,
|
||||
DenseDynBlock2d,
|
||||
DenseDilatedKnnGraph,
|
||||
)
|
||||
from .util.gcn_lib.sparse import (
|
||||
MultiSeq,
|
||||
MLP,
|
||||
GraphConv,
|
||||
PlainDynBlock,
|
||||
ResDynBlock,
|
||||
DenseDynBlock,
|
||||
DilatedKnnGraph,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DenseDeepGCN"]
|
||||
|
@@ -15,7 +30,7 @@ class DenseDeepGCN(torch.nn.Module):
|
|||
input_dim=None,
|
||||
output_dim=None,
|
||||
dropout=0.8,
|
||||
in_channels=2+1,
|
||||
in_channels=2 + 1,
|
||||
k=5,
|
||||
n_classes=1,
|
||||
block="dense",
|
||||
|
@@ -31,57 +46,99 @@ class DenseDeepGCN(torch.nn.Module):
|
|||
):
|
||||
super(DenseDeepGCN, self).__init__()
|
||||
self.dim = dim
|
||||
self.n_classes = n_classes #
|
||||
self.k = k #
|
||||
self.in_channels = in_channels #
|
||||
self.dropout = dropout #
|
||||
self.n_classes = n_classes #
|
||||
self.k = k #
|
||||
self.in_channels = in_channels #
self.dropout = dropout #
self.block = block
self.conv = conv #
self.act = act #
self.norm = norm #
self.bias = bias #
self.channels = n_filters #
self.n_blocks = n_blocks #
self.epsilon = epsilon #
self.stochastic = stochastic #
self.conv = conv #
self.act = act #
self.norm = norm #
self.bias = bias #
self.channels = n_filters #
self.n_blocks = n_blocks #
self.epsilon = epsilon #
self.stochastic = stochastic #

c_growth = self.channels

#print(self.dropout)
# print(self.dropout)
self.knn = DenseDilatedKnnGraph(k, 1, self.stochastic, self.epsilon)
self.head = GraphConv2d(self.in_channels, self.channels, conv, act, norm, bias)

if self.block.lower() == 'res':
self.backbone = Seq(*[ResDynBlock2d(self.channels, k, 1+i, conv, act, norm, bias, stochastic, epsilon)
for i in range(self.n_blocks-1)])
if self.block.lower() == "res":
self.backbone = Seq(
*[
ResDynBlock2d(
self.channels,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic,
epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1))
elif self.block.lower() == 'dense':
self.backbone = Seq(*[DenseDynBlock2d(self.channels+c_growth*i, c_growth, k, 1+i, conv, act,
norm, bias, stochastic, epsilon)
for i in range(self.n_blocks-1)])
elif self.block.lower() == "dense":
self.backbone = Seq(
*[
DenseDynBlock2d(
self.channels + c_growth * i,
c_growth,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic,
epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(
(self.channels + self.channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2)
(self.channels + self.channels + c_growth * (self.n_blocks - 1))
* self.n_blocks
// 2
)
else:
stochastic = False

self.backbone = Seq(*[PlainDynBlock2d(self.channels, k, 1, conv, act, norm,
bias, stochastic, epsilon)
for i in range(self.n_blocks - 1)])
self.backbone = Seq(
*[
PlainDynBlock2d(
self.channels, k, 1, conv, act, norm, bias, stochastic, epsilon
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1))

self.fusion_block = BasicConv([fusion_dims, 1024], act, norm, bias)
self.prediction = Seq(*[BasicConv([fusion_dims+1024, 512], act, norm, bias),
BasicConv([512, 256], act, norm, bias),
torch.nn.Dropout(p=self.dropout),
BasicConv([256, self.n_classes], None, None, bias)])
self.prediction = Seq(
*[
BasicConv([fusion_dims + 1024, 512], act, norm, bias),
BasicConv([512, 256], act, norm, bias),
torch.nn.Dropout(p=self.dropout),
BasicConv([256, self.n_classes], None, None, bias),
]
)

def forward(self, inputs):
feats = [self.head(inputs, self.knn(inputs[:, 0:self.dim]))]
for i in range(self.n_blocks-1):
feats = [self.head(inputs, self.knn(inputs[:, 0 : self.dim]))]
for i in range(self.n_blocks - 1):
feats.append(self.backbone[i](feats[-1]))
feats = torch.cat(feats, dim=1)

fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]])
fusion = torch.max_pool2d(
self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]]
)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1)
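The dense-branch `fusion_dims` above is just the closed form of an arithmetic series: the forward pass concatenates the head output plus every DenseDynBlock2d output, and those widths grow by `c_growth` per block. A quick sanity check of that bookkeeping, with illustrative values:

```python
channels, c_growth, n_blocks = 64, 64, 7   # illustrative values only
widths = [channels + c_growth * i for i in range(n_blocks)]   # head + (n_blocks - 1) dense blocks
assert sum(widths) == (channels + channels + c_growth * (n_blocks - 1)) * n_blocks // 2
```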
@ -104,28 +161,76 @@ class SparseDeepGCN(torch.nn.Module):
self.knn = DilatedKnnGraph(k, 1, stochastic, epsilon)
self.head = GraphConv(opt.in_channels, channels, conv, act, norm, bias)

if opt.block.lower() == 'res':
self.backbone = MultiSeq(*[ResDynBlock(channels, k, 1+i, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
for i in range(self.n_blocks-1)])
if opt.block.lower() == "res":
self.backbone = MultiSeq(
*[
ResDynBlock(
channels,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(channels + c_growth * (self.n_blocks - 1))
elif opt.block.lower() == 'dense':
self.backbone = MultiSeq(*[DenseDynBlock(channels+c_growth*i, c_growth, k, 1+i,
conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
for i in range(self.n_blocks-1)])
elif opt.block.lower() == "dense":
self.backbone = MultiSeq(
*[
DenseDynBlock(
channels + c_growth * i,
c_growth,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(
(channels + channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2)
(channels + channels + c_growth * (self.n_blocks - 1))
* self.n_blocks
// 2
)
else:
# Use PlainGCN without skip connection and dilated convolution.
stochastic = False
self.backbone = MultiSeq(
*[PlainDynBlock(channels, k, 1, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon)
for i in range(self.n_blocks - 1)])
*[
PlainDynBlock(
channels,
k,
1,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(channels + c_growth * (self.n_blocks - 1))

self.fusion_block = MLP([fusion_dims, 1024], act, norm, bias)
self.prediction = MultiSeq(*[MLP([fusion_dims+1024, 512], act, norm, bias),
MLP([512, 256], act, norm, bias, drop=opt.dropout),
MLP([256, opt.n_classes], None, None, bias)])
self.prediction = MultiSeq(
*[
MLP([fusion_dims + 1024, 512], act, norm, bias),
MLP([512, 256], act, norm, bias, drop=opt.dropout),
MLP([256, opt.n_classes], None, None, bias),
]
)
self.model_init()

def model_init(self):

@ -141,19 +246,21 @@ class SparseDeepGCN(torch.nn.Module):
corr, color, batch = data.pos, data.x, data.batch
x = torch.cat((corr, color), dim=1)
feats = [self.head(x, self.knn(x[:, 0:3], batch))]
for i in range(self.n_blocks-1):
for i in range(self.n_blocks - 1):
feats.append(self.backbone[i](feats[-1], batch)[0])
feats = torch.cat(feats, dim=1)

fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0)
fusion = tg.utils.scatter_("max", self.fusion_block(feats), batch)
fusion = torch.repeat_interleave(
fusion, repeats=feats.shape[0] // fusion.shape[0], dim=0
)
return self.prediction(torch.cat((fusion, feats), dim=1))
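For readers unfamiliar with the scatter utility used above: the scatter-max plus `repeat_interleave` pair computes one max-pooled vector per graph in the batch and copies it back to every node of that graph. A rough pure-PyTorch sketch of the same idea, assuming `feats` is `[N, C]` and `batch` holds a graph id per node (the loop is for clarity, not speed):

```python
import torch

def global_max_then_broadcast(feats: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    num_graphs = int(batch.max()) + 1
    # per-graph max over nodes, shape [num_graphs, C]
    pooled = torch.stack([feats[batch == g].max(dim=0).values for g in range(num_graphs)])
    # broadcast each graph's vector back to its own nodes, shape [N, C]
    return pooled[batch]

feats = torch.randn(6, 4)
batch = torch.tensor([0, 0, 0, 1, 1, 1])
print(global_max_then_broadcast(feats, batch).shape)   # torch.Size([6, 4])
```

Indexing `pooled[batch]` is the general form; the `repeat_interleave` in the model works because every graph in a batch here has the same number of points.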
if __name__ == "__main__":

import random, numpy as np, argparse

seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

@ -163,22 +270,58 @@ if __name__ == "__main__":

batch_size = 2
N = 1024
device = 'cuda'
device = "cuda"

parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation')
parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)')
parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)')
parser.add_argument('--k', default=4, type=int, help='neighbor num (default:16)')
parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}')
parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization')
parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False')
parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks')
parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout')
parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
parser.add_argument('--stochastic', default=False, type=bool, help='stochastic for gcn, True or False')
parser = argparse.ArgumentParser(
description="PyTorch implementation of Deep GCN For semantic segmentation"
)
parser.add_argument(
"--in_channels", default=9, type=int, help="input channels (default:9)"
)
parser.add_argument(
"--n_classes",
default=13,
type=int,
help="num of segmentation classes (default:13)",
)
parser.add_argument("--k", default=4, type=int, help="neighbor num (default:16)")
parser.add_argument(
"--block",
default="res",
type=str,
help="graph backbone block type {plain, res, dense}",
)
parser.add_argument(
"--conv", default="edge", type=str, help="graph conv layer {edge, mr}"
)
parser.add_argument(
"--act",
default="relu",
type=str,
help="activation layer {relu, prelu, leakyrelu}",
)
parser.add_argument(
"--norm", default="batch", type=str, help="{batch, instance} normalization"
)
parser.add_argument(
"--bias", default=True, type=bool, help="bias of conv layer True or False"
)
parser.add_argument(
"--n_filters", default=64, type=int, help="number of channels of deep features"
)
parser.add_argument(
"--n_blocks", default=7, type=int, help="number of basic blocks"
)
parser.add_argument("--dropout", default=0.5, type=float, help="ratio of dropout")
parser.add_argument(
"--epsilon", default=0.2, type=float, help="stochastic epsilon for gcn"
)
parser.add_argument(
"--stochastic",
default=False,
type=bool,
help="stochastic for gcn, True or False",
)
args = parser.parse_args()

pos = torch.rand((batch_size, N, 2), dtype=torch.float).to(device)

@ -188,9 +331,8 @@ if __name__ == "__main__":
print(inputs.size())

# net = DGCNNSegDense().to(device)
net = DenseDeepGCN(in_channels=2+6).to(device)
net = DenseDeepGCN(in_channels=2 + 6).to(device)
# net = SparseDeepGCN(args).to(device)
print(net)
out = net(inputs)
print(out.shape)

@ -2,4 +2,4 @@ import src.models.util.point_lib as pointmodels

def Point(model, u_obs, u):
return getattr(pointmodels, model)(u_obs, u)
return getattr(pointmodels, model)(u_obs, u)

@ -9,38 +9,40 @@ def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
# activation layer

act = act.lower()
if act == 'relu':
if act == "relu":
layer = nn.ReLU(inplace)
elif act == 'leakyrelu':
elif act == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace)
elif act == 'prelu':
elif act == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act == 'gelu':
elif act == "gelu":
layer = nn.GELU()
elif act == 'sigmoid':
elif act == "sigmoid":
layer = nn.Sigmoid()
else:
raise NotImplementedError('activation layer [%s] is not found' % act)
raise NotImplementedError("activation layer [%s] is not found" % act)
return layer
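A short usage example of this helper (assuming `torch.nn as nn` is imported in the module, as above):

```python
block = nn.Sequential(nn.Linear(32, 32), act_layer("leakyrelu", neg_slope=0.1))
```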
class MLP(nn.Module):
def __init__(self, input_dim, output_dim, act = "gelu", bias=True, hidden_layer=[512, 512, 512]):
def __init__(
self, input_dim, output_dim, act="gelu", bias=True, hidden_layer=[512, 512, 512]
):
super(MLP, self).__init__()
self.hidden_layer = hidden_layer
net = []
net.append(nn.Linear(input_dim, hidden_layer[0]))
if act is not None and act.lower() != 'none':
if act is not None and act.lower() != "none":
net.append(act_layer(act))

if len(hidden_layer) > 1:

if len(hidden_layer) > 1:
for i in range(1, len(hidden_layer)):
net.append(nn.Linear(hidden_layer[i - 1], hidden_layer[i], bias))
if act is not None and act.lower() != 'none':
if act is not None and act.lower() != "none":
net.append(act_layer(act))

net.append(nn.Linear(hidden_layer[-1], output_dim))
self.net = nn.Sequential(*net)

def forward(self, x):
return self.net(x)
return self.net(x)
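A minimal usage sketch of the reformatted `MLP` (the dimensions below are hypothetical):

```python
import torch

mlp = MLP(input_dim=3, output_dim=1, act="gelu", hidden_layer=[64, 64])
y = mlp(torch.randn(8, 3))   # -> torch.Size([8, 1])
```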
@ -8,28 +8,37 @@ import torch.nn.functional as F
from .util.backbone import *

__all__ = ["SegNet_VGG", "SegNet_VGG_GN", "SegNet_AlexNet", "SegNet_ResNet18",
"SegNet_ResNet50", "SegNet_ResNet101", "SegNet_ResNet34", "SegNet_ResNet152"]
__all__ = [
"SegNet_VGG",
"SegNet_VGG_GN",
"SegNet_AlexNet",
"SegNet_ResNet18",
"SegNet_ResNet50",
"SegNet_ResNet101",
"SegNet_ResNet34",
"SegNet_ResNet152",
]

# required class for decoder of SegNet_ResNet
class DecoderBottleneck(nn.Module):

def __init__(self, in_channels):
super(DecoderBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4,
kernel_size=1, bias=False)
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4,
kernel_size=2, stride=2, bias=False)
self.conv2 = nn.ConvTranspose2d(
in_channels // 4, in_channels // 4, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 2, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2, bias=False),
nn.BatchNorm2d(in_channels // 2))
nn.ConvTranspose2d(
in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
),
nn.BatchNorm2d(in_channels // 2),
)

def forward(self, x):
out = self.conv1(x)

@ -49,21 +58,21 @@ class DecoderBottleneck(nn.Module):

# required class for decoder of SegNet_ResNet
class LastBottleneck(nn.Module):

def __init__(self, in_channels):
super(LastBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4,
kernel_size=1, bias=False)
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.Conv2d(in_channels // 4, in_channels // 4,
kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(
in_channels // 4, in_channels // 4, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 4, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels // 4))
nn.BatchNorm2d(in_channels // 4),
)

def forward(self, x):
out = self.conv1(x)

@ -83,20 +92,23 @@ class LastBottleneck(nn.Module):

# required class for decoder of SegNet_ResNet
class DecoderBasicBlock(nn.Module):

def __init__(self, in_channels):
super(DecoderBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 2,
kernel_size=3, padding=1, bias=False)
self.conv1 = nn.Conv2d(
in_channels, in_channels // 2, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels // 2)
self.conv2 = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
kernel_size=2, stride=2, bias=False)
self.conv2 = nn.ConvTranspose2d(
in_channels // 2, in_channels // 2, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2,
kernel_size=2, stride=2, bias=False),
nn.BatchNorm2d(in_channels // 2))
nn.ConvTranspose2d(
in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
),
nn.BatchNorm2d(in_channels // 2),
)

def forward(self, x):
out = self.conv1(x)

@ -112,19 +124,21 @@ class DecoderBasicBlock(nn.Module):

class LastBasicBlock(nn.Module):

def __init__(self, in_channels):
super(LastBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels,
kernel_size=3, padding=1, bias=False)
self.conv1 = nn.Conv2d(
in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, in_channels,
kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(
in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels))
nn.BatchNorm2d(in_channels),
)

def forward(self, x):
out = self.conv1(x)

@ -140,7 +154,6 @@ class LastBasicBlock(nn.Module):

class SegNet_VGG(nn.Module):

def __init__(self, out_channels=1, in_channels=1, pretrained=False):
super(SegNet_VGG, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained)

@ -160,34 +173,45 @@ class SegNet_VGG(nn.Module):

# Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)]
decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3)
for item in decoder[i:i + 3][::-1]]
decoder = [
item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them
for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels:
decoder[i + 1] = nn.BatchNorm2d(module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels,
kernel_size=3, stride=1, padding=1)
decoder[i] = nn.Conv2d(
module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)

self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:],
nn.Conv2d(64, out_channels,
kernel_size=3,
stride=1,
padding=1)
)
self.stage5_decoder = nn.Sequential(
*decoder[33:],
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder,
self.stage4_decoder, self.stage5_decoder)
self._initialize_weights(
self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)

def _initialize_weights(self, *stages):
for modules in stages:

@ -242,7 +266,6 @@ class SegNet_VGG(nn.Module):

class SegNet_VGG_GN(nn.Module):

def __init__(self, out_channels=1, in_channels=3, pretrained=False):
super(SegNet_VGG_GN, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained)

@ -267,33 +290,45 @@ class SegNet_VGG_GN(nn.Module):

# Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)]
decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3)
for item in decoder[i:i + 3][::-1]]
decoder = [
item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them
for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels:
decoder[i + 1] = nn.GroupNorm(32, module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels,
kernel_size=3, stride=1, padding=1)
decoder[i] = nn.Conv2d(
module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)

self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:], nn.Conv2d(64,
out_channels,
kernel_size=3,
stride=1,
padding=1))
self.stage5_decoder = nn.Sequential(
*decoder[33:],
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder,
self.stage4_decoder, self.stage5_decoder)
self._initialize_weights(
self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)

def _initialize_weights(self, *stages):
for modules in stages:

@ -348,7 +383,6 @@ class SegNet_VGG_GN(nn.Module):

class SegNet_AlexNet(nn.Module):

def __init__(self, out_channels=1, in_channels=1, bn=False):
super(SegNet_AlexNet, self).__init__()
self.stage3_encoder = nn.Sequential(

@ -373,7 +407,9 @@ class SegNet_AlexNet(nn.Module):
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False, return_indices=True)
self.maxpool = nn.MaxPool2d(
kernel_size=2, stride=2, ceil_mode=False, return_indices=True
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.stage5_decoder = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1),

@ -419,7 +455,6 @@ class SegNet_AlexNet(nn.Module):

class SegNet_ResNet(nn.Module):

def __init__(self, backbone, out_channels=1, is_bottleneck=False, in_channels=1):
super(SegNet_ResNet, self).__init__()
resnet_backbone = backbone

@ -442,18 +477,25 @@ class SegNet_ResNet(nn.Module):
channels = (512, 256, 128)
for i, block in enumerate(resnet_r_blocks[:-1]):
new_block = list(block.children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block,
DecoderBottleneck(channels[i])
if is_bottleneck else DecoderBasicBlock(channels[i])))
decoder.append(
nn.Sequential(
*new_block,
DecoderBottleneck(channels[i])
if is_bottleneck
else DecoderBasicBlock(channels[i])
)
)
new_block = list(resnet_r_blocks[-1].children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block,
LastBottleneck(256)
if is_bottleneck else LastBasicBlock(64)))
decoder.append(
nn.Sequential(
*new_block, LastBottleneck(256) if is_bottleneck else LastBasicBlock(64)
)
)

self.decoder = nn.Sequential(*decoder)
self.last_conv = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False),
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
)

def forward(self, x):

@ -468,10 +510,14 @@ class SegNet_ResNet(nn.Module):
h_diff = ceil((x.size()[2] - indices.size()[2]) / 2)
w_diff = ceil((x.size()[3] - indices.size()[3]) / 2)
if indices.size()[2] % 2 == 1:
x = x[:, :, h_diff:x.size()[2] - (h_diff - 1),
w_diff: x.size()[3] - (w_diff - 1)]
x = x[
:,
:,
h_diff : x.size()[2] - (h_diff - 1),
w_diff : x.size()[3] - (w_diff - 1),
]
else:
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff]
x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]

x = F.max_unpool2d(x, indices, kernel_size=2, stride=2)
x = self.last_conv(x)

@ -479,9 +525,11 @@ class SegNet_ResNet(nn.Module):
if inputsize != x.size():
h_diff = (x.size()[2] - inputsize[2]) // 2
w_diff = (x.size()[3] - inputsize[3]) // 2
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff]
if h_diff % 2 != 0: x = x[:, :, :-1, :]
if w_diff % 2 != 0: x = x[:, :, :, :-1]
x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]
if h_diff % 2 != 0:
x = x[:, :, :-1, :]
if w_diff % 2 != 0:
x = x[:, :, :, :-1]

return x
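The cropping above trims the transposed-conv output down to the size the stored unpool indices expect; when the target size is odd, one fewer row/column is cut from the far edge. A standalone toy check of that length bookkeeping, with made-up sizes:

```python
from math import ceil

def crop_len(x_len: int, target_len: int) -> int:
    diff = ceil((x_len - target_len) / 2)
    if target_len % 2 == 1:                       # odd target: asymmetric crop
        return len(range(diff, x_len - (diff - 1)))
    return len(range(diff, x_len - diff))         # even target: symmetric crop

print(crop_len(12, 9), crop_len(12, 10))   # 9 10
```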
@ -492,8 +540,13 @@ def SegNet_ResNet18(in_channels=1, out_channels=1, **kwargs):
"""
backbone_net = resnet18()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False,
in_channels=in_channels, **kwargs)
model = SegNet_ResNet(
backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model

@ -503,8 +556,13 @@ def SegNet_ResNet34(in_channels=1, out_channels=1, **kwargs):
"""
backbone_net = resnet34()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False,
in_channels=in_channels, **kwargs)
model = SegNet_ResNet(
backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model

@ -514,8 +572,13 @@ def SegNet_ResNet50(in_channels=1, out_channels=1, **kwargs):
"""
backbone_net = resnet50()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True,
in_channels=in_channels, **kwargs)
model = SegNet_ResNet(
backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model

@ -525,8 +588,13 @@ def SegNet_ResNet101(in_channels=1, out_channels=1, **kwargs):
"""
backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True,
in_channels=in_channels, **kwargs)
model = SegNet_ResNet(
backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model

@ -536,15 +604,20 @@ def SegNet_ResNet152(in_channels=1, out_channels=1, **kwargs):
"""
backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True,
in_channels=in_channels, **kwargs)
model = SegNet_ResNet(
backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model

if __name__ == '__main__':
if __name__ == "__main__":
model = SegNet_AlexNet(in_channels=1, out_channels=1)
print(model)
x = torch.randn(1, 1, 200, 200)
with torch.no_grad():
y = model(x)
print(y.shape)
print(y.shape)
@ -1,4 +1,4 @@
''' Define the Transformer model '''
""" Define the Transformer model """
import torch
import torch.nn as nn
import numpy as np

@ -13,54 +13,76 @@ def get_pad_mask(seq, pad_idx):

def get_subsequent_mask(seq):
''' For masking out the subsequent info. '''
"""For masking out the subsequent info."""
sz_b, len_s = seq.size()
subsequent_mask = (1 - torch.triu(
torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
subsequent_mask = (
1 - torch.triu(torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)
).bool()
return subsequent_mask
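The reformatted expression is still the usual causal mask: position i may attend only to positions j <= i. For a toy length-4 sequence (torch is imported at the top of this file):

```python
seq = torch.zeros(1, 4, dtype=torch.long)      # dummy batch; only the shape matters here
print(get_subsequent_mask(seq).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.int32)
```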
class PositionalEncoding(nn.Module):

def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()

# Not a parameter
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
self.register_buffer(
"pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
)

def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
"""Sinusoid position encoding table"""

# TODO: make it with torch instead of numpy

def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]

sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

return torch.FloatTensor(sinusoid_table).unsqueeze(0)

def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
return x + self.pos_table[:, : x.size(1)].clone().detach()
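The TODO above asks for a torch-only version of the sinusoid table. One possible sketch (an assumption, not part of this commit) that mirrors the NumPy computation:

```python
def sinusoid_table_torch(n_position: int, d_hid: int) -> torch.Tensor:
    position = torch.arange(n_position, dtype=torch.float).unsqueeze(1)       # [n_position, 1]
    div = torch.pow(10000.0, 2 * (torch.arange(d_hid) // 2) / d_hid)          # [d_hid]
    table = position / div                                                    # [n_position, d_hid]
    table[:, 0::2] = torch.sin(table[:, 0::2])   # dim 2i
    table[:, 1::2] = torch.cos(table[:, 1::2])   # dim 2i+1
    return table.unsqueeze(0)                                                 # [1, n_position, d_hid]
```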
class Encoder(nn.Module):
''' A encoder model with self attention mechanism. '''
"""A encoder model with self attention mechanism."""

def __init__(
self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False):
self,
n_src_vocab,
d_word_vec,
n_layers,
n_head,
d_k,
d_v,
d_model,
d_inner,
pad_idx,
dropout=0.1,
n_position=200,
scale_emb=False,
):

super().__init__()

self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_stack = nn.ModuleList(
[
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb
self.d_model = d_model

@ -82,24 +104,39 @@ class Encoder(nn.Module):

if return_attns:
return enc_output, enc_slf_attn_list
return enc_output,
return (enc_output,)

class Decoder(nn.Module):
''' A decoder model with self attention mechanism. '''
"""A decoder model with self attention mechanism."""

def __init__(
self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False):
self,
n_trg_vocab,
d_word_vec,
n_layers,
n_head,
d_k,
d_v,
d_model,
d_inner,
pad_idx,
n_position=200,
dropout=0.1,
scale_emb=False,
):

super().__init__()

self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_stack = nn.ModuleList(
[
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb
self.d_model = d_model

@ -117,24 +154,41 @@ class Decoder(nn.Module):

for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
dec_output,
enc_output,
slf_attn_mask=trg_mask,
dec_enc_attn_mask=src_mask,
)
dec_slf_attn_list += [dec_slf_attn] if return_attns else []
dec_enc_attn_list += [dec_enc_attn] if return_attns else []

if return_attns:
return dec_output, dec_slf_attn_list, dec_enc_attn_list
return dec_output,
return (dec_output,)

class Transformer(nn.Module):
''' A sequence to sequence model with attention mechanism. '''
"""A sequence to sequence model with attention mechanism."""

def __init__(
self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx,
d_word_vec=512, d_model=512, d_inner=2048,
n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200,
trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True,
scale_emb_or_prj='prj'):
self,
n_src_vocab,
n_trg_vocab,
src_pad_idx,
trg_pad_idx,
d_word_vec=512,
d_model=512,
d_inner=2048,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
n_position=200,
trg_emb_prj_weight_sharing=True,
emb_src_trg_weight_sharing=True,
scale_emb_or_prj="prj",
):

super().__init__()

@ -150,22 +204,42 @@ class Transformer(nn.Module):
# 'prj': multiply (\sqrt{d_model} ^ -1) to linear projection output
# 'none': no multiplication

assert scale_emb_or_prj in ['emb', 'prj', 'none']
scale_emb = (scale_emb_or_prj == 'emb') if trg_emb_prj_weight_sharing else False
self.scale_prj = (scale_emb_or_prj == 'prj') if trg_emb_prj_weight_sharing else False
assert scale_emb_or_prj in ["emb", "prj", "none"]
scale_emb = (scale_emb_or_prj == "emb") if trg_emb_prj_weight_sharing else False
self.scale_prj = (
(scale_emb_or_prj == "prj") if trg_emb_prj_weight_sharing else False
)
self.d_model = d_model

self.encoder = Encoder(
n_src_vocab=n_src_vocab, n_position=n_position,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
pad_idx=src_pad_idx, dropout=dropout, scale_emb=scale_emb)
n_src_vocab=n_src_vocab,
n_position=n_position,
d_word_vec=d_word_vec,
d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
pad_idx=src_pad_idx,
dropout=dropout,
scale_emb=scale_emb,
)

self.decoder = Decoder(
n_trg_vocab=n_trg_vocab, n_position=n_position,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
pad_idx=trg_pad_idx, dropout=dropout, scale_emb=scale_emb)
n_trg_vocab=n_trg_vocab,
n_position=n_position,
d_word_vec=d_word_vec,
d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
pad_idx=trg_pad_idx,
dropout=dropout,
scale_emb=scale_emb,
)

self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)

@ -173,9 +247,10 @@ class Transformer(nn.Module):
if p.dim() > 1:
nn.init.xavier_uniform_(p)

assert d_model == d_word_vec, \
'To facilitate the residual connections, \
the dimensions of all module outputs shall be the same.'
assert (
d_model == d_word_vec
), "To facilitate the residual connections, \
the dimensions of all module outputs shall be the same."

if trg_emb_prj_weight_sharing:
# Share the weight between target word embedding & last dense layer

@ -187,7 +262,9 @@ class Transformer(nn.Module):
def forward(self, src_seq, trg_seq):

src_mask = get_pad_mask(src_seq, self.src_pad_idx)
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(
trg_seq
)

enc_output, *_ = self.encoder(src_seq, src_mask)
dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)

@ -202,9 +279,12 @@ class Transformer(nn.Module):
class EncoderModify(nn.Module):
def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
super(EncoderModify, self).__init__()
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_stack = nn.ModuleList(
[
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)

def forward(self, src_seq):
enc_output = src_seq

@ -216,12 +296,14 @@ class EncoderModify(nn.Module):

# Modified Decoder
class DecoderModify(nn.Module):
def __init__(
self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
super(DecoderModify, self).__init__()
self.layer_stack = nn.ModuleList([
DecoderLayerModify(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_stack = nn.ModuleList(
[
DecoderLayerModify(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)

def forward(self, trg_seq, src_key, enc_output):
dec_output = trg_seq

@ -233,8 +315,17 @@ class DecoderModify(nn.Module):
# Modified Transformer
class TransformerRecon(nn.Module):
def __init__(
self, input_dim=None, output_dim=None, d_model=128, d_inner=512,
n_layers=2, n_head=4, d_k=32, d_v=32, dropout=0.1):
self,
input_dim=None,
output_dim=None,
d_model=128,
d_inner=512,
n_layers=2,
n_head=4,
d_k=32,
d_v=32,
dropout=0.1,
):

super(TransformerRecon, self).__init__()

@ -253,19 +344,31 @@ class TransformerRecon(nn.Module):
self.d_model = d_model

self.encoder = EncoderModify(
d_model=d_model, d_inner=d_inner, n_layers=n_layers,
n_head=n_head, d_k=d_k, d_v=d_v, dropout=dropout)
d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout,
)

self.decoder = DecoderModify(
d_model=d_model, d_inner=d_inner, n_layers=1,
n_head=n_head, d_k=d_k, d_v=d_v, dropout=dropout)
d_model=d_model,
d_inner=d_inner,
n_layers=1,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout,
)

self.pre = nn.Sequential(
nn.Linear(128 + 2, 128),
nn.GELU(),
nn.Linear(128, 128),
nn.GELU(),
nn.Linear(128, 1)
nn.Linear(128, 1),
)

def forward(self, src_seq, src_label, trg_seq, trg_label=None):

@ -278,11 +381,3 @@ class TransformerRecon(nn.Module):
dec_output = self.decoder(trg_seq, src_seq, enc_output)
return self.pre(torch.cat([dec_output, trg_x], dim=-1))

def test_TransformerModify():
transformer = TransformerModify()
x_src = torch.randn(8, 10, 3)
src_seq = torch.randn(8, 10, 2)
x_trg = torch.randn(8, 400, 2)
y = transformer(x_src, src_seq, x_trg)
print(y.shape)
@ -10,8 +10,9 @@ __all__ = ["UNet_VGG"]
|
|||
|
||||
|
||||
class _EncoderBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, dropout=False, polling=True, bn=False):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, dropout=False, polling=True, bn=False
|
||||
):
|
||||
super(_EncoderBlock, self).__init__()
|
||||
layers = [
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
|
@ -35,15 +36,18 @@ class _EncoderBlock(nn.Module):
|
|||
|
||||
|
||||
class _DecoderBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, middle_channels, out_channels, bn=False):
|
||||
super(_DecoderBlock, self).__init__()
|
||||
self.decode = nn.Sequential(
|
||||
nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(middle_channels) if bn else nn.GroupNorm(32, middle_channels),
|
||||
nn.BatchNorm2d(middle_channels)
|
||||
if bn
|
||||
else nn.GroupNorm(32, middle_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(middle_channels) if bn else nn.GroupNorm(32, middle_channels),
|
||||
nn.BatchNorm2d(middle_channels)
|
||||
if bn
|
||||
else nn.GroupNorm(32, middle_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2),
|
||||
)
|
||||
|
@ -53,7 +57,6 @@ class _DecoderBlock(nn.Module):
|
|||
|
||||
|
||||
class UNet_VGG(nn.Module):
|
||||
|
||||
def __init__(self, out_channels=1, in_channels=1, bn=False):
|
||||
super(UNet_VGG, self).__init__()
|
||||
self.enc1 = _EncoderBlock(in_channels, 64, polling=False, bn=bn)
|
||||
|
@ -82,8 +85,17 @@ class UNet_VGG(nn.Module):
|
|||
enc3 = self.enc3(enc2)
|
||||
enc4 = self.enc4(enc3)
|
||||
center = self.center(self.polling(enc4))
|
||||
dec4 = self.dec4(torch.cat([F.interpolate(center, enc4.size()[-2:], mode='bilinear',
|
||||
align_corners=True), enc4], 1))
|
||||
dec4 = self.dec4(
|
||||
torch.cat(
|
||||
[
|
||||
F.interpolate(
|
||||
center, enc4.size()[-2:], mode="bilinear", align_corners=True
|
||||
),
|
||||
enc4,
|
||||
],
|
||||
1,
|
||||
)
|
||||
)
|
||||
dec3 = self.dec3(torch.cat([dec4, enc3], 1))
|
||||
dec2 = self.dec2(torch.cat([dec3, enc2], 1))
|
||||
dec1 = self.dec1(torch.cat([dec2, enc1], 1))
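The dec4 line above is the usual U-Net skip connection: upsample the coarse feature map to the encoder feature's spatial size, then concatenate along channels. A minimal sketch with illustrative sizes:

import torch
import torch.nn.functional as F

center = torch.randn(1, 512, 12, 12)   # coarse bottleneck features (sizes are illustrative)
enc4 = torch.randn(1, 512, 25, 25)     # matching encoder features
up = F.interpolate(center, enc4.size()[-2:], mode="bilinear", align_corners=True)
merged = torch.cat([up, enc4], 1)      # what gets fed into self.dec4
print(merged.shape)                    # torch.Size([1, 1024, 25, 25])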
|
||||
|
@ -91,10 +103,10 @@ class UNet_VGG(nn.Module):
|
|||
return final
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = UNet(in_channels=1, out_channels=1)
|
||||
if __name__ == "__main__":
|
||||
model = UNet_VGG(in_channels=1, out_channels=1)
|
||||
print(model)
|
||||
x = torch.randn(1, 1, 200, 200)
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
print(y.shape)
|
||||
print(y.shape)
|
||||
@ -1,3 +1,3 @@
|
|||
from .alexnet import *
|
||||
from .resnet import *
|
||||
from .vgg import *
|
||||
from .vgg import *
|
||||
@ -16,8 +16,13 @@ class AlexNet(nn.Module):
|
|||
super(AlexNet, self).__init__()
|
||||
self.features3 = nn.Sequential(
|
||||
# kernel(11, 11) -> kernel(7, 7)
|
||||
nn.Conv2d(in_channels=in_channels, out_channels=64,
|
||||
kernel_size=7, stride=4, padding=3),
|
||||
nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=4,
|
||||
padding=3,
|
||||
),
|
||||
nn.BatchNorm2d(64) if bn else nn.GroupNorm(32, 64),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
@ -49,7 +54,7 @@ class AlexNet(nn.Module):
|
|||
|
||||
if __name__ == "__main__":
|
||||
x = torch.zeros(8, 1, 200, 200)
|
||||
net = Alexnet()
|
||||
net = AlexNet()
|
||||
print(net)
|
||||
y = net(x)
|
||||
print()
|
||||
@ -14,8 +14,8 @@ __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152
|
|||
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
||||
"resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
||||
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
||||
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
||||
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
||||
|
@ -24,7 +24,9 @@ model_urls = {
|
|||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
return nn.Conv2d(
|
||||
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
|
@ -66,12 +68,9 @@ class Bottleneck(nn.Module):
|
|||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
|
||||
)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
|
@ -106,7 +105,9 @@ class ResNet(nn.Module):
|
|||
def __init__(self, block, layers, in_channels=1):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
|
||||
)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
@ -127,11 +128,13 @@ class ResNet(nn.Module):
|
|||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
nn.Conv2d(
|
||||
self.inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
|
@ -173,7 +176,7 @@ def resnet18(pretrained=False, in_channels=1, **kwargs):
|
|||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, **kwargs)
|
||||
if pretrained:
|
||||
model._load_pretrained_model(model_urls['resnet18'])
|
||||
model._load_pretrained_model(model_urls["resnet18"])
|
||||
return model
|
||||
|
||||
|
||||
|
@ -185,7 +188,7 @@ def resnet34(pretrained=False, in_channels=1, **kwargs):
|
|||
"""
|
||||
model = ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, **kwargs)
|
||||
if pretrained:
|
||||
model._load_pretrained_model(model_urls['resnet34'])
|
||||
model._load_pretrained_model(model_urls["resnet34"])
|
||||
return model
|
||||
|
||||
|
||||
@ -11,30 +11,33 @@ from src.utils.vgg_utils import load_state_dict_from_url
|
|||
|
||||
|
||||
__all__ = [
|
||||
"VGG", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn",
|
||||
"vgg19_bn", "vgg19",
|
||||
"VGG",
|
||||
"vgg11",
|
||||
"vgg11_bn",
|
||||
"vgg13",
|
||||
"vgg13_bn",
|
||||
"vgg16",
|
||||
"vgg16_bn",
|
||||
"vgg19_bn",
|
||||
"vgg19",
|
||||
]
|
||||
|
||||
|
||||
model_urls = {
|
||||
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
|
||||
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
|
||||
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
|
||||
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
|
||||
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
|
||||
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
|
||||
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
|
||||
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
|
||||
}
|
||||
|
||||
|
||||
class VGG(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: nn.Module,
|
||||
num_classes: int = 1000,
|
||||
init_weights: bool = True
|
||||
self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True
|
||||
) -> None:
|
||||
super(VGG, self).__init__()
|
||||
self.features = features
|
||||
|
@ -61,7 +64,7 @@ class VGG(nn.Module):
|
|||
def _initialize_weights(self) -> None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
|
@ -76,7 +79,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
|
|||
layers: List[nn.Module] = []
|
||||
in_channels = 3
|
||||
for v in cfg:
|
||||
if v == 'M':
|
||||
if v == "M":
|
||||
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
|
||||
else:
|
||||
v = cast(int, v)
|
||||
|
@ -90,20 +93,67 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
|
|||
|
||||
|
||||
cfgs: Dict[str, List[Union[str, int]]] = {
|
||||
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
|
||||
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
|
||||
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
|
||||
"A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
|
||||
"B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
|
||||
"D": [
|
||||
64,
|
||||
64,
|
||||
"M",
|
||||
128,
|
||||
128,
|
||||
"M",
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
],
|
||||
"E": [
|
||||
64,
|
||||
64,
|
||||
"M",
|
||||
128,
|
||||
128,
|
||||
"M",
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
256,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
512,
|
||||
"M",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
|
||||
def _vgg(
|
||||
arch: str,
|
||||
cfg: str,
|
||||
batch_norm: bool,
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
) -> VGG:
|
||||
if pretrained:
|
||||
kwargs['init_weights'] = False
|
||||
kwargs["init_weights"] = False
|
||||
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(model_urls[arch],
|
||||
progress=progress)
|
||||
state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
@ -116,7 +166,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg11", "A", False, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -127,7 +177,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -138,7 +188,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg13", "B", False, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -149,7 +199,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -160,7 +210,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg16", "D", False, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -171,7 +221,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -182,7 +232,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
|
||||
|
@ -193,4 +243,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
|
|||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)
|
||||
return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs)
|
||||
@ -1,4 +1,3 @@
|
|||
from .torch_nn import *
|
||||
from .torch_edge import *
|
||||
from .torch_vertex import *
|
||||
|
||||
@ -9,6 +9,7 @@ class DenseDilated(nn.Module):
|
|||
|
||||
edge_index: (2, batch_size, num_points, k)
|
||||
"""
|
||||
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
|
||||
super(DenseDilated, self).__init__()
|
||||
self.dilation = dilation
|
||||
|
@ -20,12 +21,12 @@ class DenseDilated(nn.Module):
|
|||
if self.stochastic:
|
||||
if torch.rand(1) < self.epsilon and self.training:
|
||||
num = self.k * self.dilation
|
||||
randnum = torch.randperm(num)[:self.k]
|
||||
randnum = torch.randperm(num)[: self.k]
|
||||
edge_index = edge_index[:, :, :, randnum]
|
||||
else:
|
||||
edge_index = edge_index[:, :, :, ::self.dilation]
|
||||
edge_index = edge_index[:, :, :, :: self.dilation]
|
||||
else:
|
||||
edge_index = edge_index[:, :, :, ::self.dilation]
|
||||
edge_index = edge_index[:, :, :, :: self.dilation]
|
||||
return edge_index
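A small illustration of the dilation step above: from k * d candidate neighbours per point, keep every d-th one, so each point ends up with k dilated neighbours.

import torch

k, d = 3, 2
edge_index = torch.arange(2 * 1 * 4 * k * d).view(2, 1, 4, k * d)  # (2, batch, points, k*d)
dilated = edge_index[:, :, :, ::d]
print(dilated.shape)   # torch.Size([2, 1, 4, 3]) -- k neighbours kept per point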
|
||||
|
||||
|
||||
|
@ -37,7 +38,7 @@ def pairwise_distance(x):
|
|||
Returns:
|
||||
pairwise distance: (batch_size, num_points, num_points)
|
||||
"""
|
||||
x_inner = -2*torch.matmul(x, x.transpose(2, 1))
|
||||
x_inner = -2 * torch.matmul(x, x.transpose(2, 1))
|
||||
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
|
||||
return x_square + x_inner + x_square.transpose(2, 1)
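The returned expression is the expanded squared Euclidean distance, ||x_i - x_j||^2 = ||x_i||^2 - 2 x_i . x_j + ||x_j||^2. A quick sanity sketch against torch.cdist:

import torch

x = torch.randn(2, 5, 3)                   # (batch_size, num_points, num_dims)
d2 = pairwise_distance(x)
print(torch.allclose(d2, torch.cdist(x, x) ** 2, atol=1e-4))   # True up to rounding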
|
||||
|
||||
|
@ -54,7 +55,11 @@ def dense_knn_matrix(x, k=16):
|
|||
x = x.transpose(2, 1).squeeze(-1)
|
||||
batch_size, n_points, n_dims = x.shape
|
||||
_, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k)
|
||||
center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
|
||||
center_idx = (
|
||||
torch.arange(0, n_points, device=x.device)
|
||||
.repeat(batch_size, k, 1)
|
||||
.transpose(2, 1)
|
||||
)
|
||||
return torch.stack((nn_idx, center_idx), dim=0)
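Usage sketch: judging from the transpose/squeeze above, dense_knn_matrix expects point features laid out as (batch, dims, num_points, 1) and returns neighbour indices stacked with centre indices as a dense edge_index of shape (2, batch, num_points, k).

import torch

x = torch.randn(4, 3, 128, 1)          # 4 clouds, 3-D coordinates, 128 points each
edge_index = dense_knn_matrix(x, k=16)
print(edge_index.shape)                # torch.Size([2, 4, 128, 16])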
|
||||
|
||||
|
||||
|
@ -62,6 +67,7 @@ class DenseDilatedKnnGraph(nn.Module):
|
|||
"""
|
||||
Find the neighbors' indices based on dilated knn
|
||||
"""
|
||||
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
|
||||
super(DenseDilatedKnnGraph, self).__init__()
|
||||
self.dilation = dilation
|
||||
|
@ -80,6 +86,7 @@ class DilatedKnnGraph(nn.Module):
|
|||
"""
|
||||
Find the neighbors' indices based on dilated knn
|
||||
"""
|
||||
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
|
||||
super(DilatedKnnGraph, self).__init__()
|
||||
self.dilation = dilation
|
||||
|
@ -94,7 +101,9 @@ class DilatedKnnGraph(nn.Module):
|
|||
B, C, N = x.shape
|
||||
edge_index = []
|
||||
for i in range(B):
|
||||
edgeindex = self.knn(x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation)
|
||||
edgeindex = self.knn(
|
||||
x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation
|
||||
)
|
||||
edgeindex = edgeindex.view(2, N, self.k * self.dilation)
|
||||
edge_index.append(edgeindex)
|
||||
edge_index = torch.stack(edge_index, dim=1)
|
||||
@ -10,55 +10,55 @@ def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
|
|||
# activation layer
|
||||
|
||||
act = act.lower()
|
||||
if act == 'relu':
|
||||
if act == "relu":
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act == 'leakyrelu':
|
||||
elif act == "leakyrelu":
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act == 'prelu':
|
||||
elif act == "prelu":
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act == 'gelu':
|
||||
elif act == "gelu":
|
||||
layer = nn.GELU()
|
||||
elif act == 'sigmoid':
|
||||
elif act == "sigmoid":
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [%s] is not found' % act)
|
||||
raise NotImplementedError("activation layer [%s] is not found" % act)
|
||||
return layer
|
||||
|
||||
|
||||
def norm_layer(norm, nc):
|
||||
# normalization layer 2d
|
||||
norm = norm.lower()
|
||||
if norm == 'batch':
|
||||
if norm == "batch":
|
||||
layer = nn.BatchNorm2d(nc, affine=False, track_running_stats=False)
|
||||
elif norm == 'instance':
|
||||
elif norm == "instance":
|
||||
layer = nn.InstanceNorm2d(nc, affine=True)
|
||||
elif norm == 'group':
|
||||
elif norm == "group":
|
||||
layer = nn.GroupNorm(32, nc, affine=False)
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [%s] is not found' % norm)
|
||||
raise NotImplementedError("normalization layer [%s] is not found" % norm)
|
||||
return layer
|
||||
|
||||
|
||||
class MLP(Seq):
|
||||
def __init__(self, channels, act='relu', norm=None, bias=True):
|
||||
def __init__(self, channels, act="relu", norm=None, bias=True):
|
||||
m = []
|
||||
for i in range(1, len(channels)):
|
||||
m.append(Lin(channels[i - 1], channels[i], bias))
|
||||
if act is not None and act.lower() != 'none':
|
||||
if act is not None and act.lower() != "none":
|
||||
m.append(act_layer(act))
|
||||
if norm is not None and norm.lower() != 'none':
|
||||
if norm is not None and norm.lower() != "none":
|
||||
m.append(norm_layer(norm, channels[-1]))
|
||||
super(MLP, self).__init__(*m)
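Usage sketch for this MLP (assuming Seq and Lin alias torch.nn.Sequential and torch.nn.Linear in this module's imports): channels lists the layer widths, so MLP([3, 64, 128]) becomes Linear(3, 64) -> ReLU -> Linear(64, 128) -> ReLU.

import torch

mlp = MLP([3, 64, 128], act="relu", norm=None)
out = mlp(torch.randn(32, 3))
print(out.shape)   # torch.Size([32, 128])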
|
||||
|
||||
|
||||
class BasicConv(Seq):
|
||||
def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
|
||||
def __init__(self, channels, act="relu", norm=None, bias=True, drop=0.0):
|
||||
m = []
|
||||
for i in range(1, len(channels)):
|
||||
m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias))
|
||||
if act is not None and act.lower() != 'none':
|
||||
if act is not None and act.lower() != "none":
|
||||
m.append(act_layer(act))
|
||||
if norm is not None and norm.lower() != 'none':
|
||||
if norm is not None and norm.lower() != "none":
|
||||
m.append(norm_layer(norm, channels[-1]))
|
||||
if drop > 0:
|
||||
m.append(nn.Dropout2d(drop))
|
||||
|
@ -95,11 +95,17 @@ def batched_index_select(x, idx):
|
|||
"""
|
||||
batch_size, num_dims, num_vertices = x.shape[:3]
|
||||
k = idx.shape[-1]
|
||||
idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
|
||||
idx_base = (
|
||||
torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
|
||||
)
|
||||
idx = idx + idx_base
|
||||
idx = idx.contiguous().view(-1)
|
||||
|
||||
x = x.transpose(2, 1)
|
||||
feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :]
|
||||
feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
|
||||
feature = (
|
||||
feature.view(batch_size, num_vertices, k, num_dims)
|
||||
.permute(0, 3, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
return feature
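Usage sketch: x holds per-vertex features as (batch, dims, num_vertices) and idx holds neighbour indices as (batch, num_vertices, k); the gathered result has shape (batch, dims, num_vertices, k).

import torch

x = torch.randn(2, 16, 50)
idx = torch.randint(0, 50, (2, 50, 9))
print(batched_index_select(x, idx).shape)   # torch.Size([2, 16, 50, 9])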
|
||||
@ -9,9 +9,10 @@ class MRConv2d(nn.Module):
|
|||
"""
|
||||
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
|
||||
|
||||
def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
|
||||
super(MRConv2d, self).__init__()
|
||||
self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias)
|
||||
self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
x_i = batched_index_select(x, edge_index[1])
|
||||
|
@ -24,14 +25,17 @@ class EdgeConv2d(nn.Module):
|
|||
"""
|
||||
Edge convolution layer (with activation, batch normalization) for dense data type
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
|
||||
|
||||
def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
|
||||
super(EdgeConv2d, self).__init__()
|
||||
self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
x_i = batched_index_select(x, edge_index[1])
|
||||
x_j = batched_index_select(x, edge_index[0])
|
||||
max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
|
||||
max_value, _ = torch.max(
|
||||
self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True
|
||||
)
|
||||
return max_value
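A usage sketch that combines the pieces above (assuming dense_knn_matrix and EdgeConv2d are imported from the dense gcn_lib modules shown in this commit): build a dense k-nn graph, then apply one EdgeConv2d step, which runs a 1x1 conv over [x_i, x_j - x_i] and max-pools over the k neighbours of each point.

import torch

x = torch.randn(4, 3, 128, 1)               # (batch, dims, points, 1)
edge_index = dense_knn_matrix(x, k=16)      # (2, 4, 128, 16)
conv = EdgeConv2d(3, 32, act="relu", norm=None, bias=True)
print(conv(x, edge_index).shape)            # torch.Size([4, 32, 128, 1])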
|
||||
|
||||
|
||||
|
@ -39,14 +43,17 @@ class GraphConv2d(nn.Module):
|
|||
"""
|
||||
Static graph convolution layer
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, conv="edge", act="relu", norm=None, bias=True
|
||||
):
|
||||
super(GraphConv2d, self).__init__()
|
||||
if conv == 'edge':
|
||||
if conv == "edge":
|
||||
self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
|
||||
elif conv == 'mr':
|
||||
elif conv == "mr":
|
||||
self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
|
||||
else:
|
||||
raise NotImplementedError('conv:{} is not supported'.format(conv))
|
||||
raise NotImplementedError("conv:{} is not supported".format(conv))
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
return self.gconv(x, edge_index)
|
||||
|
@ -56,59 +63,149 @@ class DynConv2d(GraphConv2d):
|
|||
"""
|
||||
Dynamic graph convolution layer
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
|
||||
norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
|
||||
super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
stochastic=False,
|
||||
epsilon=0.0,
|
||||
knn="matrix",
|
||||
):
|
||||
super(DynConv2d, self).__init__(
|
||||
in_channels, out_channels, conv, act, norm, bias
|
||||
)
|
||||
self.k = kernel_size
|
||||
self.d = dilation
|
||||
if knn == 'matrix':
|
||||
self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
|
||||
if knn == "matrix":
|
||||
self.dilated_knn_graph = DenseDilatedKnnGraph(
|
||||
kernel_size, dilation, stochastic, epsilon
|
||||
)
|
||||
else:
|
||||
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
|
||||
self.dilated_knn_graph = DilatedKnnGraph(
|
||||
kernel_size, dilation, stochastic, epsilon
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
edge_index = self.dilated_knn_graph(x)
|
||||
return super(DynConv2d, self).forward(x, edge_index)
|
||||
|
||||
|
||||
|
||||
class PlainDynBlock2d(nn.Module):
|
||||
"""
|
||||
Plain Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
|
||||
bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
stochastic=False,
|
||||
epsilon=0.0,
|
||||
knn="matrix",
|
||||
):
|
||||
super(PlainDynBlock2d, self).__init__()
|
||||
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, stochastic, epsilon, knn)
|
||||
self.body = DynConv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
dilation,
|
||||
conv,
|
||||
act,
|
||||
norm,
|
||||
bias,
|
||||
stochastic,
|
||||
epsilon,
|
||||
knn,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.body(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class ResDynBlock2d(nn.Module):
|
||||
"""
|
||||
Residual Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
|
||||
bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
stochastic=False,
|
||||
epsilon=0.0,
|
||||
knn="matrix",
|
||||
res_scale=1,
|
||||
):
|
||||
super(ResDynBlock2d, self).__init__()
|
||||
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, stochastic, epsilon, knn)
|
||||
self.body = DynConv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
dilation,
|
||||
conv,
|
||||
act,
|
||||
norm,
|
||||
bias,
|
||||
stochastic,
|
||||
epsilon,
|
||||
knn,
|
||||
)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def forward(self, x):
|
||||
return self.body(x) + x*self.res_scale
|
||||
return self.body(x) + x * self.res_scale
|
||||
|
||||
|
||||
class DenseDynBlock2d(nn.Module):
|
||||
"""
|
||||
Dense Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
|
||||
act='relu', norm=None,bias=True, stochastic=False, epsilon=0.0, knn='matrix'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
stochastic=False,
|
||||
epsilon=0.0,
|
||||
knn="matrix",
|
||||
):
|
||||
super(DenseDynBlock2d, self).__init__()
|
||||
self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, stochastic, epsilon, knn)
|
||||
self.body = DynConv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
dilation,
|
||||
conv,
|
||||
act,
|
||||
norm,
|
||||
bias,
|
||||
stochastic,
|
||||
epsilon,
|
||||
knn,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
dense = self.body(x)
|
||||
@ -1,4 +1,3 @@
|
|||
from .torch_nn import *
|
||||
from .torch_edge import *
|
||||
from .torch_vertex import *
|
||||
|
||||
@ -7,6 +7,7 @@ class Dilated(nn.Module):
|
|||
"""
|
||||
Find dilated neighbor from neighbor list
|
||||
"""
|
||||
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
|
||||
super(Dilated, self).__init__()
|
||||
self.dilation = dilation
|
||||
|
@ -18,14 +19,14 @@ class Dilated(nn.Module):
|
|||
if self.stochastic:
|
||||
if torch.rand(1) < self.epsilon and self.training:
|
||||
num = self.k * self.dilation
|
||||
randnum = torch.randperm(num)[:self.k]
|
||||
randnum = torch.randperm(num)[: self.k]
|
||||
edge_index = edge_index.view(2, -1, num)
|
||||
edge_index = edge_index[:, :, randnum]
|
||||
return edge_index.view(2, -1)
|
||||
else:
|
||||
edge_index = edge_index[:, ::self.dilation]
|
||||
edge_index = edge_index[:, :: self.dilation]
|
||||
else:
|
||||
edge_index = edge_index[:, ::self.dilation]
|
||||
edge_index = edge_index[:, :: self.dilation]
|
||||
return edge_index
|
||||
|
||||
|
||||
|
@ -33,14 +34,15 @@ class DilatedKnnGraph(nn.Module):
|
|||
"""
|
||||
Find the neighbors' indices based on dilated knn
|
||||
"""
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
|
||||
|
||||
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn="matrix"):
|
||||
super(DilatedKnnGraph, self).__init__()
|
||||
self.dilation = dilation
|
||||
self.stochastic = stochastic
|
||||
self.epsilon = epsilon
|
||||
self.k = k
|
||||
self._dilated = Dilated(k, dilation, stochastic, epsilon)
|
||||
if knn == 'matrix':
|
||||
if knn == "matrix":
|
||||
self.knn = knn_graph_matrix
|
||||
else:
|
||||
self.knn = knn_graph
|
||||
|
@ -58,7 +60,7 @@ def pairwise_distance(x):
|
|||
Returns:
|
||||
pairwise distance: (batch_size, num_points, num_points)
|
||||
"""
|
||||
x_inner = -2*torch.matmul(x, x.transpose(2, 1))
|
||||
x_inner = -2 * torch.matmul(x, x.transpose(2, 1))
|
||||
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
|
||||
return x_square + x_inner + x_square.transpose(2, 1)
|
||||
|
||||
|
@ -83,7 +85,11 @@ def knn_matrix(x, k=16, batch=None):
|
|||
del neg_adj
|
||||
|
||||
n_points = x.shape[1]
|
||||
start_idx = torch.arange(0, n_points*batch_size, n_points).long().view(batch_size, 1, 1)
|
||||
start_idx = (
|
||||
torch.arange(0, n_points * batch_size, n_points)
|
||||
.long()
|
||||
.view(batch_size, 1, 1)
|
||||
)
|
||||
if x.is_cuda:
|
||||
start_idx = start_idx.cuda()
|
||||
nn_idx += start_idx
|
||||
|
@ -93,7 +99,13 @@ def knn_matrix(x, k=16, batch=None):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
nn_idx = nn_idx.view(1, -1)
|
||||
center_idx = torch.arange(0, n_points*batch_size).repeat(k, 1).transpose(1, 0).contiguous().view(1, -1)
|
||||
center_idx = (
|
||||
torch.arange(0, n_points * batch_size)
|
||||
.repeat(k, 1)
|
||||
.transpose(1, 0)
|
||||
.contiguous()
|
||||
.view(1, -1)
|
||||
)
|
||||
if x.is_cuda:
|
||||
center_idx = center_idx.cuda()
|
||||
return nn_idx, center_idx
|
||||
|
@ -110,4 +122,3 @@ def knn_graph_matrix(x, k=16, batch=None):
|
|||
"""
|
||||
nn_idx, center_idx = knn_matrix(x, k, batch)
|
||||
return torch.cat((nn_idx, center_idx), dim=0)
|
||||
|
||||
@ -6,27 +6,33 @@ from torch_geometric.utils import degree
|
|||
|
||||
|
||||
class GenMessagePassing(MessagePassing):
|
||||
def __init__(self, aggr='softmax',
|
||||
t=1.0, learn_t=False,
|
||||
p=1.0, learn_p=False,
|
||||
y=0.0, learn_y=False):
|
||||
def __init__(
|
||||
self,
|
||||
aggr="softmax",
|
||||
t=1.0,
|
||||
learn_t=False,
|
||||
p=1.0,
|
||||
learn_p=False,
|
||||
y=0.0,
|
||||
learn_y=False,
|
||||
):
|
||||
|
||||
if aggr in ['softmax_sg', 'softmax', 'softmax_sum']:
|
||||
if aggr in ["softmax_sg", "softmax", "softmax_sum"]:
|
||||
|
||||
super(GenMessagePassing, self).__init__(aggr=None)
|
||||
self.aggr = aggr
|
||||
|
||||
if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'):
|
||||
if learn_t and (aggr == "softmax" or aggr == "softmax_sum"):
|
||||
self.learn_t = True
|
||||
self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
|
||||
else:
|
||||
self.learn_t = False
|
||||
self.t = t
|
||||
|
||||
if aggr == 'softmax_sum':
|
||||
if aggr == "softmax_sum":
|
||||
self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
|
||||
|
||||
elif aggr in ['power', 'power_sum']:
|
||||
elif aggr in ["power", "power_sum"]:
|
||||
|
||||
super(GenMessagePassing, self).__init__(aggr=None)
|
||||
self.aggr = aggr
|
||||
|
@ -36,45 +42,52 @@ class GenMessagePassing(MessagePassing):
|
|||
else:
|
||||
self.p = p
|
||||
|
||||
if aggr == 'power_sum':
|
||||
if aggr == "power_sum":
|
||||
self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
|
||||
else:
|
||||
super(GenMessagePassing, self).__init__(aggr=aggr)
|
||||
|
||||
def aggregate(self, inputs, index, ptr=None, dim_size=None):
|
||||
|
||||
if self.aggr in ['add', 'mean', 'max', None]:
|
||||
return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size)
|
||||
if self.aggr in ["add", "mean", "max", None]:
|
||||
return super(GenMessagePassing, self).aggregate(
|
||||
inputs, index, ptr, dim_size
|
||||
)
|
||||
|
||||
elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']:
|
||||
elif self.aggr in ["softmax_sg", "softmax", "softmax_sum"]:
|
||||
|
||||
if self.learn_t:
|
||||
out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
|
||||
out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
out = scatter_softmax(inputs*self.t, index, dim=self.node_dim)
|
||||
out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
|
||||
|
||||
out = scatter(inputs*out, index, dim=self.node_dim,
|
||||
dim_size=dim_size, reduce='sum')
|
||||
out = scatter(
|
||||
inputs * out, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
|
||||
)
|
||||
|
||||
if self.aggr == 'softmax_sum':
|
||||
if self.aggr == "softmax_sum":
|
||||
self.sigmoid_y = torch.sigmoid(self.y)
|
||||
degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
|
||||
out = torch.pow(degrees, self.sigmoid_y) * out
|
||||
|
||||
return out
|
||||
|
||||
|
||||
elif self.aggr in ['power', 'power_sum']:
|
||||
elif self.aggr in ["power", "power_sum"]:
|
||||
min_value, max_value = 1e-7, 1e1
|
||||
torch.clamp_(inputs, min_value, max_value)
|
||||
out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim,
|
||||
dim_size=dim_size, reduce='mean')
|
||||
out = scatter(
|
||||
torch.pow(inputs, self.p),
|
||||
index,
|
||||
dim=self.node_dim,
|
||||
dim_size=dim_size,
|
||||
reduce="mean",
|
||||
)
|
||||
torch.clamp_(out, min_value, max_value)
|
||||
out = torch.pow(out, 1/self.p)
|
||||
out = torch.pow(out, 1 / self.p)
|
||||
# torch.clamp(out, min_value, max_value)
|
||||
|
||||
if self.aggr == 'power_sum':
|
||||
if self.aggr == "power_sum":
|
||||
self.sigmoid_y = torch.sigmoid(self.y)
|
||||
degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
|
||||
out = torch.pow(degrees, self.sigmoid_y) * out
|
||||
|
@ -82,15 +95,16 @@ class GenMessagePassing(MessagePassing):
|
|||
return out
|
||||
|
||||
else:
|
||||
raise NotImplementedError('To be implemented')
|
||||
raise NotImplementedError("To be implemented")
|
||||
|
||||
|
||||
class MsgNorm(torch.nn.Module):
|
||||
def __init__(self, learn_msg_scale=False):
|
||||
super(MsgNorm, self).__init__()
|
||||
|
||||
self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]),
|
||||
requires_grad=learn_msg_scale)
|
||||
self.msg_scale = torch.nn.Parameter(
|
||||
torch.Tensor([1.0]), requires_grad=learn_msg_scale
|
||||
)
|
||||
|
||||
def forward(self, x, msg, p=2):
|
||||
msg = F.normalize(msg, p=p, dim=1)
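The "softmax" branch of aggregate() above weights the messages arriving at each node by a per-node softmax of t * message and then sums them. A plain-torch illustration of the same formula (the real code does this with scatter_softmax/scatter from torch_scatter):

import torch

t = 1.0
inputs = torch.tensor([[1.0], [2.0], [3.0]])  # three messages, feature dim 1
index = torch.tensor([0, 0, 1])               # messages 0 and 1 target node 0, message 2 targets node 1
out = torch.zeros(2, 1)
for node in range(2):
    msgs = inputs[index == node]
    weights = torch.softmax(t * msgs, dim=0)  # softmax over that node's incoming messages
    out[node] = (weights * msgs).sum(dim=0)
print(out)   # node 0 blends 1.0 and 2.0 by softmax weight, node 1 keeps 3.0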
|
||||
@ -9,32 +9,32 @@ from .utils.data_util import get_atom_feature_dims, get_bond_feature_dims
|
|||
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
|
||||
# activation layer
|
||||
act = act_type.lower()
|
||||
if act == 'relu':
|
||||
if act == "relu":
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act == 'leakyrelu':
|
||||
elif act == "leakyrelu":
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act == 'prelu':
|
||||
elif act == "prelu":
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act == 'sigmoid':
|
||||
layer= nn.Sigmoid()
|
||||
elif act == "sigmoid":
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError('activation layer [%s] is not found' % act)
|
||||
raise NotImplementedError("activation layer [%s] is not found" % act)
|
||||
return layer
|
||||
|
||||
|
||||
def norm_layer(norm_type, nc):
|
||||
# normalization layer 1d
|
||||
norm = norm_type.lower()
|
||||
if norm == 'batch':
|
||||
if norm == "batch":
|
||||
layer = nn.BatchNorm1d(nc, affine=True, track_running_stats=False)
|
||||
elif norm == 'layer':
|
||||
elif norm == "layer":
|
||||
layer = nn.LayerNorm(nc, elementwise_affine=True)
|
||||
elif norm == 'instance':
|
||||
elif norm == "instance":
|
||||
layer = nn.InstanceNorm1d(nc, affine=False)
|
||||
elif norm == 'group':
|
||||
elif norm == "group":
|
||||
layer = nn.GroupNorm(32, nc, affine=True)
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [%s] is not found' % norm)
|
||||
raise NotImplementedError("normalization layer [%s] is not found" % norm)
|
||||
return layer
|
||||
|
||||
|
||||
|
@ -52,9 +52,9 @@ class MultiSeq(Seq):
|
|||
|
||||
|
||||
class MLP(Seq):
|
||||
def __init__(self, channels, act='relu',
|
||||
norm=None, bias=True,
|
||||
drop=0., last_lin=False):
|
||||
def __init__(
|
||||
self, channels, act="relu", norm=None, bias=True, drop=0.0, last_lin=False
|
||||
):
|
||||
m = []
|
||||
|
||||
for i in range(1, len(channels)):
|
||||
|
@ -64,9 +64,9 @@ class MLP(Seq):
|
|||
if (i == len(channels) - 1) and last_lin:
|
||||
pass
|
||||
else:
|
||||
if norm is not None and norm.lower() != 'none':
|
||||
if norm is not None and norm.lower() != "none":
|
||||
m.append(norm_layer(norm, channels[i]))
|
||||
if act is not None and act.lower() != 'none':
|
||||
if act is not None and act.lower() != "none":
|
||||
m.append(act_layer(act))
|
||||
if drop > 0:
|
||||
m.append(nn.Dropout2d(drop))
|
||||
|
@ -76,7 +76,6 @@ class MLP(Seq):
|
|||
|
||||
|
||||
class AtomEncoder(nn.Module):
|
||||
|
||||
def __init__(self, emb_dim):
|
||||
super(AtomEncoder, self).__init__()
|
||||
|
||||
|
@ -97,7 +96,6 @@ class AtomEncoder(nn.Module):
|
|||
|
||||
|
||||
class BondEncoder(nn.Module):
|
||||
|
||||
def __init__(self, emb_dim):
|
||||
super(BondEncoder, self).__init__()
|
||||
|
||||
|
@ -115,5 +113,3 @@ class BondEncoder(nn.Module):
|
|||
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
|
||||
|
||||
return bond_embedding
|
||||
|
||||
|
||||
|
|
|
@ -10,35 +10,43 @@ from torch_geometric.utils import remove_self_loops, add_self_loops
|
|||
|
||||
class GENConv(GenMessagePassing):
|
||||
"""
|
||||
GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf
|
||||
SoftMax & PowerMean Aggregation
|
||||
GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf
|
||||
SoftMax & PowerMean Aggregation
|
||||
"""
|
||||
def __init__(self, in_dim, emb_dim,
|
||||
aggr='softmax',
|
||||
t=1.0, learn_t=False,
|
||||
p=1.0, learn_p=False,
|
||||
y=0.0, learn_y=False,
|
||||
msg_norm=False, learn_msg_scale=True,
|
||||
encode_edge=False, bond_encoder=False,
|
||||
edge_feat_dim=None,
|
||||
norm='batch', mlp_layers=2,
|
||||
eps=1e-7):
|
||||
|
||||
super(GENConv, self).__init__(aggr=aggr,
|
||||
t=t, learn_t=learn_t,
|
||||
p=p, learn_p=learn_p,
|
||||
y=y, learn_y=learn_y)
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
emb_dim,
|
||||
aggr="softmax",
|
||||
t=1.0,
|
||||
learn_t=False,
|
||||
p=1.0,
|
||||
learn_p=False,
|
||||
y=0.0,
|
||||
learn_y=False,
|
||||
msg_norm=False,
|
||||
learn_msg_scale=True,
|
||||
encode_edge=False,
|
||||
bond_encoder=False,
|
||||
edge_feat_dim=None,
|
||||
norm="batch",
|
||||
mlp_layers=2,
|
||||
eps=1e-7,
|
||||
):
|
||||
|
||||
super(GENConv, self).__init__(
|
||||
aggr=aggr, t=t, learn_t=learn_t, p=p, learn_p=learn_p, y=y, learn_y=learn_y
|
||||
)
|
||||
|
||||
channels_list = [in_dim]
|
||||
|
||||
for i in range(mlp_layers-1):
|
||||
channels_list.append(in_dim*2)
|
||||
for i in range(mlp_layers - 1):
|
||||
channels_list.append(in_dim * 2)
|
||||
|
||||
channels_list.append(emb_dim)
|
||||
|
||||
self.mlp = MLP(channels=channels_list,
|
||||
norm=norm,
|
||||
last_lin=True)
|
||||
self.mlp = MLP(channels=channels_list, norm=norm, last_lin=True)
|
||||
|
||||
self.msg_encoder = torch.nn.ReLU()
|
||||
self.eps = eps
|
||||
|
@ -91,14 +99,23 @@ class MRConv(nn.Module):
|
|||
"""
|
||||
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="max"
|
||||
):
|
||||
super(MRConv, self).__init__()
|
||||
self.nn = MLP([in_channels*2, out_channels], act, norm, bias)
|
||||
self.nn = MLP([in_channels * 2, out_channels], act, norm, bias)
|
||||
self.aggr = aggr
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
""""""
|
||||
x_j = tg.utils.scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0])
|
||||
x_j = tg.utils.scatter_(
|
||||
self.aggr,
|
||||
torch.index_select(x, 0, edge_index[0])
|
||||
- torch.index_select(x, 0, edge_index[1]),
|
||||
edge_index[1],
|
||||
dim_size=x.shape[0],
|
||||
)
|
||||
return self.nn(torch.cat([x, x_j], dim=1))
|
||||
|
||||
|
||||
|
@ -106,8 +123,13 @@ class EdgConv(tg.nn.EdgeConv):
|
|||
"""
|
||||
Edge convolution layer (with activation, batch normalization)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
|
||||
super(EdgConv, self).__init__(MLP([in_channels*2, out_channels], act, norm, bias), aggr)
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="max"
|
||||
):
|
||||
super(EdgConv, self).__init__(
|
||||
MLP([in_channels * 2, out_channels], act, norm, bias), aggr
|
||||
)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
return super(EdgConv, self).forward(x, edge_index)
|
||||
|
@ -117,10 +139,13 @@ class GATConv(nn.Module):
|
|||
"""
|
||||
Graph Attention Convolution layer (with activation, batch normalization)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8):
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, act="relu", norm=None, bias=True, heads=8
|
||||
):
|
||||
super(GATConv, self).__init__()
|
||||
self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias)
|
||||
m =[]
|
||||
m = []
|
||||
if act:
|
||||
m.append(act_layer(act))
|
||||
if norm:
|
||||
|
@ -154,19 +179,25 @@ class SAGEConv(tg.nn.SAGEConv):
|
|||
:class:`torch_geometric.nn.conv.MessagePassing`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
nn,
|
||||
norm=True,
|
||||
bias=True,
|
||||
relative=False,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
nn,
|
||||
norm=True,
|
||||
bias=True,
|
||||
relative=False,
|
||||
**kwargs
|
||||
):
|
||||
self.relative = relative
|
||||
if norm is not None:
|
||||
super(SAGEConv, self).__init__(in_channels, out_channels, True, bias, **kwargs)
|
||||
super(SAGEConv, self).__init__(
|
||||
in_channels, out_channels, True, bias, **kwargs
|
||||
)
|
||||
else:
|
||||
super(SAGEConv, self).__init__(in_channels, out_channels, False, bias, **kwargs)
|
||||
super(SAGEConv, self).__init__(
|
||||
in_channels, out_channels, False, bias, **kwargs
|
||||
)
|
||||
self.nn = nn
|
||||
|
||||
def forward(self, x, edge_index, size=None):
|
||||
|
@ -199,9 +230,19 @@ class RSAGEConv(SAGEConv):
|
|||
Residual SAGE convolution layer (with activation, batch normalization)
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
relative=False,
|
||||
):
|
||||
nn = MLP([out_channels + in_channels, out_channels], act, norm, bias)
|
||||
super(RSAGEConv, self).__init__(in_channels, out_channels, nn, norm, bias, relative)
|
||||
super(RSAGEConv, self).__init__(
|
||||
in_channels, out_channels, nn, norm, bias, relative
|
||||
)
|
||||
|
||||
|
||||
class SemiGCNConv(nn.Module):
|
||||
|
@ -209,7 +250,7 @@ class SemiGCNConv(nn.Module):
|
|||
SemiGCN convolution layer (with activation, batch normalization)
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
|
||||
def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
|
||||
super(SemiGCNConv, self).__init__()
|
||||
self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias)
|
||||
m = []
|
||||
|
@ -228,7 +269,10 @@ class GinConv(tg.nn.GINConv):
|
|||
"""
|
||||
GINConv layer (with activation, batch normalization)
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'):
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="add"
|
||||
):
|
||||
super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias))
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
|
@ -239,25 +283,36 @@ class GraphConv(nn.Module):
|
|||
"""
|
||||
Static graph convolution layer
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, conv='edge',
|
||||
act='relu', norm=None, bias=True, heads=8):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
heads=8,
|
||||
):
|
||||
super(GraphConv, self).__init__()
|
||||
if conv.lower() == 'edge':
|
||||
if conv.lower() == "edge":
|
||||
self.gconv = EdgConv(in_channels, out_channels, act, norm, bias)
|
||||
elif conv.lower() == 'mr':
|
||||
elif conv.lower() == "mr":
|
||||
self.gconv = MRConv(in_channels, out_channels, act, norm, bias)
|
||||
elif conv.lower() == 'gat':
|
||||
self.gconv = GATConv(in_channels, out_channels//heads, act, norm, bias, heads)
|
||||
elif conv.lower() == 'gcn':
|
||||
elif conv.lower() == "gat":
|
||||
self.gconv = GATConv(
|
||||
in_channels, out_channels // heads, act, norm, bias, heads
|
||||
)
|
||||
elif conv.lower() == "gcn":
|
||||
self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias)
|
||||
elif conv.lower() == 'gin':
|
||||
elif conv.lower() == "gin":
|
||||
self.gconv = GinConv(in_channels, out_channels, act, norm, bias)
|
||||
elif conv.lower() == 'sage':
|
||||
elif conv.lower() == "sage":
|
||||
self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False)
|
||||
elif conv.lower() == 'rsage':
|
||||
elif conv.lower() == "rsage":
|
||||
self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, True)
|
||||
else:
|
||||
raise NotImplementedError('conv {} is not implemented'.format(conv))
|
||||
raise NotImplementedError("conv {} is not implemented".format(conv))
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
return self.gconv(x, edge_index)
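Usage sketch for this sparse GraphConv wrapper, using torch_geometric-style inputs (node features of shape (num_nodes, in_channels) and an edge_index of shape (2, num_edges)); it assumes a torch_geometric version compatible with the layers wrapped above. conv="edge" dispatches to the EdgeConv-based EdgConv.

import torch

x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 500))
layer = GraphConv(16, 32, conv="edge", act="relu", norm=None, bias=True)
print(layer(x, edge_index).shape)   # torch.Size([100, 32])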
|
||||
|
@ -267,9 +322,23 @@ class DynConv(GraphConv):
|
|||
"""
|
||||
Dynamic graph convolution layer
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
|
||||
norm=None, bias=True, heads=8, **kwargs):
|
||||
super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
heads=8,
|
||||
**kwargs
|
||||
):
|
||||
super(DynConv, self).__init__(
|
||||
in_channels, out_channels, conv, act, norm, bias, heads
|
||||
)
|
||||
self.k = kernel_size
|
||||
self.d = dilation
|
||||
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs)
|
||||
|
@ -283,11 +352,23 @@ class PlainDynBlock(nn.Module):
|
|||
"""
|
||||
Plain Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
|
||||
bias=True, res_scale=1, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
res_scale=1,
|
||||
**kwargs
|
||||
):
|
||||
super(PlainDynBlock, self).__init__()
|
||||
self.body = DynConv(channels, channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, **kwargs)
|
||||
self.body = DynConv(
|
||||
channels, channels, kernel_size, dilation, conv, act, norm, bias, **kwargs
|
||||
)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def forward(self, x, batch=None):
|
||||
|
@ -298,25 +379,58 @@ class ResDynBlock(nn.Module):
|
|||
"""
|
||||
Residual Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
|
||||
bias=True, res_scale=1, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
res_scale=1,
|
||||
**kwargs
|
||||
):
|
||||
super(ResDynBlock, self).__init__()
|
||||
self.body = DynConv(channels, channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, **kwargs)
|
||||
self.body = DynConv(
|
||||
channels, channels, kernel_size, dilation, conv, act, norm, bias, **kwargs
|
||||
)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def forward(self, x, batch=None):
|
||||
return self.body(x, batch) + x*self.res_scale, batch
|
||||
return self.body(x, batch) + x * self.res_scale, batch
|
||||
|
||||
|
||||
class DenseDynBlock(nn.Module):
|
||||
"""
|
||||
Dense Dynamic graph convolution block
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, bias=True, **kwargs):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=9,
|
||||
dilation=1,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
**kwargs
|
||||
):
|
||||
super(DenseDynBlock, self).__init__()
|
||||
self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv,
|
||||
act, norm, bias, **kwargs)
|
||||
self.body = DynConv(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
dilation,
|
||||
conv,
|
||||
act,
|
||||
norm,
|
||||
bias,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def forward(self, x, batch=None):
|
||||
dense = self.body(x, batch)
|
||||
|
@ -327,24 +441,43 @@ class ResGraphBlock(nn.Module):
|
|||
"""
|
||||
Residual Static graph convolution block
|
||||
"""
|
||||
def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
heads=8,
|
||||
res_scale=1,
|
||||
):
|
||||
super(ResGraphBlock, self).__init__()
|
||||
self.body = GraphConv(channels, channels, conv, act, norm, bias, heads)
|
||||
self.res_scale = res_scale
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
return self.body(x, edge_index) + x*self.res_scale, edge_index
|
||||
return self.body(x, edge_index) + x * self.res_scale, edge_index
|
||||
|
||||
|
||||
class DenseGraphBlock(nn.Module):
|
||||
"""
|
||||
Dense Static graph convolution block
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv="edge",
|
||||
act="relu",
|
||||
norm=None,
|
||||
bias=True,
|
||||
heads=8,
|
||||
):
|
||||
super(DenseGraphBlock, self).__init__()
|
||||
self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads)
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
dense = self.body(x, edge_index)
|
||||
return torch.cat((x, dense), 1), edge_index
|
||||
|
||||
@ -1,7 +1,8 @@
|
|||
from .ckpt_util import *
|
||||
|
||||
# from .data_util import *
|
||||
from .loss import *
|
||||
from .metrics import *
|
||||
from .optim import *
|
||||
# from .tf_logger import *
|
||||
|
||||
# from .tf_logger import *
|
||||
@ -6,25 +6,27 @@ import logging
|
|||
import numpy as np
|
||||
|
||||
|
||||
def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post='best'):
|
||||
def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post="best"):
|
||||
model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
state = {
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model_cpu,
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'loss': loss
|
||||
}
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model_cpu,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
print("Directory ", save_path, " is created.")
|
||||
|
||||
filename = '{}/{}_{}.pth'.format(save_path, name_pre, name_post)
|
||||
filename = "{}/{}_{}.pth".format(save_path, name_pre, name_post)
|
||||
torch.save(state, filename)
|
||||
print('model has been saved as {}'.format(filename))
|
||||
print("model has been saved as {}".format(filename))
|
||||
|
||||
|
||||
def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax means max best
|
||||
def load_pretrained_models(
|
||||
model, pretrained_model, phase, ismax=True
|
||||
): # ismax means max best
|
||||
if ismax:
|
||||
best_value = -np.inf
|
||||
else:
|
||||
|
@ -36,7 +38,7 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
|
|||
logging.info("===> Loading checkpoint '{}'".format(pretrained_model))
|
||||
checkpoint = torch.load(pretrained_model)
|
||||
try:
|
||||
best_value = checkpoint['best_value']
|
||||
best_value = checkpoint["best_value"]
|
||||
if best_value == -np.inf or best_value == np.inf:
|
||||
show_best_value = False
|
||||
else:
|
||||
|
@ -46,11 +48,13 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
|
|||
show_best_value = False
|
||||
|
||||
model_dict = model.state_dict()
|
||||
ckpt_model_state_dict = checkpoint['state_dict']
|
||||
ckpt_model_state_dict = checkpoint["state_dict"]
|
||||
|
||||
# rename ckpt (avoid name is not same because of multi-gpus)
|
||||
is_model_multi_gpus = True if list(model_dict)[0][0][0] == 'm' else False
|
||||
is_ckpt_multi_gpus = True if list(ckpt_model_state_dict)[0][0] == 'm' else False
|
||||
is_model_multi_gpus = True if list(model_dict)[0][0][0] == "m" else False
|
||||
is_ckpt_multi_gpus = (
|
||||
True if list(ckpt_model_state_dict)[0][0] == "m" else False
|
||||
)
|
||||
|
||||
if not (is_model_multi_gpus == is_ckpt_multi_gpus):
|
||||
temp_dict = OrderedDict()
|
||||
|
@ -58,7 +62,7 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
|
|||
if is_ckpt_multi_gpus:
|
||||
name = k[7:] # remove 'module.'
|
||||
else:
|
||||
name = 'module.'+k # add 'module'
|
||||
name = "module." + k # add 'module'
|
||||
temp_dict[name] = v
|
||||
# load params
|
||||
ckpt_model_state_dict = temp_dict
|
||||
|
@ -67,34 +71,44 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
|
|||
model.load_state_dict(ckpt_model_state_dict)
|
||||
|
||||
if show_best_value:
|
||||
logging.info("The pretrained_model is at checkpoint {}. \t "
|
||||
"Best value: {}".format(checkpoint['epoch'], best_value))
|
||||
logging.info(
|
||||
"The pretrained_model is at checkpoint {}. \t "
|
||||
"Best value: {}".format(checkpoint["epoch"], best_value)
|
||||
)
|
||||
else:
|
||||
logging.info("The pretrained_model is at checkpoint {}.".format(checkpoint['epoch']))
|
||||
logging.info(
|
||||
"The pretrained_model is at checkpoint {}.".format(
|
||||
checkpoint["epoch"]
|
||||
)
|
||||
)
|
||||
|
||||
if phase == 'train':
|
||||
epoch = checkpoint['epoch']
|
||||
if phase == "train":
|
||||
epoch = checkpoint["epoch"]
|
||||
else:
|
||||
epoch = -1
|
||||
else:
|
||||
raise ImportError("===> No checkpoint found at '{}'".format(pretrained_model))
|
||||
raise ImportError(
|
||||
"===> No checkpoint found at '{}'".format(pretrained_model)
|
||||
)
|
||||
else:
|
||||
logging.info('===> No pre-trained model')
|
||||
logging.info("===> No pre-trained model")
|
||||
return model, best_value, epoch
|
||||
|
||||
|
||||
def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True):
|
||||
def load_pretrained_optimizer(
|
||||
pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True
|
||||
):
|
||||
if pretrained_model:
|
||||
if os.path.isfile(pretrained_model):
|
||||
checkpoint = torch.load(pretrained_model)
|
||||
if 'optimizer_state_dict' in checkpoint.keys():
|
||||
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if "optimizer_state_dict" in checkpoint.keys():
|
||||
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
for state in optimizer.state.values():
|
||||
for k, v in state.items():
|
||||
if torch.is_tensor(v):
|
||||
state[k] = v.cuda()
|
||||
if 'scheduler_state_dict' in checkpoint.keys():
|
||||
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
if "scheduler_state_dict" in checkpoint.keys():
|
||||
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
||||
if use_ckpt_lr:
|
||||
try:
|
||||
lr = scheduler.get_lr()[0]
|
||||
|
@ -105,26 +119,30 @@ def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ck
|
|||
|
||||
|
||||
def save_checkpoint(state, is_best, save_path, postname):
|
||||
filename = '{}/{}_{}.pth'.format(save_path, postname, int(state['epoch']))
|
||||
filename = "{}/{}_{}.pth".format(save_path, postname, int(state["epoch"]))
|
||||
torch.save(state, filename)
|
||||
if is_best:
|
||||
shutil.copyfile(filename, '{}/{}_best.pth'.format(save_path, postname))
|
||||
shutil.copyfile(filename, "{}/{}_best.pth".format(save_path, postname))
|
||||
|
||||
|
||||
def change_ckpt_dict(model, optimizer, scheduler, opt):
|
||||
|
||||
for _ in range(opt.epoch):
|
||||
scheduler.step()
|
||||
is_best = (opt.test_value < opt.best_value)
|
||||
is_best = opt.test_value < opt.best_value
|
||||
opt.best_value = min(opt.test_value, opt.best_value)
|
||||
|
||||
model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
# optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()}
|
||||
save_checkpoint({
|
||||
'epoch': opt.epoch,
|
||||
'state_dict': model_cpu,
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'scheduler_state_dict': scheduler.state_dict(),
|
||||
'best_value': opt.best_value,
|
||||
}, is_best, opt.save_path, opt.post)
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"epoch": opt.epoch,
|
||||
"state_dict": model_cpu,
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"best_value": opt.best_value,
|
||||
},
|
||||
is_best,
|
||||
opt.save_path,
|
||||
opt.post,
|
||||
)
|
||||
|
|
|

@@ -28,17 +28,20 @@ def add_zeros(data):
return data
def extract_node_feature(data, reduce='add'):
if reduce in ['mean', 'max', 'add']:
data.x = scatter(data.edge_attr,
data.edge_index[0],
dim=0,
dim_size=data.num_nodes,
reduce=reduce)
def extract_node_feature(data, reduce="add"):
if reduce in ["mean", "max", "add"]:
data.x = scatter(
data.edge_attr,
data.edge_index[0],
dim=0,
dim_size=data.num_nodes,
reduce=reduce,
)
else:
raise Exception('Unknown Aggregation Type')
raise Exception("Unknown Aggregation Type")
return data
# random partition graph
def random_partition_graph(num_nodes, cluster_number=10):
parts = np.random.randint(cluster_number, size=num_nodes)

@@ -47,7 +50,7 @@ def random_partition_graph(num_nodes, cluster_number=10):
def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1):
# convert sparse tensor to scipy csr
adj = adj.to_scipy(layout='csr')
adj = adj.to_scipy(layout="csr")
num_batches = cluster_number // batch_size

@@ -56,20 +59,27 @@ def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1):
for cluster in range(num_batches):
sg_nodes[cluster] = np.where(parts == cluster)[0]
sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0]
sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(
adj[sg_nodes[cluster], :][:, sg_nodes[cluster]]
)[0]
return sg_nodes, sg_edges
def random_rotate(points):
theta = np.random.uniform(0, np.pi * 2)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
rotation_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
rotation_matrix = torch.from_numpy(rotation_matrix).float()
points[:, 0:2] = torch.matmul(points[:, [0, 1]].transpose(1, 3), rotation_matrix).transpose(1, 3)
points[:, 0:2] = torch.matmul(
points[:, [0, 1]].transpose(1, 3), rotation_matrix
).transpose(1, 3)
return points
def random_translate(points, mean=0, std=0.02):
points += torch.randn(points.shape)*std + mean
points += torch.randn(points.shape) * std + mean
return points

@@ -82,15 +92,17 @@ def random_points_augmentation(points, rotate=False, translate=False, **kwargs):
return points
def scale_translate_pointcloud(pointcloud, shift=[-0.2, 0.2], scale=[2. / 3., 3. /2.]):
def scale_translate_pointcloud(
pointcloud, shift=[-0.2, 0.2], scale=[2.0 / 3.0, 3.0 / 2.0]
):
"""
for scaling and shifting the point cloud
:param pointcloud:
:return:
"""
B, C, N = pointcloud.shape[0:3]
scale = scale[0] + torch.rand([B, C, 1, 1])*(scale[1]-scale[0])
shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1]-shift[0])
scale = scale[0] + torch.rand([B, C, 1, 1]) * (scale[1] - scale[0])
shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1] - shift[0])
translated_pointcloud = torch.mul(pointcloud, scale) + shift
return translated_pointcloud
@ -126,25 +138,29 @@ class PartNet(InMemoryDataset):
|
|||
final dataset. (default: :obj:`None`)
|
||||
"""
|
||||
# the dataset we use for our paper is pre-released version
|
||||
def __init__(self,
|
||||
root,
|
||||
dataset='sem_seg_h5',
|
||||
obj_category='Bed',
|
||||
level=3,
|
||||
phase='train',
|
||||
transform=None,
|
||||
pre_transform=None,
|
||||
pre_filter=None):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
dataset="sem_seg_h5",
|
||||
obj_category="Bed",
|
||||
level=3,
|
||||
phase="train",
|
||||
transform=None,
|
||||
pre_transform=None,
|
||||
pre_filter=None,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.level = level
|
||||
self.obj_category = obj_category
|
||||
self.object = '-'.join([self.obj_category, str(self.level)])
|
||||
self.level_folder = 'level_'+str(self.level)
|
||||
self.processed_file_folder = osp.join(self.dataset, self.level_folder, self.object)
|
||||
self.object = "-".join([self.obj_category, str(self.level)])
|
||||
self.level_folder = "level_" + str(self.level)
|
||||
self.processed_file_folder = osp.join(
|
||||
self.dataset, self.level_folder, self.object
|
||||
)
|
||||
super(PartNet, self).__init__(root, transform, pre_transform, pre_filter)
|
||||
if phase == 'test':
|
||||
if phase == "test":
|
||||
path = self.processed_paths[1]
|
||||
elif phase == 'val':
|
||||
elif phase == "val":
|
||||
path = self.processed_paths[2]
|
||||
else:
|
||||
path = self.processed_paths[0]
|
||||
|
@ -156,19 +172,24 @@ class PartNet(InMemoryDataset):
|
|||
|
||||
@property
|
||||
def processed_file_names(self):
|
||||
return osp.join(self.processed_file_folder, 'train.pt'), osp.join(self.processed_file_folder, 'test.pt'), \
|
||||
osp.join(self.processed_file_folder, 'val.pt')
|
||||
return (
|
||||
osp.join(self.processed_file_folder, "train.pt"),
|
||||
osp.join(self.processed_file_folder, "test.pt"),
|
||||
osp.join(self.processed_file_folder, "val.pt"),
|
||||
)
|
||||
|
||||
def download(self):
|
||||
path = osp.join(self.raw_dir, self.dataset)
|
||||
if not osp.exists(path):
|
||||
raise FileExistsError('PartNet can only downloaded via application. '
|
||||
'See details in https://cs.stanford.edu/~kaichun/partnet/')
|
||||
raise FileExistsError(
|
||||
"PartNet can only downloaded via application. "
|
||||
"See details in https://cs.stanford.edu/~kaichun/partnet/"
|
||||
)
|
||||
# path = download_url(self.url, self.root)
|
||||
extract_zip(path, self.root)
|
||||
os.unlink(path)
|
||||
shutil.rmtree(self.raw_dir)
|
||||
name = self.url.split(os.sep)[-1].split('.')[0]
|
||||
name = self.url.split(os.sep)[-1].split(".")[0]
|
||||
os.rename(osp.join(self.root, name), self.raw_dir)
|
||||
|
||||
def process(self):
|
||||
|
@ -176,31 +197,38 @@ class PartNet(InMemoryDataset):
|
|||
processed_path = osp.join(self.processed_dir, self.processed_file_folder)
|
||||
if not osp.exists(processed_path):
|
||||
os.makedirs(osp.join(processed_path))
|
||||
torch.save(self.process_set('train'), self.processed_paths[0])
|
||||
torch.save(self.process_set('test'), self.processed_paths[1])
|
||||
torch.save(self.process_set('val'), self.processed_paths[2])
|
||||
torch.save(self.process_set("train"), self.processed_paths[0])
|
||||
torch.save(self.process_set("test"), self.processed_paths[1])
|
||||
torch.save(self.process_set("val"), self.processed_paths[2])
|
||||
|
||||
def process_set(self, dataset):
|
||||
if self.dataset == 'ins_seg_h5':
|
||||
raw_path = osp.join(self.raw_dir, 'ins_seg_h5_for_sgpn', self.dataset)
|
||||
categories = glob(osp.join(raw_path, '*'))
|
||||
if self.dataset == "ins_seg_h5":
|
||||
raw_path = osp.join(self.raw_dir, "ins_seg_h5_for_sgpn", self.dataset)
|
||||
categories = glob(osp.join(raw_path, "*"))
|
||||
categories = sorted([x.split(os.sep)[-1] for x in categories])
|
||||
|
||||
data_list = []
|
||||
for target, category in enumerate(tqdm(categories)):
|
||||
folder = osp.join(raw_path, category)
|
||||
paths = glob('{}/{}-*.h5'.format(folder, dataset))
|
||||
paths = glob("{}/{}-*.h5".format(folder, dataset))
|
||||
labels, nors, opacitys, pts, rgbs = [], [], [], [], []
|
||||
for path in paths:
|
||||
f = h5py.File(path)
|
||||
pts += torch.from_numpy(f['pts'][:]).unbind(0)
|
||||
labels += torch.from_numpy(f['label'][:]).to(torch.long).unbind(0)
|
||||
nors += torch.from_numpy(f['nor'][:]).unbind(0)
|
||||
opacitys += torch.from_numpy(f['opacity'][:]).unbind(0)
|
||||
rgbs += torch.from_numpy(f['rgb'][:]).to(torch.float32).unbind(0)
|
||||
pts += torch.from_numpy(f["pts"][:]).unbind(0)
|
||||
labels += torch.from_numpy(f["label"][:]).to(torch.long).unbind(0)
|
||||
nors += torch.from_numpy(f["nor"][:]).unbind(0)
|
||||
opacitys += torch.from_numpy(f["opacity"][:]).unbind(0)
|
||||
rgbs += torch.from_numpy(f["rgb"][:]).to(torch.float32).unbind(0)
|
||||
|
||||
for i, (pt, label, nor, opacity, rgb) in enumerate(zip(pts, labels, nors, opacitys, rgbs)):
|
||||
data = Data(pos=pt[:, :3], y=label, norm=nor[:, :3], x=torch.cat((opacity.unsqueeze(-1), rgb/255.), 1))
|
||||
for i, (pt, label, nor, opacity, rgb) in enumerate(
|
||||
zip(pts, labels, nors, opacitys, rgbs)
|
||||
):
|
||||
data = Data(
|
||||
pos=pt[:, :3],
|
||||
y=label,
|
||||
norm=nor[:, :3],
|
||||
x=torch.cat((opacity.unsqueeze(-1), rgb / 255.0), 1),
|
||||
)
|
||||
|
||||
if self.pre_filter is not None and not self.pre_filter(data):
|
||||
continue
|
||||
|
@ -215,14 +243,18 @@ class PartNet(InMemoryDataset):
|
|||
# class_name = []
|
||||
for target, category in enumerate(tqdm(categories)):
|
||||
folder = osp.join(raw_path, category)
|
||||
paths = glob('{}/{}-*.h5'.format(folder, dataset))
|
||||
paths = glob("{}/{}-*.h5".format(folder, dataset))
|
||||
labels, pts = [], []
|
||||
# clss = category.split('-')[0]
|
||||
|
||||
for path in paths:
|
||||
f = h5py.File(path)
|
||||
pts += torch.from_numpy(f['data'][:].astype(np.float32)).unbind(0)
|
||||
labels += torch.from_numpy(f['label_seg'][:].astype(np.float32)).to(torch.long).unbind(0)
|
||||
pts += torch.from_numpy(f["data"][:].astype(np.float32)).unbind(0)
|
||||
labels += (
|
||||
torch.from_numpy(f["label_seg"][:].astype(np.float32))
|
||||
.to(torch.long)
|
||||
.unbind(0)
|
||||
)
|
||||
for i, (pt, label) in enumerate(zip(pts, labels)):
|
||||
data = Data(pos=pt[:, :3], y=label)
|
||||
# data = PartData(pos=pt[:, :3], y=label, clss=clss)
|
||||
|
@ -235,10 +267,7 @@ class PartNet(InMemoryDataset):
|
|||
|
||||
|
||||
class PartData(Data):
|
||||
def __init__(self,
|
||||
y=None,
|
||||
pos=None,
|
||||
clss=None):
|
||||
def __init__(self, y=None, pos=None, clss=None):
|
||||
super(PartData).__init__(pos=pos, y=y)
|
||||
self.clss = clss
|
||||
|
||||
|
@ -246,38 +275,30 @@ class PartData(Data):
|
|||
# allowable multiple choice node and edge features
|
||||
# code from https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py
|
||||
allowable_features = {
|
||||
'possible_atomic_num_list' : list(range(1, 119)) + ['misc'],
|
||||
'possible_chirality_list' : [
|
||||
'CHI_UNSPECIFIED',
|
||||
'CHI_TETRAHEDRAL_CW',
|
||||
'CHI_TETRAHEDRAL_CCW',
|
||||
'CHI_OTHER'
|
||||
"possible_atomic_num_list": list(range(1, 119)) + ["misc"],
|
||||
"possible_chirality_list": [
|
||||
"CHI_UNSPECIFIED",
|
||||
"CHI_TETRAHEDRAL_CW",
|
||||
"CHI_TETRAHEDRAL_CCW",
|
||||
"CHI_OTHER",
|
||||
],
|
||||
'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
|
||||
'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
|
||||
'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
|
||||
'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
|
||||
'possible_hybridization_list' : [
|
||||
'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
|
||||
],
|
||||
'possible_is_aromatic_list': [False, True],
|
||||
'possible_is_in_ring_list': [False, True],
|
||||
'possible_bond_type_list' : [
|
||||
'SINGLE',
|
||||
'DOUBLE',
|
||||
'TRIPLE',
|
||||
'AROMATIC',
|
||||
'misc'
|
||||
"possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"],
|
||||
"possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"],
|
||||
"possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"],
|
||||
"possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"],
|
||||
"possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"],
|
||||
"possible_is_aromatic_list": [False, True],
|
||||
"possible_is_in_ring_list": [False, True],
|
||||
"possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"],
|
||||
"possible_bond_stereo_list": [
|
||||
"STEREONONE",
|
||||
"STEREOZ",
|
||||
"STEREOE",
|
||||
"STEREOCIS",
|
||||
"STEREOTRANS",
|
||||
"STEREOANY",
|
||||
],
|
||||
'possible_bond_stereo_list': [
|
||||
'STEREONONE',
|
||||
'STEREOZ',
|
||||
'STEREOE',
|
||||
'STEREOCIS',
|
||||
'STEREOTRANS',
|
||||
'STEREOANY',
|
||||
],
|
||||
'possible_is_conjugated_list': [False, True],
|
||||
"possible_is_conjugated_list": [False, True],
|
||||
}
|
||||
|
||||
|
||||
|
@ -298,31 +319,44 @@ def atom_to_feature_vector(atom):
|
|||
:return: list
|
||||
"""
|
||||
atom_feature = [
|
||||
safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
|
||||
allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())),
|
||||
safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
|
||||
safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()),
|
||||
safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()),
|
||||
safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
|
||||
safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())),
|
||||
allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
|
||||
allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()),
|
||||
]
|
||||
safe_index(allowable_features["possible_atomic_num_list"], atom.GetAtomicNum()),
|
||||
allowable_features["possible_chirality_list"].index(str(atom.GetChiralTag())),
|
||||
safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()),
|
||||
safe_index(
|
||||
allowable_features["possible_formal_charge_list"], atom.GetFormalCharge()
|
||||
),
|
||||
safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()),
|
||||
safe_index(
|
||||
allowable_features["possible_number_radical_e_list"],
|
||||
atom.GetNumRadicalElectrons(),
|
||||
),
|
||||
safe_index(
|
||||
allowable_features["possible_hybridization_list"],
|
||||
str(atom.GetHybridization()),
|
||||
),
|
||||
allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()),
|
||||
allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()),
|
||||
]
|
||||
return atom_feature
|
||||
|
||||
|
||||
def get_atom_feature_dims():
|
||||
return list(map(len, [
|
||||
allowable_features['possible_atomic_num_list'],
|
||||
allowable_features['possible_chirality_list'],
|
||||
allowable_features['possible_degree_list'],
|
||||
allowable_features['possible_formal_charge_list'],
|
||||
allowable_features['possible_numH_list'],
|
||||
allowable_features['possible_number_radical_e_list'],
|
||||
allowable_features['possible_hybridization_list'],
|
||||
allowable_features['possible_is_aromatic_list'],
|
||||
allowable_features['possible_is_in_ring_list']
|
||||
]))
|
||||
return list(
|
||||
map(
|
||||
len,
|
||||
[
|
||||
allowable_features["possible_atomic_num_list"],
|
||||
allowable_features["possible_chirality_list"],
|
||||
allowable_features["possible_degree_list"],
|
||||
allowable_features["possible_formal_charge_list"],
|
||||
allowable_features["possible_numH_list"],
|
||||
allowable_features["possible_number_radical_e_list"],
|
||||
allowable_features["possible_hybridization_list"],
|
||||
allowable_features["possible_is_aromatic_list"],
|
||||
allowable_features["possible_is_in_ring_list"],
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def bond_to_feature_vector(bond):
|
||||
|
@ -332,56 +366,71 @@ def bond_to_feature_vector(bond):
|
|||
:return: list
|
||||
"""
|
||||
bond_feature = [
|
||||
safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())),
|
||||
allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())),
|
||||
allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()),
|
||||
]
|
||||
safe_index(
|
||||
allowable_features["possible_bond_type_list"], str(bond.GetBondType())
|
||||
),
|
||||
allowable_features["possible_bond_stereo_list"].index(str(bond.GetStereo())),
|
||||
allowable_features["possible_is_conjugated_list"].index(bond.GetIsConjugated()),
|
||||
]
|
||||
return bond_feature
|
||||
|
||||
|
||||
def get_bond_feature_dims():
|
||||
return list(map(len, [
|
||||
allowable_features['possible_bond_type_list'],
|
||||
allowable_features['possible_bond_stereo_list'],
|
||||
allowable_features['possible_is_conjugated_list']
|
||||
]))
|
||||
return list(
|
||||
map(
|
||||
len,
|
||||
[
|
||||
allowable_features["possible_bond_type_list"],
|
||||
allowable_features["possible_bond_stereo_list"],
|
||||
allowable_features["possible_is_conjugated_list"],
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def atom_feature_vector_to_dict(atom_feature):
|
||||
[atomic_num_idx,
|
||||
chirality_idx,
|
||||
degree_idx,
|
||||
formal_charge_idx,
|
||||
num_h_idx,
|
||||
number_radical_e_idx,
|
||||
hybridization_idx,
|
||||
is_aromatic_idx,
|
||||
is_in_ring_idx] = atom_feature
|
||||
[
|
||||
atomic_num_idx,
|
||||
chirality_idx,
|
||||
degree_idx,
|
||||
formal_charge_idx,
|
||||
num_h_idx,
|
||||
number_radical_e_idx,
|
||||
hybridization_idx,
|
||||
is_aromatic_idx,
|
||||
is_in_ring_idx,
|
||||
] = atom_feature
|
||||
|
||||
feature_dict = {
|
||||
'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx],
|
||||
'chirality': allowable_features['possible_chirality_list'][chirality_idx],
|
||||
'degree': allowable_features['possible_degree_list'][degree_idx],
|
||||
'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx],
|
||||
'num_h': allowable_features['possible_numH_list'][num_h_idx],
|
||||
'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx],
|
||||
'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx],
|
||||
'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx],
|
||||
'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx]
|
||||
"atomic_num": allowable_features["possible_atomic_num_list"][atomic_num_idx],
|
||||
"chirality": allowable_features["possible_chirality_list"][chirality_idx],
|
||||
"degree": allowable_features["possible_degree_list"][degree_idx],
|
||||
"formal_charge": allowable_features["possible_formal_charge_list"][
|
||||
formal_charge_idx
|
||||
],
|
||||
"num_h": allowable_features["possible_numH_list"][num_h_idx],
|
||||
"num_rad_e": allowable_features["possible_number_radical_e_list"][
|
||||
number_radical_e_idx
|
||||
],
|
||||
"hybridization": allowable_features["possible_hybridization_list"][
|
||||
hybridization_idx
|
||||
],
|
||||
"is_aromatic": allowable_features["possible_is_aromatic_list"][is_aromatic_idx],
|
||||
"is_in_ring": allowable_features["possible_is_in_ring_list"][is_in_ring_idx],
|
||||
}
|
||||
|
||||
return feature_dict
|
||||
|
||||
|
||||
def bond_feature_vector_to_dict(bond_feature):
|
||||
[bond_type_idx,
|
||||
bond_stereo_idx,
|
||||
is_conjugated_idx] = bond_feature
|
||||
[bond_type_idx, bond_stereo_idx, is_conjugated_idx] = bond_feature
|
||||
|
||||
feature_dict = {
|
||||
'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx],
|
||||
'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx],
|
||||
'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx]
|
||||
"bond_type": allowable_features["possible_bond_type_list"][bond_type_idx],
|
||||
"bond_stereo": allowable_features["possible_bond_stereo_list"][bond_stereo_idx],
|
||||
"is_conjugated": allowable_features["possible_is_conjugated_list"][
|
||||
is_conjugated_idx
|
||||
],
|
||||
}
|
||||
|
||||
return feature_dict

@@ -3,12 +3,12 @@ import shutil
import csv
def save_best_result(list_of_dict, file_name, dir_path='best_result'):
def save_best_result(list_of_dict, file_name, dir_path="best_result"):
if not os.path.exists(dir_path):
os.mkdir(dir_path)
print("Directory ", dir_path, " is created.")
csv_file_name = '{}/{}.csv'.format(dir_path, file_name)
with open(csv_file_name, 'a+') as csv_file:
csv_file_name = "{}/{}.csv".format(dir_path, file_name)
with open(csv_file_name, "a+") as csv_file:
csv_writer = csv.writer(csv_file)
for _ in range(len(list_of_dict)):
csv_writer.writerow(list_of_dict[_].values())

@@ -17,10 +17,10 @@ def save_best_result(list_of_dict, file_name, dir_path='best_result'):
def create_exp_dir(path, scripts_to_save=None):
if not os.path.exists(path):
os.makedirs(path)
print('Experiment dir : {}'.format(path))
print("Experiment dir : {}".format(path))
if scripts_to_save is not None:
os.mkdir(os.path.join(path, 'scripts'))
os.mkdir(os.path.join(path, "scripts"))
for script in scripts_to_save:
dst_file = os.path.join(path, 'scripts', os.path.basename(script))
dst_file = os.path.join(path, "scripts", os.path.basename(script))
shutil.copyfile(script, dst_file)

@@ -14,11 +14,13 @@ class SmoothCrossEntropy(torch.nn.Module):
if self.smoothing:
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1)
one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (n_class - 1)
one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (
n_class - 1
)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()
else:
loss = F.cross_entropy(pred, gt, reduction='mean')
loss = F.cross_entropy(pred, gt, reduction="mean")
return loss

@@ -1,25 +1,24 @@
from math import log10
def PSNR(mse, peak=1.):
return 10 * log10((peak ** 2) / mse)
def PSNR(mse, peak=1.0):
return 10 * log10((peak ** 2) / mse)
class AverageMeter(object):
"""Computes and stores the average and current value"""
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
@ -4,7 +4,6 @@ from torch.optim.optimizer import Optimizer, required
|
|||
|
||||
|
||||
class RAdam(Optimizer):
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
self.buffer = [[None, None, None] for ind in range(10)]
|
||||
|
@ -21,55 +20,67 @@ class RAdam(Optimizer):
|
|||
|
||||
for group in self.param_groups:
|
||||
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('RAdam does not support sparse gradients')
|
||||
raise RuntimeError("RAdam does not support sparse gradients")
|
||||
|
||||
p_data_fp32 = p.data.float()
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p_data_fp32)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
|
||||
else:
|
||||
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
||||
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
|
||||
state['step'] += 1
|
||||
buffered = self.buffer[int(state['step'] % 10)]
|
||||
if state['step'] == buffered[0]:
|
||||
state["step"] += 1
|
||||
buffered = self.buffer[int(state["step"] % 10)]
|
||||
if state["step"] == buffered[0]:
|
||||
N_sma, step_size = buffered[1], buffered[2]
|
||||
else:
|
||||
buffered[0] = state['step']
|
||||
beta2_t = beta2 ** state['step']
|
||||
buffered[0] = state["step"]
|
||||
beta2_t = beta2 ** state["step"]
|
||||
N_sma_max = 2 / (1 - beta2) - 1
|
||||
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
||||
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
|
||||
buffered[1] = N_sma
|
||||
|
||||
# more conservative since it's an approximated value
|
||||
if N_sma >= 5:
|
||||
step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
||||
step_size = (
|
||||
group["lr"]
|
||||
* math.sqrt(
|
||||
(1 - beta2_t)
|
||||
* (N_sma - 4)
|
||||
/ (N_sma_max - 4)
|
||||
* (N_sma - 2)
|
||||
/ N_sma
|
||||
* N_sma_max
|
||||
/ (N_sma_max - 2)
|
||||
)
|
||||
/ (1 - beta1 ** state["step"])
|
||||
)
|
||||
else:
|
||||
step_size = group['lr'] / (1 - beta1 ** state['step'])
|
||||
step_size = group["lr"] / (1 - beta1 ** state["step"])
|
||||
buffered[2] = step_size
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
|
||||
|
||||
# more conservative since it's an approximated value
|
||||
if N_sma >= 5:
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
if N_sma >= 5:
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
|
||||
else:
|
||||
p_data_fp32.add_(-step_size, exp_avg)
|
||||
|
@ -78,8 +89,8 @@ class RAdam(Optimizer):
|
|||
|
||||
return loss
|
||||
|
||||
class PlainRAdam(Optimizer):
|
||||
|
||||
class PlainRAdam(Optimizer):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
|
||||
|
@ -96,46 +107,58 @@ class PlainRAdam(Optimizer):
|
|||
|
||||
for group in self.param_groups:
|
||||
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('RAdam does not support sparse gradients')
|
||||
raise RuntimeError("RAdam does not support sparse gradients")
|
||||
|
||||
p_data_fp32 = p.data.float()
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p_data_fp32)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
|
||||
else:
|
||||
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
||||
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
|
||||
state['step'] += 1
|
||||
beta2_t = beta2 ** state['step']
|
||||
state["step"] += 1
|
||||
beta2_t = beta2 ** state["step"]
|
||||
N_sma_max = 2 / (1 - beta2) - 1
|
||||
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
|
||||
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
|
||||
|
||||
# more conservative since it's an approximated value
|
||||
if N_sma >= 5:
|
||||
step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
if N_sma >= 5:
|
||||
step_size = (
|
||||
group["lr"]
|
||||
* math.sqrt(
|
||||
(1 - beta2_t)
|
||||
* (N_sma - 4)
|
||||
/ (N_sma_max - 4)
|
||||
* (N_sma - 2)
|
||||
/ N_sma
|
||||
* N_sma_max
|
||||
/ (N_sma_max - 2)
|
||||
)
|
||||
/ (1 - beta1 ** state["step"])
|
||||
)
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
|
||||
else:
|
||||
step_size = group['lr'] / (1 - beta1 ** state['step'])
|
||||
step_size = group["lr"] / (1 - beta1 ** state["step"])
|
||||
p_data_fp32.add_(-step_size, exp_avg)
|
||||
|
||||
p.data.copy_(p_data_fp32)
|
||||
|
@ -144,10 +167,18 @@ class PlainRAdam(Optimizer):
|
|||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, use_variance=True, warmup = warmup)
|
||||
def __init__(
|
||||
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0
|
||||
):
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
use_variance=True,
|
||||
warmup=warmup,
|
||||
)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
@ -160,49 +191,51 @@ class AdamW(Optimizer):
|
|||
|
||||
for group in self.param_groups:
|
||||
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
raise RuntimeError(
|
||||
"Adam does not support sparse gradients, please consider SparseAdam instead"
|
||||
)
|
||||
|
||||
p_data_fp32 = p.data.float()
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
state['exp_avg'] = torch.zeros_like(p_data_fp32)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p_data_fp32)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
|
||||
else:
|
||||
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
|
||||
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
|
||||
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state['step'] += 1
|
||||
state["step"] += 1
|
||||
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
if group['warmup'] > state['step']:
|
||||
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
|
||||
else:
|
||||
scheduled_lr = group['lr']
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
bias_correction1 = 1 - beta1 ** state["step"]
|
||||
bias_correction2 = 1 - beta2 ** state["step"]
|
||||
|
||||
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
|
||||
if group["warmup"] > state["step"]:
|
||||
scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"]
|
||||
else:
|
||||
scheduled_lr = group["lr"]
|
||||
|
||||
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(-group["weight_decay"] * scheduled_lr, p_data_fp32)
|
||||
|
||||
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
p.data.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
||||
return loss
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
import random
|
||||
import os
|
||||
|
||||
print('Using', vtk.vtkVersion.GetVTKSourceVersion())
|
||||
print("Using", vtk.vtkVersion.GetVTKSourceVersion())
|
||||
|
||||
|
||||
class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
|
||||
|
@ -14,7 +14,7 @@ class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
|
|||
|
||||
def keyPressEvent(self, obj, event):
|
||||
key = self.parent.GetKeySym()
|
||||
if key == '+':
|
||||
if key == "+":
|
||||
point_size = self.pointcloud.vtkActor.GetProperty().GetPointSize()
|
||||
self.pointcloud.vtkActor.GetProperty().SetPointSize(point_size + 1)
|
||||
print(str(point_size) + " " + key)
|
||||
|
@ -22,7 +22,6 @@ class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
|
|||
|
||||
|
||||
class VtkPointCloud:
|
||||
|
||||
def __init__(self, point_size=18, maxNumPoints=1e8):
|
||||
self.maxNumPoints = maxNumPoints
|
||||
self.vtkPolyData = vtk.vtkPolyData()
|
||||
|
@ -59,11 +58,11 @@ class VtkPointCloud:
|
|||
self.vtkPoints = vtk.vtkPoints()
|
||||
self.vtkCells = vtk.vtkCellArray()
|
||||
self.vtkDepth = vtk.vtkDoubleArray()
|
||||
self.vtkDepth.SetName('DepthArray')
|
||||
self.vtkDepth.SetName("DepthArray")
|
||||
self.vtkPolyData.SetPoints(self.vtkPoints)
|
||||
self.vtkPolyData.SetVerts(self.vtkCells)
|
||||
self.vtkPolyData.GetPointData().SetScalars(self.vtkDepth)
|
||||
self.vtkPolyData.GetPointData().SetActiveScalars('DepthArray')
|
||||
self.vtkPolyData.GetPointData().SetActiveScalars("DepthArray")
|
||||
|
||||
|
||||
def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)):
|
||||
|
@ -95,7 +94,15 @@ def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)):
|
|||
return actor
|
||||
|
||||
|
||||
def show_pointclouds(points, colors, text=[], title="Default", png_path="", interactive=True, orientation='horizontal'):
|
||||
def show_pointclouds(
|
||||
points,
|
||||
colors,
|
||||
text=[],
|
||||
title="Default",
|
||||
png_path="",
|
||||
interactive=True,
|
||||
orientation="horizontal",
|
||||
):
|
||||
"""
|
||||
Show multiple point clouds specified as lists. First clouds at the bottom.
|
||||
:param points: list of pointclouds, item: numpy (N x 3) XYZ
|
||||
|
@ -108,16 +115,18 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
"""
|
||||
|
||||
# make sure pointclouds is a list
|
||||
assert isinstance(points, type([])), \
|
||||
"Pointclouds argument must be a list"
|
||||
assert isinstance(points, type([])), "Pointclouds argument must be a list"
|
||||
|
||||
# make sure colors is a list
|
||||
assert isinstance(colors, type([])), \
|
||||
"Colors argument must be a list"
|
||||
assert isinstance(colors, type([])), "Colors argument must be a list"
|
||||
|
||||
# make sure number of pointclouds and colors are the same
|
||||
assert len(points) == len(colors), \
|
||||
"Number of pointclouds (%d) is different then number of colors (%d)" % (len(points), len(colors))
|
||||
assert len(points) == len(
|
||||
colors
|
||||
), "Number of pointclouds (%d) is different then number of colors (%d)" % (
|
||||
len(points),
|
||||
len(colors),
|
||||
)
|
||||
|
||||
while len(text) < len(points):
|
||||
text.append("")
|
||||
|
@ -130,15 +139,20 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
renderers = [vtk.vtkRenderer() for _ in range(num_pointclouds)]
|
||||
|
||||
height = 1.0 / max(num_pointclouds, 1)
|
||||
viewports = [(i*height, (i+1)*height) for i in range(num_pointclouds)]
|
||||
#print(viewports)
|
||||
viewports = [(i * height, (i + 1) * height) for i in range(num_pointclouds)]
|
||||
# print(viewports)
|
||||
|
||||
# iterate over all point clouds
|
||||
for i, pc in enumerate(points):
|
||||
pc = pc.squeeze()
|
||||
co = colors[i].squeeze()
|
||||
assert pc.shape[0] == co.shape[0], \
|
||||
"expected same number of points (%d) then colors (%d), cloud index = %d" % (pc.shape[0], co.shape[0], i)
|
||||
assert (
|
||||
pc.shape[0] == co.shape[0]
|
||||
), "expected same number of points (%d) then colors (%d), cloud index = %d" % (
|
||||
pc.shape[0],
|
||||
co.shape[0],
|
||||
i,
|
||||
)
|
||||
assert pc.shape[1] == 3, "expected points to be N x 3, got N x %d" % pc.shape[1]
|
||||
assert co.shape[1] == 3, "expected colors to be N x 3, got N x %d" % co.shape[1]
|
||||
|
||||
|
@ -151,13 +165,13 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
renderers[i].AddActor(pointclouds[i].vtkActor)
|
||||
# renderers[i].AddActor(vtk.vtkAxesActor())
|
||||
renderers[i].SetBackground(1.0, 1.0, 1.0)
|
||||
if orientation == 'horizontal':
|
||||
if orientation == "horizontal":
|
||||
print(viewports[i][0])
|
||||
renderers[i].SetViewport(viewports[i][0], 0.0, viewports[i][1], 1.0)
|
||||
elif orientation == 'vertical':
|
||||
elif orientation == "vertical":
|
||||
renderers[i].SetViewport(0.0, viewports[i][0], 1.0, viewports[i][1])
|
||||
else:
|
||||
raise Exception('Not a valid orientation!')
|
||||
raise Exception("Not a valid orientation!")
|
||||
renderers[i].ResetCamera()
|
||||
|
||||
# Add circle to first render
|
||||
|
@ -167,12 +181,12 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
# Text actors
|
||||
text_actors = [vtk.vtkTextActor() for _ in text]
|
||||
for i, ta in enumerate(text_actors):
|
||||
if orientation == 'horizontal':
|
||||
ta.SetInput(' ' + text[i])
|
||||
elif orientation == 'vertical':
|
||||
ta.SetInput(text[i] + '\n\n\n\n\n\n')
|
||||
if orientation == "horizontal":
|
||||
ta.SetInput(" " + text[i])
|
||||
elif orientation == "vertical":
|
||||
ta.SetInput(text[i] + "\n\n\n\n\n\n")
|
||||
else:
|
||||
raise Exception('Not a valid orientation!')
|
||||
raise Exception("Not a valid orientation!")
|
||||
txtprop = ta.GetTextProperty()
|
||||
txtprop.SetFontFamilyToArial()
|
||||
txtprop.SetFontSize(0)
|
||||
|
@ -201,14 +215,14 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
# camera.SetFocalPoint(0, 0, 0)
|
||||
|
||||
camera.SetViewUp(0, 0, 1)
|
||||
if orientation == 'horizontal':
|
||||
if orientation == "horizontal":
|
||||
camera.SetPosition(3, -10, 2)
|
||||
camera.SetFocalPoint(3, 1.5, 1.5)
|
||||
elif orientation == 'vertical':
|
||||
elif orientation == "vertical":
|
||||
camera.SetPosition(1.5, -6, 2)
|
||||
camera.SetFocalPoint(1.5, 1.5, 1.5)
|
||||
else:
|
||||
raise Exception('Not a valid orientation!')
|
||||
raise Exception("Not a valid orientation!")
|
||||
|
||||
camera.SetClippingRange(0.002, 1000)
|
||||
for renderer in renderers:
|
||||
|
@ -217,12 +231,12 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
|
|||
# Begin Interaction
|
||||
render_window.Render()
|
||||
render_window.SetWindowName(title)
|
||||
if orientation == 'horizontal':
|
||||
if orientation == "horizontal":
|
||||
render_window.SetSize(1940, 720)
|
||||
elif orientation == 'vertical':
|
||||
elif orientation == "vertical":
|
||||
render_window.SetSize(600, 1388)
|
||||
else:
|
||||
raise Exception('Not a valid orientation!')
|
||||
raise Exception("Not a valid orientation!")
|
||||
|
||||
if interactive:
|
||||
render_window_interactor.Start()
|
||||
|
@ -253,10 +267,20 @@ def get_points_colors_from_obj(filename, limit=1):
|
|||
return points[idx, :], colors[idx, :]
|
||||
|
||||
|
||||
def visualize_part_seg(file_name_pred, file_name_gt, comparison_folder_list, limit=1, text=[], png_path="",
|
||||
interactive=True, orientation='horizontal'):
|
||||
def visualize_part_seg(
|
||||
file_name_pred,
|
||||
file_name_gt,
|
||||
comparison_folder_list,
|
||||
limit=1,
|
||||
text=[],
|
||||
png_path="",
|
||||
interactive=True,
|
||||
orientation="horizontal",
|
||||
):
|
||||
# load base point cloud
|
||||
gt_points, gt_colors = get_points_colors_from_obj(os.path.join(comparison_folder_list[0], file_name_gt), limit)
|
||||
gt_points, gt_colors = get_points_colors_from_obj(
|
||||
os.path.join(comparison_folder_list[0], file_name_gt), limit
|
||||
)
|
||||
|
||||
idx_gt = gt_points[:, 1] >= limit
|
||||
|
||||
|
@ -264,12 +288,19 @@ def visualize_part_seg(file_name_pred, file_name_gt, comparison_folder_list, lim
|
|||
all_colors = [gt_colors[idx_gt, :3]]
|
||||
|
||||
for folder in comparison_folder_list:
|
||||
pts, col = get_points_colors_from_obj(os.path.join(folder, file_name_pred), limit=limit)
|
||||
pts, col = get_points_colors_from_obj(
|
||||
os.path.join(folder, file_name_pred), limit=limit
|
||||
)
|
||||
|
||||
all_points.append(pts)
|
||||
all_colors.append(col)
|
||||
|
||||
print(np.asarray(all_points).shape)
|
||||
show_pointclouds(all_points, all_colors, text=text, png_path=png_path, interactive=interactive,
|
||||
orientation=orientation)
|
||||
|
||||
show_pointclouds(
|
||||
all_points,
|
||||
all_colors,
|
||||
text=text,
|
||||
png_path=png_path,
|
||||
interactive=interactive,
|
||||
orientation=orientation,
|
||||
)
|
||||
|
|
|
@ -3,7 +3,7 @@ try:
|
|||
import tensorflow as tf
|
||||
import tensorboard.plugins.mesh.summary as meshsummary
|
||||
except ImportError:
|
||||
print('tensorflow is not installed.')
|
||||
print("tensorflow is not installed.")
|
||||
import numpy as np
|
||||
import scipy.misc
|
||||
|
||||
|
@ -11,38 +11,38 @@ import scipy.misc
|
|||
try:
|
||||
from StringIO import StringIO # Python 2.7
|
||||
except ImportError:
|
||||
from io import BytesIO # Python 3.x
|
||||
from io import BytesIO # Python 3.x
|
||||
|
||||
|
||||
class TfLogger(object):
|
||||
|
||||
def __init__(self, log_dir):
|
||||
"""Create a summary writer logging to log_dir."""
|
||||
self.writer = tf.compat.v1.summary.FileWriter(log_dir)
|
||||
|
||||
# Camera and scene configuration.
|
||||
self.config_dict = {
|
||||
'camera': {'cls': 'PerspectiveCamera', 'fov': 75},
|
||||
'lights': [
|
||||
"camera": {"cls": "PerspectiveCamera", "fov": 75},
|
||||
"lights": [
|
||||
{
|
||||
'cls': 'AmbientLight',
|
||||
'color': '#ffffff',
|
||||
'intensity': 0.75,
|
||||
}, {
|
||||
'cls': 'DirectionalLight',
|
||||
'color': '#ffffff',
|
||||
'intensity': 0.75,
|
||||
'position': [0, -1, 2],
|
||||
}],
|
||||
'material': {
|
||||
'cls': 'MeshStandardMaterial',
|
||||
'metalness': 0
|
||||
}
|
||||
"cls": "AmbientLight",
|
||||
"color": "#ffffff",
|
||||
"intensity": 0.75,
|
||||
},
|
||||
{
|
||||
"cls": "DirectionalLight",
|
||||
"color": "#ffffff",
|
||||
"intensity": 0.75,
|
||||
"position": [0, -1, 2],
|
||||
},
|
||||
],
|
||||
"material": {"cls": "MeshStandardMaterial", "metalness": 0},
|
||||
}
|
||||
|
||||
def scalar_summary(self, tag, value, step):
|
||||
"""Log a scalar variable."""
|
||||
summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
|
||||
summary = tf.compat.v1.Summary(
|
||||
value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]
|
||||
)
|
||||
self.writer.add_summary(summary, step)
|
||||
|
||||
def image_summary(self, tag, images, step):
|
||||
|
@ -54,10 +54,15 @@ class TfLogger(object):
|
|||
scipy.misc.toimage(img).save(s, format="png")
|
||||
|
||||
# Create an Image object
|
||||
img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0], width=img.shape[1])
|
||||
img_sum = tf.compat.v1.Summary.Image(
|
||||
encoded_image_string=s.getvalue(),
|
||||
height=img.shape[0],
|
||||
width=img.shape[1],
|
||||
)
|
||||
# Create a Summary value
|
||||
img_summaries.append(tf.compat.v1.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
|
||||
img_summaries.append(
|
||||
tf.compat.v1.Summary.Value(tag="%s/%d" % (tag, i), image=img_sum)
|
||||
)
|
||||
|
||||
# Create and write Summary
|
||||
summary = tf.Summary(value=img_summaries)
|
||||
|
@ -71,10 +76,17 @@ class TfLogger(object):
|
|||
vertices = tf.constant(vertices)
|
||||
if faces is not None:
|
||||
faces = tf.constant(faces)
|
||||
meshes_summares=[]
|
||||
meshes_summares = []
|
||||
for i in range(vertices.shape[0]):
|
||||
meshes_summares.append(meshsummary.op(
|
||||
tag, vertices=vertices, faces=faces, colors=colors, config_dict=self.config_dict))
|
||||
meshes_summares.append(
|
||||
meshsummary.op(
|
||||
tag,
|
||||
vertices=vertices,
|
||||
faces=faces,
|
||||
colors=colors,
|
||||
config_dict=self.config_dict,
|
||||
)
|
||||
)
|
||||
|
||||
sess = tf.Session()
|
||||
summaries = sess.run(meshes_summares)
|
||||
|
@ -93,7 +105,7 @@ class TfLogger(object):
|
|||
hist.max = float(np.max(values))
|
||||
hist.num = int(np.prod(values.shape))
|
||||
hist.sum = float(np.sum(values))
|
||||
hist.sum_squares = float(np.sum(values**2))
|
||||
hist.sum_squares = float(np.sum(values ** 2))
|
||||
|
||||
# Drop the start of the first bin
|
||||
bin_edges = bin_edges[1:]
|
||||
|
@ -108,4 +120,3 @@ class TfLogger(object):
|
|||
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
|
||||
self.writer.add_summary(summary, step)
|
||||
self.writer.flush()

@@ -6,4 +6,4 @@ from .mlp import MLPP
from .polynomialregression import Polynomial
from .randomforest import RandomForest
from .rbm import RBM
from .svr import RSVR
from .svr import RSVR
@ -11,22 +11,27 @@ from .util.base import Base
|
|||
|
||||
|
||||
class DBN(Base):
|
||||
|
||||
def __init__(self,u_obs,u,nComponents=[250,30,5],constant=298):
|
||||
super().__init__(u_obs,u)
|
||||
self.n_components=nComponents
|
||||
def __init__(self, u_obs, u, nComponents=[250, 30, 5], constant=298):
|
||||
super().__init__(u_obs, u)
|
||||
self.n_components = nComponents
|
||||
|
||||
def dbn(self):
|
||||
models = []
|
||||
for num, components in zip(range(0,len(self.n_components)),self.n_components):
|
||||
model = rbm(n_components=components, n_iter=300, learning_rate=0.06, random_state=1,verbose=True)
|
||||
model_name = 'rbm' + str(num)
|
||||
for num, components in zip(range(0, len(self.n_components)), self.n_components):
|
||||
model = rbm(
|
||||
n_components=components,
|
||||
n_iter=300,
|
||||
learning_rate=0.06,
|
||||
random_state=1,
|
||||
verbose=True,
|
||||
)
|
||||
model_name = "rbm" + str(num)
|
||||
models.append((model_name, model))
|
||||
models.append(('clf', LinearRegression()))
|
||||
models.append(("clf", LinearRegression()))
|
||||
return models
|
||||
|
||||
def predict(self):
|
||||
|
||||
|
||||
self.pred_init()
|
||||
X, Y = self.train_samples()
|
||||
test_samples = self.test_samples()
|
||||
|
@ -35,58 +40,58 @@ class DBN(Base):
|
|||
regressor = regressor = Pipeline(models)
|
||||
regressor.fit(X, Y)
|
||||
|
||||
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1])
|
||||
self.u_pred = regressor.predict(test_samples).reshape(
|
||||
self.u.shape[0], self.u.shape[1]
|
||||
)
|
||||
return self.u_pred
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
m=sio.loadmat('Example0.mat')
|
||||
u_obs=m['u_obs']
|
||||
u=m['u']
|
||||
sample = DBN(u_obs,u, nComponents=[250,30,5])
|
||||
if __name__ == "__main__":
|
||||
m = sio.loadmat("Example0.mat")
|
||||
u_obs = m["u_obs"]
|
||||
u = m["u"]
|
||||
sample = DBN(u_obs, u, nComponents=[250, 30, 5])
|
||||
u_pred = sample.predict()
|
||||
print('mae:',mae(u_pred,u))
|
||||
print("mae:", mae(u_pred, u))
|
||||
|
||||
u_pred=u_pred*50+298
|
||||
u_pred = u_pred * 50 + 298
|
||||
from sklearn.metrics import mean_absolute_error as mae
|
||||
print('mae:',mae(u_pred,u))
|
||||
|
||||
|
||||
print("mae:", mae(u_pred, u))
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -8,13 +8,12 @@ from .util.base import Base

class GInterpolation(Base):
    def __init__(self, u_obs, u, constant=298):
        super().__init__(u_obs, u)

    def predict(self):

        row = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
        col = row
        col, row = np.meshgrid(col, row)

@@ -24,61 +23,72 @@ class GInterpolation(Base):
        col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col)
        row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row)

        ind = np.dot(
            np.ones_like(self.rows).reshape(1, -1),
            (
                np.exp(
                    -np.sqrt(
                        np.power((self.rows.reshape(-1, 1) - row), 2)
                        + np.power((self.cols.reshape(-1, 1) - col), 2)
                    )
                )
            ),
        )

        param = np.exp(
            -np.sqrt(
                np.power((self.rows.reshape(-1, 1) - row), 2)
                + np.power((self.cols.reshape(-1, 1) - col), 2)
            )
        ) / (np.dot(np.ones_like(self.rows).reshape(-1, 1), ind))

        dis = np.dot(self.u_obs[self.rows, self.cols].reshape(1, -1), param)
        self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1])

        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example10001.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = GInterpolation(u_obs, u)
    u_pred = sample.predict()
    print("mae:", mae(u_pred, u))

    fig = plt.figure(figsize=(22.5, 5))

    grid_x = np.linspace(0, 0.1, num=200)
    grid_y = np.linspace(0, 0.1, num=200)
    X, Y = np.meshgrid(grid_x, grid_y)

    fig = plt.figure(figsize=(22.5, 5))

    plt.subplot(141)
    plt.title("Absolute Error")
    im = plt.pcolormesh(X, Y, abs(u - u_pred))
    plt.colorbar(im)
    fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)

    plt.subplot(142)
    plt.title("Real Temperature Field")
    im = plt.contourf(X, Y, u, levels=150, cmap="jet")
    plt.colorbar(im)

    plt.subplot(143)
    plt.title("Reconstructed Temperature Field")
    im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
    plt.colorbar(im)

    plt.subplot(144)
    plt.title("Absolute Error")
    im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
    plt.colorbar(im)

    # save_name = os.path.join('outputs/predict_plot', '1.png')
    # fig.savefig(save_name, dpi=300)

    # fig = plt.figure(figsize=(5,5))
    # im = plt.imshow(u,cmap='jet')
    # plt.colorbar(im)
    fig.savefig("prediction.png", dpi=300)
@@ -8,22 +8,23 @@ from .util.base import Base

class KInterpolation(Base):
    def __init__(self, u_obs, u, k=5, constant=298):
        super().__init__(u_obs, u)
        self.k = k

    def knearest(self, row, col):
        d = np.zeros_like(self.rows)
        for k in range(self.rows.shape[0]):
            d[k] = math.sqrt(
                math.pow((self.rows[k] - row), 2) + math.pow((self.cols[k] - col), 2)
            )
        kpoint = np.argsort(d)[: self.k]
        return self.rows[kpoint], self.cols[kpoint]

    def predict(self):

        row = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
        col = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
        col, row = np.meshgrid(col, row)

        col = col.reshape(1, -1)

@@ -32,69 +33,90 @@ class KInterpolation(Base):
        col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col)
        row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row)

        ksort = np.argsort(
            np.sqrt(
                np.power((self.rows.reshape(-1, 1) - row), 2)
                + np.power((self.cols.reshape(-1, 1) - col), 2)
            ),
            axis=0,
        )

        kind = np.zeros_like(ksort)

        for num in range(ksort.shape[1]):
            kind[ksort[: self.k, num], num] = 1

        assert self.k >= 1

        ind = np.dot(
            np.ones_like(self.rows).reshape(1, -1),
            np.exp(
                -np.sqrt(
                    np.power((self.rows.reshape(-1, 1) - row), 2)
                    + np.power((self.cols.reshape(-1, 1) - col), 2)
                )
            )
            * kind,
        )

        param = (
            np.exp(
                -np.sqrt(
                    np.power((self.rows.reshape(-1, 1) - row), 2)
                    + np.power((self.cols.reshape(-1, 1) - col), 2)
                )
            )
            * kind
            / (np.dot(np.ones_like(self.rows).reshape(-1, 1), ind))
        )

        dis = np.dot(self.u_obs[self.rows, self.cols].reshape(1, -1), param)
        self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1])

        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = KInterpolation(u_obs, u, k=4)
    u_pred = sample.predict()
    print("mae:", mae(u_pred, u))
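For reference, GInterpolation and KInterpolation above reduce to the same idea: each grid point receives a weighted average of the observed temperatures, with weights exp(-distance) normalized per grid point (KInterpolation additionally zeroes all but the k nearest observations). A minimal standalone sketch of that weighting, not part of this commit; the function and argument names are illustrative only:

import numpy as np

def exp_distance_interpolate(obs_rc, obs_vals, grid_rc):
    # obs_rc: (n_obs, 2) observation coordinates, obs_vals: (n_obs,) observed values,
    # grid_rc: (n_grid, 2) query coordinates.
    d = np.sqrt(
        (obs_rc[:, 0:1] - grid_rc[None, :, 0]) ** 2
        + (obs_rc[:, 1:2] - grid_rc[None, :, 1]) ** 2
    )  # pairwise distances, shape (n_obs, n_grid)
    w = np.exp(-d)
    w /= w.sum(axis=0, keepdims=True)  # normalize weights per grid point
    return obs_vals @ w  # weighted average of observations at every grid point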
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -8,78 +8,76 @@ from .util.base import Base

class Kriging(Base):
    def __init__(self, u_obs, u, constant=298):
        super().__init__(u_obs, u)

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

        kernel = (
            1.0 * gp.kernels.RBF(1.0) + gp.kernels.WhiteKernel()
        )  # + gp.kernels.DotProduct()
        regressor = gp.GaussianProcessRegressor(
            kernel=kernel, n_restarts_optimizer=10, alpha=0.01
        )
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]

    sample = Kriging(u_obs, u)
    u_pred = sample.predict()

    u_pred = u_pred * 50 + 298

    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -10,77 +10,78 @@ from .util.base import Base

class MLPP(Base):
    def __init__(self, u_obs, u, layers=[100, 50], constant=298):
        super().__init__(u_obs, u)
        self.layers = layers

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

        regressor = MLPRegressor(
            hidden_layer_sizes=self.layers,
            alpha=2,
            solver="adam",
            random_state=1,
            max_iter=20000,
        )
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = MLPP(u_obs, u)
    u_pred = sample.predict()
    print("mae:", mae(u_pred, u))

    # u_pred=u_pred*50+298

    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -10,76 +10,76 @@ from .util.base import Base

class Polynomial(Base):
    def __init__(self, u_obs, u, degree=5, constant=298):
        super().__init__(u_obs, u)
        self.degree = degree

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

        regressor = Pipeline(
            [
                ("poly", PolynomialFeatures(degree=self.degree, include_bias=False)),
                ("clf", LinearRegression()),
            ]
        )
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = Polynomial(u_obs, u, degree=5)
    u_pred = sample.predict()

    u_pred = u_pred * 50 + 298

    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -9,12 +9,11 @@ from sklearn.metrics import mean_absolute_error as mae

class RandomForest(Base):
    def __init__(self, u_obs, u, constant=298):
        super().__init__(u_obs, u)

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

@@ -22,63 +21,59 @@ class RandomForest(Base):
        regressor = RandomForestRegressor(n_estimators=500, random_state=10)
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = RandomForest(u_obs, u)
    u_pred = sample.predict()
    print("mae:", mae(u_pred, u))

    u_pred = u_pred * 50 + 298

    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -11,77 +11,78 @@ from .util.base import Base

class RBM(Base):
    def __init__(self, u_obs, u, nComponents=8000, constant=298):
        super().__init__(u_obs, u)
        self.n_components = nComponents

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

        rbm1 = rbm(
            n_components=self.n_components,
            n_iter=300,
            learning_rate=0.06,
            random_state=1,
            verbose=True,
        )
        regressor = Pipeline([("rbm", rbm1), ("clf", LinearRegression())])
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = RBM(u_obs, u, nComponents=8000)
    u_pred = sample.predict()
    print("mae:", mae(u_pred, u))

    u_pred = u_pred * 50 + 298
    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5,5))
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -8,74 +8,70 @@ from .util.base import Base

class RSVR(Base):
    def __init__(self, u_obs, u, constant=298):
        super().__init__(u_obs, u)

    def predict(self):

        self.pred_init()
        X, Y = self.train_samples()
        test_samples = self.test_samples()

        regressor = SVR(kernel="rbf")
        regressor.fit(X, Y)

        self.u_pred = regressor.predict(test_samples).reshape(
            self.u.shape[0], self.u.shape[1]
        )
        return self.u_pred


if __name__ == "__main__":
    m = sio.loadmat("Example0.mat")
    u_obs = m["u_obs"]
    u = m["u"]
    sample = RSVR(u_obs, u)
    u_pred = sample.predict()

    from sklearn.metrics import mean_absolute_error as mae

    print("mae:", mae(u_pred, u))

    u_pred = u_pred * 50 + 298
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.pcolormesh(X,Y,abs(u-u_pred))
|
||||
plt.title("Absolute Error")
|
||||
im = plt.pcolormesh(X, Y, abs(u - u_pred))
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
im = plt.contourf(X,Y,u,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
im = plt.contourf(X, Y, u, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet')
|
||||
plt.title("Absolute Error")
|
||||
im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
#save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
#fig.savefig(save_name, dpi=300)
|
||||
# save_name = os.path.join('outputs/predict_plot', '1.png')
|
||||
# fig.savefig(save_name, dpi=300)
|
||||
|
||||
|
||||
#fig = plt.figure(figsize=(5,5))
|
||||
#im = plt.imshow(u,cmap='jet')
|
||||
#plt.colorbar(im)
|
||||
fig.savefig('prediction.png', dpi=300)
|
||||
|
||||
|
||||
|
||||
# fig = plt.figure(figsize=(5,5))
|
||||
# im = plt.imshow(u,cmap='jet')
|
||||
# plt.colorbar(im)
|
||||
fig.savefig("prediction.png", dpi=300)
|
||||
|
|
|
@@ -12,7 +12,8 @@ class Base:
    """
    The observations are in matrix format
    """

    def __init__(self, u_obs, u, constant=298):
        self.u_obs = np.array(u_obs)
        self.u = np.array(u)
        self.constant = constant

@@ -22,18 +23,22 @@ class Base:
        self.obser()

    def obser(self):
        [self.rows, self.cols] = np.where(self.u_obs > self.constant)

    def pred_init(self):
        self.u_pred = np.zeros_like(self.u)

    def train_samples(self):
        X_train = np.transpose(np.array([self.rows, self.cols])) / max(self.u.shape)
        y_train = np.transpose(self.u[self.rows, self.cols])
        return X_train, y_train

    def test_samples(self):
        samples = [
            [row, col]
            for row in range(self.u.shape[0])
            for col in range(self.u.shape[1])
        ]
        samples = np.array(samples) / max(self.u.shape)
        return samples
@ -45,6 +50,7 @@ class BaseVec:
|
|||
"""
|
||||
The observations are in matrix format
|
||||
"""
|
||||
|
||||
def __init__(self, root, train_list, constant=298):
|
||||
self.root = root
|
||||
self.train_list = train_list
|
||||
|
@ -55,43 +61,45 @@ class BaseVec:
|
|||
|
||||
self.constant = constant
|
||||
|
||||
def _loader(self, path, mode='train'):
|
||||
|
||||
def _loader(self, path, mode="train"):
|
||||
|
||||
input = []
|
||||
|
||||
|
||||
output = []
|
||||
if mode == 'train':
|
||||
for _ in range(4*4):
|
||||
if mode == "train":
|
||||
for _ in range(4 * 4):
|
||||
output.append([])
|
||||
else:
|
||||
pass
|
||||
|
||||
#print((path[3]))
|
||||
# print((path[3]))
|
||||
num = 0
|
||||
for i in range(len(path)):
|
||||
num = num + 1
|
||||
#print(len(path))
|
||||
#print(i)
|
||||
source = np.array(sio.loadmat(path[i])['u_obs'])
|
||||
target = np.array(sio.loadmat(path[i])['u'])
|
||||
# print(len(path))
|
||||
# print(i)
|
||||
source = np.array(sio.loadmat(path[i])["u_obs"])
|
||||
target = np.array(sio.loadmat(path[i])["u"])
|
||||
if self.layout is None:
|
||||
self.layout = np.array(sio.loadmat(path[i])['F'])
|
||||
self.layout = np.array(sio.loadmat(path[i])["F"])
|
||||
else:
|
||||
pass
|
||||
|
||||
indata = source[np.where(source>TOL)]
|
||||
indata = source[np.where(source > TOL)]
|
||||
input.append(indata)
|
||||
if mode == 'train':
|
||||
if mode == "train":
|
||||
for k in range(4):
|
||||
for kk in range(4):
|
||||
sep = target[0+k:target.shape[0]:4, 0+kk:target.shape[1]:4].flatten()
|
||||
output[k*4+kk].append(sep)
|
||||
elif mode == 'test':
|
||||
sep = target[
|
||||
0 + k : target.shape[0] : 4, 0 + kk : target.shape[1] : 4
|
||||
].flatten()
|
||||
output[k * 4 + kk].append(sep)
|
||||
elif mode == "test":
|
||||
output.append(target)
|
||||
else:
|
||||
pass
|
||||
|
||||
if num % 1000 == 0 :
|
||||
|
||||
if num % 1000 == 0:
|
||||
print("num:", num)
|
||||
|
||||
return input, output
|
||||
|
@ -104,14 +112,17 @@ class BaseVec:
|
|||
base = os.path.dirname(list_path)
|
||||
print(base)
|
||||
test_name = os.path.splitext(os.path.basename(list_path))[0]
|
||||
subdir = os.path.join("train", "train") \
|
||||
if base=='train' else os.path.join("test", test_name)
|
||||
subdir = (
|
||||
os.path.join("train", "train")
|
||||
if base == "train"
|
||||
else os.path.join("test", test_name)
|
||||
)
|
||||
file_dir = os.path.join(root_dir, subdir)
|
||||
list_file = os.path.join(root_dir, list_path)
|
||||
print(file_dir)
|
||||
print(list_file)
|
||||
assert os.path.isdir(file_dir)
|
||||
with open(list_file, 'r') as rf:
|
||||
with open(list_file, "r") as rf:
|
||||
for line in rf.readlines():
|
||||
data_path = line.strip()
|
||||
path = os.path.join(file_dir, data_path)
|
||||
|
@ -119,13 +130,13 @@ class BaseVec:
|
|||
return files
|
||||
|
||||
def train_samples(self):
|
||||
#print(self.train_file)
|
||||
# print(self.train_file)
|
||||
X_train, y_train = self._loader(self.train_file)
|
||||
return X_train, y_train
|
||||
|
||||
def test_samples(self, test_path):
|
||||
test_file = self.make_dataset(self.root, test_path)
|
||||
X_test, y_test = self._loader(test_file, mode='test')
|
||||
X_test, y_test = self._loader(test_file, mode="test")
|
||||
|
||||
return X_test, np.array(y_test)
|
||||
|
||||
|
@ -133,21 +144,18 @@ class BaseVec:
|
|||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#m=sio.loadmat('Example0.mat')
|
||||
#u_obs=m['u_obs']
|
||||
#u=m['u']
|
||||
#sample = Base(u_obs,u)
|
||||
#sample.vec()
|
||||
root = 'g:/gong/recon_project/TFRD/HSink/'
|
||||
train_list = 'g:/gong/recon_project/TFRD/HSink/train/train_val.txt'
|
||||
test_list = 'g:/gong/recon_project/TFRD/HSink/train/test_0.txt'
|
||||
if __name__ == "__main__":
|
||||
# m=sio.loadmat('Example0.mat')
|
||||
# u_obs=m['u_obs']
|
||||
# u=m['u']
|
||||
# sample = Base(u_obs,u)
|
||||
# sample.vec()
|
||||
root = "g:/gong/recon_project/TFRD/HSink/"
|
||||
train_list = "g:/gong/recon_project/TFRD/HSink/train/train_val.txt"
|
||||
test_list = "g:/gong/recon_project/TFRD/HSink/train/test_0.txt"
|
||||
|
||||
sample = BaseVec(root, train_list)
|
||||
|
||||
a, b = sample.train_samples()
|
||||
print(a)
|
||||
print(b[0])
|
||||
|
||||
|
||||
|
|
|
@@ -1,4 +1,4 @@
PAD_WORD = "<blank>"
UNK_WORD = "<unk>"
BOS_WORD = "<s>"
EOS_WORD = "</s>"
@@ -1,4 +1,4 @@
""" Define the Layers """
import torch.nn as nn
import torch
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward

@@ -8,7 +8,7 @@ __author__ = "Yu-Hsiang Huang"

class EncoderLayer(nn.Module):
    """Compose with two layers"""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()

@@ -17,13 +17,14 @@ class EncoderLayer(nn.Module):

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
    """Compose with three layers"""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()

@@ -32,12 +33,14 @@ class DecoderLayer(nn.Module):
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
        self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None
    ):
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask
        )
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask
        )
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn

@@ -6,7 +6,7 @@ __author__ = "Yu-Hsiang Huang"

class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention"""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
|
@ -1,8 +1,9 @@
|
|||
'''A wrapper class for scheduled optimizer '''
|
||||
"""A wrapper class for scheduled optimizer """
|
||||
import numpy as np
|
||||
|
||||
class ScheduledOptim():
|
||||
'''A simple wrapper class for learning rate scheduling'''
|
||||
|
||||
class ScheduledOptim:
|
||||
"""A simple wrapper class for learning rate scheduling"""
|
||||
|
||||
def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
|
||||
self._optimizer = optimizer
|
||||
|
@ -11,30 +12,27 @@ class ScheduledOptim():
|
|||
self.n_warmup_steps = n_warmup_steps
|
||||
self.n_steps = 0
|
||||
|
||||
|
||||
def step_and_update_lr(self):
|
||||
"Step with the inner optimizer"
|
||||
self._update_learning_rate()
|
||||
self._optimizer.step()
|
||||
|
||||
|
||||
def zero_grad(self):
|
||||
"Zero out the gradients with the inner optimizer"
|
||||
self._optimizer.zero_grad()
|
||||
|
||||
|
||||
def _get_lr_scale(self):
|
||||
d_model = self.d_model
|
||||
n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
|
||||
return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))
|
||||
|
||||
return (d_model ** -0.5) * min(
|
||||
n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)
|
||||
)
|
||||
|
||||
def _update_learning_rate(self):
|
||||
''' Learning rate scheduling per step '''
|
||||
"""Learning rate scheduling per step"""
|
||||
|
||||
self.n_steps += 1
|
||||
lr = self.lr_mul * self._get_lr_scale()
|
||||
|
||||
for param_group in self._optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
param_group["lr"] = lr
|
||||
|
|
|
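Written out, `_get_lr_scale` above is the standard Transformer warmup schedule; with $s$ the current step count, $w$ the number of warmup steps, and $\mu$ the `lr_mul` factor, the effective learning rate is

$$\mathrm{lr}(s) = \mu \, d_{\mathrm{model}}^{-0.5} \cdot \min\!\left(s^{-0.5},\; s \cdot w^{-1.5}\right),$$

i.e. it grows roughly linearly during the first $w$ steps and then decays as $s^{-0.5}$.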
@@ -1,4 +1,4 @@
""" Define the sublayers in encoder/decoder layer """
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

@@ -8,7 +8,7 @@ __author__ = "Yu-Hsiang Huang"

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention module"""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

@@ -59,7 +59,7 @@ class MultiHeadAttention(nn.Module):

class PositionwiseFeedForward(nn.Module):
    """A two-feed-forward-layer module"""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
|
@ -1,4 +1,4 @@
|
|||
''' This module will handle the text generation with beam search. '''
|
||||
""" This module will handle the text generation with beam search. """
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -7,12 +7,18 @@ from models.transformer.Models import Transformer, get_pad_mask, get_subsequent_
|
|||
|
||||
|
||||
class Translator(nn.Module):
|
||||
''' Load a trained model and translate in beam search fashion. '''
|
||||
"""Load a trained model and translate in beam search fashion."""
|
||||
|
||||
def __init__(
|
||||
self, model, beam_size, max_seq_len,
|
||||
src_pad_idx, trg_pad_idx, trg_bos_idx, trg_eos_idx):
|
||||
|
||||
self,
|
||||
model,
|
||||
beam_size,
|
||||
max_seq_len,
|
||||
src_pad_idx,
|
||||
trg_pad_idx,
|
||||
trg_bos_idx,
|
||||
trg_eos_idx,
|
||||
):
|
||||
|
||||
super(Translator, self).__init__()
|
||||
|
||||
|
@ -26,28 +32,27 @@ class Translator(nn.Module):
|
|||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]]))
|
||||
self.register_buffer("init_seq", torch.LongTensor([[trg_bos_idx]]))
|
||||
self.register_buffer(
|
||||
'blank_seqs',
|
||||
torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long))
|
||||
"blank_seqs",
|
||||
torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long),
|
||||
)
|
||||
self.blank_seqs[:, 0] = self.trg_bos_idx
|
||||
self.register_buffer(
|
||||
'len_map',
|
||||
torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0))
|
||||
|
||||
"len_map", torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0)
|
||||
)
|
||||
|
||||
def _model_decode(self, trg_seq, enc_output, src_mask):
|
||||
trg_mask = get_subsequent_mask(trg_seq)
|
||||
dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
|
||||
return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)
|
||||
|
||||
|
||||
def _get_init_state(self, src_seq, src_mask):
|
||||
beam_size = self.beam_size
|
||||
|
||||
enc_output, *_ = self.model.encoder(src_seq, src_mask)
|
||||
dec_output = self._model_decode(self.init_seq, enc_output, src_mask)
|
||||
|
||||
|
||||
best_k_probs, best_k_idx = dec_output[:, -1, :].topk(beam_size)
|
||||
|
||||
scores = torch.log(best_k_probs).view(beam_size)
|
||||
|
@ -56,23 +61,27 @@ class Translator(nn.Module):
|
|||
enc_output = enc_output.repeat(beam_size, 1, 1)
|
||||
return enc_output, gen_seq, scores
|
||||
|
||||
|
||||
def _get_the_best_score_and_idx(self, gen_seq, dec_output, scores, step):
|
||||
assert len(scores.size()) == 1
|
||||
|
||||
|
||||
beam_size = self.beam_size
|
||||
|
||||
# Get k candidates for each beam, k^2 candidates in total.
|
||||
best_k2_probs, best_k2_idx = dec_output[:, -1, :].topk(beam_size)
|
||||
|
||||
# Include the previous scores.
|
||||
scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(beam_size, 1)
|
||||
scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(
|
||||
beam_size, 1
|
||||
)
|
||||
|
||||
# Get the best k candidates from k^2 candidates.
|
||||
scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
|
||||
|
||||
|
||||
# Get the corresponding positions of the best k candidiates.
|
||||
best_k_r_idxs, best_k_c_idxs = best_k_idx_in_k2 // beam_size, best_k_idx_in_k2 % beam_size
|
||||
best_k_r_idxs, best_k_c_idxs = (
|
||||
best_k_idx_in_k2 // beam_size,
|
||||
best_k_idx_in_k2 % beam_size,
|
||||
)
|
||||
best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
|
||||
|
||||
# Copy the corresponding previous tokens.
|
||||
|
@ -82,27 +91,28 @@ class Translator(nn.Module):
|
|||
|
||||
return gen_seq, scores
|
||||
|
||||
|
||||
def translate_sentence(self, src_seq):
|
||||
# Only accept batch size equals to 1 in this function.
|
||||
# TODO: expand to batch operation.
|
||||
assert src_seq.size(0) == 1
|
||||
|
||||
src_pad_idx, trg_eos_idx = self.src_pad_idx, self.trg_eos_idx
|
||||
max_seq_len, beam_size, alpha = self.max_seq_len, self.beam_size, self.alpha
|
||||
src_pad_idx, trg_eos_idx = self.src_pad_idx, self.trg_eos_idx
|
||||
max_seq_len, beam_size, alpha = self.max_seq_len, self.beam_size, self.alpha
|
||||
|
||||
with torch.no_grad():
|
||||
src_mask = get_pad_mask(src_seq, src_pad_idx)
|
||||
enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask)
|
||||
|
||||
ans_idx = 0 # default
|
||||
for step in range(2, max_seq_len): # decode up to max length
|
||||
ans_idx = 0 # default
|
||||
for step in range(2, max_seq_len): # decode up to max length
|
||||
dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
|
||||
gen_seq, scores = self._get_the_best_score_and_idx(gen_seq, dec_output, scores, step)
|
||||
gen_seq, scores = self._get_the_best_score_and_idx(
|
||||
gen_seq, dec_output, scores, step
|
||||
)
|
||||
|
||||
# Check if all path finished
|
||||
# -- locate the eos in the generated sequences
|
||||
eos_locs = gen_seq == trg_eos_idx
|
||||
eos_locs = gen_seq == trg_eos_idx
|
||||
# -- replace the eos with its position for the length penalty use
|
||||
seq_lens, _ = self.len_map.masked_fill(~eos_locs, max_seq_len).min(1)
|
||||
# -- check if all beams contain eos
|
||||
|
@ -111,4 +121,4 @@ class Translator(nn.Module):
|
|||
_, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
|
||||
ans_idx = ans_idx.item()
|
||||
break
|
||||
return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist()
|
||||
return gen_seq[ans_idx][: seq_lens[ans_idx]].tolist()
|
||||
|
|
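For reference, `translate_sentence` above accumulates log-probabilities along each beam and, once every beam has produced an end-of-sequence token (or the length limit is reached), selects the answer by a length-normalized score with penalty exponent $\alpha$ (`self.alpha`):

$$b^{*} = \arg\max_{b} \; \frac{1}{\lvert y^{(b)}\rvert^{\alpha}} \sum_{t} \log p\!\left(y^{(b)}_{t} \mid y^{(b)}_{<t}, x\right).$$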
168
src/plot.py

@@ -22,7 +22,7 @@ def main(hparams):
    if hparams.gpu == 0:
        device = torch.device("cpu")
    else:
        ngpu = "cuda:" + str(hparams.gpu - 1)
        print(ngpu)
        device = torch.device(ngpu)
    model = Model(hparams).to(device)

@@ -31,8 +31,9 @@ def main(hparams):
    print()

    # Model loading
    model_path = os.path.join(
        f"lightning_logs/version_" + hparams.test_check_num, "checkpoints/"
    )
    ckpt = list(Path(model_path).glob("*.ckpt"))[0]
    print(ckpt)

@@ -47,9 +48,9 @@ def main(hparams):
    test_list = hparams.test_list
    file_path = os.path.join(root, test_list)
    test_name = os.path.splitext(os.path.basename(test_list))[0]
    root_dir = os.path.join(root, "test", test_name)

    with open(file_path, "r") as fp:
        for line in fp.readlines():
            # Data Reading
            data_path = line.strip()
@ -61,104 +62,165 @@ def main(hparams):
|
|||
u_true = heat.squeeze().squeeze().numpy()
|
||||
heat_obs = (heat_obs - hparams.mean_layout) / hparams.std_layout
|
||||
heat0 = (heat0 - hparams.mean_heat) / hparams.std_heat
|
||||
heat = (heat-hparams.mean_heat) / hparams.std_heat
|
||||
obs_index, heat_obs, pred_index, heat0, heat = obs_index.to(device), heat_obs.to(device), pred_index.to(device), heat0.to(device), heat.to(device)
|
||||
heat = (heat - hparams.mean_heat) / hparams.std_heat
|
||||
obs_index, heat_obs, pred_index, heat0, heat = (
|
||||
obs_index.to(device),
|
||||
heat_obs.to(device),
|
||||
pred_index.to(device),
|
||||
heat0.to(device),
|
||||
heat.to(device),
|
||||
)
|
||||
heat_info = [obs_index, heat_obs, pred_index, heat0]
|
||||
|
||||
if model.layout_model=="ConditionalNeuralProcess" or model.layout_model=="TransformerRecon":
|
||||
heat_info[1] = heat_info[1].transpose(1,2)
|
||||
heat_info[3] = heat_info[3].transpose(2,3)
|
||||
elif model.layout_model=="DenseDeepGCN":
|
||||
heat_obs=heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat0[:,0,:]).squeeze()
|
||||
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
if (
|
||||
model.layout_model == "ConditionalNeuralProcess"
|
||||
or model.layout_model == "TransformerRecon"
|
||||
):
|
||||
heat_info[1] = heat_info[1].transpose(1, 2)
|
||||
heat_info[3] = heat_info[3].transpose(2, 3)
|
||||
elif model.layout_model == "DenseDeepGCN":
|
||||
heat_obs = heat_obs.squeeze()
|
||||
pseudo_heat = torch.zeros_like(heat0[:, 0, :]).squeeze()
|
||||
inputs = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
|
||||
for i in range(self.hparams.div_num*self.hparams.div_num-1):
|
||||
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0)
|
||||
for i in range(hparams.div_num * hparams.div_num - 1):
|
||||
input_single = (
|
||||
torch.cat(
|
||||
(
|
||||
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
|
||||
torch.cat(
|
||||
(obs_index, pred_index[:, i + 1, ...]), 1
|
||||
),
|
||||
),
|
||||
2,
|
||||
)
|
||||
.transpose(1, 2)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(0)
|
||||
)
|
||||
inputs = torch.cat((inputs, input_single), 0)
|
||||
|
||||
heat_info = inputs
|
||||
else:
|
||||
data = sio.loadmat(path)
|
||||
u_true, u_obs = data["u"], data["u_obs"]
|
||||
|
||||
u_obs[np.where(u_obs<TOL)]=hparams.mean_layout
|
||||
u_obs = torch.Tensor((u_obs - hparams.mean_layout) / hparams.std_layout).unsqueeze(0).unsqueeze(0).to(device)
|
||||
heat = torch.Tensor((u_true - hparams.mean_heat) / hparams.std_heat).unsqueeze(0).unsqueeze(0).to(device)
|
||||
|
||||
u_obs[np.where(u_obs < TOL)] = hparams.mean_layout
|
||||
u_obs = (
|
||||
torch.Tensor((u_obs - hparams.mean_layout) / hparams.std_layout)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.to(device)
|
||||
)
|
||||
heat = (
|
||||
torch.Tensor((u_true - hparams.mean_heat) / hparams.std_heat)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.to(device)
|
||||
)
|
||||
heat_info = u_obs
|
||||
|
||||
hs_F = sio.loadmat(path)["F"]
|
||||
|
||||
# Plot u_obs and Real Temperature Field
|
||||
fig = plt.figure(figsize=(22.5,5))
|
||||
|
||||
fig = plt.figure(figsize=(22.5, 5))
|
||||
|
||||
grid_x = np.linspace(0, 0.1, num=200)
|
||||
grid_y = np.linspace(0, 0.1, num=200)
|
||||
X, Y = np.meshgrid(grid_x, grid_y)
|
||||
|
||||
plt.subplot(141)
|
||||
plt.title('Real Time Power')
|
||||
|
||||
im = plt.pcolormesh(X,Y,hs_F)
|
||||
plt.title("Real Time Power")
|
||||
|
||||
im = plt.pcolormesh(X, Y, hs_F)
|
||||
plt.colorbar(im)
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
|
||||
|
||||
fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
|
||||
heat_pred0 = model(heat_info)
|
||||
if model.vec:
|
||||
if model.layout_model=="DenseDeepGCN":
|
||||
heat_pred0 = heat_pred0[...,-model.output_dim:]
|
||||
if model.layout_model == "DenseDeepGCN":
|
||||
heat_pred0 = heat_pred0[..., -model.output_dim :]
|
||||
else:
|
||||
pass
|
||||
|
||||
heat_pred0 = heat_pred0.reshape((-1, hparams.div_num*hparams.div_num, int(200 / hparams.div_num), int(200 / hparams.div_num)))
|
||||
|
||||
heat_pred0 = heat_pred0.reshape(
|
||||
(
|
||||
-1,
|
||||
hparams.div_num * hparams.div_num,
|
||||
int(200 / hparams.div_num),
|
||||
int(200 / hparams.div_num),
|
||||
)
|
||||
)
|
||||
heat_pre = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200))
|
||||
for i in range(hparams.div_num):
|
||||
for j in range(hparams.div_num):
|
||||
heat_pre[..., 0+i:200:hparams.div_num, 0+j:200:hparams.div_num] = heat_pred0[:, hparams.div_num*i+j,...].unsqueeze(1)
|
||||
heat_pre = heat_pre.transpose(2,3)
|
||||
heat_pre[
|
||||
...,
|
||||
0 + i : 200 : hparams.div_num,
|
||||
0 + j : 200 : hparams.div_num,
|
||||
] = heat_pred0[:, hparams.div_num * i + j, ...].unsqueeze(1)
|
||||
heat_pre = heat_pre.transpose(2, 3)
|
||||
heat = heat.unsqueeze(1)
|
||||
else:
|
||||
heat_pre = heat_pred0
|
||||
heat = heat
|
||||
|
||||
mae = F.l1_loss(heat, heat_pre) * hparams.std_heat
|
||||
print('sample:', data_path)
|
||||
print('MAE:', mae)
|
||||
print("sample:", data_path)
|
||||
print("MAE:", mae)
|
||||
mae_test.append(mae.item())
|
||||
heat_pre = heat_pre.squeeze(0).squeeze(0).cpu().numpy() * hparams.std_heat + hparams.mean_heat
|
||||
#heat_pre = np.transpose(heat_pre, (1,0))
|
||||
heat_pre = (
|
||||
heat_pre.squeeze(0).squeeze(0).cpu().numpy() * hparams.std_heat
|
||||
+ hparams.mean_heat
|
||||
)
|
||||
# heat_pre = np.transpose(heat_pre, (1,0))
|
||||
hmax = max(np.max(heat_pre), np.max(u_true))
|
||||
hmin = min(np.min(heat_pre), np.min(u_true))
|
||||
|
||||
plt.subplot(142)
|
||||
plt.title('Real Temperature Field')
|
||||
|
||||
im = plt.contourf(X,Y,u_true,levels=150,cmap='jet')
|
||||
plt.title("Real Temperature Field")
|
||||
|
||||
im = plt.contourf(X, Y, u_true, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(143)
|
||||
plt.title('Reconstructed Temperature Field')
|
||||
|
||||
im = plt.contourf(X, Y, heat_pre,levels=150,cmap='jet')
|
||||
plt.title("Reconstructed Temperature Field")
|
||||
|
||||
im = plt.contourf(X, Y, heat_pre, levels=150, cmap="jet")
|
||||
plt.colorbar(im)
|
||||
|
||||
plt.subplot(144)
|
||||
plt.title('Absolute Error')
|
||||
|
||||
im = plt.contourf(X, Y, np.abs(heat_pre-u_true),levels=150,cmap='jet')
|
||||
|
||||
plt.title("Absolute Error")
|
||||
|
||||
im = plt.contourf(X, Y, np.abs(heat_pre - u_true), levels=150, cmap="jet")
|
||||
|
||||
plt.colorbar(im)
|
||||
|
||||
save_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.png')
|
||||
mat_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.mat')
|
||||
sio.savemat(mat_name, {'pre': heat_pre, 'u_true': u_true})
|
||||
save_name = os.path.join(
|
||||
"outputs/predict_plot",
|
||||
os.path.splitext(os.path.basename(path))[0] + ".png",
|
||||
)
|
||||
mat_name = os.path.join(
|
||||
"outputs/predict_plot",
|
||||
os.path.splitext(os.path.basename(path))[0] + ".mat",
|
||||
)
|
||||
sio.savemat(mat_name, {"pre": heat_pre, "u_true": u_true})
|
||||
fig.savefig(save_name, dpi=300)
|
||||
plt.close()
|
||||
|
||||
mae_test = np.array(mae_test)
|
||||
print(mae_test.mean())
|
||||
np.savetxt('outputs/mae_test.csv', mae_test, fmt='%f', delimiter=',')
|
||||
|
||||
np.savetxt("outputs/mae_test.csv", mae_test, fmt="%f", delimiter=",")
|
||||
|
|
118
src/point.py

@@ -17,14 +17,12 @@ TOL = 1e-14

def main(hparams):

    model = PointModel(hparams)
    model.testing_step()


class PointModel:
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
@@ -44,85 +42,87 @@ class PointModel:
        trange = tqdm.tqdm(self.data)
        all_mae, all_maxae, all_cmae, all_mcae, all_bmae = 0, 0, 0, 0, 0

        mae_test = []
        for num, data in enumerate(trange):
            path = self.data.sample_files[num]
            u_obs=data[0]
            u=data[1]
            F=data[2]
            u_obs = data[0]
            u = data[1]
            F = data[2]
            sample = Point(self.hparams.model_name, u_obs, u)

            u_pred = sample.predict()
            if(self.hparams.plot):
                self.plot(path, u_pred, u, F)

            all_mae += self.metric.mae(u_pred,u)
            all_maxae += self.metric.maxae(u_pred,u)
            all_cmae += self.metric.cmae(u_pred,u,F)
            all_mcae += self.metric.mcae(u_pred,u,F)
            all_bmae += self.metric.bmae(u_pred,u)

            mae_test.append(self.metric.mae(u_pred,u))

            u_pred = sample.predict()
            if self.hparams.plot:
                self.plot(path, u_pred, u, F)

            all_mae += self.metric.mae(u_pred, u)
            all_maxae += self.metric.maxae(u_pred, u)
            all_cmae += self.metric.cmae(u_pred, u, F)
            all_mcae += self.metric.mcae(u_pred, u, F)
            all_bmae += self.metric.bmae(u_pred, u)

            mae_test.append(self.metric.mae(u_pred, u))

            trange.set_description("Testing")
            trange.set_postfix(MAE=all_mae/(num+1), MaxAE=all_maxae/(num+1), CMAE=all_cmae/(num+1), MCAE=all_mcae/(num+1), BMAE=all_bmae/(num+1))
            trange.set_postfix(
                MAE=all_mae / (num + 1),
                MaxAE=all_maxae / (num + 1),
                CMAE=all_cmae / (num + 1),
                MCAE=all_mcae / (num + 1),
                BMAE=all_bmae / (num + 1),
            )

        mae_test = np.array(mae_test)
        print(mae_test.mean())
        np.savetxt('outputs/mae_test.csv', mae_test, fmt='%f', delimiter=',')
        np.savetxt("outputs/mae_test.csv", mae_test, fmt="%f", delimiter=",")

    def plot(self, path, heat_pre, u_true, hs_F):
        fig = plt.figure(figsize=(22.5,5))
        fig = plt.figure(figsize=(22.5, 5))

        grid_x = np.linspace(0, 0.1, num=200)
        grid_y = np.linspace(0, 0.1, num=200)
        X, Y = np.meshgrid(grid_x, grid_y)

        plt.subplot(141)
        plt.title('Real Time Power')
        im = plt.pcolormesh(X,Y,hs_F)
        plt.title("Real Time Power")
        im = plt.pcolormesh(X, Y, hs_F)
        plt.colorbar(im)
        fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0)
        fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)

        plt.subplot(142)
        plt.title('Real Temperature Field')
        im = plt.contourf(X,Y,u_true,levels=150,cmap='jet')
        plt.title("Real Temperature Field")
        im = plt.contourf(X, Y, u_true, levels=150, cmap="jet")
        plt.colorbar(im)

        plt.subplot(143)
        plt.title('Reconstructed Temperature Field')
        im = plt.contourf(X, Y, heat_pre,levels=150,cmap='jet')
        plt.title("Reconstructed Temperature Field")
        im = plt.contourf(X, Y, heat_pre, levels=150, cmap="jet")
        plt.colorbar(im)

        plt.subplot(144)
        plt.title('Absolute Error')
        im = plt.contourf(X, Y, np.abs(heat_pre-u_true),levels=150,cmap='jet')
        plt.title("Absolute Error")
        im = plt.contourf(X, Y, np.abs(heat_pre - u_true), levels=150, cmap="jet")
        plt.colorbar(im)

        save_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.png')
        mat_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.mat')
        sio.savemat(mat_name, {'pre': heat_pre, 'u_true': u_true})
        save_name = os.path.join(
            "outputs/predict_plot", os.path.splitext(os.path.basename(path))[0] + ".png"
        )
        mat_name = os.path.join(
            "outputs/predict_plot", os.path.splitext(os.path.basename(path))[0] + ".mat"
        )
        sio.savemat(mat_name, {"pre": heat_pre, "u_true": u_true})
        fig.savefig(save_name, dpi=300)
        plt.close()


class Metric:
    def __init__(self):
        super().__init__()
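Reviewer note (illustration, not part of the commit): the reformatted trange.set_postfix call keeps the same running-average logic, dividing the accumulated sums by num + 1. A self-contained sketch of that pattern with made-up per-sample errors:

    # Minimal sketch of the tqdm running-average postfix used in the loop above.
    import tqdm

    values = [0.3, 0.1, 0.4, 0.2]  # stand-ins for per-sample MAE values
    total = 0.0
    trange = tqdm.tqdm(values)
    for num, v in enumerate(trange):
        total += v
        trange.set_description("Testing")
        trange.set_postfix(MAE=total / (num + 1))  # mean of the samples seen so far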
@@ -133,17 +133,17 @@ class Metric:
        return np.max(np.max(abs(u_pred - u)))

    def mcae(self, u_pred, u, F):
        F[np.where(F>TOL)] = 1
        return np.max(np.max(abs(u_pred - u)*F))
        F[np.where(F > TOL)] = 1
        return np.max(np.max(abs(u_pred - u) * F))

    def cmae(self, u_pred, u, F):
        F[np.where(F>TOL)] = 1
        return np.sum(np.sum(abs(u_pred - u)*F)) / np.sum(F)
        F[np.where(F > TOL)] = 1
        return np.sum(np.sum(abs(u_pred - u) * F)) / np.sum(F)

    def bmae(self, u_pred, u):
        ind = np.zeros_like(u_pred)
        ind[:2,...]=1
        ind[-2:,...]=1
        ind[...,:2]=1
        ind[...,-2:]=1
        return np.sum(np.sum(abs(u_pred - u)*ind)) / np.sum(ind)
        ind[:2, ...] = 1
        ind[-2:, ...] = 1
        ind[..., :2] = 1
        ind[..., -2:] = 1
        return np.sum(np.sum(abs(u_pred - u) * ind)) / np.sum(ind)
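Reviewer note (illustration, not part of the commit): cmae and mcae mask the absolute error to cells where F exceeds TOL, while bmae keeps only a two-cell border. A hedged sketch on dummy 200x200 fields (the shapes and the F layout are assumptions; TOL matches the constant shown in the first hunk of this file):

    # Sketch: the masked-error idea behind Metric.cmae and Metric.bmae.
    import numpy as np

    TOL = 1e-14
    rng = np.random.default_rng(0)
    u_pred = rng.random((200, 200))
    u = rng.random((200, 200))
    F = np.zeros((200, 200))
    F[50:60, 50:60] = 1.0  # hypothetical non-zero "component" region

    mask = (F > TOL).astype(float)  # same thresholding as cmae/mcae
    cmae = np.sum(np.abs(u_pred - u) * mask) / np.sum(mask)

    border = np.zeros_like(u_pred)  # two-cell boundary, as in bmae
    border[:2, ...] = 1
    border[-2:, ...] = 1
    border[..., :2] = 1
    border[..., -2:] = 1
    bmae = np.sum(np.abs(u_pred - u) * border) / np.sum(border)
    print(cmae, bmae)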
@@ -33,15 +33,16 @@ def main(hparams):
    if hparams.gpu == 0:
        hparams.gpu = 0
    else:
        hparams.gpu = [hparams.gpu-1]
        hparams.gpu = [hparams.gpu - 1]
    trainer = pl.Trainer(
        gpus=hparams.gpu,
        precision=16 if hparams.use_16bit else 32,
        # limit_test_batches=0.05
    )

    model_path = os.path.join(f'lightning_logs/version_' +
                              hparams.test_check_num, 'checkpoints/')
    model_path = os.path.join(
        f"lightning_logs/version_" + hparams.test_check_num, "checkpoints/"
    )
    model_path = list(Path(model_path).glob("*.ckpt"))[0]
    test_model = model.load_from_checkpoint(checkpoint_path=model_path, hparams=hparams)

@@ -52,4 +53,3 @@ def main(hparams):
    print()

    trainer.test(model=test_model)
@@ -32,8 +32,8 @@ def main(hparams):
    if hparams.gpu == 0:
        hparams.gpu = 0
    else:
        hparams.gpu = [hparams.gpu-1]
        #print(hparams.gpus)
        hparams.gpu = [hparams.gpu - 1]
        # print(hparams.gpus)
    trainer = pl.Trainer(
        max_epochs=hparams.max_epochs,
        gpus=hparams.gpu,
@@ -10,13 +10,14 @@ def weights_init(m):
    """
    class_name = m.__class__.__name__
    if class_name.find("Conv") != -1:
        torch.nn.init.kaiming_normal_(m.weight,
                                      mode="fan_out",
                                      nonlinearity="relu")  # initialize conv layer weights
        torch.nn.init.kaiming_normal_(
            m.weight, mode="fan_out", nonlinearity="relu"
        )  # initialize conv layer weights
        # torch.nn.init.xavier_normal_(m.weight)
    elif (class_name.find("BatchNorm") != -1
          and class_name.find("WithFixedBatchNorm") == -1
          ):  # batch norm layers must not be initialized with kaiming_normal
    elif (
        class_name.find("BatchNorm") != -1
        and class_name.find("WithFixedBatchNorm") == -1
    ):  # batch norm layers must not be initialized with kaiming_normal
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
        # m.weight.data.normal_(1.0, 0.02)

@@ -45,9 +46,10 @@ def weights_init_without_kaiming(m):
    if class_name.find("Conv") != -1:
        torch.nn.init.xavier_normal_(m.weight)
        # torch.nn.init.normal_(m.weight)  # initialize conv layer weights
    elif (class_name.find("BatchNorm") != -1
          and class_name.find("WithFixedBatchNorm") == -1
          ):  # batch norm layers must not be initialized with kaiming_normal
    elif (
        class_name.find("BatchNorm") != -1
        and class_name.find("WithFixedBatchNorm") == -1
    ):  # batch norm layers must not be initialized with kaiming_normal
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
        # m.weight.data.normal_(1.0, 0.02)
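Reviewer note (usage sketch, not part of the commit): both initializers are per-module hooks, so the usual way to run them is nn.Module.apply. The toy model below is a hypothetical example; weights_init and weights_init_without_kaiming are the functions from the hunks above.

    # Sketch: applying the init hooks to every submodule of a small model.
    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),  # matched by the "Conv" branch
        nn.BatchNorm2d(8),                          # matched by the "BatchNorm" branch
        nn.ReLU(),                                  # touched by neither branch
    )
    model.apply(weights_init)                     # weights_init from the hunk above
    # model.apply(weights_init_without_kaiming)   # xavier_normal_ variant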
@@ -33,7 +33,6 @@ class ToTensor:


class Resize:
    def __init__(self, size):
        self.size = size

@@ -43,7 +42,7 @@ class Resize:
        for _ in range(4 - x_dim):
            x_tensor = x_tensor.unsqueeze(0)
        x_resize = interpolate(x_tensor, size=self.size)
        for _ in range(4-x_dim):
        for _ in range(4 - x_dim):
            x_resize = x_resize.squeeze(0)
        return x_resize.numpy()
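Reviewer note (illustration, not part of the commit): the unsqueeze/interpolate/squeeze sequence in Resize exists because torch.nn.functional.interpolate expects a batched, channeled input (NCHW for 2-D data). A standalone sketch with an assumed 200x200 field and target size:

    # Sketch: resizing a bare 2-D tensor the same way Resize does.
    import torch
    from torch.nn.functional import interpolate

    x_tensor = torch.rand(200, 200)     # no batch or channel dimensions yet
    x_dim = x_tensor.dim()
    for _ in range(4 - x_dim):          # pad to 4-D: (1, 1, 200, 200)
        x_tensor = x_tensor.unsqueeze(0)
    x_resize = interpolate(x_tensor, size=(64, 64))
    for _ in range(4 - x_dim):          # drop the added dimensions again
        x_resize = x_resize.squeeze(0)
    print(x_resize.shape)               # torch.Size([64, 64])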
@@ -12,7 +12,9 @@ def get_upsampling_weight(in_channels, out_channels, kernel_size):
    center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight = np.zeros(
        (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
    )
    weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
    return torch.from_numpy(weight).float()

@@ -26,4 +28,4 @@ def initialize_weights(*models):
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                module.bias.data.zero_()
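Reviewer note (usage sketch, not part of the commit): get_upsampling_weight builds a bilinear kernel of shape (in_channels, out_channels, k, k). A common use, assumed here for illustration, is to seed a ConvTranspose2d so it starts out as bilinear upsampling; the layer sizes below are hypothetical.

    # Sketch: initializing a transposed convolution with the bilinear kernel.
    import torch
    import torch.nn as nn

    in_ch, out_ch, k = 3, 3, 4  # get_upsampling_weight fills a diagonal, so in_ch == out_ch here
    up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=k, stride=2, bias=False)
    with torch.no_grad():
        up.weight.copy_(get_upsampling_weight(in_ch, out_ch, k))  # helper from the hunk above
    y = up(torch.rand(1, in_ch, 16, 16))
    print(y.shape)  # torch.Size([1, 3, 34, 34])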