# coding: utf-8

# In[1]:


import os
import random
from collections import defaultdict

import cv2 as cv
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from skimage import io, measure, segmentation
from skimage.io import imread
from skimage.util import img_as_float
from skimage.future import graph
from networkx.linalg import adj_matrix
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset, TensorDataset

path = os.getcwd()

# get_ipython().magic('matplotlib inline')
aim = 'rgb'
aimpath = path + '\\' + aim
target = 'figure_ground'
targetpath = path + '\\' + target

# In[2]:


# Task 1: prepare the data the CNN will train on
savedictname1 = path + '\\' + aim + '_data_save_dict.npy'
savedictname2 = path + '\\' + target + '_data_save_dict.npy'


# Read the raw images and cache them in a .npy file.
def save_data_dict(savedictname, aimpath):
    datalist = os.listdir(aimpath)
    imglist = []
    if 'figure_ground' in savedictname:
        # Figure/ground masks are read as single-channel grayscale images.
        for i in datalist:
            img = cv.imread(aimpath + '\\' + i, cv.IMREAD_GRAYSCALE)
            imglist.append(img)
    else:
        # RGB images are converted to floats in [0, 1].
        for i in datalist:
            img = img_as_float(io.imread(aimpath + '\\' + i))
            imglist.append(img)
    np.save(savedictname, imglist)


if not os.path.exists(savedictname1):
    print('Cached data dictionary not found; running one-time initialization, '
          'which may take a while. Please be patient.')
    save_data_dict(savedictname1, aimpath)
    print('Initialization finished; this step will be skipped on future runs.')

if not os.path.exists(savedictname2):
    print('Cached data dictionary not found; running one-time initialization, '
          'which may take a while. Please be patient.')
    save_data_dict(savedictname2, targetpath)
    print('Initialization finished; this step will be skipped on future runs.')


# In[3]:


piclist = np.load(savedictname1, allow_pickle=True)
targetlist = np.load(savedictname2, allow_pickle=True)

# Segment the loaded images and collect the superpixel patches.
# save_img_dict[image_index][superpixel_label][class] will hold the cropped
# patch for that slot; '' marks slots with no stored patch.
save_img_dict = {}
for indexi in range(len(piclist)):
    save_img_dict[int(indexi)] = {}
for indexi in range(len(piclist)):
    for labeli in range(1, 101):
        save_img_dict[indexi][labeli] = {
            0: '',
            1: ''
        }

slic_target_dict = {
    0: 'pure background, no horse',
    1: 'contains part of a horse'
}

savedictname = path + '\\' + 'slic_data_dict.npy'


def save_slic_data_dict(savedictname):
    for indexi in range(len(piclist)):
        img = img_as_float(piclist[indexi])
        targetpic = targetlist[indexi]
        # SLIC superpixel segmentation
        labels = segmentation.slic(img, compactness=30, n_segments=100, start_label=1)
        # Turn the label image into a region adjacency graph (RAG)
        g = graph.RAG(labels)
        # Compute the adjacency matrix
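        # Sketch of the adjacency-matrix step: assuming a networkx version
        # that still exposes networkx.linalg.adj_matrix (it was removed in
        # networkx 3.0), this yields a scipy sparse matrix with one
        # row/column per region. Nothing below depends on it.
        adjacency = adj_matrix(g)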

        # SLIC labels run from 1 to n_segments here, matching the keys
        # pre-created in save_img_dict (fewer segments may actually appear).
        for labeli in range(1, 101):
            pot = (labels == labeli).astype(int)
            propsa = measure.regionprops(pot)  # connected regions of this superpixel mask
            if propsa:
                aa, bb, cc, dd = propsa[0].bbox  # bounding box of the superpixel
                img_pot = img[aa:cc, bb:dd, :]
                height, width, _ = img.shape
                # cv.resize expects (width, height)
                targetpic_resize = cv.resize(targetpic, (width, height), interpolation=cv.INTER_AREA)
                target_pot = targetpic_resize[aa:cc, bb:dd]
                if target_pot.size == 0:
                    print(indexi, labeli, 'size is 0')
                    targeti = 0
                else:
                    # Any nonzero mask pixel marks the patch as containing horse
                    targeti = target_pot.max()
                    if targeti != 0:
                        targeti = 1
                save_img_dict[indexi][labeli][targeti] = img_pot
    np.save(savedictname, save_img_dict)


if not os.path.exists(savedictname):
    print('Cached data dictionary not found; running one-time initialization, '
          'which may take a while. Please be patient.')
    save_slic_data_dict(savedictname)
    print('Initialization finished; this step will be skipped on future runs.')
save_img_dict = np.load(savedictname, allow_pickle=True).item()


# In[4]:


# Task 2: save the patch images, one folder per source image
slicaimpath = path + '\\' + 'Slic'
if not os.path.exists(slicaimpath):
    os.makedirs(slicaimpath)

for indexi in range(len(piclist)):
    slicpicpath = slicaimpath + '\\' + str(indexi)
    if not os.path.exists(slicpicpath):
        os.makedirs(slicpicpath)

# save_shrink_dict mirrors save_img_dict, with every patch resized to 32x32.
save_shrink_dict = {}
for indexi in range(len(piclist)):
    save_shrink_dict[int(indexi)] = {}
for indexi in range(len(piclist)):
    for labeli in range(1, 101):
        save_shrink_dict[indexi][labeli] = {
            0: '',
            1: ''
        }

shrinkdictname = path + '\\' + 'shrink_32_slic_data_dict.npy'


def save_slic_shrink_dict(savedictname):
    for indexi in range(len(piclist)):
        for labeli in range(1, 101):
            for targeti in range(2):
                img_pot = save_img_dict[indexi][labeli][targeti]
                # Resize real, non-empty patches to 32x32; leave empty slots
                # ('') and zero-size crops untouched.
                if isinstance(img_pot, np.ndarray) and img_pot.size > 0:
                    img_pot = cv.resize(img_pot, (32, 32), interpolation=cv.INTER_AREA)
                save_img_dict[indexi][labeli][targeti] = img_pot
    np.save(savedictname, save_img_dict)


if not os.path.exists(shrinkdictname):
    print('Cached data dictionary not found; running one-time initialization, '
          'which may take a while. Please be patient.')
    save_slic_shrink_dict(shrinkdictname)
    print('Initialization finished; this step will be skipped on future runs.')
save_shrink_dict = np.load(shrinkdictname, allow_pickle=True).item()

# Loop over every source image, plotting and saving each of its patches.
for labelx in range(len(piclist)):
    pic1slic = []
    pic1target = []
    for label, labeldict in save_shrink_dict[labelx].items():
        for target, img in labeldict.items():
            if isinstance(img, np.ndarray) and img.size > 0:
                pic1slic.append(img)
                pic1target.append(target)

    for i in range(len(pic1slic)):
        plt.title("{}_{}".format(i, pic1target[i]))
        plt.imshow(pic1slic[i])
        # Build the file name with format(), encoding patch index and class
        plt.savefig(
            path + '\\Slic\\' + str(labelx) + '\\' + "slic{}_{}.jpg".format(i, pic1target[i]))
        plt.clf()  # clear the canvas, or successive patches pile up on one figure


# In[ ]:


# Task 3: build the dataset
flat_list_name = path + '\\' + 'flat_label_list.npy'


def save_flat_list_name(flat_list_name):
    # Flatten the nested dict into a list of (label, 32x32x3 patch) pairs.
    flat_data_list = []
    for picnumber, labelnumdict in save_shrink_dict.items():
        for labeli, labeldict in labelnumdict.items():
            for label, arrayx in labeldict.items():
                if isinstance(arrayx, np.ndarray) and arrayx.size > 0:
                    flat_data_list.append((label, arrayx))
    np.save(flat_list_name, flat_data_list)


if not os.path.exists(flat_list_name):
    print('Cached data dictionary not found; running one-time initialization, '
          'which may take a while. Please be patient.')
    save_flat_list_name(flat_list_name)
    print('Initialization finished; this step will be skipped on future runs.')

flat_data_list = np.load(flat_list_name, allow_pickle=True)
# np.load returns an object array; convert to a Python list before shuffling,
# since random.shuffle does not handle numpy arrays reliably.
flat_data_list = list(flat_data_list)
random.shuffle(flat_data_list)

# 70/30 train/test split
train_split = 0.7
len_all_data = len(flat_data_list)
len_train = int(train_split * len_all_data)
len_test = len_all_data - len_train
train_x = [flat_data_list[i][1] for i in range(len_train)]
train_y = [flat_data_list[i][0] for i in range(len_train)]
test_x = [flat_data_list[i][1] for i in range(len_train, len_all_data)]
test_y = [flat_data_list[i][0] for i in range(len_train, len_all_data)]
print('Training set: {} x samples, {} y samples; test set: {} x samples, {} y samples'.format(
    len(train_x), len(train_y), len(test_x), len(test_y)))
# Patches are stored HWC (32, 32, 3); permute (not reshape) to the CHW layout
# PyTorch expects, so channel values are not scrambled.
train_x = torch.from_numpy(np.array(train_x)).float().permute(0, 3, 1, 2)
train_y = torch.from_numpy(np.array(train_y)).float()
test_x = torch.from_numpy(np.array(test_x)).float().permute(0, 3, 1, 2)
test_y = torch.from_numpy(np.array(test_y)).float()
batch_size = 128
sp_train_set = TensorDataset(train_x, train_y)
sp_test_set = TensorDataset(test_x, test_y)
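
# batch_size and the imported DataLoader are never wired up above; as a
# minimal sketch (the loader names here are illustrative, not from the
# original), the two TensorDatasets would typically be batched like this:
sp_train_loader = DataLoader(sp_train_set, batch_size=batch_size, shuffle=True)
sp_test_loader = DataLoader(sp_test_set, batch_size=batch_size, shuffle=False)

# Quick shape check on one batch, e.g. torch.Size([128, 3, 32, 32]) and torch.Size([128])
for batch_x, batch_y in sp_train_loader:
    print(batch_x.shape, batch_y.shape)
    break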