110 lines
3.5 KiB
Python
110 lines
3.5 KiB
Python
|
import tensorflow as tf
|
|||
|
import numpy as np
|
|||
|
from tensorflow.keras import layers
|
|||
|
from tensorflow.keras.models import Sequential
|
|||
|
import pathlib
|
|||
|
import cv2
|
|||
|
|
|||
|
|
|||
|
# %% Build the model
def create_model(num_classes=15, input_shape=(24, 24, 1)):
    """Build and compile the character-classification CNN.

    Args:
        num_classes: Number of output classes. Defaults to 15, matching
            the symbols decoded by ``predict`` (0-9, =, +, -, ×, ÷).
        input_shape: Shape of a single input image; defaults to 24x24
            grayscale, matching the datasets built in ``train``.

    Returns:
        A compiled ``Sequential`` model whose final layer emits raw
        logits (no softmax).
    """
    model = Sequential([
        # Scale pixel values from [0, 255] into [0, 1].
        layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=input_shape),
        layers.Conv2D(24, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(96, activation='relu'),
        # Raw logits: the loss below uses from_logits=True.
        layers.Dense(num_classes),
    ])

    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    return model
|||
|
|
|||
|
|
|||
|
# %% Training
def train():
    """Train the CNN on images under ``dataset/`` and save the results.

    Expects ``dataset/`` to contain one sub-directory of images per
    class. Saves the class-name list to ``checkpoint/class_name.npy``
    and the trained weights to ``checkpoint/char_checkpoint``.
    """
    # Folder whose sub-directories define the classes and hold the images.
    data_dir = pathlib.Path('dataset')
    batch_size = 64
    img_width = 24
    img_height = 24

    # Build training/validation datasets from the image folders.
    # Same seed + same split on both calls so the subsets do not overlap.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset='training',
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )

    # Class names correspond to the sub-directory names under dataset/.
    class_names = train_ds.class_names
    # BUG FIX: np.save / save_weights fail if checkpoint/ is missing;
    # create it up front.
    pathlib.Path('checkpoint').mkdir(parents=True, exist_ok=True)
    # Persist the class list so prediction can map indices back to labels.
    np.save("checkpoint/class_name.npy", class_names)

    # Cache + prefetch the datasets for faster training.
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

    # Build the model.
    model = create_model()
    # Train for 20 epochs; each epoch passes over the full dataset once.
    model.fit(train_ds, validation_data=val_ds, epochs=20)

    # Save the trained weights.
    model.save_weights('checkpoint/char_checkpoint')
|||
|
|
|||
|
|
|||
|
# %% Prediction
def predict(model, imgs, class_name):
    """Run the model on a batch of images and decode the symbols.

    Args:
        model: Trained model exposing ``predict``; returns one score row
            per input image.
        imgs: Batch of input images (array-like, first axis = batch).
        class_name: Sequence mapping class index -> label, where each
            label is a numeric string (as saved by ``train``).

    Returns:
        A list of human-readable symbols, one per input image.
    """
    # Dataset folder names are numeric strings; map them to display symbols.
    label_dict = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '=', 11: '+',
                  12: '-', 13: '×', 14: '÷'}
    # One row of class scores per image.
    predicts = model.predict(imgs)
    results = []  # decoded symbols, in input order
    # BUG FIX: the original loop variable was named `predict`, shadowing
    # this function inside its own body; renamed to `scores`.
    for scores in predicts:  # iterate over each prediction row
        index = np.argmax(scores)   # index of the most likely class
        result = class_name[index]  # numeric-string label for that class
        results.append(label_dict[int(result)])
    return results
|||
|
|
|||
|
|
|||
|
# %% Script entry point: train the model when run directly.
if __name__ == '__main__':
    train()

# Example usage: load the saved weights and classify a few images.
# model = create_model()
# # Load the previously trained weights
# model.load_weights('checkpoint/char_checkpoint')
# # Load the saved class names
# class_name = np.load('checkpoint/class_name.npy')
# print(class_name)
# img1=cv2.imread('img1.png',0)
# img2=cv2.imread('img2.png',0)
# img3=cv2.imread('img3.png',0)
# img4=cv2.imread('img4.png',0)
# img5=cv2.imread('img5.png',0)
# img6=cv2.imread('img6.png',0)
# imgs = np.array([img1,img2,img3,img4,img5,img6])
# results = predict(model, imgs, class_name)
# print(results)