87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
import torch
|
||
import torchvision
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from torchvision import transforms,datasets
|
||
from torchvision.transforms import ToTensor
|
||
import numpy as np
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.optim as optim
|
||
import matplotlib.pyplot as plt
|
||
from PIL import Image
|
||
|
||
|
||
device=torch.device("cpu")
|
||
|
||
cropbox=(0,0,316,316)
|
||
batch_size = 1
|
||
|
||
class CNN(nn.Module):
|
||
def __init__(self):
|
||
super().__init__()
|
||
self.features=nn.Sequential(
|
||
nn.Conv2d(3,6, kernel_size=5),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||
nn.Conv2d(6,12, kernel_size=5),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||
nn.Conv2d(12,24,kernel_size=5,stride=1),
|
||
nn.ReLU(inplace=True),
|
||
nn.Conv2d(24,24,kernel_size=5,stride=1),
|
||
nn.ReLU(inplace=True),
|
||
nn.Conv2d(24,48,kernel_size=5,stride=1),
|
||
nn.ReLU(inplace=True),
|
||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||
)
|
||
|
||
# 请把之前 Network1 的 self.fc1 复制到这里:
|
||
# self.fc1 = nn.Linear(in_features=XXX, out_features=120)
|
||
|
||
self.classifier=nn.Sequential(
|
||
nn.Linear(32*32*48,2048),
|
||
nn.ReLU(inplace=True),
|
||
#nn.Dropout(0.5),
|
||
nn.Linear(2048,1024),
|
||
nn.ReLU(inplace=True),
|
||
#nn.Dropout(0.5),
|
||
nn.Linear(1024,1),
|
||
)
|
||
|
||
|
||
def forward(self, t):
|
||
t=self.features(t)
|
||
# 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量
|
||
# 是二维张量,其中 dim0 对应该批次中的不同样本,dim1 对应每个样本流向全连接层的输入激活向量。
|
||
# 请完成如下代码。
|
||
|
||
# dim0:一个批次的样本数
|
||
# dim1:每一个样本的输入激活向量长度
|
||
|
||
# t = t.reshape(XXX,XXX)
|
||
t = t.reshape(batch_size,32*32*48)
|
||
|
||
t = self.classifier(t)
|
||
|
||
return t
|
||
|
||
const_avg=3.6156537532806396
|
||
confidence_margin = 5.0
|
||
|
||
if __name__ == '__main__':
|
||
network=CNN()
|
||
network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above
|
||
network.eval()
|
||
print("模型已导入完成。")
|
||
path=input("请输入待识别图片地址:")
|
||
img = Image.open(path)
|
||
img_2=img.crop(cropbox)
|
||
img_ten=transforms.ToTensor()(np.array(img_2))
|
||
pred=network(img_ten).item()
|
||
if pred<=const_avg:
|
||
print("该图片正常。预测输出:{:3f}".format(pred))
|
||
elif pred>=confidence_margin:
|
||
print("该图片异常。预测输出:{:3f}".format(pred))
|
||
else:
|
||
output=(pred-const_avg)/(confidence_margin-const_avg)
|
||
print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大,靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred)) |