201 lines
8.5 KiB
Plaintext
201 lines
8.5 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "9128debb",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Found 300 files belonging to 15 classes.\n",
|
||
"Using 240 files for training.\n",
|
||
"Found 300 files belonging to 15 classes.\n",
|
||
"Using 60 files for validation.\n",
|
||
"Epoch 1/20\n",
|
||
"4/4 [==============================] - 1s 99ms/step - loss: 2.6807 - accuracy: 0.0875 - val_loss: 2.6252 - val_accuracy: 0.1167\n",
|
||
"Epoch 2/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 2.5773 - accuracy: 0.1625 - val_loss: 2.5274 - val_accuracy: 0.3167\n",
|
||
"Epoch 3/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 2.4399 - accuracy: 0.2917 - val_loss: 2.3353 - val_accuracy: 0.4667\n",
|
||
"Epoch 4/20\n",
|
||
"4/4 [==============================] - 0s 14ms/step - loss: 2.2246 - accuracy: 0.5333 - val_loss: 2.0439 - val_accuracy: 0.6833\n",
|
||
"Epoch 5/20\n",
|
||
"4/4 [==============================] - 0s 16ms/step - loss: 1.9292 - accuracy: 0.6167 - val_loss: 1.6177 - val_accuracy: 0.9167\n",
|
||
"Epoch 6/20\n",
|
||
"4/4 [==============================] - 0s 38ms/step - loss: 1.5042 - accuracy: 0.8333 - val_loss: 1.1052 - val_accuracy: 1.0000\n",
|
||
"Epoch 7/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 1.0313 - accuracy: 0.8792 - val_loss: 0.6252 - val_accuracy: 0.9833\n",
|
||
"Epoch 8/20\n",
|
||
"4/4 [==============================] - 0s 16ms/step - loss: 0.6847 - accuracy: 0.8958 - val_loss: 0.3330 - val_accuracy: 0.9833\n",
|
||
"Epoch 9/20\n",
|
||
"4/4 [==============================] - 0s 16ms/step - loss: 0.3673 - accuracy: 0.9250 - val_loss: 0.1612 - val_accuracy: 1.0000\n",
|
||
"Epoch 10/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.2573 - accuracy: 0.9500 - val_loss: 0.1193 - val_accuracy: 1.0000\n",
|
||
"Epoch 11/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.1788 - accuracy: 0.9667 - val_loss: 0.0505 - val_accuracy: 1.0000\n",
|
||
"Epoch 12/20\n",
|
||
"4/4 [==============================] - 0s 16ms/step - loss: 0.1045 - accuracy: 0.9875 - val_loss: 0.0339 - val_accuracy: 1.0000\n",
|
||
"Epoch 13/20\n",
|
||
"4/4 [==============================] - 0s 18ms/step - loss: 0.0553 - accuracy: 1.0000 - val_loss: 0.0132 - val_accuracy: 1.0000\n",
|
||
"Epoch 14/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0351 - accuracy: 1.0000 - val_loss: 0.0121 - val_accuracy: 1.0000\n",
|
||
"Epoch 15/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0325 - accuracy: 0.9958 - val_loss: 0.0060 - val_accuracy: 1.0000\n",
|
||
"Epoch 16/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0273 - accuracy: 0.9958 - val_loss: 0.0044 - val_accuracy: 1.0000\n",
|
||
"Epoch 17/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0144 - accuracy: 1.0000 - val_loss: 0.0044 - val_accuracy: 1.0000\n",
|
||
"Epoch 18/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0098 - accuracy: 1.0000 - val_loss: 0.0034 - val_accuracy: 1.0000\n",
|
||
"Epoch 19/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0110 - accuracy: 1.0000 - val_loss: 0.0024 - val_accuracy: 1.0000\n",
|
||
"Epoch 20/20\n",
|
||
"4/4 [==============================] - 0s 15ms/step - loss: 0.0088 - accuracy: 0.9958 - val_loss: 0.0019 - val_accuracy: 1.0000\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import tensorflow as tf\n",
|
||
"import numpy as np\n",
|
||
"from tensorflow.keras import layers\n",
|
||
"from tensorflow.keras.models import Sequential\n",
|
||
"import pathlib\n",
|
||
"import cv2\n",
|
||
"\n",
|
||
"# %% 构建模型\n",
|
||
def create_model():
    """Build and compile the 15-class character-recognition CNN.

    Input: (24, 24, 1) grayscale images; the first layer rescales raw
    pixel values from [0, 255] into [0, 1].
    Output: 15 raw logits (one per character class — softmax is applied
    inside the loss via ``from_logits=True``).

    Returns:
        A compiled ``tf.keras.Sequential`` model.
    """
    model = Sequential()
    # Normalize pixels into [0, 1] as part of the graph so inference
    # callers can pass raw uint8-style images.
    model.add(layers.experimental.preprocessing.Rescaling(1./255, input_shape=(24, 24, 1)))
    # Three conv/pool stages with growing channel counts: 24 -> 32 -> 64.
    for n_filters in (24, 32, 64):
        model.add(layers.Conv2D(n_filters, 3, padding='same', activation='relu'))
        model.add(layers.MaxPooling2D())
    model.add(layers.Dropout(0.2))
    model.add(layers.Flatten())
    model.add(layers.Dense(96, activation='relu'))
    model.add(layers.Dense(15))  # logits, no activation

    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    return model
|
||
"\n",
|
||
"# %% 训练数据\n",
|
||
def train(epochs=20):
    """Train the character-recognition CNN on the images under ``dataset/``.

    The dataset directory is expected to contain one sub-folder per class
    (the sub-folder names become the class labels).

    Args:
        epochs: number of passes over the training data (default 20,
            matching the original script's behavior).

    Side effects:
        Writes the class-name list to ``checkpoint/class_name.npy`` and the
        trained weights to ``checkpoint/char_checkpoint``.
    """
    data_dir = pathlib.Path('dataset')
    batch_size = 64
    img_width = 24
    img_height = 24

    # Build the training split (80%) from the image folder; grayscale to
    # match the model's (24, 24, 1) input. The shared seed keeps the
    # train/validation partitions disjoint and reproducible.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset='training',
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )

    # Validation split (remaining 20%), same preprocessing.
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        seed=123,
        color_mode="grayscale",
        image_size=(img_height, img_width),
        batch_size=batch_size
    )

    # Class labels correspond to the sub-folder names under dataset/.
    class_names = train_ds.class_names
    # Make sure the output directory exists before saving anything —
    # np.save / save_weights would otherwise fail on a fresh checkout.
    pathlib.Path('checkpoint').mkdir(parents=True, exist_ok=True)
    np.save("checkpoint/class_name.npy", class_names)

    # Cache + prefetch for pipeline throughput; shuffle only the
    # training stream.
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

    # Build and fit the model, then persist the trained weights.
    model = create_model()
    model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    model.save_weights('checkpoint/char_checkpoint')
|
||
"\n",
|
||
"# %% 预测\n",
|
||
def predict(model, imgs, class_name):
    """Run the model on a batch of images and map outputs to characters.

    Args:
        model: trained Keras model (anything exposing ``predict(imgs)``).
        imgs: batch of input images, forwarded unchanged to
            ``model.predict``.
        class_name: sequence of class-folder names (numeric strings)
            indexed by the model's output class index.

    Returns:
        list[str]: one recognized character per input image.
    """
    # Maps the numeric class label (folder name as int) to the character
    # it represents.
    label_dict = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '=', 11: '+', 12: '-', 13: '×', 14: '÷'}
    # Raw per-class scores (logits) for every image in the batch.
    predicts = model.predict(imgs)
    results = []
    # NOTE: renamed the loop variable from ``predict`` — the original
    # shadowed this function's own name inside the loop.
    for scores in predicts:
        index = np.argmax(scores)    # most likely class index
        result = class_name[index]   # folder name, e.g. '13'
        results.append(label_dict[int(result)])
    return results
|
||
"\n",
|
||
"\n",
|
||
"# %% \n",
|
||
if __name__ == '__main__':

    # Train the model; weights and class names are written under checkpoint/.
    train()

    # --- Inference example (disabled) ---
    # NOTE(review): np.array([img1, ...]) below yields shape (N, 24, 24),
    # while the model's input is (24, 24, 1) — confirm a trailing channel
    # axis is added before re-enabling this block.
    # model = create_model()
    # # Load the previously trained weights
    # model.load_weights('checkpoint/char_checkpoint')
    # # Load the saved class labels
    # class_name = np.load('checkpoint/class_name.npy')
    # print(class_name)
    # img1 = cv2.imread('img1.png', 0)
    # img2 = cv2.imread('img2.png', 0)
    # img3 = cv2.imread('img3.png', 0)
    # img4 = cv2.imread('img4.png', 0)
    # img5 = cv2.imread('img5.png', 0)
    # img6 = cv2.imread('img6.png', 0)
    # imgs = np.array([img1, img2, img3, img4, img5, img6])
    # results = predict(model, imgs, class_name)
    # print(results)
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "bf017f99",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.13"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|