import torch import torchvision from torch.utils.data import Dataset, DataLoader from torchvision import transforms,datasets from torchvision.transforms import ToTensor import numpy as np import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import matplotlib.pyplot as plt from PIL import Image device=torch.device("cpu") cropbox=(0,0,316,316) batch_size = 1 class CNN(nn.Module): def __init__(self): super().__init__() self.features=nn.Sequential( nn.Conv2d(3,6, kernel_size=5), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2,stride=2), nn.Conv2d(6,12, kernel_size=5), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2,stride=2), nn.Conv2d(12,24,kernel_size=5,stride=1), nn.ReLU(inplace=True), nn.Conv2d(24,24,kernel_size=5,stride=1), nn.ReLU(inplace=True), nn.Conv2d(24,48,kernel_size=5,stride=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2,stride=2), ) # 请把之前 Network1 的 self.fc1 复制到这里: # self.fc1 = nn.Linear(in_features=XXX, out_features=120) self.classifier=nn.Sequential( nn.Linear(32*32*48,2048), nn.ReLU(inplace=True), #nn.Dropout(0.5), nn.Linear(2048,1024), nn.ReLU(inplace=True), #nn.Dropout(0.5), nn.Linear(1024,1), ) def forward(self, t): t=self.features(t) # 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量 # 是二维张量,其中 dim0 对应该批次中的不同样本,dim1 对应每个样本流向全连接层的输入激活向量。 # 请完成如下代码。 # dim0:一个批次的样本数 # dim1:每一个样本的输入激活向量长度 # t = t.reshape(XXX,XXX) t = t.reshape(batch_size,32*32*48) t = self.classifier(t) return t const_avg=3.6156537532806396 confidence_margin = 5.0 if __name__ == '__main__': network=CNN() network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above network.eval() print("模型已导入完成。") path=input("请输入待识别图片地址:") img = Image.open(path) img_2=img.crop(cropbox) img_ten=transforms.ToTensor()(np.array(img_2)) pred=network(img_ten).item() if pred<=const_avg: print("该图片正常。预测输出:{:3f}".format(pred)) elif pred>=confidence_margin: print("该图片异常。预测输出:{:3f}".format(pred)) else: output=(pred-const_avg)/(confidence_margin-const_avg) print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大,靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred))