ADD file via upload
This commit is contained in:
parent
5e96e2046d
commit
4201f20d35
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms,datasets
|
||||
from torchvision.transforms import ToTensor
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
|
||||
device=torch.device("cpu")
|
||||
|
||||
cropbox=(0,0,316,316)
|
||||
batch_size = 1
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.features=nn.Sequential(
|
||||
nn.Conv2d(3,6, kernel_size=5),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||||
nn.Conv2d(6,12, kernel_size=5),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||||
nn.Conv2d(12,24,kernel_size=5,stride=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(24,24,kernel_size=5,stride=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(24,48,kernel_size=5,stride=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2,stride=2),
|
||||
)
|
||||
|
||||
# 请把之前 Network1 的 self.fc1 复制到这里:
|
||||
# self.fc1 = nn.Linear(in_features=XXX, out_features=120)
|
||||
|
||||
self.classifier=nn.Sequential(
|
||||
nn.Linear(32*32*48,2048),
|
||||
nn.ReLU(inplace=True),
|
||||
#nn.Dropout(0.5),
|
||||
nn.Linear(2048,1024),
|
||||
nn.ReLU(inplace=True),
|
||||
#nn.Dropout(0.5),
|
||||
nn.Linear(1024,1),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, t):
|
||||
t=self.features(t)
|
||||
# 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量
|
||||
# 是二维张量,其中 dim0 对应该批次中的不同样本,dim1 对应每个样本流向全连接层的输入激活向量。
|
||||
# 请完成如下代码。
|
||||
|
||||
# dim0:一个批次的样本数
|
||||
# dim1:每一个样本的输入激活向量长度
|
||||
|
||||
# t = t.reshape(XXX,XXX)
|
||||
t = t.reshape(batch_size,32*32*48)
|
||||
|
||||
t = self.classifier(t)
|
||||
|
||||
return t
|
||||
|
||||
const_avg=3.6156537532806396
|
||||
confidence_margin = 5.0
|
||||
|
||||
if __name__ == '__main__':
|
||||
network=CNN()
|
||||
network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above
|
||||
network.eval()
|
||||
print("模型已导入完成。")
|
||||
path=input("请输入待识别图片地址:")
|
||||
img = Image.open(path)
|
||||
img_2=img.crop(cropbox)
|
||||
img_ten=transforms.ToTensor()(np.array(img_2))
|
||||
pred=network(img_ten).item()
|
||||
if pred<=const_avg:
|
||||
print("该图片正常。预测输出:{:3f}".format(pred))
|
||||
elif pred>=confidence_margin:
|
||||
print("该图片异常。预测输出:{:3f}".format(pred))
|
||||
else:
|
||||
output=(pred-const_avg)/(confidence_margin-const_avg)
|
||||
print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大,靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred))
|
Loading…
Reference in New Issue