classes for models, training

parent 3b78ae9f05
commit 9b59c3b108

.gitignore | 3 (vendored)

@@ -4,4 +4,5 @@ new_data
 model
 *avi
 new_data_transformed
 test.PNG
+__pycache__

classify.py | 13 (new file)

@@ -0,0 +1,13 @@
import argparse
from models import ClassificationModel
import cv2 as cv

def main(args):
    cls_model = ClassificationModel()
    print(cls_model.predict(cv.imread(args.image_path)))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_path", type=str, required=True)
    args = parser.parse_args()
    main(args)

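Usage sketch (hedged; test.PNG is the sample image referenced elsewhere in this commit):

    python classify.py --image_path test.PNG

This loads the frozen VGG graph from its default path in models.py and prints one of the six class names.
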
main.py | 43 (new file)

@@ -0,0 +1,43 @@
import argparse
import cv2
from models import ClassificationModel

cls_model = ClassificationModel()

def main(args):
    cap = cv2.VideoCapture(args.video_path)
    object_detector = cv2.createBackgroundSubtractorMOG2(history=100, varThreshold=50)
    while True:
        ret, frame = cap.read()
        if ret:
            roi = frame[100:900, 330:1900]
            mask = object_detector.apply(roi)
            _, mask = cv2.threshold(mask, 254, 255, cv2.THRESH_BINARY)
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            for cnt in contours:
                area = cv2.contourArea(cnt)
                if area > 200:
                    cv2.drawContours(roi, [cnt], -1, (0, 255, 0), 2)
                    x, y, w, h = cv2.boundingRect(cnt)
                    rectangle = cv2.rectangle(roi, (x, y), (x + w, y + h), (0, 255, 0), 3)
                    image_to_predict = roi[y:y + h + 10, x:x + w + 10]
                    label = cls_model.predict(image_to_predict)
                    cv2.putText(rectangle, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36, 255, 12), 1)

            roi = cv2.resize(roi, (960, 540))
            cv2.imshow("roi", roi)

            key = cv2.waitKey(30)
            if key == 27:
                break
        else:
            break
    cap.release()
    cv2.destroyAllWindows()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", type=str, default='./test_videos/rybki2.mp4')
    args = parser.parse_args()
    main(args)

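A hedged caveat on the hard-coded numbers above: frame[100:900, 330:1900] assumes frames at least 1900 pixels wide and 900 tall (roughly 1080p footage), and NumPy silently truncates smaller frames rather than raising. The 254 threshold keeps only definite-foreground MOG2 pixels (shadows are marked 127), so the crop, the threshold, and the area > 200 filter all likely need retuning per video.
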
models.py | 34 (new file)

@@ -0,0 +1,34 @@
import cv2
import numpy as np
import tensorflow as tf

class ClassificationModel:
    def __init__(self, model_path: str = './frozen_models/frozen_graph_best_vgg.pb', model_type: str = "VGG16") -> None:
        print("loading classification model")
        self.model_path = model_path
        self.model_func = self.init_frozen_func()
        self.model_type = model_type
        self.class_names = sorted(['Fish', "Jellyfish", 'Lionfish', 'Shark', 'Stingray', 'Turtle'])

    def wrap_frozen_graph(self, graph_def, inputs, outputs):
        def _imports_graph_def():
            tf.compat.v1.import_graph_def(graph_def, name="")
        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))

    def init_frozen_func(self):
        with tf.io.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        return self.wrap_frozen_graph(graph_def=graph_def,
                                      inputs=["x:0"],
                                      outputs=["Identity:0"])

    def predict(self, image, shape=(224, 224)):
        image = cv2.resize(image, shape)
        pred = self.model_func(x=tf.convert_to_tensor(image[None, :], dtype='float32'))
        return self.class_names[np.argmax(pred)]

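A minimal usage sketch of the class above (hedged: it assumes the frozen graph exists at its default path and that test.PNG is readable). One inconsistency worth checking: training/data_load.py scales pixels by 1/255, while predict() feeds raw 0-255 values to the graph.

# Hedged usage sketch for ClassificationModel.
import cv2
from models import ClassificationModel

model = ClassificationModel()
image = cv2.imread("test.PNG")  # BGR uint8 frame, as in classify.py
print(model.predict(image))     # resizes to 224x224, runs the frozen graph,
                                # returns one of the six sorted class names
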
@@ -60,16 +60,23 @@
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 7,
 "metadata": {},
 "outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"[<tf.Tensor 'x:0' shape=(None, 224, 224, 3) dtype=float32>]\n"
+]
+},
 {
 "data": {
 "text/plain": [
-"'Jellyfish'"
+"'Turtle'"
 ]
 },
-"execution_count": 28,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -78,7 +85,8 @@
 "class_names=sorted(['Fish', \"Jellyfish\", 'Lionfish', 'Shark', 'Stingray', 'Turtle'])\n",
 "a = cv2.imread('test.PNG')\n",
-"a = cv2.resize(a,(227,227))\n",
+"# a.shape\n",
+"a = cv2.resize(a,(224,224))\n",
 "print(frozen_func.inputs)\n",
 "pred = frozen_func(x=tf.convert_to_tensor(a[None, :], dtype='float32'))\n",
 "label = class_names[np.argmax(pred)]\n",
 "label"
@@ -86,23 +94,11 @@
 },
 {
 "cell_type": "code",
-"execution_count": 51,
+"execution_count": 14,
 "metadata": {},
-"outputs": [
-{
-"ename": "KeyboardInterrupt",
-"evalue": "",
-"output_type": "error",
-"traceback": [
-"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
-"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
-"Cell \u001b[1;32mIn[51], line 42\u001b[0m\n\u001b[0;32m 39\u001b[0m roi \u001b[39m=\u001b[39m cv2\u001b[39m.\u001b[39mresize(roi, (\u001b[39m960\u001b[39m, \u001b[39m540\u001b[39m)) \n\u001b[0;32m 40\u001b[0m cv2\u001b[39m.\u001b[39mimshow(\u001b[39m\"\u001b[39m\u001b[39mroi\u001b[39m\u001b[39m\"\u001b[39m, roi)\n\u001b[1;32m---> 42\u001b[0m key \u001b[39m=\u001b[39m cv2\u001b[39m.\u001b[39;49mwaitKey(\u001b[39m30\u001b[39;49m)\n\u001b[0;32m 43\u001b[0m \u001b[39mif\u001b[39;00m key \u001b[39m==\u001b[39m \u001b[39m27\u001b[39m:\n\u001b[0;32m 44\u001b[0m \u001b[39mbreak\u001b[39;00m\n",
-"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
-]
-}
-],
+"outputs": [],
 "source": [
-"cap = cv2.VideoCapture(\"rybki4.mp4\")\n",
+"cap = cv2.VideoCapture(\"rybki.mp4\")\n",
 "# cap.set(cv2.CAP_PROP_FPS, 60)\n",
 "\n",
 "class_names=sorted(['Fish', \"Jellyfish\", 'Lionfish', 'Shark', 'Stingray', 'Turtle'])\n",
@@ -125,16 +121,16 @@
 "        images = []\n",
 "        for cnt in conturs:\n",
 "            area = cv2.contourArea(cnt)\n",
-"            if area > 300:\n",
-"                #cv2.drawContours(roi,[cnt],-1,(0,255,0),2)\n",
+"            if area > 200:\n",
+"                cv2.drawContours(roi,[cnt],-1,(0,255,0),2)\n",
 "                x,y,w,h = cv2.boundingRect(cnt)\n",
 "                rectangle = cv2.rectangle(roi,(x,y),(x+w,y+h),(0,255,0),3)\n",
-"                # images.append((x,y,rectangle,np.expand_dims(image_to_predict,axis=0)))\n",
-"                # image_to_predict = roi[y:y+h,x:x+w]\n",
-"                # image_to_predict = cv2.resize(image_to_predict,(227,227))\n",
-"                # pred = frozen_func(x=tf.convert_to_tensor(image_to_predict[None, :], dtype='float32'))\n",
-"                # label = class_names[np.argmax(pred)]\n",
-"                # cv2.putText(rectangle, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 1)\n",
+"                image_to_predict = roi[y:y+h+10,x:x+w+10]\n",
+"                image_to_predict = cv2.resize(image_to_predict,(227,227))\n",
+"                images.append((x,y,rectangle,np.expand_dims(image_to_predict,axis=0)))\n",
+"                pred = frozen_func(x=tf.convert_to_tensor(image_to_predict[None, :], dtype='float32'))\n",
+"                label = class_names[np.argmax(pred)]\n",
+"                cv2.putText(rectangle, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 1)\n",
 "        \n",
 "        # if images:\n",
 "        #     pred = model.predict(np.vstack([image[3] for image in images]))\n",
@@ -1025,7 +1025,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.15"
+"version": "3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]"
 },
 "orig_nbformat": 4,
 "vscode": {

requirements.txt | 55 (new file)

@@ -0,0 +1,55 @@
astroid==2.4.2
asttokens==2.2.1
backcall==0.2.0
black==22.3.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.4
comm==0.1.2
debugpy==1.6.4
decorator==5.1.1
entrypoints==0.4
executing==1.2.0
idna==3.4
ipykernel==6.19.4
ipython==8.7.0
isort==5.6.4
jedi==0.18.2
jupyter-client==7.4.8
jupyter-core==5.1.1
lazy-object-proxy==1.4.3
llvmlite==0.38.0
matplotlib-inline==0.1.6
mccabe==0.6.1
mypy-extensions==0.4.3
nest-asyncio==1.5.6
numpy==1.24.1
packaging==22.0
pandas==1.5.2
parso==0.8.3
pathspec==0.9.0
pickleshare==0.7.5
Pillow==9.4.0
platformdirs==2.5.2
prompt-toolkit==3.0.36
psutil==5.9.4
pure-eval==0.2.2
Pygments==2.13.0
pylint==2.6.0
python-dateutil==2.8.2
pytz==2022.7.1
pywin32==305
pyzmq==24.0.1
requests==2.28.2
six==1.15.0
stack-data==0.6.2
toml==0.10.2
tomli==2.0.1
tornado==6.2
tqdm==4.64.1
traitlets==5.8.0
typing-extensions==4.2.0
urllib3==1.26.14
wcwidth==0.2.5
wrapt==1.12.1

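A hedged observation on the pins above: the committed code also imports tensorflow, keras, opencv-python (cv2), scikit-learn, scikit-image, seaborn, and matplotlib, none of which appear here; they presumably come from the surrounding environment (the notebooks reference a kernel named "um"), so a fresh `pip install -r requirements.txt` alone likely will not run the project.
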
training/AlexNet.ipynb | 450 (new file)

@@ -0,0 +1,450 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RaaVleVhamV5"
},
"outputs": [],
"source": [
"from IPython.display import Image, SVG, display\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"import subprocess\n",
"import pkg_resources\n",
"import numpy as np\n",
"import pandas as pd\n",
"import os\n",
"from skimage.io import imread\n",
"import cv2 as cv\n",
"from pathlib import Path\n",
"import random\n",
"from shutil import copyfile, rmtree\n",
"import json\n",
"import seaborn as sns\n",
"import matplotlib\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import os\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from training.data_load import load_data\n",
"\n",
"train_ds, test_ds, validation_ds = load_data((227, 227), './new_data_transformed/train', './new_data_transformed/test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "se5yACYzZcmm"
},
"outputs": [],
"source": [
"def tensorboard_callback(model_name):\n",
"    return keras.callbacks.TensorBoard(os.path.join(f\"./logs/{model_name}\", time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yvRRfMKbcTwu"
},
"outputs": [],
"source": [
"model = keras.models.Sequential([\n",
"    keras.layers.Conv2D(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3)),\n",
"    keras.layers.BatchNormalization(),\n",
"    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),\n",
"    keras.layers.Conv2D(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding=\"same\"),\n",
"    keras.layers.BatchNormalization(),\n",
"    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),\n",
"    keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding=\"same\"),\n",
"    keras.layers.BatchNormalization(),\n",
"    keras.layers.Conv2D(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding=\"same\"),\n",
"    keras.layers.BatchNormalization(),\n",
"    keras.layers.Conv2D(filters=256, kernel_size=(3,3), strides=(1,1), activation='relu', padding=\"same\"),\n",
"    keras.layers.BatchNormalization(),\n",
"    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),\n",
"    keras.layers.Flatten(),\n",
"    keras.layers.Dense(4096, activation='relu'),\n",
"    keras.layers.Dense(4096, activation='relu'),\n",
"    keras.layers.Dense(6, activation='softmax')\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eRI-u6HLcU_H",
"outputId": "7ea0a139-c9ab-4092-e082-bea02eca4e93"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d (Conv2D) (None, 55, 55, 96) 34944 \n",
" \n",
" batch_normalization (BatchN (None, 55, 55, 96) 384 \n",
" ormalization) \n",
" \n",
" max_pooling2d (MaxPooling2D (None, 27, 27, 96) 0 \n",
" ) \n",
" \n",
" conv2d_1 (Conv2D) (None, 27, 27, 256) 614656 \n",
" \n",
" batch_normalization_1 (Batc (None, 27, 27, 256) 1024 \n",
" hNormalization) \n",
" \n",
" max_pooling2d_1 (MaxPooling (None, 13, 13, 256) 0 \n",
" 2D) \n",
" \n",
" conv2d_2 (Conv2D) (None, 13, 13, 384) 885120 \n",
" \n",
" batch_normalization_2 (Batc (None, 13, 13, 384) 1536 \n",
" hNormalization) \n",
" \n",
" conv2d_3 (Conv2D) (None, 13, 13, 384) 1327488 \n",
" \n",
" batch_normalization_3 (Batc (None, 13, 13, 384) 1536 \n",
" hNormalization) \n",
" \n",
" conv2d_4 (Conv2D) (None, 13, 13, 256) 884992 \n",
" \n",
" batch_normalization_4 (Batc (None, 13, 13, 256) 1024 \n",
" hNormalization) \n",
" \n",
" max_pooling2d_2 (MaxPooling (None, 6, 6, 256) 0 \n",
" 2D) \n",
" \n",
" flatten (Flatten) (None, 9216) 0 \n",
" \n",
" dense (Dense) (None, 4096) 37752832 \n",
" \n",
" dense_1 (Dense) (None, 4096) 16781312 \n",
" \n",
" dense_2 (Dense) (None, 6) 24582 \n",
" \n",
"=================================================================\n",
"Total params: 58,311,430\n",
"Trainable params: 58,308,678\n",
"Non-trainable params: 2,752\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.optimizers.SGD(learning_rate=.001), metrics=['accuracy'])\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Jxzfxvy3cWBP",
"outputId": "fa3b738e-a125-4b19-984a-4b1d83df771f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n",
"45/45 [==============================] - 13s 82ms/step - loss: 1.5723 - accuracy: 0.4535 - val_loss: 1.8094 - val_accuracy: 0.1676\n",
"Epoch 2/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.8888 - accuracy: 0.6771 - val_loss: 1.8356 - val_accuracy: 0.1733\n",
"Epoch 3/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.7100 - accuracy: 0.7292 - val_loss: 1.8753 - val_accuracy: 0.1534\n",
"Epoch 4/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.5200 - accuracy: 0.8222 - val_loss: 1.9177 - val_accuracy: 0.1477\n",
"Epoch 5/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.3887 - accuracy: 0.8868 - val_loss: 2.0252 - val_accuracy: 0.1591\n",
"Epoch 6/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.2841 - accuracy: 0.9340 - val_loss: 2.0583 - val_accuracy: 0.2273\n",
"Epoch 7/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.2114 - accuracy: 0.9569 - val_loss: 2.1366 - val_accuracy: 0.2216\n",
"Epoch 8/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.1738 - accuracy: 0.9764 - val_loss: 2.1499 - val_accuracy: 0.3153\n",
"Epoch 9/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.1298 - accuracy: 0.9854 - val_loss: 2.0241 - val_accuracy: 0.3693\n",
"Epoch 10/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.1152 - accuracy: 0.9896 - val_loss: 1.5216 - val_accuracy: 0.4631\n",
"Epoch 11/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0893 - accuracy: 0.9958 - val_loss: 1.2411 - val_accuracy: 0.5511\n",
"Epoch 12/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0922 - accuracy: 0.9910 - val_loss: 0.9918 - val_accuracy: 0.6477\n",
"Epoch 13/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0665 - accuracy: 0.9972 - val_loss: 0.9535 - val_accuracy: 0.6648\n",
"Epoch 14/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0628 - accuracy: 0.9965 - val_loss: 0.7136 - val_accuracy: 0.7415\n",
"Epoch 15/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0624 - accuracy: 0.9972 - val_loss: 0.6692 - val_accuracy: 0.7642\n",
"Epoch 16/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0534 - accuracy: 0.9979 - val_loss: 0.6320 - val_accuracy: 0.7898\n",
"Epoch 17/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0436 - accuracy: 0.9993 - val_loss: 0.6613 - val_accuracy: 0.7841\n",
"Epoch 18/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0401 - accuracy: 0.9972 - val_loss: 0.6103 - val_accuracy: 0.7869\n",
"Epoch 19/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0314 - accuracy: 0.9993 - val_loss: 0.6222 - val_accuracy: 0.8040\n",
"Epoch 20/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0308 - accuracy: 0.9986 - val_loss: 0.6182 - val_accuracy: 0.7841\n",
"Epoch 21/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0309 - accuracy: 0.9993 - val_loss: 0.6246 - val_accuracy: 0.7926\n",
"Epoch 22/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0256 - accuracy: 1.0000 - val_loss: 0.6276 - val_accuracy: 0.7983\n",
"Epoch 23/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0259 - accuracy: 1.0000 - val_loss: 0.6332 - val_accuracy: 0.8011\n",
"Epoch 24/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.0223 - accuracy: 1.0000 - val_loss: 0.6229 - val_accuracy: 0.7955\n",
"Epoch 25/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0225 - accuracy: 1.0000 - val_loss: 0.6089 - val_accuracy: 0.8011\n",
"Epoch 26/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0217 - accuracy: 1.0000 - val_loss: 0.6463 - val_accuracy: 0.7926\n",
"Epoch 27/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0221 - accuracy: 1.0000 - val_loss: 0.6237 - val_accuracy: 0.8068\n",
"Epoch 28/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0196 - accuracy: 1.0000 - val_loss: 0.6484 - val_accuracy: 0.7898\n",
"Epoch 29/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0204 - accuracy: 0.9993 - val_loss: 0.6200 - val_accuracy: 0.7926\n",
"Epoch 30/50\n",
"45/45 [==============================] - 3s 60ms/step - loss: 0.0194 - accuracy: 0.9993 - val_loss: 0.6186 - val_accuracy: 0.8040\n",
"Epoch 31/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0171 - accuracy: 1.0000 - val_loss: 0.6418 - val_accuracy: 0.8068\n",
"Epoch 32/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.0172 - accuracy: 1.0000 - val_loss: 0.6234 - val_accuracy: 0.8011\n",
"Epoch 33/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0168 - accuracy: 1.0000 - val_loss: 0.6278 - val_accuracy: 0.8068\n",
"Epoch 34/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0153 - accuracy: 1.0000 - val_loss: 0.6527 - val_accuracy: 0.7898\n",
"Epoch 35/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0138 - accuracy: 1.0000 - val_loss: 0.6198 - val_accuracy: 0.7926\n",
"Epoch 36/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0130 - accuracy: 1.0000 - val_loss: 0.6359 - val_accuracy: 0.7869\n",
"Epoch 37/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0130 - accuracy: 1.0000 - val_loss: 0.6123 - val_accuracy: 0.8153\n",
"Epoch 38/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0116 - accuracy: 1.0000 - val_loss: 0.6061 - val_accuracy: 0.8040\n",
"Epoch 39/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0113 - accuracy: 1.0000 - val_loss: 0.6256 - val_accuracy: 0.8011\n",
"Epoch 40/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0142 - accuracy: 0.9993 - val_loss: 0.6386 - val_accuracy: 0.8011\n",
"Epoch 41/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.0119 - accuracy: 1.0000 - val_loss: 0.6137 - val_accuracy: 0.8040\n",
"Epoch 42/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0100 - accuracy: 1.0000 - val_loss: 0.6392 - val_accuracy: 0.8068\n",
"Epoch 43/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.0098 - accuracy: 1.0000 - val_loss: 0.6461 - val_accuracy: 0.8097\n",
"Epoch 44/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0113 - accuracy: 1.0000 - val_loss: 0.6131 - val_accuracy: 0.8125\n",
"Epoch 45/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0095 - accuracy: 1.0000 - val_loss: 0.6376 - val_accuracy: 0.8125\n",
"Epoch 46/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0117 - accuracy: 0.9986 - val_loss: 0.6414 - val_accuracy: 0.7841\n",
"Epoch 47/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0094 - accuracy: 1.0000 - val_loss: 0.6224 - val_accuracy: 0.8068\n",
"Epoch 48/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0095 - accuracy: 1.0000 - val_loss: 0.5973 - val_accuracy: 0.8153\n",
"Epoch 49/50\n",
"45/45 [==============================] - 3s 62ms/step - loss: 0.0088 - accuracy: 1.0000 - val_loss: 0.6366 - val_accuracy: 0.8068\n",
"Epoch 50/50\n",
"45/45 [==============================] - 3s 61ms/step - loss: 0.0093 - accuracy: 1.0000 - val_loss: 0.6526 - val_accuracy: 0.8040\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7efd9022ee80>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(train_ds,\n",
"          epochs=50,\n",
"          validation_data=validation_ds,\n",
"          validation_freq=1,\n",
"          callbacks=[tensorboard_callback(\"AlexNet\")])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hrbrA2t3Zv4d",
"outputId": "6874f3db-eacb-4c69-ed72-6616db55303d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"***** TensorBoard Uploader *****\n",
"\n",
"This will upload your TensorBoard logs to https://tensorboard.dev/ from\n",
"the following directory:\n",
"\n",
"logs/AlexNet\n",
"\n",
"This TensorBoard will be visible to everyone. Do not upload sensitive\n",
"data.\n",
"\n",
"Your use of this service is subject to Google's Terms of Service\n",
"<https://policies.google.com/terms> and Privacy Policy\n",
"<https://policies.google.com/privacy>, and TensorBoard.dev's Terms of Service\n",
"<https://tensorboard.dev/policy/terms/>.\n",
"\n",
"This notice will not be shown again while you are logged into the uploader.\n",
"To log out, run `tensorboard dev auth revoke`.\n",
"\n",
"Continue? (yes/NO) yes\n",
"\n",
"Please visit this URL to authorize this application: https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=373649185512-8v619h5kft38l4456nm2dj4ubeqsrvh6.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&state=4s45OaCdh7CXhhMYUBHIgm6u6oiMTE&prompt=consent&access_type=offline\n",
"Enter the authorization code: 4/1AWtgzh59_rKzQT7gGHgAyOnRtMf7ppSZuYlb-25UkSRY4IdlSjRmYmH_AfE\n",
"\n",
"Upload started and will continue reading any new data as it's added to the logdir.\n",
"\n",
"To stop uploading, press Ctrl-C.\n",
"\n",
"New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/0xVh6RMqTQiv0BCdbdLLBg/\n",
"\n",
"\u001b[1m[2023-02-01T08:10:10]\u001b[0m Started scanning logdir.\n",
"\u001b[1m[2023-02-01T08:10:11]\u001b[0m Total uploaded: 300 scalars, 0 tensors, 1 binary objects (75.6 kB)\n",
"\n",
"\n",
"Interrupted. View your TensorBoard at https://tensorboard.dev/experiment/0xVh6RMqTQiv0BCdbdLLBg/\n",
"Traceback (most recent call last):\n",
"  File \"/usr/local/bin/tensorboard\", line 8, in <module>\n",
"    sys.exit(run_main())\n",
"  File \"/usr/local/lib/python3.8/dist-packages/tensorboard/main.py\", line 46, in run_main\n",
"    app.run(tensorboard.main, flags_parser=tensorboard.configure)\n",
"  File \"/usr/local/lib/python3.8/dist-packages/absl/app.py\", line 308, in run\n",
"    _run_main(main, args)\n",
"  File \"/usr/local/lib/python3.8/dist-packages/absl/app.py\", line 254, in _run_main\n",
"    sys.exit(main(argv))\n",
"  File \"/usr/local/lib/python3.8/dist-packages/tensorboard/program.py\", line 276, in main\n",
"    return runner(self.flags) or 0\n",
"  File \"/usr/local/lib/python3.8/dist-packages/tensorboard/uploader/uploader_subcommand.py\", line 692, in run\n",
"    return _run(flags, self._experiment_url_callback)\n",
"  File \"/usr/local/lib/python3.8/dist-packages/tensorboard/uploader/uploader_subcommand.py\", line 125, in _run\n",
"    intent.execute(server_info, channel)\n",
"  File \"/usr/local/lib/python3.8/dist-packages/tensorboard/uploader/uploader_subcommand.py\", line 508, in execute\n",
"    sys.stdout.write(end_message + \"\\n\")\n",
"KeyboardInterrupt\n",
"^C\n"
]
}
],
"source": [
"!tensorboard dev upload --logdir logs/AlexNet --name AlexNetFish"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VdNLX57nc_Gs",
"outputId": "4353b693-12cd-45ac-f249-8b287281754b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"9/9 [==============================] - 0s 32ms/step - loss: 0.9122 - accuracy: 0.7188\n"
]
},
{
"data": {
"text/plain": [
"[0.9122310280799866, 0.71875]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(test_ds)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 36
},
"id": "aaPsyfODLBgo",
"outputId": "7a1c3dd3-a662-4e6d-9cba-28aef9067d73"
},
"outputs": [],
"source": [
"from training.model_freeze import freeze_model\n",
"\n",
"freeze_model(model, './frozen_models', \"frozen_alex_net\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "um",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]"
},
"vscode": {
"interpreter": {
"hash": "876e189cbbe99a9a838ece62aae1013186c4bb7e0254a10cfa2f9b2381853efb"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}

training/VGG16.ipynb | 136 (new file)

@@ -0,0 +1,136 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import keras, os, time\n",
"from training.data_load import load_data\n",
"import keras,os\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Conv2D, MaxPool2D , Flatten\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"import numpy as np\n",
"from keras.applications import VGG16\n",
"from keras.layers import Input, Lambda, Dense, Flatten\n",
"from keras.models import Model\n",
"from keras.preprocessing import image\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras.models import Sequential\n",
"import numpy as np\n",
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"import ssl"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"train_ds, test_ds, validation_ds = load_data((224, 224), './new_data_transformed/train', './new_data_transformed/test')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tensorboard_callback(model_name):\n",
"    return keras.callbacks.TensorBoard(os.path.join(f\"./logs/{model_name}\", time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ssl._create_default_https_context = ssl._create_unverified_context\n",
"\n",
"IMAGE_SIZE = [224, 224]\n",
"\n",
"vgg2 = VGG16(input_shape=tuple(IMAGE_SIZE + [3]), include_top=False, weights='imagenet')\n",
"\n",
"for layer in vgg2.layers:\n",
"    layer.trainable = False\n",
"    \n",
"\n",
"x = Flatten()(vgg2.output)\n",
"prediction = Dense(6, activation='softmax')(x)\n",
"\n",
"model = Model(inputs=vgg2.input, outputs=prediction)\n",
"\n",
"model.summary()\n",
"model.compile(\n",
"    loss='sparse_categorical_crossentropy',\n",
"    optimizer='adam',\n",
"    metrics=['accuracy']\n",
")\n",
"\n",
"vggr = model.fit(\n",
"    train_ds,\n",
"    validation_data=validation_ds,\n",
"    epochs=15,\n",
"    steps_per_epoch=len(train_ds),\n",
"    validation_steps=len(validation_ds),\n",
"    callbacks=[tensorboard_callback(\"VGG16\")]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!tensorboard dev upload --logdir logs/VGG16 --name VGGFishBest"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(test_ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from training.model_freeze import freeze_model\n",
"\n",
"freeze_model(model, './frozen_models', \"frozen_vgg16\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "um",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "876e189cbbe99a9a838ece62aae1013186c4bb7e0254a10cfa2f9b2381853efb"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

training/Yolo.ipynb | 76 (new file)

@@ -0,0 +1,76 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/ultralytics/ultralytics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -e ./ultralytics"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!unzip 'Aquarium Combined.v2-raw-1024.yolov8.zip' -d /aqua"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!yolo task=detect \\\n",
"    mode=train \\\n",
"    model=yolov8s.pt \\\n",
"    data=./data.yaml \\\n",
"    epochs=100 \\\n",
"    imgsz=640"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!yolo task=detect \\\n",
"    mode=val \\\n",
"    model=/content/ultralytics/runs/detect/train/weights/best.pt \\\n",
"    data=./data.yaml"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "um",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 08:41:22) [MSC v.1929 64 bit (AMD64)]"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "876e189cbbe99a9a838ece62aae1013186c4bb7e0254a10cfa2f9b2381853efb"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

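A hedged companion sketch: the same train/validate flow through the ultralytics Python API instead of the CLI. The checkpoint and data.yaml paths mirror the cells above (data.yaml comes from the unzipped dataset export).

# Hedged sketch of the YOLOv8 workflow via the Python API.
from ultralytics import YOLO

model = YOLO("yolov8s.pt")  # same base checkpoint as the CLI cell
model.train(data="./data.yaml", epochs=100, imgsz=640)
metrics = model.val()       # validates the best weights from the training run
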
training/__init__.py | 0 (new file, empty)

training/data_load.py | 129 (new file)

@@ -0,0 +1,129 @@
import numpy as np
import os
from skimage.io import imread
import cv2 as cv
from pathlib import Path
import json
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf


def load_train_data(input_dir, image_size):

    image_dir = Path(input_dir)
    categories_name = []
    for file in os.listdir(image_dir):
        d = os.path.join(image_dir, file)
        if os.path.isdir(d):
            categories_name.append(file)

    folders = [directory for directory in image_dir.iterdir() if directory.is_dir()]

    train_img = []
    categories_count = []
    labels = []
    for i, direc in enumerate(folders):
        count = 0
        for obj in direc.iterdir():
            if os.path.isfile(obj) and os.path.basename(os.path.normpath(obj)) != 'desktop.ini':
                labels.append(os.path.basename(os.path.normpath(direc)))
                count += 1
                img = imread(obj)  # returns an ndarray of shape xSize x ySize x colorDepth
                img = img[:, :, :3]
                img = cv.resize(img, image_size, interpolation=cv.INTER_AREA)  # returns an ndarray
                img = img / 255  # normalization
                train_img.append(img)
        categories_count.append(count)
    X = {}
    X["values"] = np.array(train_img)
    X["categories_name"] = categories_name
    X["categories_count"] = categories_count
    X["labels"] = labels
    return X


def load_test_data(input_dir, image_size):

    image_path = Path(input_dir)

    labels_path = image_path.parents[0] / 'test_labels.json'

    jsonString = labels_path.read_text()
    objects = json.loads(jsonString)

    categories_name = []
    categories_count = []
    count = 0
    c = objects[0]['value']
    for e in objects:
        if e['value'] != c:
            categories_count.append(count)
            c = e['value']
            count = 1
        else:
            count += 1
        if not e['value'] in categories_name:
            categories_name.append(e['value'])

    categories_count.append(count)

    test_img = []

    labels = []
    for e in objects:
        p = image_path / e['filename']
        img = imread(p)  # returns an ndarray of shape xSize x ySize x colorDepth
        img = img[:, :, :3]
        img = cv.resize(img, image_size, interpolation=cv.INTER_AREA)  # returns an ndarray
        img = img / 255  # normalization
        test_img.append(img)
        labels.append(e['value'])

    X = {}
    X["values"] = np.array(test_img)
    X["categories_name"] = categories_name
    X["categories_count"] = categories_count
    X["labels"] = labels
    return X


def load_data(shape, path_train, path_test):
    data_train = load_train_data(path_train, shape)
    values_train = data_train['values']
    labels_train = data_train['labels']

    data_test = load_test_data(path_test, shape)
    X_test = data_test['values']
    y_test = data_test['labels']

    X_train, X_validate, y_train, y_validate = train_test_split(values_train, labels_train, test_size=0.2, random_state=42)

    class_le = LabelEncoder()
    y_train_enc = class_le.fit_transform(y_train)
    y_validate_enc = class_le.fit_transform(y_validate)
    y_test_enc = class_le.fit_transform(y_test)

    train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train_enc))
    validation_ds = tf.data.Dataset.from_tensor_slices((X_validate, y_validate_enc))
    test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test_enc))

    train_ds_size = tf.data.experimental.cardinality(train_ds).numpy()
    test_ds_size = tf.data.experimental.cardinality(test_ds).numpy()
    validation_ds_size = tf.data.experimental.cardinality(validation_ds).numpy()
    print("Training data size:", train_ds_size)
    print("Test data size:", test_ds_size)
    print("Validation data size:", validation_ds_size)

    train_ds = (train_ds
                .shuffle(buffer_size=train_ds_size)
                .batch(batch_size=32, drop_remainder=True))
    test_ds = (test_ds
               .shuffle(buffer_size=train_ds_size)
               .batch(batch_size=32, drop_remainder=True))
    validation_ds = (validation_ds
                     .shuffle(buffer_size=train_ds_size)
                     .batch(batch_size=32, drop_remainder=True))

    return train_ds, test_ds, validation_ds

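A hedged usage sketch of load_data, mirroring the call the training notebooks make (the paths are the ones AlexNet.ipynb uses; adjust to your layout):

from training.data_load import load_data

train_ds, test_ds, validation_ds = load_data((227, 227), './new_data_transformed/train', './new_data_transformed/test')
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)  # (32, 227, 227, 3) and (32,), since drop_remainder=True
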
training/model_freeze.py | 14 (new file)

@@ -0,0 +1,14 @@
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def freeze_model(model, output_path, name):
    full_model = tf.function(lambda x: model(x))
    full_model = full_model.get_concrete_function(
        tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
    frozen_func = convert_variables_to_constants_v2(full_model)
    frozen_func.graph.as_graph_def()
    tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                      logdir=output_path,
                      name=f"{name}.pb",
                      as_text=False)
    return

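A hedged round-trip sketch tying this module to models.py: because freeze_model wraps the model in tf.function(lambda x: model(x)), the frozen graph's input tensor is named "x:0" and its output "Identity:0", which is exactly what ClassificationModel.init_frozen_func prunes for. Here `model` is assumed to be a trained Keras model such as the VGG16 one from the notebooks.

# Hedged sketch: freeze a trained Keras model, then reload it through models.py.
from training.model_freeze import freeze_model
from models import ClassificationModel

freeze_model(model, './frozen_models', 'frozen_graph_best_vgg')  # writes frozen_graph_best_vgg.pb
cls_model = ClassificationModel(model_path='./frozen_models/frozen_graph_best_vgg.pb')
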