From 4201f20d356940eb861f503632566151f688e36b Mon Sep 17 00:00:00 2001 From: p74035216 Date: Sat, 25 Nov 2023 16:16:45 +0800 Subject: [PATCH] ADD file via upload --- Devnet for pre.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 Devnet for pre.py diff --git a/Devnet for pre.py b/Devnet for pre.py new file mode 100644 index 0000000..020425d --- /dev/null +++ b/Devnet for pre.py @@ -0,0 +1,87 @@ +import torch +import torchvision +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms,datasets +from torchvision.transforms import ToTensor +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import matplotlib.pyplot as plt +from PIL import Image + + +device=torch.device("cpu") + +cropbox=(0,0,316,316) +batch_size = 1 + +class CNN(nn.Module): + def __init__(self): + super().__init__() + self.features=nn.Sequential( + nn.Conv2d(3,6, kernel_size=5), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2,stride=2), + nn.Conv2d(6,12, kernel_size=5), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2,stride=2), + nn.Conv2d(12,24,kernel_size=5,stride=1), + nn.ReLU(inplace=True), + nn.Conv2d(24,24,kernel_size=5,stride=1), + nn.ReLU(inplace=True), + nn.Conv2d(24,48,kernel_size=5,stride=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2,stride=2), + ) + + # 请把之前 Network1 的 self.fc1 复制到这里: + # self.fc1 = nn.Linear(in_features=XXX, out_features=120) + + self.classifier=nn.Sequential( + nn.Linear(32*32*48,2048), + nn.ReLU(inplace=True), + #nn.Dropout(0.5), + nn.Linear(2048,1024), + nn.ReLU(inplace=True), + #nn.Dropout(0.5), + nn.Linear(1024,1), + ) + + + def forward(self, t): + t=self.features(t) + # 特别注意:最后一个卷积层的输出在流入全连接层之前,要对激活值的维度进行变换,变换后的张量 + # 是二维张量,其中 dim0 对应该批次中的不同样本,dim1 对应每个样本流向全连接层的输入激活向量。 + # 请完成如下代码。 + + # dim0:一个批次的样本数 + # dim1:每一个样本的输入激活向量长度 + + # t = t.reshape(XXX,XXX) + t = t.reshape(batch_size,32*32*48) + + t = self.classifier(t) + + return t + +const_avg=3.6156537532806396 +confidence_margin = 5.0 + +if __name__ == '__main__': + network=CNN() + network = torch.load("Dev_netv1.0_CNN.pth") # PATH same as above + network.eval() + print("模型已导入完成。") + path=input("请输入待识别图片地址:") + img = Image.open(path) + img_2=img.crop(cropbox) + img_ten=transforms.ToTensor()(np.array(img_2)) + pred=network(img_ten).item() + if pred<=const_avg: + print("该图片正常。预测输出:{:3f}".format(pred)) + elif pred>=confidence_margin: + print("该图片异常。预测输出:{:3f}".format(pred)) + else: + output=(pred-const_avg)/(confidence_margin-const_avg) + print("无法完全确定该图片是否正常。参考值:{:.3f}。参考值靠近0说明正常可能性更大,靠近1说明异常可能性更大。\n预测输出:{:3f}".format(output,pred)) \ No newline at end of file