url/cnn.ipynb


{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9128debb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 300 files belonging to 15 classes.\n",
"Using 240 files for training.\n",
"Found 300 files belonging to 15 classes.\n",
"Using 60 files for validation.\n",
"Epoch 1/20\n",
"4/4 [==============================] - 1s 99ms/step - loss: 2.6807 - accuracy: 0.0875 - val_loss: 2.6252 - val_accuracy: 0.1167\n",
"Epoch 2/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 2.5773 - accuracy: 0.1625 - val_loss: 2.5274 - val_accuracy: 0.3167\n",
"Epoch 3/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 2.4399 - accuracy: 0.2917 - val_loss: 2.3353 - val_accuracy: 0.4667\n",
"Epoch 4/20\n",
"4/4 [==============================] - 0s 14ms/step - loss: 2.2246 - accuracy: 0.5333 - val_loss: 2.0439 - val_accuracy: 0.6833\n",
"Epoch 5/20\n",
"4/4 [==============================] - 0s 16ms/step - loss: 1.9292 - accuracy: 0.6167 - val_loss: 1.6177 - val_accuracy: 0.9167\n",
"Epoch 6/20\n",
"4/4 [==============================] - 0s 38ms/step - loss: 1.5042 - accuracy: 0.8333 - val_loss: 1.1052 - val_accuracy: 1.0000\n",
"Epoch 7/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 1.0313 - accuracy: 0.8792 - val_loss: 0.6252 - val_accuracy: 0.9833\n",
"Epoch 8/20\n",
"4/4 [==============================] - 0s 16ms/step - loss: 0.6847 - accuracy: 0.8958 - val_loss: 0.3330 - val_accuracy: 0.9833\n",
"Epoch 9/20\n",
"4/4 [==============================] - 0s 16ms/step - loss: 0.3673 - accuracy: 0.9250 - val_loss: 0.1612 - val_accuracy: 1.0000\n",
"Epoch 10/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.2573 - accuracy: 0.9500 - val_loss: 0.1193 - val_accuracy: 1.0000\n",
"Epoch 11/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.1788 - accuracy: 0.9667 - val_loss: 0.0505 - val_accuracy: 1.0000\n",
"Epoch 12/20\n",
"4/4 [==============================] - 0s 16ms/step - loss: 0.1045 - accuracy: 0.9875 - val_loss: 0.0339 - val_accuracy: 1.0000\n",
"Epoch 13/20\n",
"4/4 [==============================] - 0s 18ms/step - loss: 0.0553 - accuracy: 1.0000 - val_loss: 0.0132 - val_accuracy: 1.0000\n",
"Epoch 14/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0351 - accuracy: 1.0000 - val_loss: 0.0121 - val_accuracy: 1.0000\n",
"Epoch 15/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0325 - accuracy: 0.9958 - val_loss: 0.0060 - val_accuracy: 1.0000\n",
"Epoch 16/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0273 - accuracy: 0.9958 - val_loss: 0.0044 - val_accuracy: 1.0000\n",
"Epoch 17/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0144 - accuracy: 1.0000 - val_loss: 0.0044 - val_accuracy: 1.0000\n",
"Epoch 18/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0098 - accuracy: 1.0000 - val_loss: 0.0034 - val_accuracy: 1.0000\n",
"Epoch 19/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0110 - accuracy: 1.0000 - val_loss: 0.0024 - val_accuracy: 1.0000\n",
"Epoch 20/20\n",
"4/4 [==============================] - 0s 15ms/step - loss: 0.0088 - accuracy: 0.9958 - val_loss: 0.0019 - val_accuracy: 1.0000\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"from tensorflow.keras import layers\n",
"from tensorflow.keras.models import Sequential\n",
"import pathlib\n",
"import cv2\n",
"\n",
"# %% 构建模型\n",
"def create_model():\n",
" model = Sequential([\n",
" layers.experimental.preprocessing.Rescaling(1./255, input_shape=(24, 24, 1)),\n",
" layers.Conv2D(24, 3, padding='same', activation='relu'),\n",
" layers.MaxPooling2D(),\n",
" layers.Conv2D(32, 3, padding='same', activation='relu'),\n",
" layers.MaxPooling2D(),\n",
" layers.Conv2D(64, 3, padding='same', activation='relu'),\n",
" layers.MaxPooling2D(),\n",
" layers.Dropout(0.2),\n",
" layers.Flatten(),\n",
" layers.Dense(96, activation='relu'),\n",
" layers.Dense(15)]\n",
" )\n",
" \n",
" model.compile(optimizer='adam',\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"\n",
" return model\n",
"\n",
"# %% 训练数据\n",
"def train():\n",
" # 统计文件夹下的所有图片数量\n",
" data_dir = pathlib.Path('dataset')\n",
" batch_size = 64\n",
" img_width = 24\n",
" img_height = 24\n",
"\n",
" # 从文件夹下读取图片,生成数据集\n",
" train_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
" data_dir,\n",
" validation_split=0.2,\n",
" subset='training',\n",
" seed=123,\n",
" color_mode=\"grayscale\",\n",
" image_size=(img_height, img_width),\n",
" batch_size=batch_size\n",
" )\n",
"\n",
" val_ds = tf.keras.preprocessing.image_dataset_from_directory(\n",
" data_dir,\n",
" validation_split=0.2,\n",
" subset=\"validation\",\n",
" seed=123,\n",
" color_mode=\"grayscale\",\n",
" image_size=(img_height, img_width),\n",
" batch_size=batch_size\n",
" )\n",
"\n",
" # 数据集的分类对应dataset文件夹下有多少图片分类\n",
" class_names = train_ds.class_names\n",
" # 保存数据集分类\n",
" np.save(\"checkpoint/class_name.npy\", class_names)\n",
"\n",
" # 数据集缓存处理\n",
" AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
" train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)\n",
" val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
" # 创建模型\n",
" model = create_model()\n",
" # 训练模型epochs=10所有数据集训练10遍\n",
" model.fit(train_ds,validation_data=val_ds,epochs=20)\n",
" # 保存训练后的权重\n",
" model.save_weights('checkpoint/char_checkpoint')\n",
"\n",
"# %% 预测\n",
"def predict(model, imgs, class_name):\n",
" label_dict = {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: '=', 11: '+', 12: '-', 13: '×', 14: '÷'}\n",
" # 预测图片,获取预测值\n",
" predicts = model.predict(imgs) \n",
" results = [] # 保存结果的数组\n",
" for predict in predicts: #遍历每一个预测结果\n",
" index = np.argmax(predict) # 寻找最大值\n",
" result = class_name[index] # 取出字符\n",
" results.append(label_dict[int(result)])\n",
" return results\n",
"\n",
"\n",
"# %% \n",
"if __name__ == '__main__':\n",
"\n",
" train()\n",
" \n",
" # model = create_model()\n",
" # # 加载前期训练好的权重\n",
" # model.load_weights('checkpoint/char_checkpoint')\n",
" # # 读出图片分类\n",
" # class_name = np.load('checkpoint/class_name.npy')\n",
" # print(class_name)\n",
" # img1=cv2.imread('img1.png',0) \n",
" # img2=cv2.imread('img2.png',0) \n",
" # img3=cv2.imread('img3.png',0)\n",
" # img4=cv2.imread('img4.png',0)\n",
" # img5=cv2.imread('img5.png',0)\n",
" # img6=cv2.imread('img6.png',0)\n",
" # imgs = np.array([img1,img2,img3,img4,img5,img6])\n",
" # results = predict(model, imgs, class_name)\n",
" # print(results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf017f99",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}