CNN1/Devnet for pre.py

87 lines
2.8 KiB
Python
Raw Normal View History

2023-11-25 16:16:45 +08:00
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,datasets
from torchvision.transforms import ToTensor
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
device=torch.device("cpu")
cropbox=(0,0,316,316)
batch_size = 1
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.features=nn.Sequential(
nn.Conv2d(3,6, kernel_size=5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,12, kernel_size=5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(12,24,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.Conv2d(24,24,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.Conv2d(24,48,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
)
# 请把之前 Network1 的 self.fc1 复制到这里:
# self.fc1 = nn.Linear(in_features=XXX, out_features=120)
self.classifier=nn.Sequential(
nn.Linear(32*32*48,2048),
nn.ReLU(inplace=True),
#nn.Dropout(0.5),
nn.Linear(2048,1024),
nn.ReLU(inplace=True),
#nn.Dropout(0.5),
nn.Linear(1024,1),
)
def forward(self, t):
t=self.features(t)
# 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量
# 是二维张量,其中 dim0 对应该批次中的不同样本dim1 对应每个样本流向全连接层的输入激活向量。
# 请完成如下代码。
# dim0一个批次的样本数
# dim1每一个样本的输入激活向量长度
# t = t.reshape(XXX,XXX)
t = t.reshape(batch_size,32*32*48)
t = self.classifier(t)
return t
const_avg=3.6156537532806396
confidence_margin = 5.0
if __name__ == '__main__':
network=CNN()
network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above
network.eval()
print("模型已导入完成。")
path=input("请输入待识别图片地址:")
img = Image.open(path)
img_2=img.crop(cropbox)
img_ten=transforms.ToTensor()(np.array(img_2))
pred=network(img_ten).item()
if pred<=const_avg:
print("该图片正常。预测输出:{:3f}".format(pred))
elif pred>=confidence_margin:
print("该图片异常。预测输出:{:3f}".format(pred))
else:
output=(pred-const_avg)/(confidence_margin-const_avg)
print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred))