widzenie-komputerowe-projekt/training/VGG16.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import keras, os, time\n",
"from training.data_load import load_data\n",
"import keras,os\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Conv2D, MaxPool2D , Flatten\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"import numpy as np\n",
"from keras.applications import VGG16\n",
"from keras.layers import Input, Lambda, Dense, Flatten\n",
"from keras.models import Model\n",
"from keras.preprocessing import image\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.models import Sequential\n",
"import numpy as np\n",
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"import ssl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"train_ds, test_ds, validation_ds = load_data((224, 224), './new_data_transformed/train', './new_data_transformed/test')"
]
},
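{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sanity check (not part of the original pipeline): assuming load_data\n",
"# returns batched tf.data-style datasets of (image, label) pairs, as implied by\n",
"# how they are passed to model.fit below, peek at the shapes of one batch.\n",
"for images, labels in train_ds.take(1):\n",
"    print('images:', images.shape, 'labels:', labels.shape)"
]
},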
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(model_name):\n",
" return keras.callbacks.TensorBoard(os.path.join(f\"./logs/{model_name}\", time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ssl._create_default_https_context = ssl._create_unverified_context\n",
"\n",
"IMAGE_SIZE = [224, 224]\n",
"\n",
"vgg2 = VGG16(input_shape=tuple(IMAGE_SIZE + [3]), include_top=False, weights='imagenet')\n",
"\n",
"for layer in vgg2.layers:\n",
" layer.trainable = False\n",
" \n",
"\n",
"x = Flatten()(vgg2.output)\n",
"prediction = Dense(6, activation='softmax')(x)\n",
"\n",
"model = Model(inputs=vgg2.input, outputs=prediction)\n",
"\n",
"model.summary()\n",
"model.compile(\n",
" loss='sparse_categorical_crossentropy',\n",
" optimizer='adam',\n",
" metrics=['accuracy']\n",
")\n",
"\n",
"vggr = model.fit(\n",
" train_ds,\n",
" validation_data=validation_ds,\n",
" epochs=15,\n",
" steps_per_epoch=len(train_ds),\n",
" validation_steps=len(validation_ds),\n",
" callbacks=[tensorboard_callback(\"VGG16\")]\n",
")"
]
},
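{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: plot the loss and accuracy curves stored in the History\n",
"# object returned by model.fit above (keys follow the compiled 'accuracy' metric).\n",
"plt.figure(figsize=(10, 4))\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(vggr.history['loss'], label='train loss')\n",
"plt.plot(vggr.history['val_loss'], label='val loss')\n",
"plt.legend()\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(vggr.history['accuracy'], label='train accuracy')\n",
"plt.plot(vggr.history['val_accuracy'], label='val accuracy')\n",
"plt.legend()\n",
"plt.show()"
]
},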
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!tensorboard dev upload --logdir logs/VGG16 --name VGGFishBest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(test_ds)"
]
},
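{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged usage sketch: predict on a single batch from the test set and compare\n",
"# the predicted class indices with the true labels. Assumes the datasets yield\n",
"# (image, label) batches compatible with model.predict; the class-index-to-name\n",
"# mapping is defined by load_data and is not reproduced here.\n",
"images, labels = next(iter(test_ds))\n",
"pred = np.argmax(model.predict(images), axis=1)\n",
"print('predicted:', pred[:10])\n",
"print('true:     ', np.asarray(labels)[:10])"
]
},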
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from training.model_freeze import freeze_model\n",
"\n",
"freeze_model(model, './frozen_models', \"frozen_vgg16\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "um",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "876e189cbbe99a9a838ece62aae1013186c4bb7e0254a10cfa2f9b2381853efb"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}