This commit is contained in:
zhiqiang gong 2021-08-25 13:33:37 +08:00
parent a53df132bf
commit f4672490d1
61 changed files with 3012 additions and 1703 deletions

63
main.py
View File

@ -23,45 +23,80 @@ def main():
# default configuration file # default configuration file
config_path = Path(__file__).absolute().parent / "config/config.yml" config_path = Path(__file__).absolute().parent / "config/config.yml"
data_path = Path(__file__).absolute().parent / "config/data.yml" data_path = Path(__file__).absolute().parent / "config/data.yml"
parser = configargparse.ArgParser(config_file_parser_class= configargparse.YAMLConfigFileParser, \ parser = configargparse.ArgParser(
default_config_files=[str(config_path), str(data_path)], description="Hyper-parameters.") config_file_parser_class=configargparse.YAMLConfigFileParser,
default_config_files=[str(config_path), str(data_path)],
description="Hyper-parameters.",
)
# configuration file # configuration file
parser.add_argument("--config", is_config_file=True, default=False, help="config file path") parser.add_argument(
"--config", is_config_file=True, default=False, help="config file path"
)
# mode # mode
parser.add_argument("-m", "--mode", type=str, default="train", help="model: train or test or plot") parser.add_argument(
"-m", "--mode", type=str, default="train", help="model: train or test or plot"
)
# problem dimension # problem dimension
parser.add_argument("--prob_dim", default=2, type=int, help="dimension of the problem") parser.add_argument(
"--prob_dim", default=2, type=int, help="dimension of the problem"
)
# args for plot in point-based methods # args for plot in point-based methods
parser.add_argument("--plot", action="store_true", help="use profiler") parser.add_argument("--plot", action="store_true", help="use profiler")
# args for training # args for training
parser.add_argument("--gpu", type=int, default=0, help="which gpu: 0 for cpu, 1 for gpu 0, 2 for gpu 1, ...") parser.add_argument(
"--gpu",
type=int,
default=0,
help="which gpu: 0 for cpu, 1 for gpu 0, 2 for gpu 1, ...",
)
parser.add_argument("--batch_size", default=16, type=int) parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--max_epochs", default=20, type=int) parser.add_argument("--max_epochs", default=20, type=int)
parser.add_argument("--lr", default="0.01", type=float) parser.add_argument("--lr", default="0.01", type=float)
parser.add_argument("--resume_from_checkpoint", type=str, help="resume from checkpoint") parser.add_argument(
parser.add_argument("--num_workers", default=2, type=int, help="num_workers in DataLoader") "--resume_from_checkpoint", type=str, help="resume from checkpoint"
)
parser.add_argument(
"--num_workers", default=2, type=int, help="num_workers in DataLoader"
)
parser.add_argument("--seed", type=int, default=1, help="seed") parser.add_argument("--seed", type=int, default=1, help="seed")
parser.add_argument("--use_16bit", type=bool, default=False, help="use 16bit precision") parser.add_argument(
"--use_16bit", type=bool, default=False, help="use 16bit precision"
)
parser.add_argument("--profiler", action="store_true", help="use profiler") parser.add_argument("--profiler", action="store_true", help="use profiler")
# args for validation # args for validation
parser.add_argument("--val_check_interval", type=float, default=1, parser.add_argument(
help="how often within one training epoch to check the validation set") "--val_check_interval",
type=float,
default=1,
help="how often within one training epoch to check the validation set",
)
# args for testing # args for testing
parser.add_argument("-v", "--test_check_num", default='0', type=str, help="checkpoint for test") parser.add_argument(
"-v", "--test_check_num", default="0", type=str, help="checkpoint for test"
)
parser.add_argument("--test_args", action="store_true", help="print args") parser.add_argument("--test_args", action="store_true", help="print args")
# args from Model # args from Model
parser = Model.add_model_specific_args(parser) parser = Model.add_model_specific_args(parser)
hparams = parser.parse_args() hparams = parser.parse_args()
PointModel = ["RSVR", "RBM", "RandomForest", "Polynomial", "MLPP", "Kriging", "KInterpolation", "GInterpolation"] PointModel = [
"RSVR",
"RBM",
"RandomForest",
"Polynomial",
"MLPP",
"Kriging",
"KInterpolation",
"GInterpolation",
]
dpoint = True if hparams.model_name in PointModel else False dpoint = True if hparams.model_name in PointModel else False
# running # running
assert hparams.mode in ["train", "test", "plot"] assert hparams.mode in ["train", "test", "plot"]
@ -74,5 +109,5 @@ def main():
getattr(eval(hparams.mode), "main")(hparams) getattr(eval(hparams.mode), "main")(hparams)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -18,7 +18,6 @@ import src.models as models
class Model(LightningModule): class Model(LightningModule):
def __init__(self, hparams): def __init__(self, hparams):
super().__init__() super().__init__()
self.hparams = hparams self.hparams = hparams
@ -33,21 +32,57 @@ class Model(LightningModule):
self.default_layout = None self.default_layout = None
def _build_model(self): def _build_model(self):
model_list = ["SegNet_AlexNet", "SegNet_VGG", "SegNet_ResNet18", "SegNet_ResNet50", model_list = [
"SegNet_ResNet101", "SegNet_ResNet34", "SegNet_ResNet152", "SegNet_AlexNet",
"FPN_ResNet18", "FPN_ResNet50", "FPN_ResNet101", "FPN_ResNet34", "FPN_ResNet152", "SegNet_VGG",
"FCN_AlexNet", "FCN_VGG", "FCN_ResNet18", "FCN_ResNet50", "FCN_ResNet101", "SegNet_ResNet18",
"FCN_ResNet34", "FCN_ResNet152", "SegNet_ResNet50",
"UNet_VGG", "SegNet_ResNet101",
"MLP", "ConditionalNeuralProcess", "TransformerRecon", "SegNet_ResNet34",
"DenseDeepGCN"] "SegNet_ResNet152",
layout_model = self.hparams.model_name + '_' + self.hparams.backbone "FPN_ResNet18",
assert (layout_model in model_list or self.hparams.model_name in model_list) "FPN_ResNet50",
self.layout_model = layout_model if layout_model in model_list else self.hparams.model_name "FPN_ResNet101",
"FPN_ResNet34",
"FPN_ResNet152",
"FCN_AlexNet",
"FCN_VGG",
"FCN_ResNet18",
"FCN_ResNet50",
"FCN_ResNet101",
"FCN_ResNet34",
"FCN_ResNet152",
"UNet_VGG",
"MLP",
"ConditionalNeuralProcess",
"TransformerRecon",
"DenseDeepGCN",
]
layout_model = self.hparams.model_name + "_" + self.hparams.backbone
assert layout_model in model_list or self.hparams.model_name in model_list
self.layout_model = (
layout_model if layout_model in model_list else self.hparams.model_name
)
self.vec = True if self.layout_model in ["MLP", "ConditionalNeuralProcess", "TransformerRecon", "DenseDeepGCN"] else False self.vec = (
True
if self.layout_model
in ["MLP", "ConditionalNeuralProcess", "TransformerRecon", "DenseDeepGCN"]
else False
)
self.model = nn.ModuleList([getattr(models, self.layout_model)(input_dim=self.input_dim, output_dim=self.output_dim) for i in range(self.hparams.div_num*self.hparams.div_num)]) if self.vec else getattr(models, self.layout_model)(in_channels=1) self.model = (
nn.ModuleList(
[
getattr(models, self.layout_model)(
input_dim=self.input_dim, output_dim=self.output_dim
)
for i in range(self.hparams.div_num * self.hparams.div_num)
]
)
if self.vec
else getattr(models, self.layout_model)(in_channels=1)
)
def forward(self, x): def forward(self, x):
@ -57,19 +92,36 @@ class Model(LightningModule):
if num == 0: if num == 0:
output = submodel(x[1]).unsqueeze(1) output = submodel(x[1]).unsqueeze(1)
else: else:
output = torch.cat((output, submodel(x[1]).unsqueeze(1)), axis=1) output = torch.cat(
elif self.layout_model == "ConditionalNeuralProcess" or self.layout_model == "TransformerRecon": (output, submodel(x[1]).unsqueeze(1)), axis=1
)
elif (
self.layout_model == "ConditionalNeuralProcess"
or self.layout_model == "TransformerRecon"
):
for num, submodel in enumerate(self.model): for num, submodel in enumerate(self.model):
if num == 0: if num == 0:
output = submodel(x[0], x[1], (x[2])[:,0,...], (x[3])[:,0,...]).unsqueeze(1) output = submodel(
x[0], x[1], (x[2])[:, 0, ...], (x[3])[:, 0, ...]
).unsqueeze(1)
else: else:
output = torch.cat((output, submodel(x[0], x[1], (x[2])[:,num,...], (x[3])[:,num,...]).unsqueeze(1)), axis=1) output = torch.cat(
elif self.layout_model =="DenseDeepGCN": (
output,
submodel(
x[0], x[1], (x[2])[:, num, ...], (x[3])[:, num, ...]
).unsqueeze(1),
),
axis=1,
)
elif self.layout_model == "DenseDeepGCN":
for num, submodel in enumerate(self.model): for num, submodel in enumerate(self.model):
if num == 0: if num == 0:
output = submodel(x[num, ...]).unsqueeze(1) output = submodel(x[num, ...]).unsqueeze(1)
else: else:
output = torch.cat((output, submodel(x[num, ...]).unsqueeze(1)), axis=1) output = torch.cat(
(output, submodel(x[num, ...]).unsqueeze(1)), axis=1
)
else: else:
output = self.model(x) output = self.model(x)
@ -85,8 +137,7 @@ class Model(LightningModule):
return loader return loader
def configure_optimizers(self): def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer, gamma=0.9) scheduler = ExponentialLR(optimizer, gamma=0.9)
return [optimizer], [scheduler] return [optimizer], [scheduler]
@ -154,14 +205,18 @@ class Model(LightningModule):
return trainval_dataset, test_dataset return trainval_dataset, test_dataset
def prepare_data(self): def prepare_data(self):
"""Prepare dataset """Prepare dataset"""
""" trainval_dataset, test_dataset = (
trainval_dataset, test_dataset = self.read_vec_data() if self.vec else self.read_image_data() self.read_vec_data() if self.vec else self.read_image_data()
)
# split train/val set # split train/val set
train_length, val_length = int(len(trainval_dataset) * 0.8), len(trainval_dataset)-int(len(trainval_dataset) * 0.8) train_length, val_length = int(len(trainval_dataset) * 0.8), len(
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, trainval_dataset
[train_length, val_length]) ) - int(len(trainval_dataset) * 0.8)
train_dataset, val_dataset = torch.utils.data.random_split(
trainval_dataset, [train_length, val_length]
)
print( print(
f"Prepared dataset, train:{int(len(train_dataset))},\ f"Prepared dataset, train:{int(len(train_dataset))},\
val:{int(len(val_dataset))}, test:{len(test_dataset)}" val:{int(len(val_dataset))}, test:{len(test_dataset)}"
@ -174,7 +229,9 @@ class Model(LightningModule):
self.default_layout = trainval_dataset._layout() self.default_layout = trainval_dataset._layout()
self.default_layout = torch.from_numpy(self.default_layout).unsqueeze(0).unsqueeze(0) self.default_layout = (
torch.from_numpy(self.default_layout).unsqueeze(0).unsqueeze(0)
)
def train_dataloader(self): def train_dataloader(self):
return self.train_dataset return self.train_dataset
@ -194,24 +251,57 @@ class Model(LightningModule):
heat_obs, heat = batch heat_obs, heat = batch
heat_info = heat_obs heat_info = heat_obs
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon": if (
heat_info[1] = heat_info[1].transpose(1,2) self.layout_model == "ConditionalNeuralProcess"
heat_info[3] = heat_info[3].transpose(2,3) or self.layout_model == "TransformerRecon"
heat = heat.transpose(2,3) ):
elif self.layout_model=="DenseDeepGCN": heat_info[1] = heat_info[1].transpose(1, 2)
heat_obs=heat_obs.squeeze() heat_info[3] = heat_info[3].transpose(2, 3)
pseudo_heat = torch.zeros_like(heat[:,0,:]).squeeze() heat = heat.transpose(2, 3)
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) elif self.layout_model == "DenseDeepGCN":
heat_obs = heat_obs.squeeze()
pseudo_heat = torch.zeros_like(heat[:, 0, :]).squeeze()
inputs = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
for i in range(self.hparams.div_num*self.hparams.div_num-1): for i in range(self.hparams.div_num * self.hparams.div_num - 1):
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) input_single = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
inputs = torch.cat((inputs, input_single), 0) inputs = torch.cat((inputs, input_single), 0)
heat_info = inputs heat_info = inputs
labels = torch.cat((heat_obs,heat[:,0,:].squeeze()), 1).unsqueeze(1).unsqueeze(1) labels = (
for i in range(self.hparams.div_num*self.hparams.div_num-1): torch.cat((heat_obs, heat[:, 0, :].squeeze()), 1)
label = torch.cat((heat_obs,heat[:,i,:].squeeze()), 1).unsqueeze(1).unsqueeze(1) .unsqueeze(1)
.unsqueeze(1)
)
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
label = (
torch.cat((heat_obs, heat[:, i, :].squeeze()), 1)
.unsqueeze(1)
.unsqueeze(1)
)
labels = torch.cat((labels, label), 1) labels = torch.cat((labels, label), 1)
heat = labels heat = labels
@ -233,26 +323,59 @@ class Model(LightningModule):
heat_obs, heat = batch heat_obs, heat = batch
heat_info = heat_obs heat_info = heat_obs
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon": if (
heat_info[1] = heat_info[1].transpose(1,2) self.layout_model == "ConditionalNeuralProcess"
heat_info[3] = heat_info[3].transpose(2,3) or self.layout_model == "TransformerRecon"
heat = heat.transpose(2,3) ):
elif self.layout_model=="DenseDeepGCN": heat_info[1] = heat_info[1].transpose(1, 2)
heat_obs=heat_obs.squeeze() heat_info[3] = heat_info[3].transpose(2, 3)
heat = heat.transpose(2, 3)
elif self.layout_model == "DenseDeepGCN":
heat_obs = heat_obs.squeeze()
pseudo_heat = torch.zeros_like(heat[:,0,:]).squeeze() pseudo_heat = torch.zeros_like(heat[:, 0, :]).squeeze()
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) inputs = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
for i in range(self.hparams.div_num*self.hparams.div_num-1): for i in range(self.hparams.div_num * self.hparams.div_num - 1):
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) input_single = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
inputs = torch.cat((inputs, input_single), 0) inputs = torch.cat((inputs, input_single), 0)
heat_info = inputs heat_info = inputs
labels = torch.cat((heat_obs,heat[:,0,:].squeeze()), 1).unsqueeze(1).unsqueeze(1) labels = (
for i in range(self.hparams.div_num*self.hparams.div_num-1): torch.cat((heat_obs, heat[:, 0, :].squeeze()), 1)
label = torch.cat((heat_obs,heat[:,i,:].squeeze()), 1).unsqueeze(1).unsqueeze(1) .unsqueeze(1)
.unsqueeze(1)
)
for i in range(self.hparams.div_num * self.hparams.div_num - 1):
label = (
torch.cat((heat_obs, heat[:, i, :].squeeze()), 1)
.unsqueeze(1)
.unsqueeze(1)
)
labels = torch.cat((labels, label), 1) labels = torch.cat((labels, label), 1)
heat = labels heat = labels
@ -277,16 +400,41 @@ class Model(LightningModule):
heat_obs, heat = batch heat_obs, heat = batch
heat_info = heat_obs heat_info = heat_obs
if self.layout_model=="ConditionalNeuralProcess" or self.layout_model=="TransformerRecon": if (
heat_info[1] = heat_info[1].transpose(1,2) self.layout_model == "ConditionalNeuralProcess"
heat_info[3] = heat_info[3].transpose(2,3) or self.layout_model == "TransformerRecon"
elif self.layout_model=="DenseDeepGCN": ):
heat_info[1] = heat_info[1].transpose(1, 2)
heat_info[3] = heat_info[3].transpose(2, 3)
elif self.layout_model == "DenseDeepGCN":
heat_obs = heat_obs.squeeze() heat_obs = heat_obs.squeeze()
pseudo_heat = torch.zeros_like(heat0[:,0,:]).squeeze() pseudo_heat = torch.zeros_like(heat0[:, 0, :]).squeeze()
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) inputs = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
for i in range(self.hparams.div_num*self.hparams.div_num-1): for i in range(self.hparams.div_num * self.hparams.div_num - 1):
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) input_single = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, i + 1, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
inputs = torch.cat((inputs, input_single), 0) inputs = torch.cat((inputs, input_single), 0)
heat_info = inputs heat_info = inputs
@ -294,17 +442,28 @@ class Model(LightningModule):
heat_pred0 = self(heat_info) heat_pred0 = self(heat_info)
if self.vec: if self.vec:
if self.layout_model=="DenseDeepGCN": if self.layout_model == "DenseDeepGCN":
heat_pred0 = heat_pred0[...,-self.output_dim:] heat_pred0 = heat_pred0[..., -self.output_dim :]
else: else:
pass pass
heat_pred0 = heat_pred0.reshape((-1, self.hparams.div_num*self.hparams.div_num, int(200 / self.hparams.div_num), int(200 / self.hparams.div_num))) heat_pred0 = heat_pred0.reshape(
(
-1,
self.hparams.div_num * self.hparams.div_num,
int(200 / self.hparams.div_num),
int(200 / self.hparams.div_num),
)
)
heat_pred = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200)) heat_pred = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200))
for i in range(self.hparams.div_num): for i in range(self.hparams.div_num):
for j in range(self.hparams.div_num): for j in range(self.hparams.div_num):
heat_pred[..., 0+i:200:self.hparams.div_num, 0+j:200:self.hparams.div_num] = heat_pred0[:, self.hparams.div_num*i+j,...].unsqueeze(1) heat_pred[
heat_pred = heat_pred.transpose(2,3) ...,
0 + i : 200 : self.hparams.div_num,
0 + j : 200 : self.hparams.div_num,
] = heat_pred0[:, self.hparams.div_num * i + j, ...].unsqueeze(1)
heat_pred = heat_pred.transpose(2, 3)
heat = heat.unsqueeze(1) heat = heat.unsqueeze(1)
else: else:
@ -312,49 +471,110 @@ class Model(LightningModule):
loss = self.criterion(heat_pred, heat) * self.hparams.std_heat loss = self.criterion(heat_pred, heat) * self.hparams.std_heat
default_layout = torch.repeat_interleave(self.default_layout, repeats=heat_pred.size(0), dim=0).float().to(device=heat.device) default_layout = (
torch.repeat_interleave(
self.default_layout, repeats=heat_pred.size(0), dim=0
)
.float()
.to(device=heat.device)
)
ones = torch.ones_like(default_layout).to(device=heat.device) ones = torch.ones_like(default_layout).to(device=heat.device)
zeros = torch.zeros_like(default_layout).to(device=heat.device) zeros = torch.zeros_like(default_layout).to(device=heat.device)
layout_ind = torch.where(default_layout<1e-2,zeros,ones) layout_ind = torch.where(default_layout < 1e-2, zeros, ones)
loss_2 = torch.sum(torch.abs(torch.sub(heat, heat_pred)) *layout_ind )* self.hparams.std_heat/ torch.sum(layout_ind) loss_2 = (
#--------------------------------- torch.sum(torch.abs(torch.sub(heat, heat_pred)) * layout_ind)
loss_1 = torch.sum(torch.max(torch.max(torch.max(torch.abs(torch.sub(heat,heat_pred)) * layout_ind, 3).values, 2).values * self.hparams.std_heat,1).values)/heat_pred.size(0) * self.hparams.std_heat
#--------------------------------- / torch.sum(layout_ind)
)
# ---------------------------------
loss_1 = (
torch.sum(
torch.max(
torch.max(
torch.max(
torch.abs(torch.sub(heat, heat_pred)) * layout_ind, 3
).values,
2,
).values
* self.hparams.std_heat,
1,
).values
)
/ heat_pred.size(0)
)
# ---------------------------------
boundary_ones = torch.zeros_like(default_layout).to(device=heat.device) boundary_ones = torch.zeros_like(default_layout).to(device=heat.device)
boundary_ones[..., -2:, :] = ones[..., -2:, :] boundary_ones[..., -2:, :] = ones[..., -2:, :]
boundary_ones[..., :2, :] = ones[..., :2, :] boundary_ones[..., :2, :] = ones[..., :2, :]
boundary_ones[..., :, :2] = ones[..., :, :2] boundary_ones[..., :, :2] = ones[..., :, :2]
boundary_ones[..., :, -2:] = ones[..., :, -2:] boundary_ones[..., :, -2:] = ones[..., :, -2:]
loss_3 = torch.sum(torch.abs(torch.sub(heat, heat_pred)) *boundary_ones )* self.hparams.std_heat/ torch.sum(boundary_ones) loss_3 = (
#---------------------------------- torch.sum(torch.abs(torch.sub(heat, heat_pred)) * boundary_ones)
loss_4 = torch.sum(torch.max(torch.max(torch.max(torch.abs(torch.sub(heat,heat_pred)), 3).values, 2).values * self.hparams.std_heat,1).values)/heat_pred.size(0) * self.hparams.std_heat
/ torch.sum(boundary_ones)
)
# ----------------------------------
loss_4 = (
torch.sum(
torch.max(
torch.max(
torch.max(torch.abs(torch.sub(heat, heat_pred)), 3).values, 2
).values
* self.hparams.std_heat,
1,
).values
)
/ heat_pred.size(0)
)
return {"test_loss": loss, "test_loss_1": loss_1, "test_loss_2": loss_2, "test_loss_3": loss_3, "test_loss_4": loss_4} return {
"test_loss": loss,
"test_loss_1": loss_1,
"test_loss_2": loss_2,
"test_loss_3": loss_3,
"test_loss_4": loss_4,
}
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
test_loss_mean = torch.stack([x["test_loss"] for x in outputs]).mean() test_loss_mean = torch.stack([x["test_loss"] for x in outputs]).mean()
self.log("test_loss (" + "MAE" +")", test_loss_mean.item()) self.log("test_loss (" + "MAE" + ")", test_loss_mean.item())
#test_loss_max = torch.max(torch.stack([x["test_loss_1"] for x in outputs])) # test_loss_max = torch.max(torch.stack([x["test_loss_1"] for x in outputs]))
test_loss_max = torch.stack([x["test_loss_1"] for x in outputs]).mean() test_loss_max = torch.stack([x["test_loss_1"] for x in outputs]).mean()
self.log("test_loss_1 (" + "M-CAE" +")", test_loss_max.item()) self.log("test_loss_1 (" + "M-CAE" + ")", test_loss_max.item())
test_loss_com_mean = torch.stack([x["test_loss_2"] for x in outputs]).mean() test_loss_com_mean = torch.stack([x["test_loss_2"] for x in outputs]).mean()
self.log("test_loss_2 (" + "CMAE" +")", test_loss_com_mean.item()) self.log("test_loss_2 (" + "CMAE" + ")", test_loss_com_mean.item())
test_loss_bc_mean = torch.stack([x["test_loss_3"] for x in outputs]).mean() test_loss_bc_mean = torch.stack([x["test_loss_3"] for x in outputs]).mean()
self.log("test_loss_3 (" + "BMAE" +")", test_loss_bc_mean.item()) self.log("test_loss_3 (" + "BMAE" + ")", test_loss_bc_mean.item())
test_loss_max_1 = torch.stack([x["test_loss_4"] for x in outputs]).mean() test_loss_max_1 = torch.stack([x["test_loss_4"] for x in outputs]).mean()
self.log("test_loss_4 (" + "MaxAE" + ")", test_loss_max_1.item()) self.log("test_loss_4 (" + "MaxAE" + ")", test_loss_max_1.item())
@staticmethod @staticmethod
def add_model_specific_args(parser): # pragma: no-cover def add_model_specific_args(parser): # pragma: no-cover
"""Parameters you define here will be available to your model through `self.hparams`. """Parameters you define here will be available to your model through `self.hparams`."""
"""
# dataset args # dataset args
parser.add_argument("--data_root", type=str, required=True, help="path of dataset") parser.add_argument(
parser.add_argument("--train_list", type=str, required=True, help="path of train dataset list") "--data_root", type=str, required=True, help="path of dataset"
parser.add_argument("--train_size", default=0.8, type=float, help="train_size in train_test_split") )
parser.add_argument("--test_list", type=str, required=True, help="path of test dataset list") parser.add_argument(
#parser.add_argument("--boundary", type=str, default="rm_wall", help="boundary condition") "--train_list", type=str, required=True, help="path of train dataset list"
parser.add_argument("--data_format", type=str, default="mat", choices=["mat", "h5"], help="dataset format") )
parser.add_argument(
"--train_size",
default=0.8,
type=float,
help="train_size in train_test_split",
)
parser.add_argument(
"--test_list", type=str, required=True, help="path of test dataset list"
)
# parser.add_argument("--boundary", type=str, default="rm_wall", help="boundary condition")
parser.add_argument(
"--data_format",
type=str,
default="mat",
choices=["mat", "h5"],
help="dataset format",
)
# Normalization params # Normalization params
parser.add_argument("--mean_layout", default=0, type=float) parser.add_argument("--mean_layout", default=0, type=float)
@ -364,10 +584,19 @@ class Model(LightningModule):
# Model params (opt) # Model params (opt)
parser.add_argument("--input_size", default=200, type=int) parser.add_argument("--input_size", default=200, type=int)
parser.add_argument("--model_name", type=str, default='FCN', help="the name of chosen model") parser.add_argument(
parser.add_argument("--backbone", type=str, default='AlexNet', help="the used backbone in the regression model") "--model_name", type=str, default="FCN", help="the name of chosen model"
)
parser.add_argument(
"--backbone",
type=str,
default="AlexNet",
help="the used backbone in the regression model",
)
# div_num for vec (opt) # div_num for vec (opt)
parser.add_argument("--div_num", default=4, type=int, help="division of heat source systems") parser.add_argument(
"--div_num", default=4, type=int, help="division of heat source systems"
)
return parser return parser

View File

@ -6,8 +6,7 @@ from .loadresponse import LoadResponse, LoadPointResponse, LoadVecResponse, mat_
class LayoutDataset(LoadResponse): class LayoutDataset(LoadResponse):
"""Layout dataset (mutiple files) generated by 'layout-generator'. """Layout dataset (mutiple files) generated by 'layout-generator'."""
"""
def __init__( def __init__(
self, self,
@ -20,8 +19,9 @@ class LayoutDataset(LoadResponse):
resp_name="u", resp_name="u",
): ):
test_name = os.path.splitext(os.path.basename(list_path))[0] test_name = os.path.splitext(os.path.basename(list_path))[0]
subdir = os.path.join("train", "train") \ subdir = (
if train else os.path.join("test", test_name) os.path.join("train", "train") if train else os.path.join("test", test_name)
)
# find the path of the list of train/test samples # find the path of the list of train/test samples
list_path = os.path.join(root, list_path) list_path = os.path.join(root, list_path)
@ -42,7 +42,6 @@ class LayoutDataset(LoadResponse):
class LayoutPointDataset(LoadPointResponse): class LayoutPointDataset(LoadPointResponse):
def __init__( def __init__(
self, self,
root, root,
@ -53,8 +52,9 @@ class LayoutPointDataset(LoadPointResponse):
layout_name="F", layout_name="F",
): ):
test_name = os.path.splitext(os.path.basename(list_path))[0] test_name = os.path.splitext(os.path.basename(list_path))[0]
subdir = os.path.join("train", "train") \ subdir = (
if train else os.path.join("test", test_name) os.path.join("train", "train") if train else os.path.join("test", test_name)
)
# find the path of the list of train/test samples # find the path of the list of train/test samples
list_path = os.path.join(root, list_path) list_path = os.path.join(root, list_path)
@ -74,8 +74,7 @@ class LayoutPointDataset(LoadPointResponse):
class LayoutVecDataset(LoadVecResponse): class LayoutVecDataset(LoadVecResponse):
"""Layout dataset (mutiple files) generated by 'layout-generator'. """Layout dataset (mutiple files) generated by 'layout-generator'."""
"""
def __init__( def __init__(
self, self,
@ -89,8 +88,9 @@ class LayoutVecDataset(LoadVecResponse):
resp_name="u", resp_name="u",
): ):
test_name = os.path.splitext(os.path.basename(list_path))[0] test_name = os.path.splitext(os.path.basename(list_path))[0]
subdir = os.path.join("train", "train") \ subdir = (
if train else os.path.join("test", test_name) os.path.join("train", "train") if train else os.path.join("test", test_name)
)
# find the path of the list of train/test samples # find the path of the list of train/test samples
list_path = os.path.join(root, list_path) list_path = os.path.join(root, list_path)

View File

@ -28,22 +28,22 @@ class LoadResponse(VisionDataset):
target_transform=None, target_transform=None,
is_valid_file=None, is_valid_file=None,
): ):
super().__init__( super().__init__(root, transform=transform, target_transform=target_transform)
root, transform=transform, target_transform=target_transform
)
self.list_path = list_path self.list_path = list_path
self.loader = loader self.loader = loader
self.load_name = load_name self.load_name = load_name
self.resp_name = resp_name self.resp_name = resp_name
self.layout_name = layout_name self.layout_name = layout_name
self.extensions = extensions self.extensions = extensions
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file) self.sample_files = make_dataset_list(
root, list_path, extensions, is_valid_file
)
def __getitem__(self, index): def __getitem__(self, index):
path = self.sample_files[index] path = self.sample_files[index]
load, resp, _ = self.loader(path, self.load_name, self.resp_name) load, resp, _ = self.loader(path, self.load_name, self.resp_name)
load[np.where(load<TOL)] = 298 load[np.where(load < TOL)] = 298
if self.transform is not None: if self.transform is not None:
load = self.transform(load) load = self.transform(load)
@ -54,7 +54,9 @@ class LoadResponse(VisionDataset):
def _layout(self): def _layout(self):
path = self.sample_files[0] path = self.sample_files[0]
_, _, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name) _, _, layout = self.loader(
path, self.load_name, self.resp_name, self.layout_name
)
return layout return layout
def __len__(self): def __len__(self):
@ -75,20 +77,22 @@ class LoadPointResponse(VisionDataset):
extensions=None, extensions=None,
is_valid_file=None, is_valid_file=None,
): ):
super().__init__( super().__init__(root)
root
)
self.list_path = list_path self.list_path = list_path
self.loader = loader self.loader = loader
self.load_name = load_name self.load_name = load_name
self.resp_name = resp_name self.resp_name = resp_name
self.layout_name = layout_name self.layout_name = layout_name
self.extensions = extensions self.extensions = extensions
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file) self.sample_files = make_dataset_list(
root, list_path, extensions, is_valid_file
)
def __getitem__(self, index): def __getitem__(self, index):
path = self.sample_files[index] path = self.sample_files[index]
load, resp, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name) load, resp, layout = self.loader(
path, self.load_name, self.resp_name, self.layout_name
)
return load, resp, layout return load, resp, layout
@ -105,7 +109,7 @@ class LoadVecResponse(VisionDataset):
load_name="u_obs", load_name="u_obs",
resp_name="u", resp_name="u",
layout_name="F", layout_name="F",
div_num = 4, div_num=4,
extensions=None, extensions=None,
transform=None, transform=None,
target_transform=None, target_transform=None,
@ -121,7 +125,9 @@ class LoadVecResponse(VisionDataset):
self.resp_name = resp_name self.resp_name = resp_name
self.layout_name = layout_name self.layout_name = layout_name
self.div_num = div_num self.div_num = div_num
self.sample_files = make_dataset_list(root, list_path, extensions, is_valid_file) self.sample_files = make_dataset_list(
root, list_path, extensions, is_valid_file
)
def __getitem__(self, index): def __getitem__(self, index):
@ -135,11 +141,17 @@ class LoadVecResponse(VisionDataset):
if self.target_transform is not None: if self.target_transform is not None:
y_target = (y_target - self.target_transform[0]) / self.transform[1] y_target = (y_target - self.target_transform[0]) / self.transform[1]
resp = (resp-self.target_transform[0]) / self.transform[1] resp = (resp - self.target_transform[0]) / self.transform[1]
else: else:
pass pass
return x_context, y_context.type(torch.FloatTensor), x_target, y_target.type(torch.FloatTensor), resp.type(torch.FloatTensor) return (
x_context,
y_context.type(torch.FloatTensor),
x_target,
y_target.type(torch.FloatTensor),
resp.type(torch.FloatTensor),
)
def __len__(self): def __len__(self):
return len(self.sample_files) return len(self.sample_files)
@ -149,22 +161,34 @@ class LoadVecResponse(VisionDataset):
load, resp, _ = self.loader(path, self.load_name, self.resp_name) load, resp, _ = self.loader(path, self.load_name, self.resp_name)
monitor_x, monitor_y = np.where(load > TOL) monitor_x, monitor_y = np.where(load > TOL)
y_context = torch.from_numpy(load[monitor_x, monitor_y].reshape(1,-1)).float() y_context = torch.from_numpy(load[monitor_x, monitor_y].reshape(1, -1)).float()
monitor_x, monitor_y = monitor_x / load.shape[0], monitor_y / load.shape[1] monitor_x, monitor_y = monitor_x / load.shape[0], monitor_y / load.shape[1]
x_context = torch.from_numpy(np.concatenate([monitor_x.reshape(-1,1),monitor_y.reshape(-1,1)], axis=1)).float() x_context = torch.from_numpy(
np.concatenate([monitor_x.reshape(-1, 1), monitor_y.reshape(-1, 1)], axis=1)
).float()
x = np.linspace(0, load.shape[0]-1, load.shape[0]).astype(int) x = np.linspace(0, load.shape[0] - 1, load.shape[0]).astype(int)
y = np.linspace(1, load.shape[1]-1, load.shape[1]).astype(int) y = np.linspace(1, load.shape[1] - 1, load.shape[1]).astype(int)
x_target = None x_target = None
y_target = None y_target = None
for i in range(self.div_num): for i in range(self.div_num):
for j in range(self.div_num): for j in range(self.div_num):
x1, y1 = x[0+i:np.size(x):self.div_num], y[0+j:np.size(y):self.div_num] x1, y1 = (
x[0 + i : np.size(x) : self.div_num],
y[0 + j : np.size(y) : self.div_num],
)
x1, y1 = np.meshgrid(x1, y1) x1, y1 = np.meshgrid(x1, y1)
x_target0 = torch.from_numpy(np.concatenate([x1.reshape(-1, 1), y1.reshape(-1, 1)], axis=1) / np.max(load.shape)).float().unsqueeze(0) x_target0 = (
y_target0 = torch.from_numpy(resp[x1, y1].reshape(1,-1)).unsqueeze(0) torch.from_numpy(
np.concatenate([x1.reshape(-1, 1), y1.reshape(-1, 1)], axis=1)
/ np.max(load.shape)
)
.float()
.unsqueeze(0)
)
y_target0 = torch.from_numpy(resp[x1, y1].reshape(1, -1)).unsqueeze(0)
if x_target is not None: if x_target is not None:
x_target = torch.cat((x_target, x_target0), 0) x_target = torch.cat((x_target, x_target0), 0)
else: else:
@ -179,7 +203,9 @@ class LoadVecResponse(VisionDataset):
def _layout(self): def _layout(self):
path = self.sample_files[0] path = self.sample_files[0]
_, _, layout = self.loader(path, self.load_name, self.resp_name, self.layout_name) _, _, layout = self.loader(
path, self.load_name, self.resp_name, self.layout_name
)
return layout return layout
def _inputdim(self): def _inputdim(self):
@ -188,9 +214,9 @@ class LoadVecResponse(VisionDataset):
monitor_x, _ = np.where(load > TOL) monitor_x, _ = np.where(load > TOL)
return np.size(monitor_x) return np.size(monitor_x)
def make_dataset(root_dir, extensions=None, is_valid_file=None): def make_dataset(root_dir, extensions=None, is_valid_file=None):
"""make_dataset() from torchvision. """make_dataset() from torchvision."""
"""
files = [] files = []
root_dir = os.path.expanduser(root_dir) root_dir = os.path.expanduser(root_dir)
if not ((extensions is None) ^ (is_valid_file is None)): if not ((extensions is None) ^ (is_valid_file is None)):
@ -211,8 +237,7 @@ def make_dataset(root_dir, extensions=None, is_valid_file=None):
def make_dataset_list(root_dir, list_path, extensions=None, is_valid_file=None): def make_dataset_list(root_dir, list_path, extensions=None, is_valid_file=None):
"""make_dataset() from torchvision. """make_dataset() from torchvision."""
"""
files = [] files = []
root_dir = os.path.expanduser(root_dir) root_dir = os.path.expanduser(root_dir)
if not ((extensions is None) ^ (is_valid_file is None)): if not ((extensions is None) ^ (is_valid_file is None)):
@ -224,7 +249,7 @@ def make_dataset_list(root_dir, list_path, extensions=None, is_valid_file=None):
is_valid_file = lambda x: has_allowed_extension(x, extensions) is_valid_file = lambda x: has_allowed_extension(x, extensions)
assert os.path.isdir(root_dir), root_dir assert os.path.isdir(root_dir), root_dir
with open(list_path, 'r') as rf: with open(list_path, "r") as rf:
for line in rf.readlines(): for line in rf.readlines():
data_path = line.strip() data_path = line.strip()
path = os.path.join(root_dir, data_path) path = os.path.join(root_dir, data_path)
@ -247,9 +272,9 @@ def mat_loader(path, load_name, resp_name=None, layout_name=None):
if __name__ == "__main__": if __name__ == "__main__":
total_num = 50000 total_num = 50000
with open('train'+str(total_num)+'.txt', 'w') as wf: with open("train" + str(total_num) + ".txt", "w") as wf:
for idx in range(int(total_num*0.8)): for idx in range(int(total_num * 0.8)):
wf.write('Example'+str(idx)+'.mat'+'\n') wf.write("Example" + str(idx) + ".mat" + "\n")
with open('val'+str(total_num)+'.txt', 'w') as wf: with open("val" + str(total_num) + ".txt", "w") as wf:
for idx in range(int(total_num*0.8), total_num): for idx in range(int(total_num * 0.8), total_num):
wf.write('Example'+str(idx)+'.mat'+'\n') wf.write("Example" + str(idx) + ".mat" + "\n")

View File

@ -69,7 +69,13 @@ class DeterministicDecoder(nn.Module):
class ConditionalNeuralProcess(nn.Module): class ConditionalNeuralProcess(nn.Module):
def __init__(self, input_dim=None, output_dim=None, encoder_sizes=[2+1, 128, 128, 128, 256], decoder_sizes=[256+2, 256, 256, 128, 128, 2]): def __init__(
self,
input_dim=None,
output_dim=None,
encoder_sizes=[2 + 1, 128, 128, 128, 256],
decoder_sizes=[256 + 2, 256, 256, 128, 128, 2],
):
super(ConditionalNeuralProcess, self).__init__() super(ConditionalNeuralProcess, self).__init__()
self._encoder = DeterministicEncoder(encoder_sizes) self._encoder = DeterministicEncoder(encoder_sizes)
self._decoder = DeterministicDecoder(decoder_sizes) self._decoder = DeterministicDecoder(decoder_sizes)
@ -79,7 +85,7 @@ class ConditionalNeuralProcess(nn.Module):
dist, mu, sigma = self._decoder(representation, x_target) dist, mu, sigma = self._decoder(representation, x_target)
log_p = None if y_target is None else dist.log_prob(y_target) log_p = None if y_target is None else dist.log_prob(y_target)
#return log_p, mu, sigma # return log_p, mu, sigma
return mu return mu
@ -87,7 +93,7 @@ def input_mapping(x, B):
if B is None: if B is None:
return x return x
else: else:
x_proj = (2. * np.pi * x) @ B.t() x_proj = (2.0 * np.pi * x) @ B.t()
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

View File

@ -7,18 +7,24 @@ from .util.backbone import *
__all__ = [ __all__ = [
"FCN_VGG", "FCN_AlexNet", "FCN_ResNet18", "FCN_ResNet34", "FCN_VGG",
"FCN_ResNet50", "FCN_ResNet101", "FCN_ResNet152", "FCN_AlexNet",
"FCN_ResNet18",
"FCN_ResNet34",
"FCN_ResNet50",
"FCN_ResNet101",
"FCN_ResNet152",
] ]
class Conv3x3GNReLU(nn.Module): class Conv3x3GNReLU(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False): def __init__(self, in_channels, out_channels, upsample=False):
super().__init__() super().__init__()
self.upsample = upsample self.upsample = upsample
self.block = nn.Sequential( self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False), nn.Conv2d(
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
),
nn.GroupNorm(32, out_channels), nn.GroupNorm(32, out_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
@ -31,18 +37,19 @@ class Conv3x3GNReLU(nn.Module):
class FCN_VGG(nn.Module): class FCN_VGG(nn.Module):
def __init__(self, inter_channels=256, in_channels=1, bn=False): def __init__(self, inter_channels=256, in_channels=1, bn=False):
super(FCN_VGG, self).__init__() super(FCN_VGG, self).__init__()
vgg = vgg16() vgg = vgg16()
features, classifier = list(vgg.features.children()), list(vgg.classifier.children()) features, classifier = list(vgg.features.children()), list(
vgg.classifier.children()
)
if in_channels != 3: if in_channels != 3:
features[0] = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1) features[0] = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
for f in features: for f in features:
if 'MaxPool' in f.__class__.__name__: if "MaxPool" in f.__class__.__name__:
f.ceil_mode = True f.ceil_mode = True
elif 'ReLU' in f.__class__.__name__: elif "ReLU" in f.__class__.__name__:
f.inplace = True f.inplace = True
features_temp = [] features_temp = []
@ -53,7 +60,7 @@ class FCN_VGG(nn.Module):
features_temp.append(nn.GroupNorm(32, features[i].out_channels)) features_temp.append(nn.GroupNorm(32, features[i].out_channels))
self.features3 = nn.Sequential(*features[:17]) self.features3 = nn.Sequential(*features[:17])
self.features4 = nn.Sequential(*features[17: 24]) self.features4 = nn.Sequential(*features[17:24])
self.features5 = nn.Sequential(*features[24:]) self.features5 = nn.Sequential(*features[24:])
self.score_pool3 = nn.Conv2d(256, inter_channels, kernel_size=1) self.score_pool3 = nn.Conv2d(256, inter_channels, kernel_size=1)
@ -67,7 +74,9 @@ class FCN_VGG(nn.Module):
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
) )
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore_pool4 = Conv3x3GNReLU(
inter_channels, inter_channels, upsample=True
)
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1) self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
def forward(self, x): def forward(self, x):
@ -82,12 +91,16 @@ class FCN_VGG(nn.Module):
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:]) upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
score_pool3 = self.score_pool3(pool3) score_pool3 = self.score_pool3(pool3)
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:], mode='bilinear', align_corners=True) upscore8 = F.interpolate(
self.final_conv(score_pool3 + upscore_pool4),
x.size()[-2:],
mode="bilinear",
align_corners=True,
)
return upscore8 return upscore8
class FCN_AlexNet(nn.Module): class FCN_AlexNet(nn.Module):
def __init__(self, inter_channels=256, in_channels=1): def __init__(self, inter_channels=256, in_channels=1):
super(FCN_AlexNet, self).__init__() super(FCN_AlexNet, self).__init__()
self.alexnet = AlexNet(in_channels=in_channels) self.alexnet = AlexNet(in_channels=in_channels)
@ -103,7 +116,9 @@ class FCN_AlexNet(nn.Module):
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
) )
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore_pool4 = Conv3x3GNReLU(
inter_channels, inter_channels, upsample=True
)
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1) self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
def forward(self, x): def forward(self, x):
@ -118,31 +133,39 @@ class FCN_AlexNet(nn.Module):
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:]) upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
score_pool3 = self.score_pool3(pool3) score_pool3 = self.score_pool3(pool3)
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:], upscore8 = F.interpolate(
mode='bilinear', align_corners=True) self.final_conv(score_pool3 + upscore_pool4),
x.size()[-2:],
mode="bilinear",
align_corners=True,
)
return upscore8 return upscore8
class FCN_ResNet(nn.Module): class FCN_ResNet(nn.Module):
def __init__(self, backbone, inter_channels=256): def __init__(self, backbone, inter_channels=256):
super(FCN_ResNet, self).__init__() super(FCN_ResNet, self).__init__()
self.backbone = backbone self.backbone = backbone
self.score_pool3 = nn.Conv2d(backbone.layer2[0].downsample[1].num_features, self.score_pool3 = nn.Conv2d(
inter_channels, kernel_size=1) backbone.layer2[0].downsample[1].num_features, inter_channels, kernel_size=1
self.score_pool4 = nn.Conv2d(backbone.layer3[0].downsample[1].num_features, )
inter_channels, kernel_size=1) self.score_pool4 = nn.Conv2d(
backbone.layer3[0].downsample[1].num_features, inter_channels, kernel_size=1
)
fc6 = nn.Conv2d(backbone.layer4[0].downsample[1].num_features, fc6 = nn.Conv2d(
512, kernel_size=3, padding=1) backbone.layer4[0].downsample[1].num_features, 512, kernel_size=3, padding=1
)
fc7 = nn.Conv2d(512, 512, kernel_size=1) fc7 = nn.Conv2d(512, 512, kernel_size=1)
score_fr = nn.Conv2d(512, inter_channels, kernel_size=1) score_fr = nn.Conv2d(512, inter_channels, kernel_size=1)
self.score_fr = nn.Sequential( self.score_fr = nn.Sequential(
fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True), score_fr
) )
self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore2 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True)
self.upscore_pool4 = Conv3x3GNReLU(inter_channels, inter_channels, upsample=True) self.upscore_pool4 = Conv3x3GNReLU(
inter_channels, inter_channels, upsample=True
)
self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1) self.final_conv = nn.Conv2d(inter_channels, 1, kernel_size=1)
def forward(self, x): def forward(self, x):
@ -155,7 +178,12 @@ class FCN_ResNet(nn.Module):
upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:]) upscore_pool4 = self.upscore_pool4(score_pool4 + upscore2, pool3.size()[-2:])
score_pool3 = self.score_pool3(pool3) score_pool3 = self.score_pool3(pool3)
upscore8 = F.interpolate(self.final_conv(score_pool3 + upscore_pool4), x.size()[-2:], mode='bilinear', align_corners=True) upscore8 = F.interpolate(
self.final_conv(score_pool3 + upscore_pool4),
x.size()[-2:],
mode="bilinear",
align_corners=True,
)
return upscore8 return upscore8
@ -209,7 +237,7 @@ def FCN_ResNet152(in_channels=1, **kwargs):
return model return model
if __name__ == '__main__': if __name__ == "__main__":
model = FCN_AlexNet(in_channels=1, inter_channels=128) model = FCN_AlexNet(in_channels=1, inter_channels=128)
x = torch.randn(1, 1, 200, 200) x = torch.randn(1, 1, 200, 200)
with torch.no_grad(): with torch.no_grad():

View File

@ -6,7 +6,13 @@ from src.utils.model_init import weights_init
from .util.backbone import * from .util.backbone import *
__all__ = ["FPN_ResNet18", "FPN_ResNet34", "FPN_ResNet50", "FPN_ResNet101", "FPN_ResNet152"] __all__ = [
"FPN_ResNet18",
"FPN_ResNet34",
"FPN_ResNet50",
"FPN_ResNet101",
"FPN_ResNet152",
]
class Conv3x3GNReLU(nn.Module): class Conv3x3GNReLU(nn.Module):
@ -14,7 +20,9 @@ class Conv3x3GNReLU(nn.Module):
super().__init__() super().__init__()
self.upsample = upsample self.upsample = upsample
self.block = nn.Sequential( self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False), nn.Conv2d(
in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
),
nn.GroupNorm(32, out_channels), nn.GroupNorm(32, out_channels),
# nn.BatchNorm2d(out_channels), # nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
@ -45,11 +53,15 @@ class SegmentationBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_upsamples=0): def __init__(self, in_channels, out_channels, n_upsamples=0):
super().__init__() super().__init__()
self.blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] self.blocks = [
Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
]
if n_upsamples > 1: if n_upsamples > 1:
for _ in range(1, n_upsamples): for _ in range(1, n_upsamples):
self.blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) self.blocks.append(
Conv3x3GNReLU(out_channels, out_channels, upsample=True)
)
self.blocks_name = [] self.blocks_name = []
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
@ -77,32 +89,31 @@ class FPN_ResNet(nn.Module):
self.backbone = backbone self.backbone = backbone
self.backbone.apply(weights_init) self.backbone.apply(weights_init)
self.final_upsampling = final_upsampling self.final_upsampling = final_upsampling
self.conv1 = nn.Conv2d(encoder_channels[0], self.conv1 = nn.Conv2d(
pyramid_channels, encoder_channels[0], pyramid_channels, kernel_size=(1, 1)
kernel_size=(1, 1)) )
self.p4 = FPNBlock(pyramid_channels, encoder_channels[1]) self.p4 = FPNBlock(pyramid_channels, encoder_channels[1])
self.p3 = FPNBlock(pyramid_channels, encoder_channels[2]) self.p3 = FPNBlock(pyramid_channels, encoder_channels[2])
self.p2 = FPNBlock(pyramid_channels, encoder_channels[3]) self.p2 = FPNBlock(pyramid_channels, encoder_channels[3])
self.s5 = SegmentationBlock(pyramid_channels, self.s5 = SegmentationBlock(
segmentation_channels, pyramid_channels, segmentation_channels, n_upsamples=3
n_upsamples=3) )
self.s4 = SegmentationBlock(pyramid_channels, self.s4 = SegmentationBlock(
segmentation_channels, pyramid_channels, segmentation_channels, n_upsamples=2
n_upsamples=2) )
self.s3 = SegmentationBlock(pyramid_channels, self.s3 = SegmentationBlock(
segmentation_channels, pyramid_channels, segmentation_channels, n_upsamples=1
n_upsamples=1) )
self.s2 = SegmentationBlock(pyramid_channels, self.s2 = SegmentationBlock(
segmentation_channels, pyramid_channels, segmentation_channels, n_upsamples=0
n_upsamples=0) )
self.dropout = nn.Dropout2d(p=dropout, inplace=True) self.dropout = nn.Dropout2d(p=dropout, inplace=True)
self.final_conv = nn.Conv2d(segmentation_channels, self.final_conv = nn.Conv2d(
final_channels, segmentation_channels, final_channels, kernel_size=1, padding=0
kernel_size=1, )
padding=0)
def forward(self, x): def forward(self, x):
x = self.backbone(x) x = self.backbone(x)
@ -126,45 +137,45 @@ class FPN_ResNet(nn.Module):
x = self.final_conv(x) x = self.final_conv(x)
if self.final_upsampling is not None and self.final_upsampling > 1: if self.final_upsampling is not None and self.final_upsampling > 1:
x = F.interpolate(x, scale_factor=self.final_upsampling, mode="bilinear", align_corners=True) x = F.interpolate(
x,
scale_factor=self.final_upsampling,
mode="bilinear",
align_corners=True,
)
return x return x
def FPN_ResNet18(in_channels=1, **kwargs): def FPN_ResNet18(in_channels=1, **kwargs):
"""FPN with ResNet18 as backbone """FPN with ResNet18 as backbone"""
"""
backbone = resnet18(in_channels=in_channels) backbone = resnet18(in_channels=in_channels)
model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs) model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs)
return model return model
def FPN_ResNet34(in_channels=1, **kwargs): def FPN_ResNet34(in_channels=1, **kwargs):
"""FPN with ResNet18 as backbone """FPN with ResNet18 as backbone"""
"""
backbone = resnet34(in_channels=in_channels) backbone = resnet34(in_channels=in_channels)
model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs) model = FPN_ResNet(backbone, encoder_channels=[512, 256, 128, 64], **kwargs)
return model return model
def FPN_ResNet50(in_channels=1, **kwargs): def FPN_ResNet50(in_channels=1, **kwargs):
"""FPN with ResNet50 as backbone """FPN with ResNet50 as backbone"""
"""
backbone = resnet50(in_channels=in_channels) backbone = resnet50(in_channels=in_channels)
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs) model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
return model return model
def FPN_ResNet101(in_channels=1, **kwargs): def FPN_ResNet101(in_channels=1, **kwargs):
"""FPN with ResNet101 as backbone """FPN with ResNet101 as backbone"""
"""
backbone = resnet101(in_channels=in_channels) backbone = resnet101(in_channels=in_channels)
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs) model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
return model return model
def FPN_ResNet152(in_channels=1, **kwargs): def FPN_ResNet152(in_channels=1, **kwargs):
"""FPN with ResNet101 as backbone """FPN with ResNet101 as backbone"""
"""
backbone = resnet152(in_channels=in_channels) backbone = resnet152(in_channels=in_channels)
model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs) model = FPN_ResNet(backbone, encoder_channels=[2048, 1024, 512, 256], **kwargs)
return model return model

View File

@ -2,8 +2,23 @@ import torch
from torch.nn import Sequential as Seq from torch.nn import Sequential as Seq
from torch.nn import Linear as Lin from torch.nn import Linear as Lin
import torch_geometric as tg import torch_geometric as tg
from .util.gcn_lib.dense import BasicConv, GraphConv2d, PlainDynBlock2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph from .util.gcn_lib.dense import (
from .util.gcn_lib.sparse import MultiSeq, MLP, GraphConv, PlainDynBlock, ResDynBlock, DenseDynBlock, DilatedKnnGraph BasicConv,
GraphConv2d,
PlainDynBlock2d,
ResDynBlock2d,
DenseDynBlock2d,
DenseDilatedKnnGraph,
)
from .util.gcn_lib.sparse import (
MultiSeq,
MLP,
GraphConv,
PlainDynBlock,
ResDynBlock,
DenseDynBlock,
DilatedKnnGraph,
)
__all__ = ["DenseDeepGCN"] __all__ = ["DenseDeepGCN"]
@ -15,7 +30,7 @@ class DenseDeepGCN(torch.nn.Module):
input_dim=None, input_dim=None,
output_dim=None, output_dim=None,
dropout=0.8, dropout=0.8,
in_channels=2+1, in_channels=2 + 1,
k=5, k=5,
n_classes=1, n_classes=1,
block="dense", block="dense",
@ -31,57 +46,99 @@ class DenseDeepGCN(torch.nn.Module):
): ):
super(DenseDeepGCN, self).__init__() super(DenseDeepGCN, self).__init__()
self.dim = dim self.dim = dim
self.n_classes = n_classes # self.n_classes = n_classes #
self.k = k # self.k = k #
self.in_channels = in_channels # self.in_channels = in_channels #
self.dropout = dropout # self.dropout = dropout #
self.block = block self.block = block
self.conv = conv # self.conv = conv #
self.act = act # self.act = act #
self.norm = norm # self.norm = norm #
self.bias = bias # self.bias = bias #
self.channels = n_filters # self.channels = n_filters #
self.n_blocks = n_blocks # self.n_blocks = n_blocks #
self.epsilon = epsilon # self.epsilon = epsilon #
self.stochastic = stochastic # self.stochastic = stochastic #
c_growth = self.channels c_growth = self.channels
#print(self.dropout) # print(self.dropout)
self.knn = DenseDilatedKnnGraph(k, 1, self.stochastic, self.epsilon) self.knn = DenseDilatedKnnGraph(k, 1, self.stochastic, self.epsilon)
self.head = GraphConv2d(self.in_channels, self.channels, conv, act, norm, bias) self.head = GraphConv2d(self.in_channels, self.channels, conv, act, norm, bias)
if self.block.lower() == 'res': if self.block.lower() == "res":
self.backbone = Seq(*[ResDynBlock2d(self.channels, k, 1+i, conv, act, norm, bias, stochastic, epsilon) self.backbone = Seq(
for i in range(self.n_blocks-1)]) *[
ResDynBlock2d(
self.channels,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic,
epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1)) fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1))
elif self.block.lower() == 'dense': elif self.block.lower() == "dense":
self.backbone = Seq(*[DenseDynBlock2d(self.channels+c_growth*i, c_growth, k, 1+i, conv, act, self.backbone = Seq(
norm, bias, stochastic, epsilon) *[
for i in range(self.n_blocks-1)]) DenseDynBlock2d(
self.channels + c_growth * i,
c_growth,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic,
epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int( fusion_dims = int(
(self.channels + self.channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2) (self.channels + self.channels + c_growth * (self.n_blocks - 1))
* self.n_blocks
// 2
)
else: else:
stochastic = False stochastic = False
self.backbone = Seq(*[PlainDynBlock2d(self.channels, k, 1, conv, act, norm, self.backbone = Seq(
bias, stochastic, epsilon) *[
for i in range(self.n_blocks - 1)]) PlainDynBlock2d(
self.channels, k, 1, conv, act, norm, bias, stochastic, epsilon
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1)) fusion_dims = int(self.channels + c_growth * (self.n_blocks - 1))
self.fusion_block = BasicConv([fusion_dims, 1024], act, norm, bias) self.fusion_block = BasicConv([fusion_dims, 1024], act, norm, bias)
self.prediction = Seq(*[BasicConv([fusion_dims+1024, 512], act, norm, bias), self.prediction = Seq(
BasicConv([512, 256], act, norm, bias), *[
torch.nn.Dropout(p=self.dropout), BasicConv([fusion_dims + 1024, 512], act, norm, bias),
BasicConv([256, self.n_classes], None, None, bias)]) BasicConv([512, 256], act, norm, bias),
torch.nn.Dropout(p=self.dropout),
BasicConv([256, self.n_classes], None, None, bias),
]
)
def forward(self, inputs): def forward(self, inputs):
feats = [self.head(inputs, self.knn(inputs[:, 0:self.dim]))] feats = [self.head(inputs, self.knn(inputs[:, 0 : self.dim]))]
for i in range(self.n_blocks-1): for i in range(self.n_blocks - 1):
feats.append(self.backbone[i](feats[-1])) feats.append(self.backbone[i](feats[-1]))
feats = torch.cat(feats, dim=1) feats = torch.cat(feats, dim=1)
fusion = torch.max_pool2d(self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]]) fusion = torch.max_pool2d(
self.fusion_block(feats), kernel_size=[feats.shape[2], feats.shape[3]]
)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2) fusion = torch.repeat_interleave(fusion, repeats=feats.shape[2], dim=2)
return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1) return self.prediction(torch.cat((fusion, feats), dim=1)).squeeze(-1)
@ -104,28 +161,76 @@ class SparseDeepGCN(torch.nn.Module):
self.knn = DilatedKnnGraph(k, 1, stochastic, epsilon) self.knn = DilatedKnnGraph(k, 1, stochastic, epsilon)
self.head = GraphConv(opt.in_channels, channels, conv, act, norm, bias) self.head = GraphConv(opt.in_channels, channels, conv, act, norm, bias)
if opt.block.lower() == 'res': if opt.block.lower() == "res":
self.backbone = MultiSeq(*[ResDynBlock(channels, k, 1+i, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon) self.backbone = MultiSeq(
for i in range(self.n_blocks-1)]) *[
ResDynBlock(
channels,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(channels + c_growth * (self.n_blocks - 1)) fusion_dims = int(channels + c_growth * (self.n_blocks - 1))
elif opt.block.lower() == 'dense': elif opt.block.lower() == "dense":
self.backbone = MultiSeq(*[DenseDynBlock(channels+c_growth*i, c_growth, k, 1+i, self.backbone = MultiSeq(
conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon) *[
for i in range(self.n_blocks-1)]) DenseDynBlock(
channels + c_growth * i,
c_growth,
k,
1 + i,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int( fusion_dims = int(
(channels + channels + c_growth * (self.n_blocks - 1)) * self.n_blocks // 2) (channels + channels + c_growth * (self.n_blocks - 1))
* self.n_blocks
// 2
)
else: else:
# Use PlainGCN without skip connection and dilated convolution. # Use PlainGCN without skip connection and dilated convolution.
stochastic = False stochastic = False
self.backbone = MultiSeq( self.backbone = MultiSeq(
*[PlainDynBlock(channels, k, 1, conv, act, norm, bias, stochastic=stochastic, epsilon=epsilon) *[
for i in range(self.n_blocks - 1)]) PlainDynBlock(
channels,
k,
1,
conv,
act,
norm,
bias,
stochastic=stochastic,
epsilon=epsilon,
)
for i in range(self.n_blocks - 1)
]
)
fusion_dims = int(channels + c_growth * (self.n_blocks - 1)) fusion_dims = int(channels + c_growth * (self.n_blocks - 1))
self.fusion_block = MLP([fusion_dims, 1024], act, norm, bias) self.fusion_block = MLP([fusion_dims, 1024], act, norm, bias)
self.prediction = MultiSeq(*[MLP([fusion_dims+1024, 512], act, norm, bias), self.prediction = MultiSeq(
MLP([512, 256], act, norm, bias, drop=opt.dropout), *[
MLP([256, opt.n_classes], None, None, bias)]) MLP([fusion_dims + 1024, 512], act, norm, bias),
MLP([512, 256], act, norm, bias, drop=opt.dropout),
MLP([256, opt.n_classes], None, None, bias),
]
)
self.model_init() self.model_init()
def model_init(self): def model_init(self):
@ -141,19 +246,21 @@ class SparseDeepGCN(torch.nn.Module):
corr, color, batch = data.pos, data.x, data.batch corr, color, batch = data.pos, data.x, data.batch
x = torch.cat((corr, color), dim=1) x = torch.cat((corr, color), dim=1)
feats = [self.head(x, self.knn(x[:, 0:3], batch))] feats = [self.head(x, self.knn(x[:, 0:3], batch))]
for i in range(self.n_blocks-1): for i in range(self.n_blocks - 1):
feats.append(self.backbone[i](feats[-1], batch)[0]) feats.append(self.backbone[i](feats[-1], batch)[0])
feats = torch.cat(feats, dim=1) feats = torch.cat(feats, dim=1)
fusion = tg.utils.scatter_('max', self.fusion_block(feats), batch) fusion = tg.utils.scatter_("max", self.fusion_block(feats), batch)
fusion = torch.repeat_interleave(fusion, repeats=feats.shape[0]//fusion.shape[0], dim=0) fusion = torch.repeat_interleave(
fusion, repeats=feats.shape[0] // fusion.shape[0], dim=0
)
return self.prediction(torch.cat((fusion, feats), dim=1)) return self.prediction(torch.cat((fusion, feats), dim=1))
if __name__ == "__main__": if __name__ == "__main__":
import random, numpy as np, argparse import random, numpy as np, argparse
seed = 0 seed = 0
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
@ -163,22 +270,58 @@ if __name__ == "__main__":
batch_size = 2 batch_size = 2
N = 1024 N = 1024
device = 'cuda' device = "cuda"
parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation') parser = argparse.ArgumentParser(
parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)') description="PyTorch implementation of Deep GCN For semantic segmentation"
parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)') )
parser.add_argument('--k', default=4, type=int, help='neighbor num (default:16)') parser.add_argument(
parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}') "--in_channels", default=9, type=int, help="input channels (default:9)"
parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}') )
parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}') parser.add_argument(
parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization') "--n_classes",
parser.add_argument('--bias', default=True, type=bool, help='bias of conv layer True or False') default=13,
parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features') type=int,
parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks') help="num of segmentation classes (default:13)",
parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout') )
parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn') parser.add_argument("--k", default=4, type=int, help="neighbor num (default:16)")
parser.add_argument('--stochastic', default=False, type=bool, help='stochastic for gcn, True or False') parser.add_argument(
"--block",
default="res",
type=str,
help="graph backbone block type {plain, res, dense}",
)
parser.add_argument(
"--conv", default="edge", type=str, help="graph conv layer {edge, mr}"
)
parser.add_argument(
"--act",
default="relu",
type=str,
help="activation layer {relu, prelu, leakyrelu}",
)
parser.add_argument(
"--norm", default="batch", type=str, help="{batch, instance} normalization"
)
parser.add_argument(
"--bias", default=True, type=bool, help="bias of conv layer True or False"
)
parser.add_argument(
"--n_filters", default=64, type=int, help="number of channels of deep features"
)
parser.add_argument(
"--n_blocks", default=7, type=int, help="number of basic blocks"
)
parser.add_argument("--dropout", default=0.5, type=float, help="ratio of dropout")
parser.add_argument(
"--epsilon", default=0.2, type=float, help="stochastic epsilon for gcn"
)
parser.add_argument(
"--stochastic",
default=False,
type=bool,
help="stochastic for gcn, True or False",
)
args = parser.parse_args() args = parser.parse_args()
pos = torch.rand((batch_size, N, 2), dtype=torch.float).to(device) pos = torch.rand((batch_size, N, 2), dtype=torch.float).to(device)
@ -188,9 +331,8 @@ if __name__ == "__main__":
print(inputs.size()) print(inputs.size())
# net = DGCNNSegDense().to(device) # net = DGCNNSegDense().to(device)
net = DenseDeepGCN(in_channels=2+6).to(device) net = DenseDeepGCN(in_channels=2 + 6).to(device)
# net = SparseDeepGCN(args).to(device) # net = SparseDeepGCN(args).to(device)
print(net) print(net)
out = net(inputs) out = net(inputs)
print(out.shape) print(out.shape)

View File

@ -9,34 +9,36 @@ def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
# activation layer # activation layer
act = act.lower() act = act.lower()
if act == 'relu': if act == "relu":
layer = nn.ReLU(inplace) layer = nn.ReLU(inplace)
elif act == 'leakyrelu': elif act == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace) layer = nn.LeakyReLU(neg_slope, inplace)
elif act == 'prelu': elif act == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act == 'gelu': elif act == "gelu":
layer = nn.GELU() layer = nn.GELU()
elif act == 'sigmoid': elif act == "sigmoid":
layer = nn.Sigmoid() layer = nn.Sigmoid()
else: else:
raise NotImplementedError('activation layer [%s] is not found' % act) raise NotImplementedError("activation layer [%s] is not found" % act)
return layer return layer
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, input_dim, output_dim, act = "gelu", bias=True, hidden_layer=[512, 512, 512]): def __init__(
self, input_dim, output_dim, act="gelu", bias=True, hidden_layer=[512, 512, 512]
):
super(MLP, self).__init__() super(MLP, self).__init__()
self.hidden_layer = hidden_layer self.hidden_layer = hidden_layer
net = [] net = []
net.append(nn.Linear(input_dim, hidden_layer[0])) net.append(nn.Linear(input_dim, hidden_layer[0]))
if act is not None and act.lower() != 'none': if act is not None and act.lower() != "none":
net.append(act_layer(act)) net.append(act_layer(act))
if len(hidden_layer) > 1: if len(hidden_layer) > 1:
for i in range(1, len(hidden_layer)): for i in range(1, len(hidden_layer)):
net.append(nn.Linear(hidden_layer[i - 1], hidden_layer[i], bias)) net.append(nn.Linear(hidden_layer[i - 1], hidden_layer[i], bias))
if act is not None and act.lower() != 'none': if act is not None and act.lower() != "none":
net.append(act_layer(act)) net.append(act_layer(act))
net.append(nn.Linear(hidden_layer[-1], output_dim)) net.append(nn.Linear(hidden_layer[-1], output_dim))

View File

@ -8,28 +8,37 @@ import torch.nn.functional as F
from .util.backbone import * from .util.backbone import *
__all__ = ["SegNet_VGG", "SegNet_VGG_GN", "SegNet_AlexNet", "SegNet_ResNet18", __all__ = [
"SegNet_ResNet50", "SegNet_ResNet101", "SegNet_ResNet34", "SegNet_ResNet152"] "SegNet_VGG",
"SegNet_VGG_GN",
"SegNet_AlexNet",
"SegNet_ResNet18",
"SegNet_ResNet50",
"SegNet_ResNet101",
"SegNet_ResNet34",
"SegNet_ResNet152",
]
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class DecoderBottleneck(nn.Module): class DecoderBottleneck(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(DecoderBottleneck, self).__init__() super(DecoderBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4) self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, self.conv2 = nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False) in_channels // 4, in_channels // 4, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4) self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 2, 1, bias=False) self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 2, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 2) self.bn3 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2, nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False), in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
nn.BatchNorm2d(in_channels // 2)) ),
nn.BatchNorm2d(in_channels // 2),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -49,21 +58,21 @@ class DecoderBottleneck(nn.Module):
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class LastBottleneck(nn.Module): class LastBottleneck(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(LastBottleneck, self).__init__() super(LastBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4) self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.Conv2d(in_channels // 4, in_channels // 4, self.conv2 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels // 4, in_channels // 4, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4) self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 4, 1, bias=False) self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 4, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 4) self.bn3 = nn.BatchNorm2d(in_channels // 4)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False), nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels // 4)) nn.BatchNorm2d(in_channels // 4),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -83,20 +92,23 @@ class LastBottleneck(nn.Module):
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class DecoderBasicBlock(nn.Module): class DecoderBasicBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(DecoderBasicBlock, self).__init__() super(DecoderBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 2, self.conv1 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels // 2, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels // 2) self.bn1 = nn.BatchNorm2d(in_channels // 2)
self.conv2 = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, self.conv2 = nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False) in_channels // 2, in_channels // 2, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 2) self.bn2 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2, nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False), in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
nn.BatchNorm2d(in_channels // 2)) ),
nn.BatchNorm2d(in_channels // 2),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -112,19 +124,21 @@ class DecoderBasicBlock(nn.Module):
class LastBasicBlock(nn.Module): class LastBasicBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(LastBasicBlock, self).__init__() super(LastBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, self.conv1 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels) self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, in_channels, self.conv2 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels) self.bn2 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False), nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels)) nn.BatchNorm2d(in_channels),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -140,7 +154,6 @@ class LastBasicBlock(nn.Module):
class SegNet_VGG(nn.Module): class SegNet_VGG(nn.Module):
def __init__(self, out_channels=1, in_channels=1, pretrained=False): def __init__(self, out_channels=1, in_channels=1, pretrained=False):
super(SegNet_VGG, self).__init__() super(SegNet_VGG, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained) vgg_bn = vgg16_bn(pretrained=pretrained)
@ -160,34 +173,45 @@ class SegNet_VGG(nn.Module):
# Decoder, same as the encoder but reversed, maxpool will not be used # Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)] decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer # Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it # When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3) decoder = [
for item in decoder[i:i + 3][::-1]] item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them # Replace some conv layers & batchN after them
for i, module in enumerate(decoder): for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels: if module.in_channels != module.out_channels:
decoder[i + 1] = nn.BatchNorm2d(module.in_channels) decoder[i + 1] = nn.BatchNorm2d(module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, decoder[i] = nn.Conv2d(
kernel_size=3, stride=1, padding=1) module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.stage1_decoder = nn.Sequential(*decoder[0:9]) self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18]) self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27]) self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33]) self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:], self.stage5_decoder = nn.Sequential(
nn.Conv2d(64, out_channels, *decoder[33:],
kernel_size=3, nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
stride=1, )
padding=1)
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, self._initialize_weights(
self.stage4_decoder, self.stage5_decoder) self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)
def _initialize_weights(self, *stages): def _initialize_weights(self, *stages):
for modules in stages: for modules in stages:
@ -242,7 +266,6 @@ class SegNet_VGG(nn.Module):
class SegNet_VGG_GN(nn.Module): class SegNet_VGG_GN(nn.Module):
def __init__(self, out_channels=1, in_channels=3, pretrained=False): def __init__(self, out_channels=1, in_channels=3, pretrained=False):
super(SegNet_VGG_GN, self).__init__() super(SegNet_VGG_GN, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained) vgg_bn = vgg16_bn(pretrained=pretrained)
@ -267,33 +290,45 @@ class SegNet_VGG_GN(nn.Module):
# Decoder, same as the encoder but reversed, maxpool will not be used # Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)] decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer # Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it # When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3) decoder = [
for item in decoder[i:i + 3][::-1]] item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them # Replace some conv layers & batchN after them
for i, module in enumerate(decoder): for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels: if module.in_channels != module.out_channels:
decoder[i + 1] = nn.GroupNorm(32, module.in_channels) decoder[i + 1] = nn.GroupNorm(32, module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, decoder[i] = nn.Conv2d(
kernel_size=3, stride=1, padding=1) module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.stage1_decoder = nn.Sequential(*decoder[0:9]) self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18]) self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27]) self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33]) self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:], nn.Conv2d(64, self.stage5_decoder = nn.Sequential(
out_channels, *decoder[33:],
kernel_size=3, nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
stride=1, )
padding=1))
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, self._initialize_weights(
self.stage4_decoder, self.stage5_decoder) self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)
def _initialize_weights(self, *stages): def _initialize_weights(self, *stages):
for modules in stages: for modules in stages:
@ -348,7 +383,6 @@ class SegNet_VGG_GN(nn.Module):
class SegNet_AlexNet(nn.Module): class SegNet_AlexNet(nn.Module):
def __init__(self, out_channels=1, in_channels=1, bn=False): def __init__(self, out_channels=1, in_channels=1, bn=False):
super(SegNet_AlexNet, self).__init__() super(SegNet_AlexNet, self).__init__()
self.stage3_encoder = nn.Sequential( self.stage3_encoder = nn.Sequential(
@ -373,7 +407,9 @@ class SegNet_AlexNet(nn.Module):
nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False, return_indices=True) self.maxpool = nn.MaxPool2d(
kernel_size=2, stride=2, ceil_mode=False, return_indices=True
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.stage5_decoder = nn.Sequential( self.stage5_decoder = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.Conv2d(256, 256, kernel_size=3, padding=1),
@ -419,7 +455,6 @@ class SegNet_AlexNet(nn.Module):
class SegNet_ResNet(nn.Module): class SegNet_ResNet(nn.Module):
def __init__(self, backbone, out_channels=1, is_bottleneck=False, in_channels=1): def __init__(self, backbone, out_channels=1, is_bottleneck=False, in_channels=1):
super(SegNet_ResNet, self).__init__() super(SegNet_ResNet, self).__init__()
resnet_backbone = backbone resnet_backbone = backbone
@ -442,18 +477,25 @@ class SegNet_ResNet(nn.Module):
channels = (512, 256, 128) channels = (512, 256, 128)
for i, block in enumerate(resnet_r_blocks[:-1]): for i, block in enumerate(resnet_r_blocks[:-1]):
new_block = list(block.children())[::-1][:-1] new_block = list(block.children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block, decoder.append(
DecoderBottleneck(channels[i]) nn.Sequential(
if is_bottleneck else DecoderBasicBlock(channels[i]))) *new_block,
DecoderBottleneck(channels[i])
if is_bottleneck
else DecoderBasicBlock(channels[i])
)
)
new_block = list(resnet_r_blocks[-1].children())[::-1][:-1] new_block = list(resnet_r_blocks[-1].children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block, decoder.append(
LastBottleneck(256) nn.Sequential(
if is_bottleneck else LastBasicBlock(64))) *new_block, LastBottleneck(256) if is_bottleneck else LastBasicBlock(64)
)
)
self.decoder = nn.Sequential(*decoder) self.decoder = nn.Sequential(*decoder)
self.last_conv = nn.Sequential( self.last_conv = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False), nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False),
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1) nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
) )
def forward(self, x): def forward(self, x):
@ -468,10 +510,14 @@ class SegNet_ResNet(nn.Module):
h_diff = ceil((x.size()[2] - indices.size()[2]) / 2) h_diff = ceil((x.size()[2] - indices.size()[2]) / 2)
w_diff = ceil((x.size()[3] - indices.size()[3]) / 2) w_diff = ceil((x.size()[3] - indices.size()[3]) / 2)
if indices.size()[2] % 2 == 1: if indices.size()[2] % 2 == 1:
x = x[:, :, h_diff:x.size()[2] - (h_diff - 1), x = x[
w_diff: x.size()[3] - (w_diff - 1)] :,
:,
h_diff : x.size()[2] - (h_diff - 1),
w_diff : x.size()[3] - (w_diff - 1),
]
else: else:
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff] x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]
x = F.max_unpool2d(x, indices, kernel_size=2, stride=2) x = F.max_unpool2d(x, indices, kernel_size=2, stride=2)
x = self.last_conv(x) x = self.last_conv(x)
@ -479,9 +525,11 @@ class SegNet_ResNet(nn.Module):
if inputsize != x.size(): if inputsize != x.size():
h_diff = (x.size()[2] - inputsize[2]) // 2 h_diff = (x.size()[2] - inputsize[2]) // 2
w_diff = (x.size()[3] - inputsize[3]) // 2 w_diff = (x.size()[3] - inputsize[3]) // 2
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff] x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]
if h_diff % 2 != 0: x = x[:, :, :-1, :] if h_diff % 2 != 0:
if w_diff % 2 != 0: x = x[:, :, :, :-1] x = x[:, :, :-1, :]
if w_diff % 2 != 0:
x = x[:, :, :, :-1]
return x return x
@ -492,8 +540,13 @@ def SegNet_ResNet18(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet18() backbone_net = resnet18()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model return model
@ -503,8 +556,13 @@ def SegNet_ResNet34(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet34() backbone_net = resnet34()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model return model
@ -514,8 +572,13 @@ def SegNet_ResNet50(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet50() backbone_net = resnet50()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
@ -525,8 +588,13 @@ def SegNet_ResNet101(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet101() backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
@ -536,12 +604,17 @@ def SegNet_ResNet152(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet101() backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
if __name__ == '__main__': if __name__ == "__main__":
model = SegNet_AlexNet(in_channels=1, out_channels=1) model = SegNet_AlexNet(in_channels=1, out_channels=1)
print(model) print(model)
x = torch.randn(1, 1, 200, 200) x = torch.randn(1, 1, 200, 200)

View File

@ -8,28 +8,37 @@ import torch.nn.functional as F
from .util.backbone import * from .util.backbone import *
__all__ = ["SegNet_VGG", "SegNet_VGG_GN", "SegNet_AlexNet", "SegNet_ResNet18", __all__ = [
"SegNet_ResNet50", "SegNet_ResNet101", "SegNet_ResNet34", "SegNet_ResNet152"] "SegNet_VGG",
"SegNet_VGG_GN",
"SegNet_AlexNet",
"SegNet_ResNet18",
"SegNet_ResNet50",
"SegNet_ResNet101",
"SegNet_ResNet34",
"SegNet_ResNet152",
]
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class DecoderBottleneck(nn.Module): class DecoderBottleneck(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(DecoderBottleneck, self).__init__() super(DecoderBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4) self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, self.conv2 = nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False) in_channels // 4, in_channels // 4, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4) self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 2, 1, bias=False) self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 2, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 2) self.bn3 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2, nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False), in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
nn.BatchNorm2d(in_channels // 2)) ),
nn.BatchNorm2d(in_channels // 2),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -49,21 +58,21 @@ class DecoderBottleneck(nn.Module):
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class LastBottleneck(nn.Module): class LastBottleneck(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(LastBottleneck, self).__init__() super(LastBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, self.conv1 = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False)
kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(in_channels // 4) self.bn1 = nn.BatchNorm2d(in_channels // 4)
self.conv2 = nn.Conv2d(in_channels // 4, in_channels // 4, self.conv2 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels // 4, in_channels // 4, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 4) self.bn2 = nn.BatchNorm2d(in_channels // 4)
self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 4, 1, bias=False) self.conv3 = nn.Conv2d(in_channels // 4, in_channels // 4, 1, bias=False)
self.bn3 = nn.BatchNorm2d(in_channels // 4) self.bn3 = nn.BatchNorm2d(in_channels // 4)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False), nn.Conv2d(in_channels, in_channels // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels // 4)) nn.BatchNorm2d(in_channels // 4),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -83,20 +92,23 @@ class LastBottleneck(nn.Module):
# required class for decoder of SegNet_ResNet # required class for decoder of SegNet_ResNet
class DecoderBasicBlock(nn.Module): class DecoderBasicBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(DecoderBasicBlock, self).__init__() super(DecoderBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 2, self.conv1 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels // 2, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels // 2) self.bn1 = nn.BatchNorm2d(in_channels // 2)
self.conv2 = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, self.conv2 = nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False) in_channels // 2, in_channels // 2, kernel_size=2, stride=2, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels // 2) self.bn2 = nn.BatchNorm2d(in_channels // 2)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels // 2, nn.ConvTranspose2d(
kernel_size=2, stride=2, bias=False), in_channels, in_channels // 2, kernel_size=2, stride=2, bias=False
nn.BatchNorm2d(in_channels // 2)) ),
nn.BatchNorm2d(in_channels // 2),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -112,19 +124,21 @@ class DecoderBasicBlock(nn.Module):
class LastBasicBlock(nn.Module): class LastBasicBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super(LastBasicBlock, self).__init__() super(LastBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, self.conv1 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(in_channels) self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, in_channels, self.conv2 = nn.Conv2d(
kernel_size=3, padding=1, bias=False) in_channels, in_channels, kernel_size=3, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(in_channels) self.bn2 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False), nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(in_channels)) nn.BatchNorm2d(in_channels),
)
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
@ -140,7 +154,6 @@ class LastBasicBlock(nn.Module):
class SegNet_VGG(nn.Module): class SegNet_VGG(nn.Module):
def __init__(self, out_channels=1, in_channels=1, pretrained=False): def __init__(self, out_channels=1, in_channels=1, pretrained=False):
super(SegNet_VGG, self).__init__() super(SegNet_VGG, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained) vgg_bn = vgg16_bn(pretrained=pretrained)
@ -160,34 +173,45 @@ class SegNet_VGG(nn.Module):
# Decoder, same as the encoder but reversed, maxpool will not be used # Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)] decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer # Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it # When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3) decoder = [
for item in decoder[i:i + 3][::-1]] item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them # Replace some conv layers & batchN after them
for i, module in enumerate(decoder): for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels: if module.in_channels != module.out_channels:
decoder[i + 1] = nn.BatchNorm2d(module.in_channels) decoder[i + 1] = nn.BatchNorm2d(module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, decoder[i] = nn.Conv2d(
kernel_size=3, stride=1, padding=1) module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.stage1_decoder = nn.Sequential(*decoder[0:9]) self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18]) self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27]) self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33]) self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:], self.stage5_decoder = nn.Sequential(
nn.Conv2d(64, out_channels, *decoder[33:],
kernel_size=3, nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
stride=1, )
padding=1)
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, self._initialize_weights(
self.stage4_decoder, self.stage5_decoder) self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)
def _initialize_weights(self, *stages): def _initialize_weights(self, *stages):
for modules in stages: for modules in stages:
@ -242,7 +266,6 @@ class SegNet_VGG(nn.Module):
class SegNet_VGG_GN(nn.Module): class SegNet_VGG_GN(nn.Module):
def __init__(self, out_channels=1, in_channels=3, pretrained=False): def __init__(self, out_channels=1, in_channels=3, pretrained=False):
super(SegNet_VGG_GN, self).__init__() super(SegNet_VGG_GN, self).__init__()
vgg_bn = vgg16_bn(pretrained=pretrained) vgg_bn = vgg16_bn(pretrained=pretrained)
@ -267,33 +290,45 @@ class SegNet_VGG_GN(nn.Module):
# Decoder, same as the encoder but reversed, maxpool will not be used # Decoder, same as the encoder but reversed, maxpool will not be used
decoder = encoder decoder = encoder
decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)] decoder = [
i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)
]
# Replace the last conv layer # Replace the last conv layer
decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
# When reversing, we also reversed conv->batchN->relu, correct it # When reversing, we also reversed conv->batchN->relu, correct it
decoder = [item for i in range(0, len(decoder), 3) decoder = [
for item in decoder[i:i + 3][::-1]] item for i in range(0, len(decoder), 3) for item in decoder[i : i + 3][::-1]
]
# Replace some conv layers & batchN after them # Replace some conv layers & batchN after them
for i, module in enumerate(decoder): for i, module in enumerate(decoder):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
if module.in_channels != module.out_channels: if module.in_channels != module.out_channels:
decoder[i + 1] = nn.GroupNorm(32, module.in_channels) decoder[i + 1] = nn.GroupNorm(32, module.in_channels)
decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, decoder[i] = nn.Conv2d(
kernel_size=3, stride=1, padding=1) module.out_channels,
module.in_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.stage1_decoder = nn.Sequential(*decoder[0:9]) self.stage1_decoder = nn.Sequential(*decoder[0:9])
self.stage2_decoder = nn.Sequential(*decoder[9:18]) self.stage2_decoder = nn.Sequential(*decoder[9:18])
self.stage3_decoder = nn.Sequential(*decoder[18:27]) self.stage3_decoder = nn.Sequential(*decoder[18:27])
self.stage4_decoder = nn.Sequential(*decoder[27:33]) self.stage4_decoder = nn.Sequential(*decoder[27:33])
self.stage5_decoder = nn.Sequential(*decoder[33:], nn.Conv2d(64, self.stage5_decoder = nn.Sequential(
out_channels, *decoder[33:],
kernel_size=3, nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)
stride=1, )
padding=1))
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, self._initialize_weights(
self.stage4_decoder, self.stage5_decoder) self.stage1_decoder,
self.stage2_decoder,
self.stage3_decoder,
self.stage4_decoder,
self.stage5_decoder,
)
def _initialize_weights(self, *stages): def _initialize_weights(self, *stages):
for modules in stages: for modules in stages:
@ -348,7 +383,6 @@ class SegNet_VGG_GN(nn.Module):
class SegNet_AlexNet(nn.Module): class SegNet_AlexNet(nn.Module):
def __init__(self, out_channels=1, in_channels=1, bn=False): def __init__(self, out_channels=1, in_channels=1, bn=False):
super(SegNet_AlexNet, self).__init__() super(SegNet_AlexNet, self).__init__()
self.stage3_encoder = nn.Sequential( self.stage3_encoder = nn.Sequential(
@ -373,7 +407,9 @@ class SegNet_AlexNet(nn.Module):
nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=False, return_indices=True) self.maxpool = nn.MaxPool2d(
kernel_size=2, stride=2, ceil_mode=False, return_indices=True
)
self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
self.stage5_decoder = nn.Sequential( self.stage5_decoder = nn.Sequential(
nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.Conv2d(256, 256, kernel_size=3, padding=1),
@ -419,7 +455,6 @@ class SegNet_AlexNet(nn.Module):
class SegNet_ResNet(nn.Module): class SegNet_ResNet(nn.Module):
def __init__(self, backbone, out_channels=1, is_bottleneck=False, in_channels=1): def __init__(self, backbone, out_channels=1, is_bottleneck=False, in_channels=1):
super(SegNet_ResNet, self).__init__() super(SegNet_ResNet, self).__init__()
resnet_backbone = backbone resnet_backbone = backbone
@ -442,18 +477,25 @@ class SegNet_ResNet(nn.Module):
channels = (512, 256, 128) channels = (512, 256, 128)
for i, block in enumerate(resnet_r_blocks[:-1]): for i, block in enumerate(resnet_r_blocks[:-1]):
new_block = list(block.children())[::-1][:-1] new_block = list(block.children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block, decoder.append(
DecoderBottleneck(channels[i]) nn.Sequential(
if is_bottleneck else DecoderBasicBlock(channels[i]))) *new_block,
DecoderBottleneck(channels[i])
if is_bottleneck
else DecoderBasicBlock(channels[i])
)
)
new_block = list(resnet_r_blocks[-1].children())[::-1][:-1] new_block = list(resnet_r_blocks[-1].children())[::-1][:-1]
decoder.append(nn.Sequential(*new_block, decoder.append(
LastBottleneck(256) nn.Sequential(
if is_bottleneck else LastBasicBlock(64))) *new_block, LastBottleneck(256) if is_bottleneck else LastBasicBlock(64)
)
)
self.decoder = nn.Sequential(*decoder) self.decoder = nn.Sequential(*decoder)
self.last_conv = nn.Sequential( self.last_conv = nn.Sequential(
nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False), nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False),
nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1) nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
) )
def forward(self, x): def forward(self, x):
@ -468,10 +510,14 @@ class SegNet_ResNet(nn.Module):
h_diff = ceil((x.size()[2] - indices.size()[2]) / 2) h_diff = ceil((x.size()[2] - indices.size()[2]) / 2)
w_diff = ceil((x.size()[3] - indices.size()[3]) / 2) w_diff = ceil((x.size()[3] - indices.size()[3]) / 2)
if indices.size()[2] % 2 == 1: if indices.size()[2] % 2 == 1:
x = x[:, :, h_diff:x.size()[2] - (h_diff - 1), x = x[
w_diff: x.size()[3] - (w_diff - 1)] :,
:,
h_diff : x.size()[2] - (h_diff - 1),
w_diff : x.size()[3] - (w_diff - 1),
]
else: else:
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff] x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]
x = F.max_unpool2d(x, indices, kernel_size=2, stride=2) x = F.max_unpool2d(x, indices, kernel_size=2, stride=2)
x = self.last_conv(x) x = self.last_conv(x)
@ -479,9 +525,11 @@ class SegNet_ResNet(nn.Module):
if inputsize != x.size(): if inputsize != x.size():
h_diff = (x.size()[2] - inputsize[2]) // 2 h_diff = (x.size()[2] - inputsize[2]) // 2
w_diff = (x.size()[3] - inputsize[3]) // 2 w_diff = (x.size()[3] - inputsize[3]) // 2
x = x[:, :, h_diff:x.size()[2] - h_diff, w_diff: x.size()[3] - w_diff] x = x[:, :, h_diff : x.size()[2] - h_diff, w_diff : x.size()[3] - w_diff]
if h_diff % 2 != 0: x = x[:, :, :-1, :] if h_diff % 2 != 0:
if w_diff % 2 != 0: x = x[:, :, :, :-1] x = x[:, :, :-1, :]
if w_diff % 2 != 0:
x = x[:, :, :, :-1]
return x return x
@ -492,8 +540,13 @@ def SegNet_ResNet18(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet18() backbone_net = resnet18()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model return model
@ -503,8 +556,13 @@ def SegNet_ResNet34(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet34() backbone_net = resnet34()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=False, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=False,
in_channels=in_channels,
**kwargs
)
return model return model
@ -514,8 +572,13 @@ def SegNet_ResNet50(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet50() backbone_net = resnet50()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
@ -525,8 +588,13 @@ def SegNet_ResNet101(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet101() backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
@ -536,12 +604,17 @@ def SegNet_ResNet152(in_channels=1, out_channels=1, **kwargs):
""" """
backbone_net = resnet101() backbone_net = resnet101()
model = SegNet_ResNet(backbone_net, out_channels=out_channels, is_bottleneck=True, model = SegNet_ResNet(
in_channels=in_channels, **kwargs) backbone_net,
out_channels=out_channels,
is_bottleneck=True,
in_channels=in_channels,
**kwargs
)
return model return model
if __name__ == '__main__': if __name__ == "__main__":
model = SegNet_AlexNet(in_channels=1, out_channels=1) model = SegNet_AlexNet(in_channels=1, out_channels=1)
print(model) print(model)
x = torch.randn(1, 1, 200, 200) x = torch.randn(1, 1, 200, 200)

View File

@ -1,4 +1,4 @@
''' Define the Transformer model ''' """ Define the Transformer model """
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
@ -13,54 +13,76 @@ def get_pad_mask(seq, pad_idx):
def get_subsequent_mask(seq): def get_subsequent_mask(seq):
''' For masking out the subsequent info. ''' """For masking out the subsequent info."""
sz_b, len_s = seq.size() sz_b, len_s = seq.size()
subsequent_mask = (1 - torch.triu( subsequent_mask = (
torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool() 1 - torch.triu(torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)
).bool()
return subsequent_mask return subsequent_mask
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200): def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__() super(PositionalEncoding, self).__init__()
# Not a parameter # Not a parameter
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) self.register_buffer(
"pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
)
def _get_sinusoid_encoding_table(self, n_position, d_hid): def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table ''' """Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy # TODO: make it with torch instead of numpy
def get_position_angle_vec(position): def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0) return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x): def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach() return x + self.pos_table[:, : x.size(1)].clone().detach()
class Encoder(nn.Module): class Encoder(nn.Module):
''' A encoder model with self attention mechanism. ''' """A encoder model with self attention mechanism."""
def __init__( def __init__(
self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v, self,
d_model, d_inner, pad_idx, dropout=0.1, n_position=200, scale_emb=False): n_src_vocab,
d_word_vec,
n_layers,
n_head,
d_k,
d_v,
d_model,
d_inner,
pad_idx,
dropout=0.1,
n_position=200,
scale_emb=False,
):
super().__init__() super().__init__()
self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx) self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([ self.layer_stack = nn.ModuleList(
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) [
for _ in range(n_layers)]) EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb self.scale_emb = scale_emb
self.d_model = d_model self.d_model = d_model
@ -82,24 +104,39 @@ class Encoder(nn.Module):
if return_attns: if return_attns:
return enc_output, enc_slf_attn_list return enc_output, enc_slf_attn_list
return enc_output, return (enc_output,)
class Decoder(nn.Module): class Decoder(nn.Module):
''' A decoder model with self attention mechanism. ''' """A decoder model with self attention mechanism."""
def __init__( def __init__(
self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v, self,
d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False): n_trg_vocab,
d_word_vec,
n_layers,
n_head,
d_k,
d_v,
d_model,
d_inner,
pad_idx,
n_position=200,
dropout=0.1,
scale_emb=False,
):
super().__init__() super().__init__()
self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx) self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([ self.layer_stack = nn.ModuleList(
DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) [
for _ in range(n_layers)]) DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb self.scale_emb = scale_emb
self.d_model = d_model self.d_model = d_model
@ -117,24 +154,41 @@ class Decoder(nn.Module):
for dec_layer in self.layer_stack: for dec_layer in self.layer_stack:
dec_output, dec_slf_attn, dec_enc_attn = dec_layer( dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask) dec_output,
enc_output,
slf_attn_mask=trg_mask,
dec_enc_attn_mask=src_mask,
)
dec_slf_attn_list += [dec_slf_attn] if return_attns else [] dec_slf_attn_list += [dec_slf_attn] if return_attns else []
dec_enc_attn_list += [dec_enc_attn] if return_attns else [] dec_enc_attn_list += [dec_enc_attn] if return_attns else []
if return_attns: if return_attns:
return dec_output, dec_slf_attn_list, dec_enc_attn_list return dec_output, dec_slf_attn_list, dec_enc_attn_list
return dec_output, return (dec_output,)
class Transformer(nn.Module): class Transformer(nn.Module):
''' A sequence to sequence model with attention mechanism. ''' """A sequence to sequence model with attention mechanism."""
def __init__( def __init__(
self, n_src_vocab, n_trg_vocab, src_pad_idx, trg_pad_idx, self,
d_word_vec=512, d_model=512, d_inner=2048, n_src_vocab,
n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200, n_trg_vocab,
trg_emb_prj_weight_sharing=True, emb_src_trg_weight_sharing=True, src_pad_idx,
scale_emb_or_prj='prj'): trg_pad_idx,
d_word_vec=512,
d_model=512,
d_inner=2048,
n_layers=6,
n_head=8,
d_k=64,
d_v=64,
dropout=0.1,
n_position=200,
trg_emb_prj_weight_sharing=True,
emb_src_trg_weight_sharing=True,
scale_emb_or_prj="prj",
):
super().__init__() super().__init__()
@ -150,22 +204,42 @@ class Transformer(nn.Module):
# 'prj': multiply (\sqrt{d_model} ^ -1) to linear projection output # 'prj': multiply (\sqrt{d_model} ^ -1) to linear projection output
# 'none': no multiplication # 'none': no multiplication
assert scale_emb_or_prj in ['emb', 'prj', 'none'] assert scale_emb_or_prj in ["emb", "prj", "none"]
scale_emb = (scale_emb_or_prj == 'emb') if trg_emb_prj_weight_sharing else False scale_emb = (scale_emb_or_prj == "emb") if trg_emb_prj_weight_sharing else False
self.scale_prj = (scale_emb_or_prj == 'prj') if trg_emb_prj_weight_sharing else False self.scale_prj = (
(scale_emb_or_prj == "prj") if trg_emb_prj_weight_sharing else False
)
self.d_model = d_model self.d_model = d_model
self.encoder = Encoder( self.encoder = Encoder(
n_src_vocab=n_src_vocab, n_position=n_position, n_src_vocab=n_src_vocab,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, n_position=n_position,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, d_word_vec=d_word_vec,
pad_idx=src_pad_idx, dropout=dropout, scale_emb=scale_emb) d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
pad_idx=src_pad_idx,
dropout=dropout,
scale_emb=scale_emb,
)
self.decoder = Decoder( self.decoder = Decoder(
n_trg_vocab=n_trg_vocab, n_position=n_position, n_trg_vocab=n_trg_vocab,
d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner, n_position=n_position,
n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v, d_word_vec=d_word_vec,
pad_idx=trg_pad_idx, dropout=dropout, scale_emb=scale_emb) d_model=d_model,
d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
pad_idx=trg_pad_idx,
dropout=dropout,
scale_emb=scale_emb,
)
self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False) self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)
@ -173,9 +247,10 @@ class Transformer(nn.Module):
if p.dim() > 1: if p.dim() > 1:
nn.init.xavier_uniform_(p) nn.init.xavier_uniform_(p)
assert d_model == d_word_vec, \ assert (
'To facilitate the residual connections, \ d_model == d_word_vec
the dimensions of all module outputs shall be the same.' ), "To facilitate the residual connections, \
the dimensions of all module outputs shall be the same."
if trg_emb_prj_weight_sharing: if trg_emb_prj_weight_sharing:
# Share the weight between target word embedding & last dense layer # Share the weight between target word embedding & last dense layer
@ -187,7 +262,9 @@ class Transformer(nn.Module):
def forward(self, src_seq, trg_seq): def forward(self, src_seq, trg_seq):
src_mask = get_pad_mask(src_seq, self.src_pad_idx) src_mask = get_pad_mask(src_seq, self.src_pad_idx)
trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq) trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(
trg_seq
)
enc_output, *_ = self.encoder(src_seq, src_mask) enc_output, *_ = self.encoder(src_seq, src_mask)
dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask) dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
@ -202,9 +279,12 @@ class Transformer(nn.Module):
class EncoderModify(nn.Module): class EncoderModify(nn.Module):
def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
super(EncoderModify, self).__init__() super(EncoderModify, self).__init__()
self.layer_stack = nn.ModuleList([ self.layer_stack = nn.ModuleList(
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) [
for _ in range(n_layers)]) EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
def forward(self, src_seq): def forward(self, src_seq):
enc_output = src_seq enc_output = src_seq
@ -216,12 +296,14 @@ class EncoderModify(nn.Module):
# Modified Decoder # Modified Decoder
class DecoderModify(nn.Module): class DecoderModify(nn.Module):
def __init__( def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):
super(DecoderModify, self).__init__() super(DecoderModify, self).__init__()
self.layer_stack = nn.ModuleList([ self.layer_stack = nn.ModuleList(
DecoderLayerModify(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) [
for _ in range(n_layers)]) DecoderLayerModify(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)
]
)
def forward(self, trg_seq, src_key, enc_output): def forward(self, trg_seq, src_key, enc_output):
dec_output = trg_seq dec_output = trg_seq
@ -233,8 +315,17 @@ class DecoderModify(nn.Module):
# Modified Transformer # Modified Transformer
class TransformerRecon(nn.Module): class TransformerRecon(nn.Module):
def __init__( def __init__(
self, input_dim=None, output_dim=None, d_model=128, d_inner=512, self,
n_layers=2, n_head=4, d_k=32, d_v=32, dropout=0.1): input_dim=None,
output_dim=None,
d_model=128,
d_inner=512,
n_layers=2,
n_head=4,
d_k=32,
d_v=32,
dropout=0.1,
):
super(TransformerRecon, self).__init__() super(TransformerRecon, self).__init__()
@ -253,19 +344,31 @@ class TransformerRecon(nn.Module):
self.d_model = d_model self.d_model = d_model
self.encoder = EncoderModify( self.encoder = EncoderModify(
d_model=d_model, d_inner=d_inner, n_layers=n_layers, d_model=d_model,
n_head=n_head, d_k=d_k, d_v=d_v, dropout=dropout) d_inner=d_inner,
n_layers=n_layers,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout,
)
self.decoder = DecoderModify( self.decoder = DecoderModify(
d_model=d_model, d_inner=d_inner, n_layers=1, d_model=d_model,
n_head=n_head, d_k=d_k, d_v=d_v, dropout=dropout) d_inner=d_inner,
n_layers=1,
n_head=n_head,
d_k=d_k,
d_v=d_v,
dropout=dropout,
)
self.pre = nn.Sequential( self.pre = nn.Sequential(
nn.Linear(128 + 2, 128), nn.Linear(128 + 2, 128),
nn.GELU(), nn.GELU(),
nn.Linear(128, 128), nn.Linear(128, 128),
nn.GELU(), nn.GELU(),
nn.Linear(128, 1) nn.Linear(128, 1),
) )
def forward(self, src_seq, src_label, trg_seq, trg_label=None): def forward(self, src_seq, src_label, trg_seq, trg_label=None):
@ -278,11 +381,3 @@ class TransformerRecon(nn.Module):
dec_output = self.decoder(trg_seq, src_seq, enc_output) dec_output = self.decoder(trg_seq, src_seq, enc_output)
return self.pre(torch.cat([dec_output, trg_x], dim=-1)) return self.pre(torch.cat([dec_output, trg_x], dim=-1))
def test_TransformerModify():
transformer = TransformerModify()
x_src = torch.randn(8, 10, 3)
src_seq = torch.randn(8, 10, 2)
x_trg = torch.randn(8, 400, 2)
y = transformer(x_src, src_seq, x_trg)
print(y.shape)

View File

@ -10,8 +10,9 @@ __all__ = ["UNet_VGG"]
class _EncoderBlock(nn.Module): class _EncoderBlock(nn.Module):
def __init__(
def __init__(self, in_channels, out_channels, dropout=False, polling=True, bn=False): self, in_channels, out_channels, dropout=False, polling=True, bn=False
):
super(_EncoderBlock, self).__init__() super(_EncoderBlock, self).__init__()
layers = [ layers = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
@ -35,15 +36,18 @@ class _EncoderBlock(nn.Module):
class _DecoderBlock(nn.Module): class _DecoderBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, bn=False): def __init__(self, in_channels, middle_channels, out_channels, bn=False):
super(_DecoderBlock, self).__init__() super(_DecoderBlock, self).__init__()
self.decode = nn.Sequential( self.decode = nn.Sequential(
nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1), nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(middle_channels) if bn else nn.GroupNorm(32, middle_channels), nn.BatchNorm2d(middle_channels)
if bn
else nn.GroupNorm(32, middle_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1), nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(middle_channels) if bn else nn.GroupNorm(32, middle_channels), nn.BatchNorm2d(middle_channels)
if bn
else nn.GroupNorm(32, middle_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2), nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2),
) )
@ -53,7 +57,6 @@ class _DecoderBlock(nn.Module):
class UNet_VGG(nn.Module): class UNet_VGG(nn.Module):
def __init__(self, out_channels=1, in_channels=1, bn=False): def __init__(self, out_channels=1, in_channels=1, bn=False):
super(UNet_VGG, self).__init__() super(UNet_VGG, self).__init__()
self.enc1 = _EncoderBlock(in_channels, 64, polling=False, bn=bn) self.enc1 = _EncoderBlock(in_channels, 64, polling=False, bn=bn)
@ -82,8 +85,17 @@ class UNet_VGG(nn.Module):
enc3 = self.enc3(enc2) enc3 = self.enc3(enc2)
enc4 = self.enc4(enc3) enc4 = self.enc4(enc3)
center = self.center(self.polling(enc4)) center = self.center(self.polling(enc4))
dec4 = self.dec4(torch.cat([F.interpolate(center, enc4.size()[-2:], mode='bilinear', dec4 = self.dec4(
align_corners=True), enc4], 1)) torch.cat(
[
F.interpolate(
center, enc4.size()[-2:], mode="bilinear", align_corners=True
),
enc4,
],
1,
)
)
dec3 = self.dec3(torch.cat([dec4, enc3], 1)) dec3 = self.dec3(torch.cat([dec4, enc3], 1))
dec2 = self.dec2(torch.cat([dec3, enc2], 1)) dec2 = self.dec2(torch.cat([dec3, enc2], 1))
dec1 = self.dec1(torch.cat([dec2, enc1], 1)) dec1 = self.dec1(torch.cat([dec2, enc1], 1))
@ -91,8 +103,8 @@ class UNet_VGG(nn.Module):
return final return final
if __name__ == '__main__': if __name__ == "__main__":
model = UNet(in_channels=1, out_channels=1) model = UNet_VGG(in_channels=1, out_channels=1)
print(model) print(model)
x = torch.randn(1, 1, 200, 200) x = torch.randn(1, 1, 200, 200)
with torch.no_grad(): with torch.no_grad():

View File

@ -16,8 +16,13 @@ class AlexNet(nn.Module):
super(AlexNet, self).__init__() super(AlexNet, self).__init__()
self.features3 = nn.Sequential( self.features3 = nn.Sequential(
# kernel(11, 11) -> kernel(7, 7) # kernel(11, 11) -> kernel(7, 7)
nn.Conv2d(in_channels=in_channels, out_channels=64, nn.Conv2d(
kernel_size=7, stride=4, padding=3), in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=4,
padding=3,
),
nn.BatchNorm2d(64) if bn else nn.GroupNorm(32, 64), nn.BatchNorm2d(64) if bn else nn.GroupNorm(32, 64),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
@ -49,7 +54,7 @@ class AlexNet(nn.Module):
if __name__ == "__main__": if __name__ == "__main__":
x = torch.zeros(8, 1, 200, 200) x = torch.zeros(8, 1, 200, 200)
net = Alexnet() net = AlexNet()
print(net) print(net)
y = net(x) y = net(x)
print() print()

View File

@ -14,8 +14,8 @@ __all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152
model_urls = { model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
@ -24,7 +24,9 @@ model_urls = {
def conv3x3(in_planes, out_planes, stride=1): def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding""" """3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
)
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
@ -66,12 +68,9 @@ class Bottleneck(nn.Module):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, self.conv2 = nn.Conv2d(
planes, planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
kernel_size=3, )
stride=stride,
padding=1,
bias=False)
self.bn2 = nn.BatchNorm2d(planes) self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4) self.bn3 = nn.BatchNorm2d(planes * 4)
@ -106,7 +105,9 @@ class ResNet(nn.Module):
def __init__(self, block, layers, in_channels=1): def __init__(self, block, layers, in_channels=1):
self.inplanes = 64 self.inplanes = 64
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) self.conv1 = nn.Conv2d(
in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -127,11 +128,13 @@ class ResNet(nn.Module):
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
nn.Conv2d(self.inplanes, nn.Conv2d(
planes * block.expansion, self.inplanes,
kernel_size=1, planes * block.expansion,
stride=stride, kernel_size=1,
bias=False), stride=stride,
bias=False,
),
nn.BatchNorm2d(planes * block.expansion), nn.BatchNorm2d(planes * block.expansion),
) )
@ -173,7 +176,7 @@ def resnet18(pretrained=False, in_channels=1, **kwargs):
""" """
model = ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, **kwargs) model = ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, **kwargs)
if pretrained: if pretrained:
model._load_pretrained_model(model_urls['resnet18']) model._load_pretrained_model(model_urls["resnet18"])
return model return model
@ -185,7 +188,7 @@ def resnet34(pretrained=False, in_channels=1, **kwargs):
""" """
model = ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, **kwargs) model = ResNet(BasicBlock, [3, 4, 6, 3], in_channels=in_channels, **kwargs)
if pretrained: if pretrained:
model._load_pretrained_model(model_urls['resnet34']) model._load_pretrained_model(model_urls["resnet34"])
return model return model

View File

@ -11,30 +11,33 @@ from src.utils.vgg_utils import load_state_dict_from_url
__all__ = [ __all__ = [
"VGG", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "VGG",
"vgg19_bn", "vgg19", "vgg11",
"vgg11_bn",
"vgg13",
"vgg13_bn",
"vgg16",
"vgg16_bn",
"vgg19_bn",
"vgg19",
] ]
model_urls = { model_urls = {
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
} }
class VGG(nn.Module): class VGG(nn.Module):
def __init__( def __init__(
self, self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True
features: nn.Module,
num_classes: int = 1000,
init_weights: bool = True
) -> None: ) -> None:
super(VGG, self).__init__() super(VGG, self).__init__()
self.features = features self.features = features
@ -61,7 +64,7 @@ class VGG(nn.Module):
def _initialize_weights(self) -> None: def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
@ -76,7 +79,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
layers: List[nn.Module] = [] layers: List[nn.Module] = []
in_channels = 3 in_channels = 3
for v in cfg: for v in cfg:
if v == 'M': if v == "M":
layers += [nn.MaxPool2d(kernel_size=2, stride=2)] layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else: else:
v = cast(int, v) v = cast(int, v)
@ -90,20 +93,67 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ
cfgs: Dict[str, List[Union[str, int]]] = { cfgs: Dict[str, List[Union[str, int]]] = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], "D": [
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 64,
64,
"M",
128,
128,
"M",
256,
256,
256,
"M",
512,
512,
512,
"M",
512,
512,
512,
"M",
],
"E": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
256,
"M",
512,
512,
512,
512,
"M",
512,
512,
512,
512,
"M",
],
} }
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: def _vgg(
arch: str,
cfg: str,
batch_norm: bool,
pretrained: bool,
progress: bool,
**kwargs: Any
) -> VGG:
if pretrained: if pretrained:
kwargs['init_weights'] = False kwargs["init_weights"] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
@ -116,7 +166,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) return _vgg("vgg11", "A", False, pretrained, progress, **kwargs)
def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -127,7 +177,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs)
def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -138,7 +188,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) return _vgg("vgg13", "B", False, pretrained, progress, **kwargs)
def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -149,7 +199,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs)
def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -160,7 +210,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) return _vgg("vgg16", "D", False, pretrained, progress, **kwargs)
def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -171,7 +221,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs)
def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -182,7 +232,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) return _vgg("vgg19", "E", False, pretrained, progress, **kwargs)
def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG:
@ -193,4 +243,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
pretrained (bool): If True, returns a model pre-trained on ImageNet pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs)

View File

@ -1,4 +1,3 @@
from .torch_nn import * from .torch_nn import *
from .torch_edge import * from .torch_edge import *
from .torch_vertex import * from .torch_vertex import *

View File

@ -9,6 +9,7 @@ class DenseDilated(nn.Module):
edge_index: (2, batch_size, num_points, k) edge_index: (2, batch_size, num_points, k)
""" """
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilated, self).__init__() super(DenseDilated, self).__init__()
self.dilation = dilation self.dilation = dilation
@ -20,12 +21,12 @@ class DenseDilated(nn.Module):
if self.stochastic: if self.stochastic:
if torch.rand(1) < self.epsilon and self.training: if torch.rand(1) < self.epsilon and self.training:
num = self.k * self.dilation num = self.k * self.dilation
randnum = torch.randperm(num)[:self.k] randnum = torch.randperm(num)[: self.k]
edge_index = edge_index[:, :, :, randnum] edge_index = edge_index[:, :, :, randnum]
else: else:
edge_index = edge_index[:, :, :, ::self.dilation] edge_index = edge_index[:, :, :, :: self.dilation]
else: else:
edge_index = edge_index[:, :, :, ::self.dilation] edge_index = edge_index[:, :, :, :: self.dilation]
return edge_index return edge_index
@ -37,7 +38,7 @@ def pairwise_distance(x):
Returns: Returns:
pairwise distance: (batch_size, num_points, num_points) pairwise distance: (batch_size, num_points, num_points)
""" """
x_inner = -2*torch.matmul(x, x.transpose(2, 1)) x_inner = -2 * torch.matmul(x, x.transpose(2, 1))
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
return x_square + x_inner + x_square.transpose(2, 1) return x_square + x_inner + x_square.transpose(2, 1)
@ -54,7 +55,11 @@ def dense_knn_matrix(x, k=16):
x = x.transpose(2, 1).squeeze(-1) x = x.transpose(2, 1).squeeze(-1)
batch_size, n_points, n_dims = x.shape batch_size, n_points, n_dims = x.shape
_, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k) _, nn_idx = torch.topk(-pairwise_distance(x.detach()), k=k)
center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1) center_idx = (
torch.arange(0, n_points, device=x.device)
.repeat(batch_size, k, 1)
.transpose(2, 1)
)
return torch.stack((nn_idx, center_idx), dim=0) return torch.stack((nn_idx, center_idx), dim=0)
@ -62,6 +67,7 @@ class DenseDilatedKnnGraph(nn.Module):
""" """
Find the neighbors' indices based on dilated knn Find the neighbors' indices based on dilated knn
""" """
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilatedKnnGraph, self).__init__() super(DenseDilatedKnnGraph, self).__init__()
self.dilation = dilation self.dilation = dilation
@ -80,6 +86,7 @@ class DilatedKnnGraph(nn.Module):
""" """
Find the neighbors' indices based on dilated knn Find the neighbors' indices based on dilated knn
""" """
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DilatedKnnGraph, self).__init__() super(DilatedKnnGraph, self).__init__()
self.dilation = dilation self.dilation = dilation
@ -94,7 +101,9 @@ class DilatedKnnGraph(nn.Module):
B, C, N = x.shape B, C, N = x.shape
edge_index = [] edge_index = []
for i in range(B): for i in range(B):
edgeindex = self.knn(x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation) edgeindex = self.knn(
x[i].contiguous().transpose(1, 0).contiguous(), self.k * self.dilation
)
edgeindex = edgeindex.view(2, N, self.k * self.dilation) edgeindex = edgeindex.view(2, N, self.k * self.dilation)
edge_index.append(edgeindex) edge_index.append(edgeindex)
edge_index = torch.stack(edge_index, dim=1) edge_index = torch.stack(edge_index, dim=1)

View File

@ -10,55 +10,55 @@ def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
# activation layer # activation layer
act = act.lower() act = act.lower()
if act == 'relu': if act == "relu":
layer = nn.ReLU(inplace) layer = nn.ReLU(inplace)
elif act == 'leakyrelu': elif act == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace) layer = nn.LeakyReLU(neg_slope, inplace)
elif act == 'prelu': elif act == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act == 'gelu': elif act == "gelu":
layer = nn.GELU() layer = nn.GELU()
elif act == 'sigmoid': elif act == "sigmoid":
layer = nn.Sigmoid() layer = nn.Sigmoid()
else: else:
raise NotImplementedError('activation layer [%s] is not found' % act) raise NotImplementedError("activation layer [%s] is not found" % act)
return layer return layer
def norm_layer(norm, nc): def norm_layer(norm, nc):
# normalization layer 2d # normalization layer 2d
norm = norm.lower() norm = norm.lower()
if norm == 'batch': if norm == "batch":
layer = nn.BatchNorm2d(nc, affine=False, track_running_stats=False) layer = nn.BatchNorm2d(nc, affine=False, track_running_stats=False)
elif norm == 'instance': elif norm == "instance":
layer = nn.InstanceNorm2d(nc, affine=True) layer = nn.InstanceNorm2d(nc, affine=True)
elif norm == 'group': elif norm == "group":
layer = nn.GroupNorm(32, nc, affine=False) layer = nn.GroupNorm(32, nc, affine=False)
else: else:
raise NotImplementedError('normalization layer [%s] is not found' % norm) raise NotImplementedError("normalization layer [%s] is not found" % norm)
return layer return layer
class MLP(Seq): class MLP(Seq):
def __init__(self, channels, act='relu', norm=None, bias=True): def __init__(self, channels, act="relu", norm=None, bias=True):
m = [] m = []
for i in range(1, len(channels)): for i in range(1, len(channels)):
m.append(Lin(channels[i - 1], channels[i], bias)) m.append(Lin(channels[i - 1], channels[i], bias))
if act is not None and act.lower() != 'none': if act is not None and act.lower() != "none":
m.append(act_layer(act)) m.append(act_layer(act))
if norm is not None and norm.lower() != 'none': if norm is not None and norm.lower() != "none":
m.append(norm_layer(norm, channels[-1])) m.append(norm_layer(norm, channels[-1]))
super(MLP, self).__init__(*m) super(MLP, self).__init__(*m)
class BasicConv(Seq): class BasicConv(Seq):
def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.): def __init__(self, channels, act="relu", norm=None, bias=True, drop=0.0):
m = [] m = []
for i in range(1, len(channels)): for i in range(1, len(channels)):
m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias)) m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias))
if act is not None and act.lower() != 'none': if act is not None and act.lower() != "none":
m.append(act_layer(act)) m.append(act_layer(act))
if norm is not None and norm.lower() != 'none': if norm is not None and norm.lower() != "none":
m.append(norm_layer(norm, channels[-1])) m.append(norm_layer(norm, channels[-1]))
if drop > 0: if drop > 0:
m.append(nn.Dropout2d(drop)) m.append(nn.Dropout2d(drop))
@ -95,11 +95,17 @@ def batched_index_select(x, idx):
""" """
batch_size, num_dims, num_vertices = x.shape[:3] batch_size, num_dims, num_vertices = x.shape[:3]
k = idx.shape[-1] k = idx.shape[-1]
idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices idx_base = (
torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices
)
idx = idx + idx_base idx = idx + idx_base
idx = idx.contiguous().view(-1) idx = idx.contiguous().view(-1)
x = x.transpose(2, 1) x = x.transpose(2, 1)
feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :] feature = x.contiguous().view(batch_size * num_vertices, -1)[idx, :]
feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous() feature = (
feature.view(batch_size, num_vertices, k, num_dims)
.permute(0, 3, 1, 2)
.contiguous()
)
return feature return feature

View File

@ -9,9 +9,10 @@ class MRConv2d(nn.Module):
""" """
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
super(MRConv2d, self).__init__() super(MRConv2d, self).__init__()
self.nn = BasicConv([in_channels*2, out_channels], act, norm, bias) self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)
def forward(self, x, edge_index): def forward(self, x, edge_index):
x_i = batched_index_select(x, edge_index[1]) x_i = batched_index_select(x, edge_index[1])
@ -24,14 +25,17 @@ class EdgeConv2d(nn.Module):
""" """
Edge convolution layer (with activation, batch normalization) for dense data type Edge convolution layer (with activation, batch normalization) for dense data type
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True):
def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
super(EdgeConv2d, self).__init__() super(EdgeConv2d, self).__init__()
self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias) self.nn = BasicConv([in_channels * 2, out_channels], act, norm, bias)
def forward(self, x, edge_index): def forward(self, x, edge_index):
x_i = batched_index_select(x, edge_index[1]) x_i = batched_index_select(x, edge_index[1])
x_j = batched_index_select(x, edge_index[0]) x_j = batched_index_select(x, edge_index[0])
max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True) max_value, _ = torch.max(
self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True
)
return max_value return max_value
@ -39,14 +43,17 @@ class GraphConv2d(nn.Module):
""" """
Static graph convolution layer Static graph convolution layer
""" """
def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True):
def __init__(
self, in_channels, out_channels, conv="edge", act="relu", norm=None, bias=True
):
super(GraphConv2d, self).__init__() super(GraphConv2d, self).__init__()
if conv == 'edge': if conv == "edge":
self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias) self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
elif conv == 'mr': elif conv == "mr":
self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias) self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
else: else:
raise NotImplementedError('conv:{} is not supported'.format(conv)) raise NotImplementedError("conv:{} is not supported".format(conv))
def forward(self, x, edge_index): def forward(self, x, edge_index):
return self.gconv(x, edge_index) return self.gconv(x, edge_index)
@ -56,15 +63,34 @@ class DynConv2d(GraphConv2d):
""" """
Dynamic graph convolution layer Dynamic graph convolution layer
""" """
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
norm=None, bias=True, stochastic=False, epsilon=0.0, knn='matrix'): def __init__(
super(DynConv2d, self).__init__(in_channels, out_channels, conv, act, norm, bias) self,
in_channels,
out_channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
stochastic=False,
epsilon=0.0,
knn="matrix",
):
super(DynConv2d, self).__init__(
in_channels, out_channels, conv, act, norm, bias
)
self.k = kernel_size self.k = kernel_size
self.d = dilation self.d = dilation
if knn == 'matrix': if knn == "matrix":
self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) self.dilated_knn_graph = DenseDilatedKnnGraph(
kernel_size, dilation, stochastic, epsilon
)
else: else:
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon) self.dilated_knn_graph = DilatedKnnGraph(
kernel_size, dilation, stochastic, epsilon
)
def forward(self, x): def forward(self, x):
edge_index = self.dilated_knn_graph(x) edge_index = self.dilated_knn_graph(x)
@ -75,11 +101,34 @@ class PlainDynBlock2d(nn.Module):
""" """
Plain Dynamic graph convolution block Plain Dynamic graph convolution block
""" """
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, stochastic=False, epsilon=0.0, knn='matrix'): def __init__(
self,
in_channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
stochastic=False,
epsilon=0.0,
knn="matrix",
):
super(PlainDynBlock2d, self).__init__() super(PlainDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, self.body = DynConv2d(
act, norm, bias, stochastic, epsilon, knn) in_channels,
in_channels,
kernel_size,
dilation,
conv,
act,
norm,
bias,
stochastic,
epsilon,
knn,
)
def forward(self, x): def forward(self, x):
return self.body(x) return self.body(x)
@ -89,26 +138,74 @@ class ResDynBlock2d(nn.Module):
""" """
Residual Dynamic graph convolution block Residual Dynamic graph convolution block
""" """
def __init__(self, in_channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, stochastic=False, epsilon=0.0, knn='matrix', res_scale=1): def __init__(
self,
in_channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
stochastic=False,
epsilon=0.0,
knn="matrix",
res_scale=1,
):
super(ResDynBlock2d, self).__init__() super(ResDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, in_channels, kernel_size, dilation, conv, self.body = DynConv2d(
act, norm, bias, stochastic, epsilon, knn) in_channels,
in_channels,
kernel_size,
dilation,
conv,
act,
norm,
bias,
stochastic,
epsilon,
knn,
)
self.res_scale = res_scale self.res_scale = res_scale
def forward(self, x): def forward(self, x):
return self.body(x) + x*self.res_scale return self.body(x) + x * self.res_scale
class DenseDynBlock2d(nn.Module): class DenseDynBlock2d(nn.Module):
""" """
Dense Dynamic graph convolution block Dense Dynamic graph convolution block
""" """
def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge',
act='relu', norm=None,bias=True, stochastic=False, epsilon=0.0, knn='matrix'): def __init__(
self,
in_channels,
out_channels=64,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
stochastic=False,
epsilon=0.0,
knn="matrix",
):
super(DenseDynBlock2d, self).__init__() super(DenseDynBlock2d, self).__init__()
self.body = DynConv2d(in_channels, out_channels, kernel_size, dilation, conv, self.body = DynConv2d(
act, norm, bias, stochastic, epsilon, knn) in_channels,
out_channels,
kernel_size,
dilation,
conv,
act,
norm,
bias,
stochastic,
epsilon,
knn,
)
def forward(self, x): def forward(self, x):
dense = self.body(x) dense = self.body(x)

View File

@ -1,4 +1,3 @@
from .torch_nn import * from .torch_nn import *
from .torch_edge import * from .torch_edge import *
from .torch_vertex import * from .torch_vertex import *

View File

@ -7,6 +7,7 @@ class Dilated(nn.Module):
""" """
Find dilated neighbor from neighbor list Find dilated neighbor from neighbor list
""" """
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0): def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(Dilated, self).__init__() super(Dilated, self).__init__()
self.dilation = dilation self.dilation = dilation
@ -18,14 +19,14 @@ class Dilated(nn.Module):
if self.stochastic: if self.stochastic:
if torch.rand(1) < self.epsilon and self.training: if torch.rand(1) < self.epsilon and self.training:
num = self.k * self.dilation num = self.k * self.dilation
randnum = torch.randperm(num)[:self.k] randnum = torch.randperm(num)[: self.k]
edge_index = edge_index.view(2, -1, num) edge_index = edge_index.view(2, -1, num)
edge_index = edge_index[:, :, randnum] edge_index = edge_index[:, :, randnum]
return edge_index.view(2, -1) return edge_index.view(2, -1)
else: else:
edge_index = edge_index[:, ::self.dilation] edge_index = edge_index[:, :: self.dilation]
else: else:
edge_index = edge_index[:, ::self.dilation] edge_index = edge_index[:, :: self.dilation]
return edge_index return edge_index
@ -33,14 +34,15 @@ class DilatedKnnGraph(nn.Module):
""" """
Find the neighbors' indices based on dilated knn Find the neighbors' indices based on dilated knn
""" """
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn="matrix"):
super(DilatedKnnGraph, self).__init__() super(DilatedKnnGraph, self).__init__()
self.dilation = dilation self.dilation = dilation
self.stochastic = stochastic self.stochastic = stochastic
self.epsilon = epsilon self.epsilon = epsilon
self.k = k self.k = k
self._dilated = Dilated(k, dilation, stochastic, epsilon) self._dilated = Dilated(k, dilation, stochastic, epsilon)
if knn == 'matrix': if knn == "matrix":
self.knn = knn_graph_matrix self.knn = knn_graph_matrix
else: else:
self.knn = knn_graph self.knn = knn_graph
@ -58,7 +60,7 @@ def pairwise_distance(x):
Returns: Returns:
pairwise distance: (batch_size, num_points, num_points) pairwise distance: (batch_size, num_points, num_points)
""" """
x_inner = -2*torch.matmul(x, x.transpose(2, 1)) x_inner = -2 * torch.matmul(x, x.transpose(2, 1))
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True) x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
return x_square + x_inner + x_square.transpose(2, 1) return x_square + x_inner + x_square.transpose(2, 1)
@ -83,7 +85,11 @@ def knn_matrix(x, k=16, batch=None):
del neg_adj del neg_adj
n_points = x.shape[1] n_points = x.shape[1]
start_idx = torch.arange(0, n_points*batch_size, n_points).long().view(batch_size, 1, 1) start_idx = (
torch.arange(0, n_points * batch_size, n_points)
.long()
.view(batch_size, 1, 1)
)
if x.is_cuda: if x.is_cuda:
start_idx = start_idx.cuda() start_idx = start_idx.cuda()
nn_idx += start_idx nn_idx += start_idx
@ -93,7 +99,13 @@ def knn_matrix(x, k=16, batch=None):
torch.cuda.empty_cache() torch.cuda.empty_cache()
nn_idx = nn_idx.view(1, -1) nn_idx = nn_idx.view(1, -1)
center_idx = torch.arange(0, n_points*batch_size).repeat(k, 1).transpose(1, 0).contiguous().view(1, -1) center_idx = (
torch.arange(0, n_points * batch_size)
.repeat(k, 1)
.transpose(1, 0)
.contiguous()
.view(1, -1)
)
if x.is_cuda: if x.is_cuda:
center_idx = center_idx.cuda() center_idx = center_idx.cuda()
return nn_idx, center_idx return nn_idx, center_idx
@ -110,4 +122,3 @@ def knn_graph_matrix(x, k=16, batch=None):
""" """
nn_idx, center_idx = knn_matrix(x, k, batch) nn_idx, center_idx = knn_matrix(x, k, batch)
return torch.cat((nn_idx, center_idx), dim=0) return torch.cat((nn_idx, center_idx), dim=0)

View File

@ -6,27 +6,33 @@ from torch_geometric.utils import degree
class GenMessagePassing(MessagePassing): class GenMessagePassing(MessagePassing):
def __init__(self, aggr='softmax', def __init__(
t=1.0, learn_t=False, self,
p=1.0, learn_p=False, aggr="softmax",
y=0.0, learn_y=False): t=1.0,
learn_t=False,
p=1.0,
learn_p=False,
y=0.0,
learn_y=False,
):
if aggr in ['softmax_sg', 'softmax', 'softmax_sum']: if aggr in ["softmax_sg", "softmax", "softmax_sum"]:
super(GenMessagePassing, self).__init__(aggr=None) super(GenMessagePassing, self).__init__(aggr=None)
self.aggr = aggr self.aggr = aggr
if learn_t and (aggr == 'softmax' or aggr == 'softmax_sum'): if learn_t and (aggr == "softmax" or aggr == "softmax_sum"):
self.learn_t = True self.learn_t = True
self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True) self.t = torch.nn.Parameter(torch.Tensor([t]), requires_grad=True)
else: else:
self.learn_t = False self.learn_t = False
self.t = t self.t = t
if aggr == 'softmax_sum': if aggr == "softmax_sum":
self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
elif aggr in ['power', 'power_sum']: elif aggr in ["power", "power_sum"]:
super(GenMessagePassing, self).__init__(aggr=None) super(GenMessagePassing, self).__init__(aggr=None)
self.aggr = aggr self.aggr = aggr
@ -36,45 +42,52 @@ class GenMessagePassing(MessagePassing):
else: else:
self.p = p self.p = p
if aggr == 'power_sum': if aggr == "power_sum":
self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y) self.y = torch.nn.Parameter(torch.Tensor([y]), requires_grad=learn_y)
else: else:
super(GenMessagePassing, self).__init__(aggr=aggr) super(GenMessagePassing, self).__init__(aggr=aggr)
def aggregate(self, inputs, index, ptr=None, dim_size=None): def aggregate(self, inputs, index, ptr=None, dim_size=None):
if self.aggr in ['add', 'mean', 'max', None]: if self.aggr in ["add", "mean", "max", None]:
return super(GenMessagePassing, self).aggregate(inputs, index, ptr, dim_size) return super(GenMessagePassing, self).aggregate(
inputs, index, ptr, dim_size
)
elif self.aggr in ['softmax_sg', 'softmax', 'softmax_sum']: elif self.aggr in ["softmax_sg", "softmax", "softmax_sum"]:
if self.learn_t: if self.learn_t:
out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
else: else:
with torch.no_grad(): with torch.no_grad():
out = scatter_softmax(inputs*self.t, index, dim=self.node_dim) out = scatter_softmax(inputs * self.t, index, dim=self.node_dim)
out = scatter(inputs*out, index, dim=self.node_dim, out = scatter(
dim_size=dim_size, reduce='sum') inputs * out, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
)
if self.aggr == 'softmax_sum': if self.aggr == "softmax_sum":
self.sigmoid_y = torch.sigmoid(self.y) self.sigmoid_y = torch.sigmoid(self.y)
degrees = degree(index, num_nodes=dim_size).unsqueeze(1) degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
out = torch.pow(degrees, self.sigmoid_y) * out out = torch.pow(degrees, self.sigmoid_y) * out
return out return out
elif self.aggr in ["power", "power_sum"]:
elif self.aggr in ['power', 'power_sum']:
min_value, max_value = 1e-7, 1e1 min_value, max_value = 1e-7, 1e1
torch.clamp_(inputs, min_value, max_value) torch.clamp_(inputs, min_value, max_value)
out = scatter(torch.pow(inputs, self.p), index, dim=self.node_dim, out = scatter(
dim_size=dim_size, reduce='mean') torch.pow(inputs, self.p),
index,
dim=self.node_dim,
dim_size=dim_size,
reduce="mean",
)
torch.clamp_(out, min_value, max_value) torch.clamp_(out, min_value, max_value)
out = torch.pow(out, 1/self.p) out = torch.pow(out, 1 / self.p)
# torch.clamp(out, min_value, max_value) # torch.clamp(out, min_value, max_value)
if self.aggr == 'power_sum': if self.aggr == "power_sum":
self.sigmoid_y = torch.sigmoid(self.y) self.sigmoid_y = torch.sigmoid(self.y)
degrees = degree(index, num_nodes=dim_size).unsqueeze(1) degrees = degree(index, num_nodes=dim_size).unsqueeze(1)
out = torch.pow(degrees, self.sigmoid_y) * out out = torch.pow(degrees, self.sigmoid_y) * out
@ -82,15 +95,16 @@ class GenMessagePassing(MessagePassing):
return out return out
else: else:
raise NotImplementedError('To be implemented') raise NotImplementedError("To be implemented")
class MsgNorm(torch.nn.Module): class MsgNorm(torch.nn.Module):
def __init__(self, learn_msg_scale=False): def __init__(self, learn_msg_scale=False):
super(MsgNorm, self).__init__() super(MsgNorm, self).__init__()
self.msg_scale = torch.nn.Parameter(torch.Tensor([1.0]), self.msg_scale = torch.nn.Parameter(
requires_grad=learn_msg_scale) torch.Tensor([1.0]), requires_grad=learn_msg_scale
)
def forward(self, x, msg, p=2): def forward(self, x, msg, p=2):
msg = F.normalize(msg, p=p, dim=1) msg = F.normalize(msg, p=p, dim=1)

View File

@ -9,32 +9,32 @@ from .utils.data_util import get_atom_feature_dims, get_bond_feature_dims
def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1): def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1):
# activation layer # activation layer
act = act_type.lower() act = act_type.lower()
if act == 'relu': if act == "relu":
layer = nn.ReLU(inplace) layer = nn.ReLU(inplace)
elif act == 'leakyrelu': elif act == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace) layer = nn.LeakyReLU(neg_slope, inplace)
elif act == 'prelu': elif act == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act == 'sigmoid': elif act == "sigmoid":
layer= nn.Sigmoid() layer = nn.Sigmoid()
else: else:
raise NotImplementedError('activation layer [%s] is not found' % act) raise NotImplementedError("activation layer [%s] is not found" % act)
return layer return layer
def norm_layer(norm_type, nc): def norm_layer(norm_type, nc):
# normalization layer 1d # normalization layer 1d
norm = norm_type.lower() norm = norm_type.lower()
if norm == 'batch': if norm == "batch":
layer = nn.BatchNorm1d(nc, affine=True, track_running_stats=False) layer = nn.BatchNorm1d(nc, affine=True, track_running_stats=False)
elif norm == 'layer': elif norm == "layer":
layer = nn.LayerNorm(nc, elementwise_affine=True) layer = nn.LayerNorm(nc, elementwise_affine=True)
elif norm == 'instance': elif norm == "instance":
layer = nn.InstanceNorm1d(nc, affine=False) layer = nn.InstanceNorm1d(nc, affine=False)
elif norm == 'group': elif norm == "group":
layer = nn.GroupNorm(32, nc, affine=True) layer = nn.GroupNorm(32, nc, affine=True)
else: else:
raise NotImplementedError('normalization layer [%s] is not found' % norm) raise NotImplementedError("normalization layer [%s] is not found" % norm)
return layer return layer
@ -52,9 +52,9 @@ class MultiSeq(Seq):
class MLP(Seq): class MLP(Seq):
def __init__(self, channels, act='relu', def __init__(
norm=None, bias=True, self, channels, act="relu", norm=None, bias=True, drop=0.0, last_lin=False
drop=0., last_lin=False): ):
m = [] m = []
for i in range(1, len(channels)): for i in range(1, len(channels)):
@ -64,9 +64,9 @@ class MLP(Seq):
if (i == len(channels) - 1) and last_lin: if (i == len(channels) - 1) and last_lin:
pass pass
else: else:
if norm is not None and norm.lower() != 'none': if norm is not None and norm.lower() != "none":
m.append(norm_layer(norm, channels[i])) m.append(norm_layer(norm, channels[i]))
if act is not None and act.lower() != 'none': if act is not None and act.lower() != "none":
m.append(act_layer(act)) m.append(act_layer(act))
if drop > 0: if drop > 0:
m.append(nn.Dropout2d(drop)) m.append(nn.Dropout2d(drop))
@ -76,7 +76,6 @@ class MLP(Seq):
class AtomEncoder(nn.Module): class AtomEncoder(nn.Module):
def __init__(self, emb_dim): def __init__(self, emb_dim):
super(AtomEncoder, self).__init__() super(AtomEncoder, self).__init__()
@ -97,7 +96,6 @@ class AtomEncoder(nn.Module):
class BondEncoder(nn.Module): class BondEncoder(nn.Module):
def __init__(self, emb_dim): def __init__(self, emb_dim):
super(BondEncoder, self).__init__() super(BondEncoder, self).__init__()
@ -115,5 +113,3 @@ class BondEncoder(nn.Module):
bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])
return bond_embedding return bond_embedding

View File

@ -10,35 +10,43 @@ from torch_geometric.utils import remove_self_loops, add_self_loops
class GENConv(GenMessagePassing): class GENConv(GenMessagePassing):
""" """
GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf GENeralized Graph Convolution (GENConv): https://arxiv.org/pdf/2006.07739.pdf
SoftMax & PowerMean Aggregation SoftMax & PowerMean Aggregation
""" """
def __init__(self, in_dim, emb_dim,
aggr='softmax',
t=1.0, learn_t=False,
p=1.0, learn_p=False,
y=0.0, learn_y=False,
msg_norm=False, learn_msg_scale=True,
encode_edge=False, bond_encoder=False,
edge_feat_dim=None,
norm='batch', mlp_layers=2,
eps=1e-7):
super(GENConv, self).__init__(aggr=aggr, def __init__(
t=t, learn_t=learn_t, self,
p=p, learn_p=learn_p, in_dim,
y=y, learn_y=learn_y) emb_dim,
aggr="softmax",
t=1.0,
learn_t=False,
p=1.0,
learn_p=False,
y=0.0,
learn_y=False,
msg_norm=False,
learn_msg_scale=True,
encode_edge=False,
bond_encoder=False,
edge_feat_dim=None,
norm="batch",
mlp_layers=2,
eps=1e-7,
):
super(GENConv, self).__init__(
aggr=aggr, t=t, learn_t=learn_t, p=p, learn_p=learn_p, y=y, learn_y=learn_y
)
channels_list = [in_dim] channels_list = [in_dim]
for i in range(mlp_layers-1): for i in range(mlp_layers - 1):
channels_list.append(in_dim*2) channels_list.append(in_dim * 2)
channels_list.append(emb_dim) channels_list.append(emb_dim)
self.mlp = MLP(channels=channels_list, self.mlp = MLP(channels=channels_list, norm=norm, last_lin=True)
norm=norm,
last_lin=True)
self.msg_encoder = torch.nn.ReLU() self.msg_encoder = torch.nn.ReLU()
self.eps = eps self.eps = eps
@ -91,14 +99,23 @@ class MRConv(nn.Module):
""" """
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
def __init__(
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="max"
):
super(MRConv, self).__init__() super(MRConv, self).__init__()
self.nn = MLP([in_channels*2, out_channels], act, norm, bias) self.nn = MLP([in_channels * 2, out_channels], act, norm, bias)
self.aggr = aggr self.aggr = aggr
def forward(self, x, edge_index): def forward(self, x, edge_index):
"""""" """"""
x_j = tg.utils.scatter_(self.aggr, torch.index_select(x, 0, edge_index[0]) - torch.index_select(x, 0, edge_index[1]), edge_index[1], dim_size=x.shape[0]) x_j = tg.utils.scatter_(
self.aggr,
torch.index_select(x, 0, edge_index[0])
- torch.index_select(x, 0, edge_index[1]),
edge_index[1],
dim_size=x.shape[0],
)
return self.nn(torch.cat([x, x_j], dim=1)) return self.nn(torch.cat([x, x_j], dim=1))
@ -106,8 +123,13 @@ class EdgConv(tg.nn.EdgeConv):
""" """
Edge convolution layer (with activation, batch normalization) Edge convolution layer (with activation, batch normalization)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='max'):
super(EdgConv, self).__init__(MLP([in_channels*2, out_channels], act, norm, bias), aggr) def __init__(
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="max"
):
super(EdgConv, self).__init__(
MLP([in_channels * 2, out_channels], act, norm, bias), aggr
)
def forward(self, x, edge_index): def forward(self, x, edge_index):
return super(EdgConv, self).forward(x, edge_index) return super(EdgConv, self).forward(x, edge_index)
@ -117,10 +139,13 @@ class GATConv(nn.Module):
""" """
Graph Attention Convolution layer (with activation, batch normalization) Graph Attention Convolution layer (with activation, batch normalization)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, heads=8):
def __init__(
self, in_channels, out_channels, act="relu", norm=None, bias=True, heads=8
):
super(GATConv, self).__init__() super(GATConv, self).__init__()
self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias) self.gconv = tg.nn.GATConv(in_channels, out_channels, heads, bias=bias)
m =[] m = []
if act: if act:
m.append(act_layer(act)) m.append(act_layer(act))
if norm: if norm:
@ -154,19 +179,25 @@ class SAGEConv(tg.nn.SAGEConv):
:class:`torch_geometric.nn.conv.MessagePassing`. :class:`torch_geometric.nn.conv.MessagePassing`.
""" """
def __init__(self, def __init__(
in_channels, self,
out_channels, in_channels,
nn, out_channels,
norm=True, nn,
bias=True, norm=True,
relative=False, bias=True,
**kwargs): relative=False,
**kwargs
):
self.relative = relative self.relative = relative
if norm is not None: if norm is not None:
super(SAGEConv, self).__init__(in_channels, out_channels, True, bias, **kwargs) super(SAGEConv, self).__init__(
in_channels, out_channels, True, bias, **kwargs
)
else: else:
super(SAGEConv, self).__init__(in_channels, out_channels, False, bias, **kwargs) super(SAGEConv, self).__init__(
in_channels, out_channels, False, bias, **kwargs
)
self.nn = nn self.nn = nn
def forward(self, x, edge_index, size=None): def forward(self, x, edge_index, size=None):
@ -199,9 +230,19 @@ class RSAGEConv(SAGEConv):
Residual SAGE convolution layer (with activation, batch normalization) Residual SAGE convolution layer (with activation, batch normalization)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, relative=False): def __init__(
self,
in_channels,
out_channels,
act="relu",
norm=None,
bias=True,
relative=False,
):
nn = MLP([out_channels + in_channels, out_channels], act, norm, bias) nn = MLP([out_channels + in_channels, out_channels], act, norm, bias)
super(RSAGEConv, self).__init__(in_channels, out_channels, nn, norm, bias, relative) super(RSAGEConv, self).__init__(
in_channels, out_channels, nn, norm, bias, relative
)
class SemiGCNConv(nn.Module): class SemiGCNConv(nn.Module):
@ -209,7 +250,7 @@ class SemiGCNConv(nn.Module):
SemiGCN convolution layer (with activation, batch normalization) SemiGCN convolution layer (with activation, batch normalization)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True): def __init__(self, in_channels, out_channels, act="relu", norm=None, bias=True):
super(SemiGCNConv, self).__init__() super(SemiGCNConv, self).__init__()
self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias) self.gconv = tg.nn.GCNConv(in_channels, out_channels, bias=bias)
m = [] m = []
@ -228,7 +269,10 @@ class GinConv(tg.nn.GINConv):
""" """
GINConv layer (with activation, batch normalization) GINConv layer (with activation, batch normalization)
""" """
def __init__(self, in_channels, out_channels, act='relu', norm=None, bias=True, aggr='add'):
def __init__(
self, in_channels, out_channels, act="relu", norm=None, bias=True, aggr="add"
):
super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias)) super(GinConv, self).__init__(MLP([in_channels, out_channels], act, norm, bias))
def forward(self, x, edge_index): def forward(self, x, edge_index):
@ -239,25 +283,36 @@ class GraphConv(nn.Module):
""" """
Static graph convolution layer Static graph convolution layer
""" """
def __init__(self, in_channels, out_channels, conv='edge',
act='relu', norm=None, bias=True, heads=8): def __init__(
self,
in_channels,
out_channels,
conv="edge",
act="relu",
norm=None,
bias=True,
heads=8,
):
super(GraphConv, self).__init__() super(GraphConv, self).__init__()
if conv.lower() == 'edge': if conv.lower() == "edge":
self.gconv = EdgConv(in_channels, out_channels, act, norm, bias) self.gconv = EdgConv(in_channels, out_channels, act, norm, bias)
elif conv.lower() == 'mr': elif conv.lower() == "mr":
self.gconv = MRConv(in_channels, out_channels, act, norm, bias) self.gconv = MRConv(in_channels, out_channels, act, norm, bias)
elif conv.lower() == 'gat': elif conv.lower() == "gat":
self.gconv = GATConv(in_channels, out_channels//heads, act, norm, bias, heads) self.gconv = GATConv(
elif conv.lower() == 'gcn': in_channels, out_channels // heads, act, norm, bias, heads
)
elif conv.lower() == "gcn":
self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias) self.gconv = SemiGCNConv(in_channels, out_channels, act, norm, bias)
elif conv.lower() == 'gin': elif conv.lower() == "gin":
self.gconv = GinConv(in_channels, out_channels, act, norm, bias) self.gconv = GinConv(in_channels, out_channels, act, norm, bias)
elif conv.lower() == 'sage': elif conv.lower() == "sage":
self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False) self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, False)
elif conv.lower() == 'rsage': elif conv.lower() == "rsage":
self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, True) self.gconv = RSAGEConv(in_channels, out_channels, act, norm, bias, True)
else: else:
raise NotImplementedError('conv {} is not implemented'.format(conv)) raise NotImplementedError("conv {} is not implemented".format(conv))
def forward(self, x, edge_index): def forward(self, x, edge_index):
return self.gconv(x, edge_index) return self.gconv(x, edge_index)
@ -267,9 +322,23 @@ class DynConv(GraphConv):
""" """
Dynamic graph convolution layer Dynamic graph convolution layer
""" """
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='edge', act='relu',
norm=None, bias=True, heads=8, **kwargs): def __init__(
super(DynConv, self).__init__(in_channels, out_channels, conv, act, norm, bias, heads) self,
in_channels,
out_channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
heads=8,
**kwargs
):
super(DynConv, self).__init__(
in_channels, out_channels, conv, act, norm, bias, heads
)
self.k = kernel_size self.k = kernel_size
self.d = dilation self.d = dilation
self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs) self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, **kwargs)
@ -283,11 +352,23 @@ class PlainDynBlock(nn.Module):
""" """
Plain Dynamic graph convolution block Plain Dynamic graph convolution block
""" """
def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, res_scale=1, **kwargs): def __init__(
self,
channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
res_scale=1,
**kwargs
):
super(PlainDynBlock, self).__init__() super(PlainDynBlock, self).__init__()
self.body = DynConv(channels, channels, kernel_size, dilation, conv, self.body = DynConv(
act, norm, bias, **kwargs) channels, channels, kernel_size, dilation, conv, act, norm, bias, **kwargs
)
self.res_scale = res_scale self.res_scale = res_scale
def forward(self, x, batch=None): def forward(self, x, batch=None):
@ -298,25 +379,58 @@ class ResDynBlock(nn.Module):
""" """
Residual Dynamic graph convolution block Residual Dynamic graph convolution block
""" """
def __init__(self, channels, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None,
bias=True, res_scale=1, **kwargs): def __init__(
self,
channels,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
res_scale=1,
**kwargs
):
super(ResDynBlock, self).__init__() super(ResDynBlock, self).__init__()
self.body = DynConv(channels, channels, kernel_size, dilation, conv, self.body = DynConv(
act, norm, bias, **kwargs) channels, channels, kernel_size, dilation, conv, act, norm, bias, **kwargs
)
self.res_scale = res_scale self.res_scale = res_scale
def forward(self, x, batch=None): def forward(self, x, batch=None):
return self.body(x, batch) + x*self.res_scale, batch return self.body(x, batch) + x * self.res_scale, batch
class DenseDynBlock(nn.Module): class DenseDynBlock(nn.Module):
""" """
Dense Dynamic graph convolution block Dense Dynamic graph convolution block
""" """
def __init__(self, in_channels, out_channels=64, kernel_size=9, dilation=1, conv='edge', act='relu', norm=None, bias=True, **kwargs):
def __init__(
self,
in_channels,
out_channels=64,
kernel_size=9,
dilation=1,
conv="edge",
act="relu",
norm=None,
bias=True,
**kwargs
):
super(DenseDynBlock, self).__init__() super(DenseDynBlock, self).__init__()
self.body = DynConv(in_channels, out_channels, kernel_size, dilation, conv, self.body = DynConv(
act, norm, bias, **kwargs) in_channels,
out_channels,
kernel_size,
dilation,
conv,
act,
norm,
bias,
**kwargs
)
def forward(self, x, batch=None): def forward(self, x, batch=None):
dense = self.body(x, batch) dense = self.body(x, batch)
@ -327,24 +441,43 @@ class ResGraphBlock(nn.Module):
""" """
Residual Static graph convolution block Residual Static graph convolution block
""" """
def __init__(self, channels, conv='edge', act='relu', norm=None, bias=True, heads=8, res_scale=1):
def __init__(
self,
channels,
conv="edge",
act="relu",
norm=None,
bias=True,
heads=8,
res_scale=1,
):
super(ResGraphBlock, self).__init__() super(ResGraphBlock, self).__init__()
self.body = GraphConv(channels, channels, conv, act, norm, bias, heads) self.body = GraphConv(channels, channels, conv, act, norm, bias, heads)
self.res_scale = res_scale self.res_scale = res_scale
def forward(self, x, edge_index): def forward(self, x, edge_index):
return self.body(x, edge_index) + x*self.res_scale, edge_index return self.body(x, edge_index) + x * self.res_scale, edge_index
class DenseGraphBlock(nn.Module): class DenseGraphBlock(nn.Module):
""" """
Dense Static graph convolution block Dense Static graph convolution block
""" """
def __init__(self, in_channels, out_channels, conv='edge', act='relu', norm=None, bias=True, heads=8):
def __init__(
self,
in_channels,
out_channels,
conv="edge",
act="relu",
norm=None,
bias=True,
heads=8,
):
super(DenseGraphBlock, self).__init__() super(DenseGraphBlock, self).__init__()
self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads) self.body = GraphConv(in_channels, out_channels, conv, act, norm, bias, heads)
def forward(self, x, edge_index): def forward(self, x, edge_index):
dense = self.body(x, edge_index) dense = self.body(x, edge_index)
return torch.cat((x, dense), 1), edge_index return torch.cat((x, dense), 1), edge_index

View File

@ -1,7 +1,8 @@
from .ckpt_util import * from .ckpt_util import *
# from .data_util import * # from .data_util import *
from .loss import * from .loss import *
from .metrics import * from .metrics import *
from .optim import * from .optim import *
# from .tf_logger import *
# from .tf_logger import *

View File

@ -6,25 +6,27 @@ import logging
import numpy as np import numpy as np
def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post='best'): def save_ckpt(model, optimizer, loss, epoch, save_path, name_pre, name_post="best"):
model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
state = { state = {
'epoch': epoch, "epoch": epoch,
'model_state_dict': model_cpu, "model_state_dict": model_cpu,
'optimizer_state_dict': optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
'loss': loss "loss": loss,
} }
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.mkdir(save_path) os.mkdir(save_path)
print("Directory ", save_path, " is created.") print("Directory ", save_path, " is created.")
filename = '{}/{}_{}.pth'.format(save_path, name_pre, name_post) filename = "{}/{}_{}.pth".format(save_path, name_pre, name_post)
torch.save(state, filename) torch.save(state, filename)
print('model has been saved as {}'.format(filename)) print("model has been saved as {}".format(filename))
def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax means max best def load_pretrained_models(
model, pretrained_model, phase, ismax=True
): # ismax means max best
if ismax: if ismax:
best_value = -np.inf best_value = -np.inf
else: else:
@ -36,7 +38,7 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
logging.info("===> Loading checkpoint '{}'".format(pretrained_model)) logging.info("===> Loading checkpoint '{}'".format(pretrained_model))
checkpoint = torch.load(pretrained_model) checkpoint = torch.load(pretrained_model)
try: try:
best_value = checkpoint['best_value'] best_value = checkpoint["best_value"]
if best_value == -np.inf or best_value == np.inf: if best_value == -np.inf or best_value == np.inf:
show_best_value = False show_best_value = False
else: else:
@ -46,11 +48,13 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
show_best_value = False show_best_value = False
model_dict = model.state_dict() model_dict = model.state_dict()
ckpt_model_state_dict = checkpoint['state_dict'] ckpt_model_state_dict = checkpoint["state_dict"]
# rename ckpt (avoid name is not same because of multi-gpus) # rename ckpt (avoid name is not same because of multi-gpus)
is_model_multi_gpus = True if list(model_dict)[0][0][0] == 'm' else False is_model_multi_gpus = True if list(model_dict)[0][0][0] == "m" else False
is_ckpt_multi_gpus = True if list(ckpt_model_state_dict)[0][0] == 'm' else False is_ckpt_multi_gpus = (
True if list(ckpt_model_state_dict)[0][0] == "m" else False
)
if not (is_model_multi_gpus == is_ckpt_multi_gpus): if not (is_model_multi_gpus == is_ckpt_multi_gpus):
temp_dict = OrderedDict() temp_dict = OrderedDict()
@ -58,7 +62,7 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
if is_ckpt_multi_gpus: if is_ckpt_multi_gpus:
name = k[7:] # remove 'module.' name = k[7:] # remove 'module.'
else: else:
name = 'module.'+k # add 'module' name = "module." + k # add 'module'
temp_dict[name] = v temp_dict[name] = v
# load params # load params
ckpt_model_state_dict = temp_dict ckpt_model_state_dict = temp_dict
@ -67,34 +71,44 @@ def load_pretrained_models(model, pretrained_model, phase, ismax=True): # ismax
model.load_state_dict(ckpt_model_state_dict) model.load_state_dict(ckpt_model_state_dict)
if show_best_value: if show_best_value:
logging.info("The pretrained_model is at checkpoint {}. \t " logging.info(
"Best value: {}".format(checkpoint['epoch'], best_value)) "The pretrained_model is at checkpoint {}. \t "
"Best value: {}".format(checkpoint["epoch"], best_value)
)
else: else:
logging.info("The pretrained_model is at checkpoint {}.".format(checkpoint['epoch'])) logging.info(
"The pretrained_model is at checkpoint {}.".format(
checkpoint["epoch"]
)
)
if phase == 'train': if phase == "train":
epoch = checkpoint['epoch'] epoch = checkpoint["epoch"]
else: else:
epoch = -1 epoch = -1
else: else:
raise ImportError("===> No checkpoint found at '{}'".format(pretrained_model)) raise ImportError(
"===> No checkpoint found at '{}'".format(pretrained_model)
)
else: else:
logging.info('===> No pre-trained model') logging.info("===> No pre-trained model")
return model, best_value, epoch return model, best_value, epoch
def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True): def load_pretrained_optimizer(
pretrained_model, optimizer, scheduler, lr, use_ckpt_lr=True
):
if pretrained_model: if pretrained_model:
if os.path.isfile(pretrained_model): if os.path.isfile(pretrained_model):
checkpoint = torch.load(pretrained_model) checkpoint = torch.load(pretrained_model)
if 'optimizer_state_dict' in checkpoint.keys(): if "optimizer_state_dict" in checkpoint.keys():
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
for state in optimizer.state.values(): for state in optimizer.state.values():
for k, v in state.items(): for k, v in state.items():
if torch.is_tensor(v): if torch.is_tensor(v):
state[k] = v.cuda() state[k] = v.cuda()
if 'scheduler_state_dict' in checkpoint.keys(): if "scheduler_state_dict" in checkpoint.keys():
scheduler.load_state_dict(checkpoint['scheduler_state_dict']) scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
if use_ckpt_lr: if use_ckpt_lr:
try: try:
lr = scheduler.get_lr()[0] lr = scheduler.get_lr()[0]
@ -105,26 +119,30 @@ def load_pretrained_optimizer(pretrained_model, optimizer, scheduler, lr, use_ck
def save_checkpoint(state, is_best, save_path, postname): def save_checkpoint(state, is_best, save_path, postname):
filename = '{}/{}_{}.pth'.format(save_path, postname, int(state['epoch'])) filename = "{}/{}_{}.pth".format(save_path, postname, int(state["epoch"]))
torch.save(state, filename) torch.save(state, filename)
if is_best: if is_best:
shutil.copyfile(filename, '{}/{}_best.pth'.format(save_path, postname)) shutil.copyfile(filename, "{}/{}_best.pth".format(save_path, postname))
def change_ckpt_dict(model, optimizer, scheduler, opt): def change_ckpt_dict(model, optimizer, scheduler, opt):
for _ in range(opt.epoch): for _ in range(opt.epoch):
scheduler.step() scheduler.step()
is_best = (opt.test_value < opt.best_value) is_best = opt.test_value < opt.best_value
opt.best_value = min(opt.test_value, opt.best_value) opt.best_value = min(opt.test_value, opt.best_value)
model_cpu = {k: v.cpu() for k, v in model.state_dict().items()} model_cpu = {k: v.cpu() for k, v in model.state_dict().items()}
# optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()} # optim_cpu = {k: v.cpu() for k, v in optimizer.state_dict().items()}
save_checkpoint({ save_checkpoint(
'epoch': opt.epoch, {
'state_dict': model_cpu, "epoch": opt.epoch,
'optimizer_state_dict': optimizer.state_dict(), "state_dict": model_cpu,
'scheduler_state_dict': scheduler.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
'best_value': opt.best_value, "scheduler_state_dict": scheduler.state_dict(),
}, is_best, opt.save_path, opt.post) "best_value": opt.best_value,
},
is_best,
opt.save_path,
opt.post,
)

View File

@ -28,17 +28,20 @@ def add_zeros(data):
return data return data
def extract_node_feature(data, reduce='add'): def extract_node_feature(data, reduce="add"):
if reduce in ['mean', 'max', 'add']: if reduce in ["mean", "max", "add"]:
data.x = scatter(data.edge_attr, data.x = scatter(
data.edge_index[0], data.edge_attr,
dim=0, data.edge_index[0],
dim_size=data.num_nodes, dim=0,
reduce=reduce) dim_size=data.num_nodes,
reduce=reduce,
)
else: else:
raise Exception('Unknown Aggregation Type') raise Exception("Unknown Aggregation Type")
return data return data
# random partition graph # random partition graph
def random_partition_graph(num_nodes, cluster_number=10): def random_partition_graph(num_nodes, cluster_number=10):
parts = np.random.randint(cluster_number, size=num_nodes) parts = np.random.randint(cluster_number, size=num_nodes)
@ -47,7 +50,7 @@ def random_partition_graph(num_nodes, cluster_number=10):
def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1): def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1):
# convert sparse tensor to scipy csr # convert sparse tensor to scipy csr
adj = adj.to_scipy(layout='csr') adj = adj.to_scipy(layout="csr")
num_batches = cluster_number // batch_size num_batches = cluster_number // batch_size
@ -56,20 +59,27 @@ def generate_sub_graphs(adj, parts, cluster_number=10, batch_size=1):
for cluster in range(num_batches): for cluster in range(num_batches):
sg_nodes[cluster] = np.where(parts == cluster)[0] sg_nodes[cluster] = np.where(parts == cluster)[0]
sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(adj[sg_nodes[cluster], :][:, sg_nodes[cluster]])[0] sg_edges[cluster] = tg.utils.from_scipy_sparse_matrix(
adj[sg_nodes[cluster], :][:, sg_nodes[cluster]]
)[0]
return sg_nodes, sg_edges return sg_nodes, sg_edges
def random_rotate(points): def random_rotate(points):
theta = np.random.uniform(0, np.pi * 2) theta = np.random.uniform(0, np.pi * 2)
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) rotation_matrix = np.array(
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
rotation_matrix = torch.from_numpy(rotation_matrix).float() rotation_matrix = torch.from_numpy(rotation_matrix).float()
points[:, 0:2] = torch.matmul(points[:, [0, 1]].transpose(1, 3), rotation_matrix).transpose(1, 3) points[:, 0:2] = torch.matmul(
points[:, [0, 1]].transpose(1, 3), rotation_matrix
).transpose(1, 3)
return points return points
def random_translate(points, mean=0, std=0.02): def random_translate(points, mean=0, std=0.02):
points += torch.randn(points.shape)*std + mean points += torch.randn(points.shape) * std + mean
return points return points
@ -82,15 +92,17 @@ def random_points_augmentation(points, rotate=False, translate=False, **kwargs):
return points return points
def scale_translate_pointcloud(pointcloud, shift=[-0.2, 0.2], scale=[2. / 3., 3. /2.]): def scale_translate_pointcloud(
pointcloud, shift=[-0.2, 0.2], scale=[2.0 / 3.0, 3.0 / 2.0]
):
""" """
for scaling and shifting the point cloud for scaling and shifting the point cloud
:param pointcloud: :param pointcloud:
:return: :return:
""" """
B, C, N = pointcloud.shape[0:3] B, C, N = pointcloud.shape[0:3]
scale = scale[0] + torch.rand([B, C, 1, 1])*(scale[1]-scale[0]) scale = scale[0] + torch.rand([B, C, 1, 1]) * (scale[1] - scale[0])
shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1]-shift[0]) shift = shift[0] + torch.rand([B, C, 1, 1]) * (shift[1] - shift[0])
translated_pointcloud = torch.mul(pointcloud, scale) + shift translated_pointcloud = torch.mul(pointcloud, scale) + shift
return translated_pointcloud return translated_pointcloud
@ -126,25 +138,29 @@ class PartNet(InMemoryDataset):
final dataset. (default: :obj:`None`) final dataset. (default: :obj:`None`)
""" """
# the dataset we use for our paper is pre-released version # the dataset we use for our paper is pre-released version
def __init__(self, def __init__(
root, self,
dataset='sem_seg_h5', root,
obj_category='Bed', dataset="sem_seg_h5",
level=3, obj_category="Bed",
phase='train', level=3,
transform=None, phase="train",
pre_transform=None, transform=None,
pre_filter=None): pre_transform=None,
pre_filter=None,
):
self.dataset = dataset self.dataset = dataset
self.level = level self.level = level
self.obj_category = obj_category self.obj_category = obj_category
self.object = '-'.join([self.obj_category, str(self.level)]) self.object = "-".join([self.obj_category, str(self.level)])
self.level_folder = 'level_'+str(self.level) self.level_folder = "level_" + str(self.level)
self.processed_file_folder = osp.join(self.dataset, self.level_folder, self.object) self.processed_file_folder = osp.join(
self.dataset, self.level_folder, self.object
)
super(PartNet, self).__init__(root, transform, pre_transform, pre_filter) super(PartNet, self).__init__(root, transform, pre_transform, pre_filter)
if phase == 'test': if phase == "test":
path = self.processed_paths[1] path = self.processed_paths[1]
elif phase == 'val': elif phase == "val":
path = self.processed_paths[2] path = self.processed_paths[2]
else: else:
path = self.processed_paths[0] path = self.processed_paths[0]
@ -156,19 +172,24 @@ class PartNet(InMemoryDataset):
@property @property
def processed_file_names(self): def processed_file_names(self):
return osp.join(self.processed_file_folder, 'train.pt'), osp.join(self.processed_file_folder, 'test.pt'), \ return (
osp.join(self.processed_file_folder, 'val.pt') osp.join(self.processed_file_folder, "train.pt"),
osp.join(self.processed_file_folder, "test.pt"),
osp.join(self.processed_file_folder, "val.pt"),
)
def download(self): def download(self):
path = osp.join(self.raw_dir, self.dataset) path = osp.join(self.raw_dir, self.dataset)
if not osp.exists(path): if not osp.exists(path):
raise FileExistsError('PartNet can only downloaded via application. ' raise FileExistsError(
'See details in https://cs.stanford.edu/~kaichun/partnet/') "PartNet can only downloaded via application. "
"See details in https://cs.stanford.edu/~kaichun/partnet/"
)
# path = download_url(self.url, self.root) # path = download_url(self.url, self.root)
extract_zip(path, self.root) extract_zip(path, self.root)
os.unlink(path) os.unlink(path)
shutil.rmtree(self.raw_dir) shutil.rmtree(self.raw_dir)
name = self.url.split(os.sep)[-1].split('.')[0] name = self.url.split(os.sep)[-1].split(".")[0]
os.rename(osp.join(self.root, name), self.raw_dir) os.rename(osp.join(self.root, name), self.raw_dir)
def process(self): def process(self):
@ -176,31 +197,38 @@ class PartNet(InMemoryDataset):
processed_path = osp.join(self.processed_dir, self.processed_file_folder) processed_path = osp.join(self.processed_dir, self.processed_file_folder)
if not osp.exists(processed_path): if not osp.exists(processed_path):
os.makedirs(osp.join(processed_path)) os.makedirs(osp.join(processed_path))
torch.save(self.process_set('train'), self.processed_paths[0]) torch.save(self.process_set("train"), self.processed_paths[0])
torch.save(self.process_set('test'), self.processed_paths[1]) torch.save(self.process_set("test"), self.processed_paths[1])
torch.save(self.process_set('val'), self.processed_paths[2]) torch.save(self.process_set("val"), self.processed_paths[2])
def process_set(self, dataset): def process_set(self, dataset):
if self.dataset == 'ins_seg_h5': if self.dataset == "ins_seg_h5":
raw_path = osp.join(self.raw_dir, 'ins_seg_h5_for_sgpn', self.dataset) raw_path = osp.join(self.raw_dir, "ins_seg_h5_for_sgpn", self.dataset)
categories = glob(osp.join(raw_path, '*')) categories = glob(osp.join(raw_path, "*"))
categories = sorted([x.split(os.sep)[-1] for x in categories]) categories = sorted([x.split(os.sep)[-1] for x in categories])
data_list = [] data_list = []
for target, category in enumerate(tqdm(categories)): for target, category in enumerate(tqdm(categories)):
folder = osp.join(raw_path, category) folder = osp.join(raw_path, category)
paths = glob('{}/{}-*.h5'.format(folder, dataset)) paths = glob("{}/{}-*.h5".format(folder, dataset))
labels, nors, opacitys, pts, rgbs = [], [], [], [], [] labels, nors, opacitys, pts, rgbs = [], [], [], [], []
for path in paths: for path in paths:
f = h5py.File(path) f = h5py.File(path)
pts += torch.from_numpy(f['pts'][:]).unbind(0) pts += torch.from_numpy(f["pts"][:]).unbind(0)
labels += torch.from_numpy(f['label'][:]).to(torch.long).unbind(0) labels += torch.from_numpy(f["label"][:]).to(torch.long).unbind(0)
nors += torch.from_numpy(f['nor'][:]).unbind(0) nors += torch.from_numpy(f["nor"][:]).unbind(0)
opacitys += torch.from_numpy(f['opacity'][:]).unbind(0) opacitys += torch.from_numpy(f["opacity"][:]).unbind(0)
rgbs += torch.from_numpy(f['rgb'][:]).to(torch.float32).unbind(0) rgbs += torch.from_numpy(f["rgb"][:]).to(torch.float32).unbind(0)
for i, (pt, label, nor, opacity, rgb) in enumerate(zip(pts, labels, nors, opacitys, rgbs)): for i, (pt, label, nor, opacity, rgb) in enumerate(
data = Data(pos=pt[:, :3], y=label, norm=nor[:, :3], x=torch.cat((opacity.unsqueeze(-1), rgb/255.), 1)) zip(pts, labels, nors, opacitys, rgbs)
):
data = Data(
pos=pt[:, :3],
y=label,
norm=nor[:, :3],
x=torch.cat((opacity.unsqueeze(-1), rgb / 255.0), 1),
)
if self.pre_filter is not None and not self.pre_filter(data): if self.pre_filter is not None and not self.pre_filter(data):
continue continue
@ -215,14 +243,18 @@ class PartNet(InMemoryDataset):
# class_name = [] # class_name = []
for target, category in enumerate(tqdm(categories)): for target, category in enumerate(tqdm(categories)):
folder = osp.join(raw_path, category) folder = osp.join(raw_path, category)
paths = glob('{}/{}-*.h5'.format(folder, dataset)) paths = glob("{}/{}-*.h5".format(folder, dataset))
labels, pts = [], [] labels, pts = [], []
# clss = category.split('-')[0] # clss = category.split('-')[0]
for path in paths: for path in paths:
f = h5py.File(path) f = h5py.File(path)
pts += torch.from_numpy(f['data'][:].astype(np.float32)).unbind(0) pts += torch.from_numpy(f["data"][:].astype(np.float32)).unbind(0)
labels += torch.from_numpy(f['label_seg'][:].astype(np.float32)).to(torch.long).unbind(0) labels += (
torch.from_numpy(f["label_seg"][:].astype(np.float32))
.to(torch.long)
.unbind(0)
)
for i, (pt, label) in enumerate(zip(pts, labels)): for i, (pt, label) in enumerate(zip(pts, labels)):
data = Data(pos=pt[:, :3], y=label) data = Data(pos=pt[:, :3], y=label)
# data = PartData(pos=pt[:, :3], y=label, clss=clss) # data = PartData(pos=pt[:, :3], y=label, clss=clss)
@ -235,10 +267,7 @@ class PartNet(InMemoryDataset):
class PartData(Data): class PartData(Data):
def __init__(self, def __init__(self, y=None, pos=None, clss=None):
y=None,
pos=None,
clss=None):
super(PartData).__init__(pos=pos, y=y) super(PartData).__init__(pos=pos, y=y)
self.clss = clss self.clss = clss
@ -246,38 +275,30 @@ class PartData(Data):
# allowable multiple choice node and edge features # allowable multiple choice node and edge features
# code from https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py # code from https://github.com/snap-stanford/ogb/blob/master/ogb/utils/features.py
allowable_features = { allowable_features = {
'possible_atomic_num_list' : list(range(1, 119)) + ['misc'], "possible_atomic_num_list": list(range(1, 119)) + ["misc"],
'possible_chirality_list' : [ "possible_chirality_list": [
'CHI_UNSPECIFIED', "CHI_UNSPECIFIED",
'CHI_TETRAHEDRAL_CW', "CHI_TETRAHEDRAL_CW",
'CHI_TETRAHEDRAL_CCW', "CHI_TETRAHEDRAL_CCW",
'CHI_OTHER' "CHI_OTHER",
], ],
'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], "possible_degree_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "misc"],
'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], "possible_formal_charge_list": [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, "misc"],
'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], "possible_numH_list": [0, 1, 2, 3, 4, 5, 6, 7, 8, "misc"],
'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], "possible_number_radical_e_list": [0, 1, 2, 3, 4, "misc"],
'possible_hybridization_list' : [ "possible_hybridization_list": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "misc"],
'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc' "possible_is_aromatic_list": [False, True],
], "possible_is_in_ring_list": [False, True],
'possible_is_aromatic_list': [False, True], "possible_bond_type_list": ["SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "misc"],
'possible_is_in_ring_list': [False, True], "possible_bond_stereo_list": [
'possible_bond_type_list' : [ "STEREONONE",
'SINGLE', "STEREOZ",
'DOUBLE', "STEREOE",
'TRIPLE', "STEREOCIS",
'AROMATIC', "STEREOTRANS",
'misc' "STEREOANY",
], ],
'possible_bond_stereo_list': [ "possible_is_conjugated_list": [False, True],
'STEREONONE',
'STEREOZ',
'STEREOE',
'STEREOCIS',
'STEREOTRANS',
'STEREOANY',
],
'possible_is_conjugated_list': [False, True],
} }
@ -298,31 +319,44 @@ def atom_to_feature_vector(atom):
:return: list :return: list
""" """
atom_feature = [ atom_feature = [
safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), safe_index(allowable_features["possible_atomic_num_list"], atom.GetAtomicNum()),
allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())), allowable_features["possible_chirality_list"].index(str(atom.GetChiralTag())),
safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), safe_index(allowable_features["possible_degree_list"], atom.GetTotalDegree()),
safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), safe_index(
safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), allowable_features["possible_formal_charge_list"], atom.GetFormalCharge()
safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), ),
safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())), safe_index(allowable_features["possible_numH_list"], atom.GetTotalNumHs()),
allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()), safe_index(
allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()), allowable_features["possible_number_radical_e_list"],
] atom.GetNumRadicalElectrons(),
),
safe_index(
allowable_features["possible_hybridization_list"],
str(atom.GetHybridization()),
),
allowable_features["possible_is_aromatic_list"].index(atom.GetIsAromatic()),
allowable_features["possible_is_in_ring_list"].index(atom.IsInRing()),
]
return atom_feature return atom_feature
def get_atom_feature_dims(): def get_atom_feature_dims():
return list(map(len, [ return list(
allowable_features['possible_atomic_num_list'], map(
allowable_features['possible_chirality_list'], len,
allowable_features['possible_degree_list'], [
allowable_features['possible_formal_charge_list'], allowable_features["possible_atomic_num_list"],
allowable_features['possible_numH_list'], allowable_features["possible_chirality_list"],
allowable_features['possible_number_radical_e_list'], allowable_features["possible_degree_list"],
allowable_features['possible_hybridization_list'], allowable_features["possible_formal_charge_list"],
allowable_features['possible_is_aromatic_list'], allowable_features["possible_numH_list"],
allowable_features['possible_is_in_ring_list'] allowable_features["possible_number_radical_e_list"],
])) allowable_features["possible_hybridization_list"],
allowable_features["possible_is_aromatic_list"],
allowable_features["possible_is_in_ring_list"],
],
)
)
def bond_to_feature_vector(bond): def bond_to_feature_vector(bond):
@ -332,56 +366,71 @@ def bond_to_feature_vector(bond):
:return: list :return: list
""" """
bond_feature = [ bond_feature = [
safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())), safe_index(
allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())), allowable_features["possible_bond_type_list"], str(bond.GetBondType())
allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()), ),
] allowable_features["possible_bond_stereo_list"].index(str(bond.GetStereo())),
allowable_features["possible_is_conjugated_list"].index(bond.GetIsConjugated()),
]
return bond_feature return bond_feature
def get_bond_feature_dims(): def get_bond_feature_dims():
return list(map(len, [ return list(
allowable_features['possible_bond_type_list'], map(
allowable_features['possible_bond_stereo_list'], len,
allowable_features['possible_is_conjugated_list'] [
])) allowable_features["possible_bond_type_list"],
allowable_features["possible_bond_stereo_list"],
allowable_features["possible_is_conjugated_list"],
],
)
)
def atom_feature_vector_to_dict(atom_feature): def atom_feature_vector_to_dict(atom_feature):
[atomic_num_idx, [
chirality_idx, atomic_num_idx,
degree_idx, chirality_idx,
formal_charge_idx, degree_idx,
num_h_idx, formal_charge_idx,
number_radical_e_idx, num_h_idx,
hybridization_idx, number_radical_e_idx,
is_aromatic_idx, hybridization_idx,
is_in_ring_idx] = atom_feature is_aromatic_idx,
is_in_ring_idx,
] = atom_feature
feature_dict = { feature_dict = {
'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx], "atomic_num": allowable_features["possible_atomic_num_list"][atomic_num_idx],
'chirality': allowable_features['possible_chirality_list'][chirality_idx], "chirality": allowable_features["possible_chirality_list"][chirality_idx],
'degree': allowable_features['possible_degree_list'][degree_idx], "degree": allowable_features["possible_degree_list"][degree_idx],
'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx], "formal_charge": allowable_features["possible_formal_charge_list"][
'num_h': allowable_features['possible_numH_list'][num_h_idx], formal_charge_idx
'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx], ],
'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx], "num_h": allowable_features["possible_numH_list"][num_h_idx],
'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx], "num_rad_e": allowable_features["possible_number_radical_e_list"][
'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx] number_radical_e_idx
],
"hybridization": allowable_features["possible_hybridization_list"][
hybridization_idx
],
"is_aromatic": allowable_features["possible_is_aromatic_list"][is_aromatic_idx],
"is_in_ring": allowable_features["possible_is_in_ring_list"][is_in_ring_idx],
} }
return feature_dict return feature_dict
def bond_feature_vector_to_dict(bond_feature): def bond_feature_vector_to_dict(bond_feature):
[bond_type_idx, [bond_type_idx, bond_stereo_idx, is_conjugated_idx] = bond_feature
bond_stereo_idx,
is_conjugated_idx] = bond_feature
feature_dict = { feature_dict = {
'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx], "bond_type": allowable_features["possible_bond_type_list"][bond_type_idx],
'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx], "bond_stereo": allowable_features["possible_bond_stereo_list"][bond_stereo_idx],
'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx] "is_conjugated": allowable_features["possible_is_conjugated_list"][
is_conjugated_idx
],
} }
return feature_dict return feature_dict

View File

@ -3,12 +3,12 @@ import shutil
import csv import csv
def save_best_result(list_of_dict, file_name, dir_path='best_result'): def save_best_result(list_of_dict, file_name, dir_path="best_result"):
if not os.path.exists(dir_path): if not os.path.exists(dir_path):
os.mkdir(dir_path) os.mkdir(dir_path)
print("Directory ", dir_path, " is created.") print("Directory ", dir_path, " is created.")
csv_file_name = '{}/{}.csv'.format(dir_path, file_name) csv_file_name = "{}/{}.csv".format(dir_path, file_name)
with open(csv_file_name, 'a+') as csv_file: with open(csv_file_name, "a+") as csv_file:
csv_writer = csv.writer(csv_file) csv_writer = csv.writer(csv_file)
for _ in range(len(list_of_dict)): for _ in range(len(list_of_dict)):
csv_writer.writerow(list_of_dict[_].values()) csv_writer.writerow(list_of_dict[_].values())
@ -17,10 +17,10 @@ def save_best_result(list_of_dict, file_name, dir_path='best_result'):
def create_exp_dir(path, scripts_to_save=None): def create_exp_dir(path, scripts_to_save=None):
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
print('Experiment dir : {}'.format(path)) print("Experiment dir : {}".format(path))
if scripts_to_save is not None: if scripts_to_save is not None:
os.mkdir(os.path.join(path, 'scripts')) os.mkdir(os.path.join(path, "scripts"))
for script in scripts_to_save: for script in scripts_to_save:
dst_file = os.path.join(path, 'scripts', os.path.basename(script)) dst_file = os.path.join(path, "scripts", os.path.basename(script))
shutil.copyfile(script, dst_file) shutil.copyfile(script, dst_file)

View File

@ -14,11 +14,13 @@ class SmoothCrossEntropy(torch.nn.Module):
if self.smoothing: if self.smoothing:
n_class = pred.size(1) n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1) one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1)
one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (n_class - 1) one_hot = one_hot * (1 - self.eps) + (1 - one_hot) * self.eps / (
n_class - 1
)
log_prb = F.log_softmax(pred, dim=1) log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean() loss = -(one_hot * log_prb).sum(dim=1).mean()
else: else:
loss = F.cross_entropy(pred, gt, reduction='mean') loss = F.cross_entropy(pred, gt, reduction="mean")
return loss return loss

View File

@ -1,25 +1,24 @@
from math import log10 from math import log10
def PSNR(mse, peak=1.): def PSNR(mse, peak=1.0):
return 10 * log10((peak ** 2) / mse) return 10 * log10((peak ** 2) / mse)
class AverageMeter(object): class AverageMeter(object):
"""Computes and stores the average and current value""" """Computes and stores the average and current value"""
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self):
self.val = 0 self.val = 0
self.avg = 0 self.avg = 0
self.sum = 0 self.sum = 0
self.count = 0 self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

View File

@ -4,7 +4,6 @@ from torch.optim.optimizer import Optimizer, required
class RAdam(Optimizer): class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)] self.buffer = [[None, None, None] for ind in range(10)]
@ -21,55 +20,67 @@ class RAdam(Optimizer):
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data.float() grad = p.grad.data.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients') raise RuntimeError("RAdam does not support sparse gradients")
p_data_fp32 = p.data.float() p_data_fp32 = p.data.float()
state = self.state[p] state = self.state[p]
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state["step"] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32) state["exp_avg"] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else: else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group['betas'] beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1 state["step"] += 1
buffered = self.buffer[int(state['step'] % 10)] buffered = self.buffer[int(state["step"] % 10)]
if state['step'] == buffered[0]: if state["step"] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2] N_sma, step_size = buffered[1], buffered[2]
else: else:
buffered[0] = state['step'] buffered[0] = state["step"]
beta2_t = beta2 ** state['step'] beta2_t = beta2 ** state["step"]
N_sma_max = 2 / (1 - beta2) - 1 N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma buffered[1] = N_sma
# more conservative since it's an approximated value # more conservative since it's an approximated value
if N_sma >= 5: if N_sma >= 5:
step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) step_size = (
group["lr"]
* math.sqrt(
(1 - beta2_t)
* (N_sma - 4)
/ (N_sma_max - 4)
* (N_sma - 2)
/ N_sma
* N_sma_max
/ (N_sma_max - 2)
)
/ (1 - beta1 ** state["step"])
)
else: else:
step_size = group['lr'] / (1 - beta1 ** state['step']) step_size = group["lr"] / (1 - beta1 ** state["step"])
buffered[2] = step_size buffered[2] = step_size
if group['weight_decay'] != 0: if group["weight_decay"] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
# more conservative since it's an approximated value # more conservative since it's an approximated value
if N_sma >= 5: if N_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps']) denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
else: else:
p_data_fp32.add_(-step_size, exp_avg) p_data_fp32.add_(-step_size, exp_avg)
@ -78,8 +89,8 @@ class RAdam(Optimizer):
return loss return loss
class PlainRAdam(Optimizer):
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
@ -96,46 +107,58 @@ class PlainRAdam(Optimizer):
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data.float() grad = p.grad.data.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients') raise RuntimeError("RAdam does not support sparse gradients")
p_data_fp32 = p.data.float() p_data_fp32 = p.data.float()
state = self.state[p] state = self.state[p]
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state["step"] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32) state["exp_avg"] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else: else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group['betas'] beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad)
state['step'] += 1 state["step"] += 1
beta2_t = beta2 ** state['step'] beta2_t = beta2 ** state["step"]
N_sma_max = 2 / (1 - beta2) - 1 N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
if group['weight_decay'] != 0: if group["weight_decay"] != 0:
p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) p_data_fp32.add_(-group["weight_decay"] * group["lr"], p_data_fp32)
# more conservative since it's an approximated value # more conservative since it's an approximated value
if N_sma >= 5: if N_sma >= 5:
step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) step_size = (
denom = exp_avg_sq.sqrt().add_(group['eps']) group["lr"]
* math.sqrt(
(1 - beta2_t)
* (N_sma - 4)
/ (N_sma_max - 4)
* (N_sma - 2)
/ N_sma
* N_sma_max
/ (N_sma_max - 2)
)
/ (1 - beta1 ** state["step"])
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
else: else:
step_size = group['lr'] / (1 - beta1 ** state['step']) step_size = group["lr"] / (1 - beta1 ** state["step"])
p_data_fp32.add_(-step_size, exp_avg) p_data_fp32.add_(-step_size, exp_avg)
p.data.copy_(p_data_fp32) p.data.copy_(p_data_fp32)
@ -144,10 +167,18 @@ class PlainRAdam(Optimizer):
class AdamW(Optimizer): class AdamW(Optimizer):
def __init__(
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0
defaults = dict(lr=lr, betas=betas, eps=eps, ):
weight_decay=weight_decay, amsgrad=amsgrad, use_variance=True, warmup = warmup) defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
use_variance=True,
warmup=warmup,
)
super(AdamW, self).__init__(params, defaults) super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
@ -160,46 +191,48 @@ class AdamW(Optimizer):
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad.data.float() grad = p.grad.data.float()
if grad.is_sparse: if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)
p_data_fp32 = p.data.float() p_data_fp32 = p.data.float()
state = self.state[p] state = self.state[p]
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state["step"] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32) state["exp_avg"] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else: else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group['betas'] beta1, beta2 = group["betas"]
state['step'] += 1 state["step"] += 1
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad)
denom = exp_avg_sq.sqrt().add_(group['eps']) denom = exp_avg_sq.sqrt().add_(group["eps"])
bias_correction1 = 1 - beta1 ** state['step'] bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state['step'] bias_correction2 = 1 - beta2 ** state["step"]
if group['warmup'] > state['step']: if group["warmup"] > state["step"]:
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] scheduled_lr = 1e-8 + state["step"] * group["lr"] / group["warmup"]
else: else:
scheduled_lr = group['lr'] scheduled_lr = group["lr"]
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0: if group["weight_decay"] != 0:
p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) p_data_fp32.add_(-group["weight_decay"] * scheduled_lr, p_data_fp32)
p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p_data_fp32.addcdiv_(-step_size, exp_avg, denom)

View File

@ -3,7 +3,7 @@ import numpy as np
import random import random
import os import os
print('Using', vtk.vtkVersion.GetVTKSourceVersion()) print("Using", vtk.vtkVersion.GetVTKSourceVersion())
class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera): class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
@ -14,7 +14,7 @@ class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
def keyPressEvent(self, obj, event): def keyPressEvent(self, obj, event):
key = self.parent.GetKeySym() key = self.parent.GetKeySym()
if key == '+': if key == "+":
point_size = self.pointcloud.vtkActor.GetProperty().GetPointSize() point_size = self.pointcloud.vtkActor.GetProperty().GetPointSize()
self.pointcloud.vtkActor.GetProperty().SetPointSize(point_size + 1) self.pointcloud.vtkActor.GetProperty().SetPointSize(point_size + 1)
print(str(point_size) + " " + key) print(str(point_size) + " " + key)
@ -22,7 +22,6 @@ class MyInteractorStyle(vtk.vtkInteractorStyleTrackballCamera):
class VtkPointCloud: class VtkPointCloud:
def __init__(self, point_size=18, maxNumPoints=1e8): def __init__(self, point_size=18, maxNumPoints=1e8):
self.maxNumPoints = maxNumPoints self.maxNumPoints = maxNumPoints
self.vtkPolyData = vtk.vtkPolyData() self.vtkPolyData = vtk.vtkPolyData()
@ -59,11 +58,11 @@ class VtkPointCloud:
self.vtkPoints = vtk.vtkPoints() self.vtkPoints = vtk.vtkPoints()
self.vtkCells = vtk.vtkCellArray() self.vtkCells = vtk.vtkCellArray()
self.vtkDepth = vtk.vtkDoubleArray() self.vtkDepth = vtk.vtkDoubleArray()
self.vtkDepth.SetName('DepthArray') self.vtkDepth.SetName("DepthArray")
self.vtkPolyData.SetPoints(self.vtkPoints) self.vtkPolyData.SetPoints(self.vtkPoints)
self.vtkPolyData.SetVerts(self.vtkCells) self.vtkPolyData.SetVerts(self.vtkCells)
self.vtkPolyData.GetPointData().SetScalars(self.vtkDepth) self.vtkPolyData.GetPointData().SetScalars(self.vtkDepth)
self.vtkPolyData.GetPointData().SetActiveScalars('DepthArray') self.vtkPolyData.GetPointData().SetActiveScalars("DepthArray")
def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)): def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)):
@ -95,7 +94,15 @@ def getActorCircle(radius_inner=100, radius_outer=99, color=(1, 0, 0)):
return actor return actor
def show_pointclouds(points, colors, text=[], title="Default", png_path="", interactive=True, orientation='horizontal'): def show_pointclouds(
points,
colors,
text=[],
title="Default",
png_path="",
interactive=True,
orientation="horizontal",
):
""" """
Show multiple point clouds specified as lists. First clouds at the bottom. Show multiple point clouds specified as lists. First clouds at the bottom.
:param points: list of pointclouds, item: numpy (N x 3) XYZ :param points: list of pointclouds, item: numpy (N x 3) XYZ
@ -108,16 +115,18 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
""" """
# make sure pointclouds is a list # make sure pointclouds is a list
assert isinstance(points, type([])), \ assert isinstance(points, type([])), "Pointclouds argument must be a list"
"Pointclouds argument must be a list"
# make sure colors is a list # make sure colors is a list
assert isinstance(colors, type([])), \ assert isinstance(colors, type([])), "Colors argument must be a list"
"Colors argument must be a list"
# make sure number of pointclouds and colors are the same # make sure number of pointclouds and colors are the same
assert len(points) == len(colors), \ assert len(points) == len(
"Number of pointclouds (%d) is different then number of colors (%d)" % (len(points), len(colors)) colors
), "Number of pointclouds (%d) is different then number of colors (%d)" % (
len(points),
len(colors),
)
while len(text) < len(points): while len(text) < len(points):
text.append("") text.append("")
@ -130,15 +139,20 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
renderers = [vtk.vtkRenderer() for _ in range(num_pointclouds)] renderers = [vtk.vtkRenderer() for _ in range(num_pointclouds)]
height = 1.0 / max(num_pointclouds, 1) height = 1.0 / max(num_pointclouds, 1)
viewports = [(i*height, (i+1)*height) for i in range(num_pointclouds)] viewports = [(i * height, (i + 1) * height) for i in range(num_pointclouds)]
#print(viewports) # print(viewports)
# iterate over all point clouds # iterate over all point clouds
for i, pc in enumerate(points): for i, pc in enumerate(points):
pc = pc.squeeze() pc = pc.squeeze()
co = colors[i].squeeze() co = colors[i].squeeze()
assert pc.shape[0] == co.shape[0], \ assert (
"expected same number of points (%d) then colors (%d), cloud index = %d" % (pc.shape[0], co.shape[0], i) pc.shape[0] == co.shape[0]
), "expected same number of points (%d) then colors (%d), cloud index = %d" % (
pc.shape[0],
co.shape[0],
i,
)
assert pc.shape[1] == 3, "expected points to be N x 3, got N x %d" % pc.shape[1] assert pc.shape[1] == 3, "expected points to be N x 3, got N x %d" % pc.shape[1]
assert co.shape[1] == 3, "expected colors to be N x 3, got N x %d" % co.shape[1] assert co.shape[1] == 3, "expected colors to be N x 3, got N x %d" % co.shape[1]
@ -151,13 +165,13 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
renderers[i].AddActor(pointclouds[i].vtkActor) renderers[i].AddActor(pointclouds[i].vtkActor)
# renderers[i].AddActor(vtk.vtkAxesActor()) # renderers[i].AddActor(vtk.vtkAxesActor())
renderers[i].SetBackground(1.0, 1.0, 1.0) renderers[i].SetBackground(1.0, 1.0, 1.0)
if orientation == 'horizontal': if orientation == "horizontal":
print(viewports[i][0]) print(viewports[i][0])
renderers[i].SetViewport(viewports[i][0], 0.0, viewports[i][1], 1.0) renderers[i].SetViewport(viewports[i][0], 0.0, viewports[i][1], 1.0)
elif orientation == 'vertical': elif orientation == "vertical":
renderers[i].SetViewport(0.0, viewports[i][0], 1.0, viewports[i][1]) renderers[i].SetViewport(0.0, viewports[i][0], 1.0, viewports[i][1])
else: else:
raise Exception('Not a valid orientation!') raise Exception("Not a valid orientation!")
renderers[i].ResetCamera() renderers[i].ResetCamera()
# Add circle to first render # Add circle to first render
@ -167,12 +181,12 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
# Text actors # Text actors
text_actors = [vtk.vtkTextActor() for _ in text] text_actors = [vtk.vtkTextActor() for _ in text]
for i, ta in enumerate(text_actors): for i, ta in enumerate(text_actors):
if orientation == 'horizontal': if orientation == "horizontal":
ta.SetInput(' ' + text[i]) ta.SetInput(" " + text[i])
elif orientation == 'vertical': elif orientation == "vertical":
ta.SetInput(text[i] + '\n\n\n\n\n\n') ta.SetInput(text[i] + "\n\n\n\n\n\n")
else: else:
raise Exception('Not a valid orientation!') raise Exception("Not a valid orientation!")
txtprop = ta.GetTextProperty() txtprop = ta.GetTextProperty()
txtprop.SetFontFamilyToArial() txtprop.SetFontFamilyToArial()
txtprop.SetFontSize(0) txtprop.SetFontSize(0)
@ -201,14 +215,14 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
# camera.SetFocalPoint(0, 0, 0) # camera.SetFocalPoint(0, 0, 0)
camera.SetViewUp(0, 0, 1) camera.SetViewUp(0, 0, 1)
if orientation == 'horizontal': if orientation == "horizontal":
camera.SetPosition(3, -10, 2) camera.SetPosition(3, -10, 2)
camera.SetFocalPoint(3, 1.5, 1.5) camera.SetFocalPoint(3, 1.5, 1.5)
elif orientation == 'vertical': elif orientation == "vertical":
camera.SetPosition(1.5, -6, 2) camera.SetPosition(1.5, -6, 2)
camera.SetFocalPoint(1.5, 1.5, 1.5) camera.SetFocalPoint(1.5, 1.5, 1.5)
else: else:
raise Exception('Not a valid orientation!') raise Exception("Not a valid orientation!")
camera.SetClippingRange(0.002, 1000) camera.SetClippingRange(0.002, 1000)
for renderer in renderers: for renderer in renderers:
@ -217,12 +231,12 @@ def show_pointclouds(points, colors, text=[], title="Default", png_path="", inte
# Begin Interaction # Begin Interaction
render_window.Render() render_window.Render()
render_window.SetWindowName(title) render_window.SetWindowName(title)
if orientation == 'horizontal': if orientation == "horizontal":
render_window.SetSize(1940, 720) render_window.SetSize(1940, 720)
elif orientation == 'vertical': elif orientation == "vertical":
render_window.SetSize(600, 1388) render_window.SetSize(600, 1388)
else: else:
raise Exception('Not a valid orientation!') raise Exception("Not a valid orientation!")
if interactive: if interactive:
render_window_interactor.Start() render_window_interactor.Start()
@ -253,10 +267,20 @@ def get_points_colors_from_obj(filename, limit=1):
return points[idx, :], colors[idx, :] return points[idx, :], colors[idx, :]
def visualize_part_seg(file_name_pred, file_name_gt, comparison_folder_list, limit=1, text=[], png_path="", def visualize_part_seg(
interactive=True, orientation='horizontal'): file_name_pred,
file_name_gt,
comparison_folder_list,
limit=1,
text=[],
png_path="",
interactive=True,
orientation="horizontal",
):
# load base point cloud # load base point cloud
gt_points, gt_colors = get_points_colors_from_obj(os.path.join(comparison_folder_list[0], file_name_gt), limit) gt_points, gt_colors = get_points_colors_from_obj(
os.path.join(comparison_folder_list[0], file_name_gt), limit
)
idx_gt = gt_points[:, 1] >= limit idx_gt = gt_points[:, 1] >= limit
@ -264,12 +288,19 @@ def visualize_part_seg(file_name_pred, file_name_gt, comparison_folder_list, lim
all_colors = [gt_colors[idx_gt, :3]] all_colors = [gt_colors[idx_gt, :3]]
for folder in comparison_folder_list: for folder in comparison_folder_list:
pts, col = get_points_colors_from_obj(os.path.join(folder, file_name_pred), limit=limit) pts, col = get_points_colors_from_obj(
os.path.join(folder, file_name_pred), limit=limit
)
all_points.append(pts) all_points.append(pts)
all_colors.append(col) all_colors.append(col)
print(np.asarray(all_points).shape) print(np.asarray(all_points).shape)
show_pointclouds(all_points, all_colors, text=text, png_path=png_path, interactive=interactive, show_pointclouds(
orientation=orientation) all_points,
all_colors,
text=text,
png_path=png_path,
interactive=interactive,
orientation=orientation,
)

View File

@ -3,7 +3,7 @@ try:
import tensorflow as tf import tensorflow as tf
import tensorboard.plugins.mesh.summary as meshsummary import tensorboard.plugins.mesh.summary as meshsummary
except ImportError: except ImportError:
print('tensorflow is not installed.') print("tensorflow is not installed.")
import numpy as np import numpy as np
import scipy.misc import scipy.misc
@ -11,38 +11,38 @@ import scipy.misc
try: try:
from StringIO import StringIO # Python 2.7 from StringIO import StringIO # Python 2.7
except ImportError: except ImportError:
from io import BytesIO # Python 3.x from io import BytesIO # Python 3.x
class TfLogger(object): class TfLogger(object):
def __init__(self, log_dir): def __init__(self, log_dir):
"""Create a summary writer logging to log_dir.""" """Create a summary writer logging to log_dir."""
self.writer = tf.compat.v1.summary.FileWriter(log_dir) self.writer = tf.compat.v1.summary.FileWriter(log_dir)
# Camera and scene configuration. # Camera and scene configuration.
self.config_dict = { self.config_dict = {
'camera': {'cls': 'PerspectiveCamera', 'fov': 75}, "camera": {"cls": "PerspectiveCamera", "fov": 75},
'lights': [ "lights": [
{ {
'cls': 'AmbientLight', "cls": "AmbientLight",
'color': '#ffffff', "color": "#ffffff",
'intensity': 0.75, "intensity": 0.75,
}, { },
'cls': 'DirectionalLight', {
'color': '#ffffff', "cls": "DirectionalLight",
'intensity': 0.75, "color": "#ffffff",
'position': [0, -1, 2], "intensity": 0.75,
}], "position": [0, -1, 2],
'material': { },
'cls': 'MeshStandardMaterial', ],
'metalness': 0 "material": {"cls": "MeshStandardMaterial", "metalness": 0},
}
} }
def scalar_summary(self, tag, value, step): def scalar_summary(self, tag, value, step):
"""Log a scalar variable.""" """Log a scalar variable."""
summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]) summary = tf.compat.v1.Summary(
value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]
)
self.writer.add_summary(summary, step) self.writer.add_summary(summary, step)
def image_summary(self, tag, images, step): def image_summary(self, tag, images, step):
@ -54,10 +54,15 @@ class TfLogger(object):
scipy.misc.toimage(img).save(s, format="png") scipy.misc.toimage(img).save(s, format="png")
# Create an Image object # Create an Image object
img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(), img_sum = tf.compat.v1.Summary.Image(
height=img.shape[0], width=img.shape[1]) encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1],
)
# Create a Summary value # Create a Summary value
img_summaries.append(tf.compat.v1.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) img_summaries.append(
tf.compat.v1.Summary.Value(tag="%s/%d" % (tag, i), image=img_sum)
)
# Create and write Summary # Create and write Summary
summary = tf.Summary(value=img_summaries) summary = tf.Summary(value=img_summaries)
@ -71,10 +76,17 @@ class TfLogger(object):
vertices = tf.constant(vertices) vertices = tf.constant(vertices)
if faces is not None: if faces is not None:
faces = tf.constant(faces) faces = tf.constant(faces)
meshes_summares=[] meshes_summares = []
for i in range(vertices.shape[0]): for i in range(vertices.shape[0]):
meshes_summares.append(meshsummary.op( meshes_summares.append(
tag, vertices=vertices, faces=faces, colors=colors, config_dict=self.config_dict)) meshsummary.op(
tag,
vertices=vertices,
faces=faces,
colors=colors,
config_dict=self.config_dict,
)
)
sess = tf.Session() sess = tf.Session()
summaries = sess.run(meshes_summares) summaries = sess.run(meshes_summares)
@ -93,7 +105,7 @@ class TfLogger(object):
hist.max = float(np.max(values)) hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape)) hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values)) hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2)) hist.sum_squares = float(np.sum(values ** 2))
# Drop the start of the first bin # Drop the start of the first bin
bin_edges = bin_edges[1:] bin_edges = bin_edges[1:]
@ -108,4 +120,3 @@ class TfLogger(object):
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step) self.writer.add_summary(summary, step)
self.writer.flush() self.writer.flush()

View File

@ -11,18 +11,23 @@ from .util.base import Base
class DBN(Base): class DBN(Base):
def __init__(self, u_obs, u, nComponents=[250, 30, 5], constant=298):
def __init__(self,u_obs,u,nComponents=[250,30,5],constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u) self.n_components = nComponents
self.n_components=nComponents
def dbn(self): def dbn(self):
models = [] models = []
for num, components in zip(range(0,len(self.n_components)),self.n_components): for num, components in zip(range(0, len(self.n_components)), self.n_components):
model = rbm(n_components=components, n_iter=300, learning_rate=0.06, random_state=1,verbose=True) model = rbm(
model_name = 'rbm' + str(num) n_components=components,
n_iter=300,
learning_rate=0.06,
random_state=1,
verbose=True,
)
model_name = "rbm" + str(num)
models.append((model_name, model)) models.append((model_name, model))
models.append(('clf', LinearRegression())) models.append(("clf", LinearRegression()))
return models return models
def predict(self): def predict(self):
@ -35,58 +40,58 @@ class DBN(Base):
regressor = regressor = Pipeline(models) regressor = regressor = Pipeline(models)
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = DBN(u_obs,u, nComponents=[250,30,5]) sample = DBN(u_obs, u, nComponents=[250, 30, 5])
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
u_pred=u_pred*50+298 u_pred = u_pred * 50 + 298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
# fig = plt.figure(figsize=(5,5))
#fig = plt.figure(figsize=(5,5)) # im = plt.imshow(u,cmap='jet')
#im = plt.imshow(u,cmap='jet') # plt.colorbar(im)
#plt.colorbar(im) fig.savefig("prediction.png", dpi=300)
fig.savefig('prediction.png', dpi=300)

View File

@ -8,13 +8,12 @@ from .util.base import Base
class GInterpolation(Base): class GInterpolation(Base):
def __init__(self, u_obs, u, constant=298):
def __init__(self,u_obs,u,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
def predict(self): def predict(self):
row = np.linspace(0, self.u.shape[0]-1, num=self.u.shape[0]) row = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
col = row col = row
col, row = np.meshgrid(col, row) col, row = np.meshgrid(col, row)
@ -24,61 +23,72 @@ class GInterpolation(Base):
col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col) col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col)
row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row) row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row)
ind = np.dot(np.ones_like(self.rows).reshape(1, -1), (np.exp(-np.sqrt(np.power((self.rows.reshape(-1,1) - row), 2)+np.power((self.cols.reshape(-1,1) - col), 2))))) ind = np.dot(
np.ones_like(self.rows).reshape(1, -1),
(
np.exp(
-np.sqrt(
np.power((self.rows.reshape(-1, 1) - row), 2)
+ np.power((self.cols.reshape(-1, 1) - col), 2)
)
)
),
)
param = np.exp(-np.sqrt(np.power((self.rows.reshape(-1,1) - row), 2)+np.power((self.cols.reshape(-1,1) - col), 2)))/(np.dot(np.ones_like(self.rows).reshape(-1, 1), ind)) param = np.exp(
-np.sqrt(
np.power((self.rows.reshape(-1, 1) - row), 2)
+ np.power((self.cols.reshape(-1, 1) - col), 2)
)
) / (np.dot(np.ones_like(self.rows).reshape(-1, 1), ind))
dis = np.dot(self.u_obs[self.rows,self.cols].reshape(1, -1), param) dis = np.dot(self.u_obs[self.rows, self.cols].reshape(1, -1), param)
self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1]) self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1])
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example10001.mat') m = sio.loadmat("Example10001.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = GInterpolation(u_obs,u) sample = GInterpolation(u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
# fig = plt.figure(figsize=(5,5))
#fig = plt.figure(figsize=(5,5)) # im = plt.imshow(u,cmap='jet')
#im = plt.imshow(u,cmap='jet') # plt.colorbar(im)
#plt.colorbar(im) fig.savefig("prediction.png", dpi=300)
fig.savefig('prediction.png', dpi=300)

View File

@ -8,22 +8,23 @@ from .util.base import Base
class KInterpolation(Base): class KInterpolation(Base):
def __init__(self, u_obs, u, k=5, constant=298):
def __init__(self,u_obs,u,k=5,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
self.k = k self.k = k
def knearest(self, row, col): def knearest(self, row, col):
d = np.zeros_like(self.rows) d = np.zeros_like(self.rows)
for k in range(self.rows.shape[0]): for k in range(self.rows.shape[0]):
d[k] = math.sqrt(math.pow((self.rows[k]-row),2)+math.pow((self.cols[k]-col),2)) d[k] = math.sqrt(
kpoint = np.argsort(d)[:self.k] math.pow((self.rows[k] - row), 2) + math.pow((self.cols[k] - col), 2)
)
kpoint = np.argsort(d)[: self.k]
return self.rows[kpoint], self.cols[kpoint] return self.rows[kpoint], self.cols[kpoint]
def predict(self): def predict(self):
row = np.linspace(0, self.u.shape[0]-1, num=self.u.shape[0]) row = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
col = np.linspace(0, self.u.shape[0]-1, num=self.u.shape[0]) col = np.linspace(0, self.u.shape[0] - 1, num=self.u.shape[0])
col, row = np.meshgrid(col, row) col, row = np.meshgrid(col, row)
col = col.reshape(1, -1) col = col.reshape(1, -1)
@ -32,69 +33,90 @@ class KInterpolation(Base):
col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col) col = np.dot(np.ones_like(self.cols).reshape(-1, 1), col)
row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row) row = np.dot(np.ones_like(self.rows).reshape(-1, 1), row)
ksort = np.argsort(np.sqrt(np.power((self.rows.reshape(-1,1) - row), 2)+np.power((self.cols.reshape(-1,1) - col), 2)), axis=0) ksort = np.argsort(
np.sqrt(
np.power((self.rows.reshape(-1, 1) - row), 2)
+ np.power((self.cols.reshape(-1, 1) - col), 2)
),
axis=0,
)
kind = np.zeros_like(ksort) kind = np.zeros_like(ksort)
for num in range(ksort.shape[1]): for num in range(ksort.shape[1]):
kind[ksort[:self.k, num], num] = 1 kind[ksort[: self.k, num], num] = 1
assert (self.k >= 1) assert self.k >= 1
ind = np.dot(np.ones_like(self.rows).reshape(1, -1), np.exp(-np.sqrt(np.power((self.rows.reshape(-1,1) - row), 2)+np.power((self.cols.reshape(-1,1) - col), 2))) * kind) ind = np.dot(
np.ones_like(self.rows).reshape(1, -1),
np.exp(
-np.sqrt(
np.power((self.rows.reshape(-1, 1) - row), 2)
+ np.power((self.cols.reshape(-1, 1) - col), 2)
)
)
* kind,
)
param = np.exp(-np.sqrt(np.power((self.rows.reshape(-1,1) - row), 2)+np.power((self.cols.reshape(-1,1) - col), 2)))*kind/(np.dot(np.ones_like(self.rows).reshape(-1, 1), ind)) param = (
np.exp(
-np.sqrt(
np.power((self.rows.reshape(-1, 1) - row), 2)
+ np.power((self.cols.reshape(-1, 1) - col), 2)
)
)
* kind
/ (np.dot(np.ones_like(self.rows).reshape(-1, 1), ind))
)
dis = np.dot(self.u_obs[self.rows,self.cols].reshape(1, -1), param) dis = np.dot(self.u_obs[self.rows, self.cols].reshape(1, -1), param)
self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1]) self.u_pred = dis.reshape(self.u.shape[0], self.u.shape[1])
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = KInterpolation(u_obs,u, k=4) sample = KInterpolation(u_obs, u, k=4)
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
# fig = plt.figure(figsize=(5,5))
#fig = plt.figure(figsize=(5,5)) # im = plt.imshow(u,cmap='jet')
#im = plt.imshow(u,cmap='jet') # plt.colorbar(im)
#plt.colorbar(im) fig.savefig("prediction.png", dpi=300)
fig.savefig('prediction.png', dpi=300)

View File

@ -8,9 +8,8 @@ from .util.base import Base
class Kriging(Base): class Kriging(Base):
def __init__(self, u_obs, u, constant=298):
def __init__(self,u_obs,u,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
def predict(self): def predict(self):
@ -18,68 +17,67 @@ class Kriging(Base):
X, Y = self.train_samples() X, Y = self.train_samples()
test_samples = self.test_samples() test_samples = self.test_samples()
kernel = 1.0 * gp.kernels.RBF(1.0) + gp.kernels.WhiteKernel() # + gp.kernels.DotProduct() kernel = (
regressor = gp.GaussianProcessRegressor(kernel=kernel,n_restarts_optimizer=10, alpha=0.01) 1.0 * gp.kernels.RBF(1.0) + gp.kernels.WhiteKernel()
) # + gp.kernels.DotProduct()
regressor = gp.GaussianProcessRegressor(
kernel=kernel, n_restarts_optimizer=10, alpha=0.01
)
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = Kriging(u_obs,u) sample = Kriging(u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
u_pred = u_pred * 50 + 298
u_pred=u_pred*50+298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -10,9 +10,8 @@ from .util.base import Base
class MLPP(Base): class MLPP(Base):
def __init__(self, u_obs, u, layers=[100, 50], constant=298):
def __init__(self,u_obs,u,layers=[100,50],constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
self.layers = layers self.layers = layers
def predict(self): def predict(self):
@ -21,66 +20,68 @@ class MLPP(Base):
X, Y = self.train_samples() X, Y = self.train_samples()
test_samples = self.test_samples() test_samples = self.test_samples()
regressor = MLPRegressor(hidden_layer_sizes=self.layers,alpha=2,solver='adam',random_state=1,max_iter=20000) regressor = MLPRegressor(
hidden_layer_sizes=self.layers,
alpha=2,
solver="adam",
random_state=1,
max_iter=20000,
)
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = MLP(u_obs,u) sample = MLPP(u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
#u_pred=u_pred*50+298 # u_pred=u_pred*50+298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -10,9 +10,8 @@ from .util.base import Base
class Polynomial(Base): class Polynomial(Base):
def __init__(self, u_obs, u, degree=5, constant=298):
def __init__(self,u_obs,u,degree=5, constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
self.degree = degree self.degree = degree
def predict(self): def predict(self):
@ -21,65 +20,66 @@ class Polynomial(Base):
X, Y = self.train_samples() X, Y = self.train_samples()
test_samples = self.test_samples() test_samples = self.test_samples()
regressor = Pipeline([('poly', PolynomialFeatures(degree=self.degree, include_bias=False)), ('clf', LinearRegression())]) regressor = Pipeline(
[
("poly", PolynomialFeatures(degree=self.degree, include_bias=False)),
("clf", LinearRegression()),
]
)
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = Polynomial(u_obs,u, degree=5) sample = Polynomial(u_obs, u, degree=5)
u_pred = sample.predict() u_pred = sample.predict()
u_pred=u_pred*50+298 u_pred = u_pred * 50 + 298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -9,9 +9,8 @@ from sklearn.metrics import mean_absolute_error as mae
class RandomForest(Base): class RandomForest(Base):
def __init__(self, u_obs, u, constant=298):
def __init__(self,u_obs,u,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
def predict(self): def predict(self):
@ -22,63 +21,59 @@ class RandomForest(Base):
regressor = RandomForestRegressor(n_estimators=500, random_state=10) regressor = RandomForestRegressor(n_estimators=500, random_state=10)
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = RandomForest(u_obs,u) sample = RandomForest(u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
u_pred=u_pred*50+298 u_pred = u_pred * 50 + 298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -11,10 +11,9 @@ from .util.base import Base
class RBM(Base): class RBM(Base):
def __init__(self, u_obs, u, nComponents=8000, constant=298):
def __init__(self,u_obs,u,nComponents=8000,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u) self.n_components = nComponents
self.n_components=nComponents
def predict(self): def predict(self):
@ -22,66 +21,68 @@ class RBM(Base):
X, Y = self.train_samples() X, Y = self.train_samples()
test_samples = self.test_samples() test_samples = self.test_samples()
rbm1 = rbm(n_components=self.n_components, n_iter=300, learning_rate=0.06, random_state=1,verbose=True) rbm1 = rbm(
regressor = Pipeline([('rbm', rbm1), ('clf', LinearRegression())]) n_components=self.n_components,
n_iter=300,
learning_rate=0.06,
random_state=1,
verbose=True,
)
regressor = Pipeline([("rbm", rbm1), ("clf", LinearRegression())])
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = RBM(u_obs,u, nComponents=8000) sample = RBM(u_obs, u, nComponents=8000)
u_pred = sample.predict() u_pred = sample.predict()
print('mae:',mae(u_pred,u)) print("mae:", mae(u_pred, u))
u_pred=u_pred*50+298 u_pred = u_pred * 50 + 298
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5, 5))
fig = plt.figure(figsize=(22.5,5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -8,9 +8,8 @@ from .util.base import Base
class RSVR(Base): class RSVR(Base):
def __init__(self, u_obs, u, constant=298):
def __init__(self,u_obs,u,constant=298): super().__init__(u_obs, u)
super().__init__(u_obs,u)
def predict(self): def predict(self):
@ -18,64 +17,61 @@ class RSVR(Base):
X, Y = self.train_samples() X, Y = self.train_samples()
test_samples = self.test_samples() test_samples = self.test_samples()
regressor = SVR(kernel='rbf') regressor = SVR(kernel="rbf")
regressor.fit(X, Y) regressor.fit(X, Y)
self.u_pred=regressor.predict(test_samples).reshape(self.u.shape[0],self.u.shape[1]) self.u_pred = regressor.predict(test_samples).reshape(
self.u.shape[0], self.u.shape[1]
)
return self.u_pred return self.u_pred
if __name__ == '__main__': if __name__ == "__main__":
m=sio.loadmat('Example0.mat') m = sio.loadmat("Example0.mat")
u_obs=m['u_obs'] u_obs = m["u_obs"]
u=m['u'] u = m["u"]
sample = RSVR(u_obs,u) sample = RSVR(u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
from sklearn.metrics import mean_absolute_error as mae from sklearn.metrics import mean_absolute_error as mae
print('mae:',mae(u_pred,u))
u_pred=u_pred*50+298 print("mae:", mae(u_pred, u))
fig = plt.figure(figsize=(22.5,5)) u_pred = u_pred * 50 + 298
fig = plt.figure(figsize=(22.5, 5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
plt.subplot(141) plt.subplot(141)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.pcolormesh(X,Y,abs(u-u_pred)) im = plt.pcolormesh(X, Y, abs(u - u_pred))
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u,levels=150,cmap='jet') im = plt.contourf(X, Y, u, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, u_pred,levels=150,cmap='jet') im = plt.contourf(X, Y, u_pred, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, abs(u-u_pred),levels=150,cmap='jet') im = plt.contourf(X, Y, abs(u - u_pred), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
#save_name = os.path.join('outputs/predict_plot', '1.png') # save_name = os.path.join('outputs/predict_plot', '1.png')
#fig.savefig(save_name, dpi=300) # fig.savefig(save_name, dpi=300)
#fig = plt.figure(figsize=(5,5))
#im = plt.imshow(u,cmap='jet')
#plt.colorbar(im)
fig.savefig('prediction.png', dpi=300)
# fig = plt.figure(figsize=(5,5))
# im = plt.imshow(u,cmap='jet')
# plt.colorbar(im)
fig.savefig("prediction.png", dpi=300)

View File

@ -12,7 +12,8 @@ class Base:
""" """
The observations are in matrix format The observations are in matrix format
""" """
def __init__(self,u_obs,u,constant=298):
def __init__(self, u_obs, u, constant=298):
self.u_obs = np.array(u_obs) self.u_obs = np.array(u_obs)
self.u = np.array(u) self.u = np.array(u)
self.constant = constant self.constant = constant
@ -22,18 +23,22 @@ class Base:
self.obser() self.obser()
def obser(self): def obser(self):
[self.rows,self.cols] = np.where(self.u_obs>self.constant) [self.rows, self.cols] = np.where(self.u_obs > self.constant)
def pred_init(self): def pred_init(self):
self.u_pred = np.zeros_like(self.u) self.u_pred = np.zeros_like(self.u)
def train_samples(self): def train_samples(self):
X_train = np.transpose(np.array([self.rows,self.cols])) / max(self.u.shape) X_train = np.transpose(np.array([self.rows, self.cols])) / max(self.u.shape)
y_train = np.transpose(self.u[self.rows,self.cols]) y_train = np.transpose(self.u[self.rows, self.cols])
return X_train, y_train return X_train, y_train
def test_samples(self): def test_samples(self):
samples = [[row,col] for row in range(self.u.shape[0]) for col in range(self.u.shape[1])] samples = [
[row, col]
for row in range(self.u.shape[0])
for col in range(self.u.shape[1])
]
samples = np.array(samples) / max(self.u.shape) samples = np.array(samples) / max(self.u.shape)
return samples return samples
@ -45,6 +50,7 @@ class BaseVec:
""" """
The observations are in matrix format The observations are in matrix format
""" """
def __init__(self, root, train_list, constant=298): def __init__(self, root, train_list, constant=298):
self.root = root self.root = root
self.train_list = train_list self.train_list = train_list
@ -55,43 +61,45 @@ class BaseVec:
self.constant = constant self.constant = constant
def _loader(self, path, mode='train'): def _loader(self, path, mode="train"):
input = [] input = []
output = [] output = []
if mode == 'train': if mode == "train":
for _ in range(4*4): for _ in range(4 * 4):
output.append([]) output.append([])
else: else:
pass pass
#print((path[3])) # print((path[3]))
num = 0 num = 0
for i in range(len(path)): for i in range(len(path)):
num = num + 1 num = num + 1
#print(len(path)) # print(len(path))
#print(i) # print(i)
source = np.array(sio.loadmat(path[i])['u_obs']) source = np.array(sio.loadmat(path[i])["u_obs"])
target = np.array(sio.loadmat(path[i])['u']) target = np.array(sio.loadmat(path[i])["u"])
if self.layout is None: if self.layout is None:
self.layout = np.array(sio.loadmat(path[i])['F']) self.layout = np.array(sio.loadmat(path[i])["F"])
else: else:
pass pass
indata = source[np.where(source>TOL)] indata = source[np.where(source > TOL)]
input.append(indata) input.append(indata)
if mode == 'train': if mode == "train":
for k in range(4): for k in range(4):
for kk in range(4): for kk in range(4):
sep = target[0+k:target.shape[0]:4, 0+kk:target.shape[1]:4].flatten() sep = target[
output[k*4+kk].append(sep) 0 + k : target.shape[0] : 4, 0 + kk : target.shape[1] : 4
elif mode == 'test': ].flatten()
output[k * 4 + kk].append(sep)
elif mode == "test":
output.append(target) output.append(target)
else: else:
pass pass
if num % 1000 == 0 : if num % 1000 == 0:
print("num:", num) print("num:", num)
return input, output return input, output
@ -104,14 +112,17 @@ class BaseVec:
base = os.path.dirname(list_path) base = os.path.dirname(list_path)
print(base) print(base)
test_name = os.path.splitext(os.path.basename(list_path))[0] test_name = os.path.splitext(os.path.basename(list_path))[0]
subdir = os.path.join("train", "train") \ subdir = (
if base=='train' else os.path.join("test", test_name) os.path.join("train", "train")
if base == "train"
else os.path.join("test", test_name)
)
file_dir = os.path.join(root_dir, subdir) file_dir = os.path.join(root_dir, subdir)
list_file = os.path.join(root_dir, list_path) list_file = os.path.join(root_dir, list_path)
print(file_dir) print(file_dir)
print(list_file) print(list_file)
assert os.path.isdir(file_dir) assert os.path.isdir(file_dir)
with open(list_file, 'r') as rf: with open(list_file, "r") as rf:
for line in rf.readlines(): for line in rf.readlines():
data_path = line.strip() data_path = line.strip()
path = os.path.join(file_dir, data_path) path = os.path.join(file_dir, data_path)
@ -119,13 +130,13 @@ class BaseVec:
return files return files
def train_samples(self): def train_samples(self):
#print(self.train_file) # print(self.train_file)
X_train, y_train = self._loader(self.train_file) X_train, y_train = self._loader(self.train_file)
return X_train, y_train return X_train, y_train
def test_samples(self, test_path): def test_samples(self, test_path):
test_file = self.make_dataset(self.root, test_path) test_file = self.make_dataset(self.root, test_path)
X_test, y_test = self._loader(test_file, mode='test') X_test, y_test = self._loader(test_file, mode="test")
return X_test, np.array(y_test) return X_test, np.array(y_test)
@ -133,21 +144,18 @@ class BaseVec:
pass pass
if __name__ == "__main__":
if __name__ == '__main__': # m=sio.loadmat('Example0.mat')
#m=sio.loadmat('Example0.mat') # u_obs=m['u_obs']
#u_obs=m['u_obs'] # u=m['u']
#u=m['u'] # sample = Base(u_obs,u)
#sample = Base(u_obs,u) # sample.vec()
#sample.vec() root = "g:/gong/recon_project/TFRD/HSink/"
root = 'g:/gong/recon_project/TFRD/HSink/' train_list = "g:/gong/recon_project/TFRD/HSink/train/train_val.txt"
train_list = 'g:/gong/recon_project/TFRD/HSink/train/train_val.txt' test_list = "g:/gong/recon_project/TFRD/HSink/train/test_0.txt"
test_list = 'g:/gong/recon_project/TFRD/HSink/train/test_0.txt'
sample = BaseVec(root, train_list) sample = BaseVec(root, train_list)
a, b = sample.train_samples() a, b = sample.train_samples()
print(a) print(a)
print(b[0]) print(b[0])

View File

@ -1,4 +1,4 @@
PAD_WORD = '<blank>' PAD_WORD = "<blank>"
UNK_WORD = '<unk>' UNK_WORD = "<unk>"
BOS_WORD = '<s>' BOS_WORD = "<s>"
EOS_WORD = '</s>' EOS_WORD = "</s>"

View File

@ -1,4 +1,4 @@
''' Define the Layers ''' """ Define the Layers """
import torch.nn as nn import torch.nn as nn
import torch import torch
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
@ -8,7 +8,7 @@ __author__ = "Yu-Hsiang Huang"
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
''' Compose with two layers ''' """Compose with two layers"""
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__() super(EncoderLayer, self).__init__()
@ -17,13 +17,14 @@ class EncoderLayer(nn.Module):
def forward(self, enc_input, slf_attn_mask=None): def forward(self, enc_input, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn( enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask) enc_input, enc_input, enc_input, mask=slf_attn_mask
)
enc_output = self.pos_ffn(enc_output) enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn return enc_output, enc_slf_attn
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
''' Compose with three layers ''' """Compose with three layers"""
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(DecoderLayer, self).__init__() super(DecoderLayer, self).__init__()
@ -32,12 +33,14 @@ class DecoderLayer(nn.Module):
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward( def forward(
self, dec_input, enc_output, self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None
slf_attn_mask=None, dec_enc_attn_mask=None): ):
dec_output, dec_slf_attn = self.slf_attn( dec_output, dec_slf_attn = self.slf_attn(
dec_input, dec_input, dec_input, mask=slf_attn_mask) dec_input, dec_input, dec_input, mask=slf_attn_mask
)
dec_output, dec_enc_attn = self.enc_attn( dec_output, dec_enc_attn = self.enc_attn(
dec_output, enc_output, enc_output, mask=dec_enc_attn_mask) dec_output, enc_output, enc_output, mask=dec_enc_attn_mask
)
dec_output = self.pos_ffn(dec_output) dec_output = self.pos_ffn(dec_output)
return dec_output, dec_slf_attn, dec_enc_attn return dec_output, dec_slf_attn, dec_enc_attn

View File

@ -6,7 +6,7 @@ __author__ = "Yu-Hsiang Huang"
class ScaledDotProductAttention(nn.Module): class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention ''' """Scaled Dot-Product Attention"""
def __init__(self, temperature, attn_dropout=0.1): def __init__(self, temperature, attn_dropout=0.1):
super().__init__() super().__init__()

View File

@ -1,8 +1,9 @@
'''A wrapper class for scheduled optimizer ''' """A wrapper class for scheduled optimizer """
import numpy as np import numpy as np
class ScheduledOptim():
'''A simple wrapper class for learning rate scheduling''' class ScheduledOptim:
"""A simple wrapper class for learning rate scheduling"""
def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps): def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
self._optimizer = optimizer self._optimizer = optimizer
@ -11,30 +12,27 @@ class ScheduledOptim():
self.n_warmup_steps = n_warmup_steps self.n_warmup_steps = n_warmup_steps
self.n_steps = 0 self.n_steps = 0
def step_and_update_lr(self): def step_and_update_lr(self):
"Step with the inner optimizer" "Step with the inner optimizer"
self._update_learning_rate() self._update_learning_rate()
self._optimizer.step() self._optimizer.step()
def zero_grad(self): def zero_grad(self):
"Zero out the gradients with the inner optimizer" "Zero out the gradients with the inner optimizer"
self._optimizer.zero_grad() self._optimizer.zero_grad()
def _get_lr_scale(self): def _get_lr_scale(self):
d_model = self.d_model d_model = self.d_model
n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)) return (d_model ** -0.5) * min(
n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5)
)
def _update_learning_rate(self): def _update_learning_rate(self):
''' Learning rate scheduling per step ''' """Learning rate scheduling per step"""
self.n_steps += 1 self.n_steps += 1
lr = self.lr_mul * self._get_lr_scale() lr = self.lr_mul * self._get_lr_scale()
for param_group in self._optimizer.param_groups: for param_group in self._optimizer.param_groups:
param_group['lr'] = lr param_group["lr"] = lr

View File

@ -1,4 +1,4 @@
''' Define the sublayers in encoder/decoder layer ''' """ Define the sublayers in encoder/decoder layer """
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -8,7 +8,7 @@ __author__ = "Yu-Hsiang Huang"
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module ''' """Multi-Head Attention module"""
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__() super().__init__()
@ -59,7 +59,7 @@ class MultiHeadAttention(nn.Module):
class PositionwiseFeedForward(nn.Module): class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module ''' """A two-feed-forward-layer module"""
def __init__(self, d_in, d_hid, dropout=0.1): def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__() super().__init__()

View File

@ -1,4 +1,4 @@
''' This module will handle the text generation with beam search. ''' """ This module will handle the text generation with beam search. """
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -7,12 +7,18 @@ from models.transformer.Models import Transformer, get_pad_mask, get_subsequent_
class Translator(nn.Module): class Translator(nn.Module):
''' Load a trained model and translate in beam search fashion. ''' """Load a trained model and translate in beam search fashion."""
def __init__( def __init__(
self, model, beam_size, max_seq_len, self,
src_pad_idx, trg_pad_idx, trg_bos_idx, trg_eos_idx): model,
beam_size,
max_seq_len,
src_pad_idx,
trg_pad_idx,
trg_bos_idx,
trg_eos_idx,
):
super(Translator, self).__init__() super(Translator, self).__init__()
@ -26,22 +32,21 @@ class Translator(nn.Module):
self.model = model self.model = model
self.model.eval() self.model.eval()
self.register_buffer('init_seq', torch.LongTensor([[trg_bos_idx]])) self.register_buffer("init_seq", torch.LongTensor([[trg_bos_idx]]))
self.register_buffer( self.register_buffer(
'blank_seqs', "blank_seqs",
torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long)) torch.full((beam_size, max_seq_len), trg_pad_idx, dtype=torch.long),
)
self.blank_seqs[:, 0] = self.trg_bos_idx self.blank_seqs[:, 0] = self.trg_bos_idx
self.register_buffer( self.register_buffer(
'len_map', "len_map", torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0)
torch.arange(1, max_seq_len + 1, dtype=torch.long).unsqueeze(0)) )
def _model_decode(self, trg_seq, enc_output, src_mask): def _model_decode(self, trg_seq, enc_output, src_mask):
trg_mask = get_subsequent_mask(trg_seq) trg_mask = get_subsequent_mask(trg_seq)
dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask) dec_output, *_ = self.model.decoder(trg_seq, trg_mask, enc_output, src_mask)
return F.softmax(self.model.trg_word_prj(dec_output), dim=-1) return F.softmax(self.model.trg_word_prj(dec_output), dim=-1)
def _get_init_state(self, src_seq, src_mask): def _get_init_state(self, src_seq, src_mask):
beam_size = self.beam_size beam_size = self.beam_size
@ -56,7 +61,6 @@ class Translator(nn.Module):
enc_output = enc_output.repeat(beam_size, 1, 1) enc_output = enc_output.repeat(beam_size, 1, 1)
return enc_output, gen_seq, scores return enc_output, gen_seq, scores
def _get_the_best_score_and_idx(self, gen_seq, dec_output, scores, step): def _get_the_best_score_and_idx(self, gen_seq, dec_output, scores, step):
assert len(scores.size()) == 1 assert len(scores.size()) == 1
@ -66,13 +70,18 @@ class Translator(nn.Module):
best_k2_probs, best_k2_idx = dec_output[:, -1, :].topk(beam_size) best_k2_probs, best_k2_idx = dec_output[:, -1, :].topk(beam_size)
# Include the previous scores. # Include the previous scores.
scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(beam_size, 1) scores = torch.log(best_k2_probs).view(beam_size, -1) + scores.view(
beam_size, 1
)
# Get the best k candidates from k^2 candidates. # Get the best k candidates from k^2 candidates.
scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size) scores, best_k_idx_in_k2 = scores.view(-1).topk(beam_size)
# Get the corresponding positions of the best k candidiates. # Get the corresponding positions of the best k candidiates.
best_k_r_idxs, best_k_c_idxs = best_k_idx_in_k2 // beam_size, best_k_idx_in_k2 % beam_size best_k_r_idxs, best_k_c_idxs = (
best_k_idx_in_k2 // beam_size,
best_k_idx_in_k2 % beam_size,
)
best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs] best_k_idx = best_k2_idx[best_k_r_idxs, best_k_c_idxs]
# Copy the corresponding previous tokens. # Copy the corresponding previous tokens.
@ -82,7 +91,6 @@ class Translator(nn.Module):
return gen_seq, scores return gen_seq, scores
def translate_sentence(self, src_seq): def translate_sentence(self, src_seq):
# Only accept batch size equals to 1 in this function. # Only accept batch size equals to 1 in this function.
# TODO: expand to batch operation. # TODO: expand to batch operation.
@ -95,10 +103,12 @@ class Translator(nn.Module):
src_mask = get_pad_mask(src_seq, src_pad_idx) src_mask = get_pad_mask(src_seq, src_pad_idx)
enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask) enc_output, gen_seq, scores = self._get_init_state(src_seq, src_mask)
ans_idx = 0 # default ans_idx = 0 # default
for step in range(2, max_seq_len): # decode up to max length for step in range(2, max_seq_len): # decode up to max length
dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask) dec_output = self._model_decode(gen_seq[:, :step], enc_output, src_mask)
gen_seq, scores = self._get_the_best_score_and_idx(gen_seq, dec_output, scores, step) gen_seq, scores = self._get_the_best_score_and_idx(
gen_seq, dec_output, scores, step
)
# Check if all path finished # Check if all path finished
# -- locate the eos in the generated sequences # -- locate the eos in the generated sequences
@ -111,4 +121,4 @@ class Translator(nn.Module):
_, ans_idx = scores.div(seq_lens.float() ** alpha).max(0) _, ans_idx = scores.div(seq_lens.float() ** alpha).max(0)
ans_idx = ans_idx.item() ans_idx = ans_idx.item()
break break
return gen_seq[ans_idx][:seq_lens[ans_idx]].tolist() return gen_seq[ans_idx][: seq_lens[ans_idx]].tolist()

View File

@ -22,7 +22,7 @@ def main(hparams):
if hparams.gpu == 0: if hparams.gpu == 0:
device = torch.device("cpu") device = torch.device("cpu")
else: else:
ngpu = "cuda:"+str(hparams.gpu-1) ngpu = "cuda:" + str(hparams.gpu - 1)
print(ngpu) print(ngpu)
device = torch.device(ngpu) device = torch.device(ngpu)
model = Model(hparams).to(device) model = Model(hparams).to(device)
@ -31,8 +31,9 @@ def main(hparams):
print() print()
# Model loading # Model loading
model_path = os.path.join(f'lightning_logs/version_' + model_path = os.path.join(
hparams.test_check_num, 'checkpoints/') f"lightning_logs/version_" + hparams.test_check_num, "checkpoints/"
)
ckpt = list(Path(model_path).glob("*.ckpt"))[0] ckpt = list(Path(model_path).glob("*.ckpt"))[0]
print(ckpt) print(ckpt)
@ -47,9 +48,9 @@ def main(hparams):
test_list = hparams.test_list test_list = hparams.test_list
file_path = os.path.join(root, test_list) file_path = os.path.join(root, test_list)
test_name = os.path.splitext(os.path.basename(test_list))[0] test_name = os.path.splitext(os.path.basename(test_list))[0]
root_dir = os.path.join(root, 'test', test_name) root_dir = os.path.join(root, "test", test_name)
with open(file_path, 'r') as fp: with open(file_path, "r") as fp:
for line in fp.readlines(): for line in fp.readlines():
# Data Reading # Data Reading
data_path = line.strip() data_path = line.strip()
@ -61,20 +62,53 @@ def main(hparams):
u_true = heat.squeeze().squeeze().numpy() u_true = heat.squeeze().squeeze().numpy()
heat_obs = (heat_obs - hparams.mean_layout) / hparams.std_layout heat_obs = (heat_obs - hparams.mean_layout) / hparams.std_layout
heat0 = (heat0 - hparams.mean_heat) / hparams.std_heat heat0 = (heat0 - hparams.mean_heat) / hparams.std_heat
heat = (heat-hparams.mean_heat) / hparams.std_heat heat = (heat - hparams.mean_heat) / hparams.std_heat
obs_index, heat_obs, pred_index, heat0, heat = obs_index.to(device), heat_obs.to(device), pred_index.to(device), heat0.to(device), heat.to(device) obs_index, heat_obs, pred_index, heat0, heat = (
obs_index.to(device),
heat_obs.to(device),
pred_index.to(device),
heat0.to(device),
heat.to(device),
)
heat_info = [obs_index, heat_obs, pred_index, heat0] heat_info = [obs_index, heat_obs, pred_index, heat0]
if model.layout_model=="ConditionalNeuralProcess" or model.layout_model=="TransformerRecon": if (
heat_info[1] = heat_info[1].transpose(1,2) model.layout_model == "ConditionalNeuralProcess"
heat_info[3] = heat_info[3].transpose(2,3) or model.layout_model == "TransformerRecon"
elif model.layout_model=="DenseDeepGCN": ):
heat_obs=heat_obs.squeeze() heat_info[1] = heat_info[1].transpose(1, 2)
pseudo_heat = torch.zeros_like(heat0[:,0,:]).squeeze() heat_info[3] = heat_info[3].transpose(2, 3)
inputs = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,0,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) elif model.layout_model == "DenseDeepGCN":
heat_obs = heat_obs.squeeze()
pseudo_heat = torch.zeros_like(heat0[:, 0, :]).squeeze()
inputs = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat((obs_index, pred_index[:, 0, ...]), 1),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
for i in range(self.hparams.div_num*self.hparams.div_num-1): for i in range(hparams.div_num * hparams.div_num - 1):
input_single = torch.cat((torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1), torch.cat((obs_index, pred_index[:,i+1,...]), 1)), 2).transpose(1,2).unsqueeze(-1).unsqueeze(0) input_single = (
torch.cat(
(
torch.cat((heat_obs, pseudo_heat), 1).unsqueeze(-1),
torch.cat(
(obs_index, pred_index[:, i + 1, ...]), 1
),
),
2,
)
.transpose(1, 2)
.unsqueeze(-1)
.unsqueeze(0)
)
inputs = torch.cat((inputs, input_single), 0) inputs = torch.cat((inputs, input_single), 0)
heat_info = inputs heat_info = inputs
@ -82,83 +116,111 @@ def main(hparams):
data = sio.loadmat(path) data = sio.loadmat(path)
u_true, u_obs = data["u"], data["u_obs"] u_true, u_obs = data["u"], data["u_obs"]
u_obs[np.where(u_obs<TOL)]=hparams.mean_layout u_obs[np.where(u_obs < TOL)] = hparams.mean_layout
u_obs = torch.Tensor((u_obs - hparams.mean_layout) / hparams.std_layout).unsqueeze(0).unsqueeze(0).to(device) u_obs = (
heat = torch.Tensor((u_true - hparams.mean_heat) / hparams.std_heat).unsqueeze(0).unsqueeze(0).to(device) torch.Tensor((u_obs - hparams.mean_layout) / hparams.std_layout)
.unsqueeze(0)
.unsqueeze(0)
.to(device)
)
heat = (
torch.Tensor((u_true - hparams.mean_heat) / hparams.std_heat)
.unsqueeze(0)
.unsqueeze(0)
.to(device)
)
heat_info = u_obs heat_info = u_obs
hs_F = sio.loadmat(path)["F"] hs_F = sio.loadmat(path)["F"]
# Plot u_obs and Real Temperature Field # Plot u_obs and Real Temperature Field
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
plt.subplot(141) plt.subplot(141)
plt.title('Real Time Power') plt.title("Real Time Power")
im = plt.pcolormesh(X,Y,hs_F) im = plt.pcolormesh(X, Y, hs_F)
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
with torch.no_grad(): with torch.no_grad():
heat_pred0 = model(heat_info) heat_pred0 = model(heat_info)
if model.vec: if model.vec:
if model.layout_model=="DenseDeepGCN": if model.layout_model == "DenseDeepGCN":
heat_pred0 = heat_pred0[...,-model.output_dim:] heat_pred0 = heat_pred0[..., -model.output_dim :]
else: else:
pass pass
heat_pred0 = heat_pred0.reshape((-1, hparams.div_num*hparams.div_num, int(200 / hparams.div_num), int(200 / hparams.div_num))) heat_pred0 = heat_pred0.reshape(
(
-1,
hparams.div_num * hparams.div_num,
int(200 / hparams.div_num),
int(200 / hparams.div_num),
)
)
heat_pre = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200)) heat_pre = torch.zeros_like(heat_pred0).reshape((-1, 1, 200, 200))
for i in range(hparams.div_num): for i in range(hparams.div_num):
for j in range(hparams.div_num): for j in range(hparams.div_num):
heat_pre[..., 0+i:200:hparams.div_num, 0+j:200:hparams.div_num] = heat_pred0[:, hparams.div_num*i+j,...].unsqueeze(1) heat_pre[
heat_pre = heat_pre.transpose(2,3) ...,
0 + i : 200 : hparams.div_num,
0 + j : 200 : hparams.div_num,
] = heat_pred0[:, hparams.div_num * i + j, ...].unsqueeze(1)
heat_pre = heat_pre.transpose(2, 3)
heat = heat.unsqueeze(1) heat = heat.unsqueeze(1)
else: else:
heat_pre = heat_pred0 heat_pre = heat_pred0
heat = heat heat = heat
mae = F.l1_loss(heat, heat_pre) * hparams.std_heat mae = F.l1_loss(heat, heat_pre) * hparams.std_heat
print('sample:', data_path) print("sample:", data_path)
print('MAE:', mae) print("MAE:", mae)
mae_test.append(mae.item()) mae_test.append(mae.item())
heat_pre = heat_pre.squeeze(0).squeeze(0).cpu().numpy() * hparams.std_heat + hparams.mean_heat heat_pre = (
#heat_pre = np.transpose(heat_pre, (1,0)) heat_pre.squeeze(0).squeeze(0).cpu().numpy() * hparams.std_heat
+ hparams.mean_heat
)
# heat_pre = np.transpose(heat_pre, (1,0))
hmax = max(np.max(heat_pre), np.max(u_true)) hmax = max(np.max(heat_pre), np.max(u_true))
hmin = min(np.min(heat_pre), np.min(u_true)) hmin = min(np.min(heat_pre), np.min(u_true))
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u_true,levels=150,cmap='jet') im = plt.contourf(X, Y, u_true, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, heat_pre,levels=150,cmap='jet') im = plt.contourf(X, Y, heat_pre, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, np.abs(heat_pre-u_true),levels=150,cmap='jet') im = plt.contourf(X, Y, np.abs(heat_pre - u_true), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
save_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.png') save_name = os.path.join(
mat_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.mat') "outputs/predict_plot",
sio.savemat(mat_name, {'pre': heat_pre, 'u_true': u_true}) os.path.splitext(os.path.basename(path))[0] + ".png",
)
mat_name = os.path.join(
"outputs/predict_plot",
os.path.splitext(os.path.basename(path))[0] + ".mat",
)
sio.savemat(mat_name, {"pre": heat_pre, "u_true": u_true})
fig.savefig(save_name, dpi=300) fig.savefig(save_name, dpi=300)
plt.close() plt.close()
mae_test = np.array(mae_test) mae_test = np.array(mae_test)
print(mae_test.mean()) print(mae_test.mean())
np.savetxt('outputs/mae_test.csv', mae_test, fmt='%f', delimiter=',') np.savetxt("outputs/mae_test.csv", mae_test, fmt="%f", delimiter=",")

View File

@ -22,9 +22,7 @@ def main(hparams):
model.testing_step() model.testing_step()
class PointModel: class PointModel:
def __init__(self, hparams): def __init__(self, hparams):
super().__init__() super().__init__()
self.hparams = hparams self.hparams = hparams
@ -49,80 +47,82 @@ class PointModel:
for num, data in enumerate(trange): for num, data in enumerate(trange):
path = self.data.sample_files[num] path = self.data.sample_files[num]
u_obs=data[0] u_obs = data[0]
u=data[1] u = data[1]
F=data[2] F = data[2]
sample = Point(self.hparams.model_name, u_obs, u) sample = Point(self.hparams.model_name, u_obs, u)
u_pred = sample.predict() u_pred = sample.predict()
if(self.hparams.plot): if self.hparams.plot:
self.plot(path, u_pred, u, F) self.plot(path, u_pred, u, F)
all_mae += self.metric.mae(u_pred,u) all_mae += self.metric.mae(u_pred, u)
all_maxae += self.metric.maxae(u_pred,u) all_maxae += self.metric.maxae(u_pred, u)
all_cmae += self.metric.cmae(u_pred,u,F) all_cmae += self.metric.cmae(u_pred, u, F)
all_mcae += self.metric.mcae(u_pred,u,F) all_mcae += self.metric.mcae(u_pred, u, F)
all_bmae += self.metric.bmae(u_pred,u) all_bmae += self.metric.bmae(u_pred, u)
mae_test.append(self.metric.mae(u_pred,u)) mae_test.append(self.metric.mae(u_pred, u))
trange.set_description("Testing") trange.set_description("Testing")
trange.set_postfix(MAE=all_mae/(num+1), MaxAE=all_maxae/(num+1), CMAE=all_cmae/(num+1), MCAE=all_mcae/(num+1), BMAE=all_bmae/(num+1)) trange.set_postfix(
MAE=all_mae / (num + 1),
MaxAE=all_maxae / (num + 1),
CMAE=all_cmae / (num + 1),
MCAE=all_mcae / (num + 1),
BMAE=all_bmae / (num + 1),
)
mae_test = np.array(mae_test) mae_test = np.array(mae_test)
print(mae_test.mean()) print(mae_test.mean())
np.savetxt('outputs/mae_test.csv', mae_test, fmt='%f', delimiter=',') np.savetxt("outputs/mae_test.csv", mae_test, fmt="%f", delimiter=",")
def plot(self, path, heat_pre, u_true, hs_F): def plot(self, path, heat_pre, u_true, hs_F):
fig = plt.figure(figsize=(22.5,5)) fig = plt.figure(figsize=(22.5, 5))
grid_x = np.linspace(0, 0.1, num=200) grid_x = np.linspace(0, 0.1, num=200)
grid_y = np.linspace(0, 0.1, num=200) grid_y = np.linspace(0, 0.1, num=200)
X, Y = np.meshgrid(grid_x, grid_y) X, Y = np.meshgrid(grid_x, grid_y)
plt.subplot(141) plt.subplot(141)
plt.title('Real Time Power') plt.title("Real Time Power")
im = plt.pcolormesh(X,Y,hs_F) im = plt.pcolormesh(X, Y, hs_F)
plt.colorbar(im) plt.colorbar(im)
fig.tight_layout(pad=2.0, w_pad=3.0,h_pad=2.0) fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)
plt.subplot(142) plt.subplot(142)
plt.title('Real Temperature Field') plt.title("Real Temperature Field")
im = plt.contourf(X,Y,u_true,levels=150,cmap='jet') im = plt.contourf(X, Y, u_true, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(143) plt.subplot(143)
plt.title('Reconstructed Temperature Field') plt.title("Reconstructed Temperature Field")
im = plt.contourf(X, Y, heat_pre,levels=150,cmap='jet') im = plt.contourf(X, Y, heat_pre, levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
plt.subplot(144) plt.subplot(144)
plt.title('Absolute Error') plt.title("Absolute Error")
im = plt.contourf(X, Y, np.abs(heat_pre-u_true),levels=150,cmap='jet') im = plt.contourf(X, Y, np.abs(heat_pre - u_true), levels=150, cmap="jet")
plt.colorbar(im) plt.colorbar(im)
save_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.png') save_name = os.path.join(
mat_name = os.path.join('outputs/predict_plot', os.path.splitext(os.path.basename(path))[0]+'.mat') "outputs/predict_plot", os.path.splitext(os.path.basename(path))[0] + ".png"
sio.savemat(mat_name, {'pre': heat_pre, 'u_true': u_true}) )
mat_name = os.path.join(
"outputs/predict_plot", os.path.splitext(os.path.basename(path))[0] + ".mat"
)
sio.savemat(mat_name, {"pre": heat_pre, "u_true": u_true})
fig.savefig(save_name, dpi=300) fig.savefig(save_name, dpi=300)
plt.close() plt.close()
class Metric: class Metric:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -133,17 +133,17 @@ class Metric:
return np.max(np.max(abs(u_pred - u))) return np.max(np.max(abs(u_pred - u)))
def mcae(self, u_pred, u, F): def mcae(self, u_pred, u, F):
F[np.where(F>TOL)] = 1 F[np.where(F > TOL)] = 1
return np.max(np.max(abs(u_pred - u)*F)) return np.max(np.max(abs(u_pred - u) * F))
def cmae(self, u_pred, u, F): def cmae(self, u_pred, u, F):
F[np.where(F>TOL)] = 1 F[np.where(F > TOL)] = 1
return np.sum(np.sum(abs(u_pred - u)*F)) / np.sum(F) return np.sum(np.sum(abs(u_pred - u) * F)) / np.sum(F)
def bmae(self, u_pred, u): def bmae(self, u_pred, u):
ind = np.zeros_like(u_pred) ind = np.zeros_like(u_pred)
ind[:2,...]=1 ind[:2, ...] = 1
ind[-2:,...]=1 ind[-2:, ...] = 1
ind[...,:2]=1 ind[..., :2] = 1
ind[...,-2:]=1 ind[..., -2:] = 1
return np.sum(np.sum(abs(u_pred - u)*ind)) / np.sum(ind) return np.sum(np.sum(abs(u_pred - u) * ind)) / np.sum(ind)

View File

@ -33,15 +33,16 @@ def main(hparams):
if hparams.gpu == 0: if hparams.gpu == 0:
hparams.gpu = 0 hparams.gpu = 0
else: else:
hparams.gpu = [hparams.gpu-1] hparams.gpu = [hparams.gpu - 1]
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=hparams.gpu, gpus=hparams.gpu,
precision=16 if hparams.use_16bit else 32, precision=16 if hparams.use_16bit else 32,
# limit_test_batches=0.05 # limit_test_batches=0.05
) )
model_path = os.path.join(f'lightning_logs/version_' + model_path = os.path.join(
hparams.test_check_num, 'checkpoints/') f"lightning_logs/version_" + hparams.test_check_num, "checkpoints/"
)
model_path = list(Path(model_path).glob("*.ckpt"))[0] model_path = list(Path(model_path).glob("*.ckpt"))[0]
test_model = model.load_from_checkpoint(checkpoint_path=model_path, hparams=hparams) test_model = model.load_from_checkpoint(checkpoint_path=model_path, hparams=hparams)
@ -52,4 +53,3 @@ def main(hparams):
print() print()
trainer.test(model=test_model) trainer.test(model=test_model)

View File

@ -32,8 +32,8 @@ def main(hparams):
if hparams.gpu == 0: if hparams.gpu == 0:
hparams.gpu = 0 hparams.gpu = 0
else: else:
hparams.gpu = [hparams.gpu-1] hparams.gpu = [hparams.gpu - 1]
#print(hparams.gpus) # print(hparams.gpus)
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=hparams.max_epochs, max_epochs=hparams.max_epochs,
gpus=hparams.gpu, gpus=hparams.gpu,

View File

@ -10,13 +10,14 @@ def weights_init(m):
""" """
class_name = m.__class__.__name__ class_name = m.__class__.__name__
if class_name.find("Conv") != -1: if class_name.find("Conv") != -1:
torch.nn.init.kaiming_normal_(m.weight, torch.nn.init.kaiming_normal_(
mode="fan_out", m.weight, mode="fan_out", nonlinearity="relu"
nonlinearity="relu") # 初始化卷积层权重 ) # 初始化卷积层权重
# torch.nn.init.xavier_normal_(m.weight) # torch.nn.init.xavier_normal_(m.weight)
elif (class_name.find("BatchNorm") != -1 elif (
and class_name.find("WithFixedBatchNorm") == -1 class_name.find("BatchNorm") != -1
): # batch norm层不能用kaiming_normal初始化 and class_name.find("WithFixedBatchNorm") == -1
): # batch norm层不能用kaiming_normal初始化
torch.nn.init.constant_(m.weight, 1) torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.bias, 0)
# m.weight.data.normal_(1.0, 0.02) # m.weight.data.normal_(1.0, 0.02)
@ -45,9 +46,10 @@ def weights_init_without_kaiming(m):
if class_name.find("Conv") != -1: if class_name.find("Conv") != -1:
torch.nn.init.xavier_normal_(m.weight) torch.nn.init.xavier_normal_(m.weight)
# torch.nn.init.normal_(m.weight) # 初始化卷积层权重 # torch.nn.init.normal_(m.weight) # 初始化卷积层权重
elif (class_name.find("BatchNorm") != -1 elif (
and class_name.find("WithFixedBatchNorm") == -1 class_name.find("BatchNorm") != -1
): # batch norm层不能用kaiming_normal初始化 and class_name.find("WithFixedBatchNorm") == -1
): # batch norm层不能用kaiming_normal初始化
torch.nn.init.constant_(m.weight, 1) torch.nn.init.constant_(m.weight, 1)
torch.nn.init.constant_(m.bias, 0) torch.nn.init.constant_(m.bias, 0)
# m.weight.data.normal_(1.0, 0.02) # m.weight.data.normal_(1.0, 0.02)

View File

@ -33,7 +33,6 @@ class ToTensor:
class Resize: class Resize:
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
@ -43,7 +42,7 @@ class Resize:
for _ in range(4 - x_dim): for _ in range(4 - x_dim):
x_tensor = x_tensor.unsqueeze(0) x_tensor = x_tensor.unsqueeze(0)
x_resize = interpolate(x_tensor, size=self.size) x_resize = interpolate(x_tensor, size=self.size)
for _ in range(4-x_dim): for _ in range(4 - x_dim):
x_resize = x_resize.squeeze(0) x_resize = x_resize.squeeze(0)
return x_resize.numpy() return x_resize.numpy()

View File

@ -12,7 +12,9 @@ def get_upsampling_weight(in_channels, out_channels, kernel_size):
center = factor - 0.5 center = factor - 0.5
og = np.ogrid[:kernel_size, :kernel_size] og = np.ogrid[:kernel_size, :kernel_size]
filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) weight = np.zeros(
(in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
)
weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
return torch.from_numpy(weight).float() return torch.from_numpy(weight).float()