weizmann_horse_detection/pre_cnn.py

261 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding: utf-8
# In[1]:
import os
import random
import cv2 as cv
import pandas as pd
import numpy as np
from skimage.io import imread
from skimage import segmentation
from skimage.util import img_as_float
from skimage.future import graph
from networkx.linalg import adj_matrix
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from skimage import io
from skimage import measure
from skimage.io import imread
from skimage import segmentation
from skimage.util import img_as_float
from skimage.future import graph
from networkx.linalg import adj_matrix
import matplotlib
from skimage import io
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, TensorDataset
# Every input folder and every cache file lives under the current working directory.
path = os.getcwd()
# get_ipython().magic('matplotlib inline')
# Name and location of the folder with the input pictures.
aim = 'rgb'
aimpath = os.path.join(path, aim)
# Name and location of the folder with the target pictures.
target = 'figure_ground'
targetpath = os.path.join(path, target)
# In[2]:
# Task 1: prepare the data for the CNN.
# Cache files holding the loaded pictures of each folder. os.path.join is used
# so the paths are valid on every platform, not only where '\\' is the separator.
savedictname1 = os.path.join(path, aim + '_data_save_dict.npy')
savedictname2 = os.path.join(path, target + '_data_save_dict.npy')
def save_data_dict(savedictname, aimpath):
    """Load every file of a directory as an image and cache them in a ``.npy`` file.

    When ``savedictname`` contains ``'figure_ground'`` the files are read with
    OpenCV as grayscale images; otherwise they are read with scikit-image and
    converted with ``img_as_float``.

    Args:
        savedictname: Path of the ``.npy`` cache file to write.
        aimpath: Directory whose files are loaded.
    """
    read_as_mask = 'figure_ground' in savedictname
    # The directory passed in is the one that is listed; the mask branch used to
    # list the module-level ``targetpath`` while reading files from ``aimpath``.
    # Sorting makes the order deterministic: the two cached lists are paired by
    # position later on, and os.listdir gives no ordering guarantee.
    filenames = sorted(os.listdir(aimpath))
    # Pictures may differ in size, so they are stored in a 1-D object array.
    # Passing a ragged list straight to np.save is rejected by current NumPy.
    imglist = np.empty(len(filenames), dtype=object)
    for position, filename in enumerate(filenames):
        filepath = os.path.join(aimpath, filename)
        if read_as_mask:
            imglist[position] = cv.imread(filepath, cv.IMREAD_GRAYSCALE)
        else:
            imglist[position] = img_as_float(io.imread(filepath))
    np.save(savedictname, imglist)
# Build each picture cache on the first run only; an existing cache file is reused.
for cache_file, image_dir in ((savedictname1, aimpath), (savedictname2, targetpath)):
    if os.path.exists(cache_file):
        continue
    print('不存在数字字典,此次开始初始化运行,花费时间较长,请耐心等待')
    save_data_dict(cache_file, image_dir)
    print('初始化完毕,以后不会运行本次初始化任务')
# In[3]:
# Read back the two picture caches written by save_data_dict.
piclist = np.load(savedictname1, allow_pickle=True)
targetlist = np.load(savedictname2, allow_pickle=True)
# Load the pictures and prepare the container for the segmented patches.
from collections import defaultdict
# Layout: picture index -> label 1..100 -> {0: patch, 1: patch}.
# An empty string marks a slot that holds no patch yet.
save_img_dict = {
    pic_index: {seg_label: {0: '', 1: ''} for seg_label in range(1, 101)}
    for pic_index in range(len(piclist))
}
# Meaning of the two patch classes: 0 = pure background without horse,
# 1 = contains part of a horse.
slic_target_dict = {0: '纯背景无马', 1: '含有马部位'}
savedictname = path + '\\' + 'slic_data_dict.npy'
def save_slic_data_dict(savedictname):
    """Cut every picture into SLIC superpixel patches and save them as one dict.

    For each picture in the module-level ``piclist`` the SLIC labels are
    computed, and for each label the bounding box of that label is cropped from
    the picture. The same box is cropped from the matching ``targetlist``
    picture (resized to the picture's size); the patch is filed under class 1
    when that crop contains any non-zero value and under class 0 otherwise.
    Results are written into the module-level ``save_img_dict`` in place and
    the whole dict is saved to ``savedictname``.

    Args:
        savedictname: Path of the ``.npy`` file to write.
    """
    for indexi in range(len(piclist)):
        img = img_as_float(piclist[indexi])
        targetpic = targetlist[indexi]
        # SLIC segmentation, labels start at 1.
        labels = segmentation.slic(img, compactness=30, n_segments=100, start_label=1)
        # Resize the target once per picture; the result does not depend on the
        # label. cv.resize expects (width, height), i.e. (columns, rows).
        rows, cols = img.shape[:2]
        tagetpic_resize = cv.resize(targetpic, (cols, rows), interpolation=cv.INTER_AREA)
        # NOTE(review): save_img_dict has slots for labels 1..100 but label 100 is
        # never visited here. Kept as is because the cached data and the later
        # steps were produced with this range; confirm before widening it.
        for labeli in range(1, 100):
            pot = (labels == labeli).astype(int)
            propsa = measure.regionprops(pot)
            if not propsa:
                # This label does not occur in the picture.
                continue
            aa, bb, cc, dd = propsa[0].bbox  # bounding box of the label
            img_pot = img[aa:cc, bb:dd, :]
            target_pot = tagetpic_resize[aa:cc, bb:dd]
            if target_pot.size == 0:
                print(indexi, labeli, 'size is 0')
                targeti = 0
            else:
                targeti = 1 if target_pot.max() != 0 else 0
            save_img_dict[indexi][labeli][targeti] = img_pot
    np.save(savedictname, save_img_dict)
# First run only: build the SLIC patch cache. It is always read back from disk afterwards.
slic_cache_missing = not os.path.exists(savedictname)
if slic_cache_missing:
    print('不存在数字字典,此次开始初始化运行,花费时间较长,请耐心等待')
    save_slic_data_dict(savedictname)
    print('初始化完毕,以后不会运行本次初始化任务')
save_img_dict = np.load(savedictname, allow_pickle=True).item()
# In[4]:
# Task 2: save the patch pictures into one folder tree.
slicaimpath = path + '\\' + 'Slic'
# The root folder is handled first, then one sub-folder per picture index.
for slicpicpath in [slicaimpath] + [slicaimpath + '\\' + str(n) for n in range(len(piclist))]:
    if not os.path.exists(slicpicpath):
        os.makedirs(slicpicpath)
# Same layout as save_img_dict: picture index -> label 1..100 -> {0: '', 1: ''}.
save_shrink_dict = {
    pic_index: {seg_label: {0: '', 1: ''} for seg_label in range(1, 101)}
    for pic_index in range(len(piclist))
}
shrinkdictname = path + '\\' + 'shrink_32_slic_data_dict.npy'
def save_slic_shrink_dict(savedictname):
    """Resize every stored patch to 32x32 and save the resulting dict.

    Walks the module-level ``save_img_dict`` and replaces each non-empty patch
    by its 32x32 resize. Empty slots (the ``''`` placeholder) and arrays of
    size 0 are left untouched. The dict is modified in place and then saved to
    ``savedictname``.

    Args:
        savedictname: Path of the ``.npy`` file to write.
    """
    # Iterate over what the dict actually contains, so no picture count or
    # label range has to be assumed here.
    for labelnumdict in save_img_dict.values():
        for labeldict in labelnumdict.values():
            for targeti, img_pot in labeldict.items():
                # The placeholder is detected by type: comparing an ndarray with
                # '' is an elementwise comparison and cannot be used as a condition.
                if isinstance(img_pot, str) or img_pot.size == 0:
                    continue
                # Only the value of an existing key changes, which is safe while iterating.
                labeldict[targeti] = cv.resize(img_pot, (32, 32), interpolation=cv.INTER_AREA)
    np.save(savedictname, save_img_dict)
# First run only: build the 32x32 patch cache. It is always read back from disk afterwards.
if not os.path.exists(shrinkdictname):
    print('不存在数字字典,此次开始初始化运行,花费时间较长,请耐心等待')
    save_slic_shrink_dict(shrinkdictname)
    print('初始化完毕,以后不会运行本次初始化任务')
save_shrink_dict = np.load(shrinkdictname, allow_pickle=True).item()
# Write every non-empty patch of every picture as a titled JPEG into that
# picture's folder under Slic. The pictures come from the dict itself, so the
# loop works for any number of pictures.
for labelx, labelnumdict in save_shrink_dict.items():
    pic1slic = []
    pic1target = []
    for labeldict in labelnumdict.values():
        for patch_class, patch in labeldict.items():
            # '' marks an empty slot; tested by type because comparing an
            # ndarray with '' cannot be used as a condition.
            if not isinstance(patch, str):
                pic1slic.append(patch)
                pic1target.append(patch_class)
    for i, (patch, patch_class) in enumerate(zip(pic1slic, pic1target)):
        plt.title("{}_{}".format(i, patch_class))
        plt.imshow(patch)
        # The file name carries the running patch number and the patch class.
        plt.savefig(
            path + '\\Slic\\' + str(labelx) + '\\' + "slic{}_{}.jpg".format(i, patch_class))
        # Clear the canvas, otherwise several patches are drawn on top of each other.
        plt.clf()
# In[ ]:
# Task 3: build the data set.
flat_list_name = path + '\\' + 'flat_label_list.npy'


def save_flat_list_name(flat_list_name):
    """Flatten the patch dict into (class, patch) rows and save them.

    Collects every non-empty patch of the module-level ``save_shrink_dict``
    together with the class key (0 or 1) it is stored under, and saves the
    result as an object array of shape ``(N, 2)``: column 0 holds the class,
    column 1 the patch.

    Args:
        flat_list_name: Path of the ``.npy`` file to write.
    """
    flat_pairs = []
    for labelnumdict in save_shrink_dict.values():
        for labeldict in labelnumdict.values():
            for label, arrayx in labeldict.items():
                # '' marks an empty slot; tested by type because comparing an
                # ndarray with '' cannot be used as a condition.
                if not isinstance(arrayx, str):
                    flat_pairs.append((label, arrayx))
    # Build the object array explicitly: handing a list of (int, ndarray)
    # tuples to np.save relies on ragged-array creation, which current NumPy rejects.
    flat_data_list = np.empty((len(flat_pairs), 2), dtype=object)
    for row, (label, arrayx) in enumerate(flat_pairs):
        flat_data_list[row, 0] = label
        flat_data_list[row, 1] = arrayx
    np.save(flat_list_name, flat_data_list)


# First run only: build the flat list cache.
if not os.path.exists(flat_list_name):
    print('不存在数字字典,此次开始初始化运行,花费时间较长,请耐心等待')
    save_flat_list_name(flat_list_name)
    print('初始化完毕,以后不会运行本次初始化任务')
flat_data_list = np.load(flat_list_name, allow_pickle=True)
# Shuffle through a permuted index. random.shuffle swaps items in place, and
# the items of a 2-D ndarray are row views, so swapping them overwrites rows
# (duplicating some and losing others) instead of exchanging them.
shuffled_order = list(range(len(flat_data_list)))
random.shuffle(shuffled_order)
flat_data_list = flat_data_list[np.array(shuffled_order, dtype=np.intp)]
# The first 70 % of the shuffled rows form the training set, the rest the test set.
train_split = 0.7
len_all_data = len(flat_data_list)
len_train = int(train_split * len_all_data)
len_test = len_all_data - len_train
# Column 0 of each row is the class, column 1 the patch.
train_x = [flat_data_list[i][1] for i in range(len_train)]
train_y = [flat_data_list[i][0] for i in range(len_train)]
test_x = [flat_data_list[i][1] for i in range(len_train, len_all_data)]
test_y = [flat_data_list[i][0] for i in range(len_train, len_all_data)]
# The last count is the test-set y count, so the message labels it as such.
print('拥有训练集数据x{}个,训练集y数据{}测试集数据x{}测试集y数据{}'.format(len(train_x), len(train_y), len(test_x), len(test_y)))
# The patches are stored channel-last (32, 32, 3). Moving the channel axis with
# permute keeps every pixel intact; a plain reshape to (N, 3, 32, 32) would
# reinterpret the memory and scramble the pictures.
train_x = torch.from_numpy(
    np.array(train_x).reshape(len_train, 32, 32, 3)).float().permute(0, 3, 1, 2).contiguous()
train_y = torch.from_numpy(np.array(train_y)).float()
test_x = torch.from_numpy(
    np.array(test_x).reshape(len_test, 32, 32, 3)).float().permute(0, 3, 1, 2).contiguous()
# NOTE(review): train_y is cast to float but test_y is not; confirm which dtype
# the loss and the evaluation code expect before unifying them.
test_y = torch.from_numpy(np.array(test_y))
batch_size = 128
sp_train_set = TensorDataset(train_x, train_y)
sp_test_set = TensorDataset(test_x, test_y)