CNN1/Devnet for pre.py

87 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms,datasets
from torchvision.transforms import ToTensor
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from PIL import Image
device=torch.device("cpu")
cropbox=(0,0,316,316)
batch_size = 1
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.features=nn.Sequential(
nn.Conv2d(3,6, kernel_size=5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(6,12, kernel_size=5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(12,24,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.Conv2d(24,24,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.Conv2d(24,48,kernel_size=5,stride=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2),
)
# 请把之前 Network1 的 self.fc1 复制到这里:
# self.fc1 = nn.Linear(in_features=XXX, out_features=120)
self.classifier=nn.Sequential(
nn.Linear(32*32*48,2048),
nn.ReLU(inplace=True),
#nn.Dropout(0.5),
nn.Linear(2048,1024),
nn.ReLU(inplace=True),
#nn.Dropout(0.5),
nn.Linear(1024,1),
)
def forward(self, t):
t=self.features(t)
# 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量
# 是二维张量,其中 dim0 对应该批次中的不同样本dim1 对应每个样本流向全连接层的输入激活向量。
# 请完成如下代码。
# dim0一个批次的样本数
# dim1每一个样本的输入激活向量长度
# t = t.reshape(XXX,XXX)
t = t.reshape(batch_size,32*32*48)
t = self.classifier(t)
return t
const_avg=3.6156537532806396
confidence_margin = 5.0
if __name__ == '__main__':
network=CNN()
network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above
network.eval()
print("模型已导入完成。")
path=input("请输入待识别图片地址:")
img = Image.open(path)
img_2=img.crop(cropbox)
img_ten=transforms.ToTensor()(np.array(img_2))
pred=network(img_ten).item()
if pred<=const_avg:
print("该图片正常。预测输出:{:3f}".format(pred))
elif pred>=confidence_margin:
print("该图片异常。预测输出:{:3f}".format(pred))
else:
output=(pred-const_avg)/(confidence_margin-const_avg)
print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred))