87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
|
import torch
|
|||
|
import torchvision
|
|||
|
from torch.utils.data import Dataset, DataLoader
|
|||
|
from torchvision import transforms,datasets
|
|||
|
from torchvision.transforms import ToTensor
|
|||
|
import numpy as np
|
|||
|
import torch.nn as nn
|
|||
|
import torch.nn.functional as F
|
|||
|
import torch.optim as optim
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from PIL import Image
|
|||
|
|
|||
|
|
|||
|
device=torch.device("cpu")
|
|||
|
|
|||
|
cropbox=(0,0,316,316)
|
|||
|
batch_size = 1
|
|||
|
|
|||
|
class CNN(nn.Module):
|
|||
|
def __init__(self):
|
|||
|
super().__init__()
|
|||
|
self.features=nn.Sequential(
|
|||
|
nn.Conv2d(3,6, kernel_size=5),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
nn.MaxPool2d(kernel_size=2,stride=2),
|
|||
|
nn.Conv2d(6,12, kernel_size=5),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
nn.MaxPool2d(kernel_size=2,stride=2),
|
|||
|
nn.Conv2d(12,24,kernel_size=5,stride=1),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
nn.Conv2d(24,24,kernel_size=5,stride=1),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
nn.Conv2d(24,48,kernel_size=5,stride=1),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
nn.MaxPool2d(kernel_size=2,stride=2),
|
|||
|
)
|
|||
|
|
|||
|
# 请把之前 Network1 的 self.fc1 复制到这里:
|
|||
|
# self.fc1 = nn.Linear(in_features=XXX, out_features=120)
|
|||
|
|
|||
|
self.classifier=nn.Sequential(
|
|||
|
nn.Linear(32*32*48,2048),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
#nn.Dropout(0.5),
|
|||
|
nn.Linear(2048,1024),
|
|||
|
nn.ReLU(inplace=True),
|
|||
|
#nn.Dropout(0.5),
|
|||
|
nn.Linear(1024,1),
|
|||
|
)
|
|||
|
|
|||
|
|
|||
|
def forward(self, t):
|
|||
|
t=self.features(t)
|
|||
|
# 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量
|
|||
|
# 是二维张量,其中 dim0 对应该批次中的不同样本,dim1 对应每个样本流向全连接层的输入激活向量。
|
|||
|
# 请完成如下代码。
|
|||
|
|
|||
|
# dim0:一个批次的样本数
|
|||
|
# dim1:每一个样本的输入激活向量长度
|
|||
|
|
|||
|
# t = t.reshape(XXX,XXX)
|
|||
|
t = t.reshape(batch_size,32*32*48)
|
|||
|
|
|||
|
t = self.classifier(t)
|
|||
|
|
|||
|
return t
|
|||
|
|
|||
|
const_avg=3.6156537532806396
|
|||
|
confidence_margin = 5.0
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
network=CNN()
|
|||
|
network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above
|
|||
|
network.eval()
|
|||
|
print("模型已导入完成。")
|
|||
|
path=input("请输入待识别图片地址:")
|
|||
|
img = Image.open(path)
|
|||
|
img_2=img.crop(cropbox)
|
|||
|
img_ten=transforms.ToTensor()(np.array(img_2))
|
|||
|
pred=network(img_ten).item()
|
|||
|
if pred<=const_avg:
|
|||
|
print("该图片正常。预测输出:{:3f}".format(pred))
|
|||
|
elif pred>=confidence_margin:
|
|||
|
print("该图片异常。预测输出:{:3f}".format(pred))
|
|||
|
else:
|
|||
|
output=(pred-const_avg)/(confidence_margin-const_avg)
|
|||
|
print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大,靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred))
|