# cnn.py — CNN model definition, training, and prediction for 24x24
# single-channel character images (digits and arithmetic operators).
# NOTE: this file intentionally contains the Unicode characters '×' and '÷'
# (used as class labels in predict); they are not typos.
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import pathlib
import cv2
# %% Build the model
def create_model():
    """Build and compile the character-classification CNN.

    Input: 24x24 grayscale images (single channel). Output: raw logits over
    15 classes (digits 0-9 plus operator symbols, per the label map in
    `predict`).
    """
    model = Sequential()
    # Normalize pixel values from [0, 255] into [0, 1].
    model.add(layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(24, 24, 1)))
    # Three conv + max-pool stages with increasing filter counts.
    model.add(layers.Conv2D(24, 3, padding='same', activation='relu'))
    model.add(layers.MaxPooling2D())
    model.add(layers.Conv2D(32, 3, padding='same', activation='relu'))
    model.add(layers.MaxPooling2D())
    model.add(layers.Conv2D(64, 3, padding='same', activation='relu'))
    model.add(layers.MaxPooling2D())
    model.add(layers.Dropout(0.2))
    model.add(layers.Flatten())
    model.add(layers.Dense(96, activation='relu'))
    # Final layer emits raw logits; the loss below uses from_logits=True.
    model.add(layers.Dense(15))
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )
    return model
# %% Train the model
def train():
    """Train the CNN on images under ./dataset and save the results.

    Expects a dataset/<class_dir>/<image> layout; each subdirectory name is a
    class label. Saves the class-name list to checkpoint/class_name.npy and
    the trained weights to checkpoint/char_checkpoint.
    """
    data_dir = pathlib.Path('dataset')
    batch_size = 64
    img_width = 24
    img_height = 24
    # Read images from the directory tree: 80/20 train/validation split.
    # The same seed on both calls keeps the two subsets disjoint.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset='training',
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )
    # Class names come from the subdirectory names under dataset/.
    class_names = train_ds.class_names
    # Ensure the output directory exists before writing into it; otherwise
    # np.save / save_weights fail on a fresh checkout.
    pathlib.Path('checkpoint').mkdir(parents=True, exist_ok=True)
    np.save("checkpoint/class_name.npy", class_names)
    # Cache + prefetch so the input pipeline overlaps with training.
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    model = create_model()
    # epochs=20: every sample is seen 20 times.
    model.fit(train_ds, validation_data=val_ds, epochs=20)
    # Persist the trained weights for later inference.
    model.save_weights('checkpoint/char_checkpoint')
# %% Prediction
def predict(model, imgs, class_name):
    """Classify a batch of images and map each to its display character.

    Args:
        model: a Keras model producing one score row per image.
        imgs: batch of images accepted by ``model.predict``.
        class_name: sequence mapping class index -> dataset directory name,
            as saved by ``train()``; each name is a numeric string ('0'-'14').

    Returns:
        List of display characters, one per input image.
    """
    # Directory names 0-14 map to the characters they represent.
    label_dict = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '=', 11: '+',
                  12: '-', 13: '×', 14: '÷'}
    scores = model.predict(imgs)
    # argmax picks the highest-scoring class for each image; the loop
    # variable no longer shadows this function's own name.
    return [label_dict[int(class_name[np.argmax(row)])] for row in scores]
# %% Script entry point
if __name__ == '__main__':
    train()
    # Inference example: load the trained weights, read the saved class
    # names, and classify a batch of grayscale images.
    # model = create_model()
    # model.load_weights('checkpoint/char_checkpoint')
    # class_name = np.load('checkpoint/class_name.npy')
    # print(class_name)
    # img1=cv2.imread('img1.png',0)
    # img2=cv2.imread('img2.png',0)
    # img3=cv2.imread('img3.png',0)
    # img4=cv2.imread('img4.png',0)
    # img5=cv2.imread('img5.png',0)
    # img6=cv2.imread('img6.png',0)
    # imgs = np.array([img1,img2,img3,img4,img5,img6])
    # results = predict(model, imgs, class_name)
    # print(results)