SW-Wiktor-Bombola/SW-Unity/Plants Neural Network.ipynb

298 lines
29 KiB
Plaintext
Raw Normal View History

2021-12-19 23:09:15 +01:00
{
"cells": [
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 28,
2021-12-19 23:09:15 +01:00
"id": "comprehensive-talent",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import os\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation, Conv2D, MaxPooling2D\n",
2021-12-20 04:30:49 +01:00
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
2021-12-19 23:09:15 +01:00
"from sklearn.neural_network import MLPClassifier\n",
"from sklearn.model_selection import train_test_split\n",
2021-12-20 04:30:49 +01:00
"from tensorflow.keras.optimizers import RMSprop\n",
2021-12-19 23:09:15 +01:00
"from sklearn.metrics import classification_report\n",
2021-12-20 04:30:49 +01:00
"import re\n",
"import matplotlib.pyplot as plt"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 15,
"id": "macro-michigan",
2021-12-19 23:09:15 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-12-20 04:30:49 +01:00
"train_data_dir=\"../Trees\"\n",
"validation_data_dir=\"../Trees\"\n",
"batch_size=14\n",
"img_height, img_width = 60,80\n",
"\n",
"# Seed numpy (used for batch shuffling / random augmentation) and TensorFlow (weight init)\n",
"# so repeated runs are comparable; GPU kernels may still add some nondeterminism.\n",
"SEED = 42\n",
"np.random.seed(SEED)\n",
"tf.random.set_seed(SEED)"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 16,
"id": "defined-briefing",
2021-12-19 23:09:15 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-12-20 04:30:49 +01:00
"Found 428 images belonging to 3 classes.\n",
"Found 105 images belonging to 3 classes.\n"
2021-12-19 23:09:15 +01:00
]
}
],
"source": [
2021-12-20 04:30:49 +01:00
"# Fraction of each class directory held out for validation; shared by both generators\n",
"# so flow_from_directory picks the same held-out files for subset='validation'.\n",
"validation_split = 0.2\n",
"\n",
"# Random shear/zoom/flip is applied to the training subset only.\n",
"train_datagen = ImageDataGenerator(rescale=1./255,\n",
"                                   shear_range=0.2,\n",
"                                   zoom_range=0.2,\n",
"                                   horizontal_flip=True,\n",
"                                   validation_split=validation_split) # set validation split\n",
"\n",
"# Validation images are only rescaled, so val_loss / val_accuracy are measured\n",
"# on unmodified images rather than on randomly augmented ones.\n",
"val_datagen = ImageDataGenerator(rescale=1./255,\n",
"                                 validation_split=validation_split)\n",
"\n",
"train_generator = train_datagen.flow_from_directory(\n",
"    train_data_dir,\n",
"    target_size=(img_height, img_width),\n",
"    batch_size=batch_size,\n",
"    class_mode='categorical',\n",
"    subset='training') # set as training data\n",
"\n",
"validation_generator = val_datagen.flow_from_directory(\n",
"    validation_data_dir, # same directory as training data\n",
"    target_size=(img_height, img_width),\n",
"    batch_size=batch_size,\n",
"    class_mode='categorical',\n",
"    subset='validation')"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 17,
"id": "natural-cutting",
2021-12-19 23:09:15 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-12-20 04:30:49 +01:00
"model = Sequential()"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 18,
"id": "conservative-hypothetical",
2021-12-19 23:09:15 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-12-20 04:30:49 +01:00
"# Start from a fresh Sequential model so re-running this cell does not append\n",
"# a second copy of the layers to a model built by an earlier run.\n",
"model = Sequential()\n",
"# Input shape follows the generator's target_size (img_height, img_width) with 3 RGB channels.\n",
"model.add(Conv2D(32, (3,3), activation='relu', input_shape=(img_height, img_width, 3)))\n",
"model.add(MaxPooling2D((2,2)))\n",
"model.add(Conv2D(64, (3,3), activation='relu'))\n",
"model.add(MaxPooling2D((2,2)))\n",
"model.add(Conv2D(64, (3,3), activation='relu'))\n",
"model.add(MaxPooling2D((2,2)))\n",
"model.add(Flatten())\n",
"model.add(Dense(512, activation='relu'))\n",
"model.add(Dense(512, activation='relu'))\n",
2021-12-19 23:09:15 +01:00
"\n",
2021-12-20 04:30:49 +01:00
"# One softmax unit per class subdirectory found by flow_from_directory.\n",
"model.add(Dense(train_generator.num_classes, activation='softmax'))"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 23,
"id": "irish-monitoring",
2021-12-19 23:09:15 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-12-20 04:30:49 +01:00
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: pydot in c:\\program files\\python39\\lib\\site-packages (1.4.2)\n",
"Requirement already satisfied: pyparsing>=2.1.4 in c:\\users\\wbloc\\appdata\\roaming\\python\\python39\\site-packages (from pydot) (2.4.7)\n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: graphviz in c:\\program files\\python39\\lib\\site-packages (0.19.1)\n",
"('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')\n"
2021-12-19 23:09:15 +01:00
]
}
],
"source": [
2021-12-20 04:30:49 +01:00
"from tensorflow.keras.utils import plot_model\n",
"# plot_model needs pydot AND the Graphviz system binaries. Having the `pydot` and `graphviz`\n",
"# pip packages installed was not enough for it to work in this environment, so install\n",
"# Graphviz once outside the notebook (https://graphviz.gitlab.io/download/) instead of\n",
"# running unpinned `!pip install` on every execution.\n",
"plot_model(model, to_file='model1_plot.png', show_shapes=True, show_dtype=True, show_layer_names=True, expand_nested=True)"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 24,
2021-12-19 23:09:15 +01:00
"id": "illegal-zoning",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-12-20 04:30:49 +01:00
"Model: \"sequential_2\"\n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
2021-12-20 04:30:49 +01:00
"conv2d_6 (Conv2D) (None, 58, 78, 32) 896 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"max_pooling2d_6 (MaxPooling2 (None, 29, 39, 32) 0 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"conv2d_7 (Conv2D) (None, 27, 37, 64) 18496 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"max_pooling2d_7 (MaxPooling2 (None, 13, 18, 64) 0 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"conv2d_8 (Conv2D) (None, 11, 16, 64) 36928 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"max_pooling2d_8 (MaxPooling2 (None, 5, 8, 64) 0 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"flatten_2 (Flatten) (None, 2560) 0 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"dense_6 (Dense) (None, 512) 1311232 \n",
2021-12-19 23:09:15 +01:00
"_________________________________________________________________\n",
2021-12-20 04:30:49 +01:00
"dense_7 (Dense) (None, 512) 262656 \n",
"_________________________________________________________________\n",
"dense_8 (Dense) (None, 3) 1539 \n",
2021-12-19 23:09:15 +01:00
"=================================================================\n",
2021-12-20 04:30:49 +01:00
"Total params: 1,631,747\n",
"Trainable params: 1,631,747\n",
2021-12-19 23:09:15 +01:00
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"# summary() prints the table itself and returns None, so it is not wrapped in print().\n",
"model.summary()"
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 25,
2021-12-19 23:09:15 +01:00
"id": "cardiac-highland",
"metadata": {},
"outputs": [],
"source": [
2021-12-20 04:30:49 +01:00
"model.compile(optimizer=RMSprop(learning_rate=0.0001),\n",
" loss=tf.keras.losses.CategoricalCrossentropy(),\n",
2021-12-19 23:09:15 +01:00
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 26,
2021-12-19 23:09:15 +01:00
"id": "informed-baker",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 3s 393ms/step - loss: 0.9368 - accuracy: 0.5897 - val_loss: 0.7715 - val_accuracy: 0.6786\n",
2021-12-19 23:09:15 +01:00
"Epoch 2/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 352ms/step - loss: 0.8682 - accuracy: 0.6310 - val_loss: 0.7431 - val_accuracy: 0.6786\n",
2021-12-19 23:09:15 +01:00
"Epoch 3/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 352ms/step - loss: 0.7690 - accuracy: 0.6667 - val_loss: 0.6637 - val_accuracy: 0.7143\n",
2021-12-19 23:09:15 +01:00
"Epoch 4/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 351ms/step - loss: 0.7816 - accuracy: 0.6310 - val_loss: 0.6656 - val_accuracy: 0.6607\n",
2021-12-19 23:09:15 +01:00
"Epoch 5/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 346ms/step - loss: 0.7070 - accuracy: 0.6667 - val_loss: 0.7131 - val_accuracy: 0.5893\n",
2021-12-19 23:09:15 +01:00
"Epoch 6/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 351ms/step - loss: 0.6677 - accuracy: 0.6905 - val_loss: 0.5662 - val_accuracy: 0.7143\n",
2021-12-19 23:09:15 +01:00
"Epoch 7/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 350ms/step - loss: 0.6171 - accuracy: 0.7262 - val_loss: 0.7703 - val_accuracy: 0.5893\n",
2021-12-19 23:09:15 +01:00
"Epoch 8/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 351ms/step - loss: 0.6001 - accuracy: 0.7738 - val_loss: 0.4881 - val_accuracy: 0.8036\n",
2021-12-19 23:09:15 +01:00
"Epoch 9/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 362ms/step - loss: 0.4903 - accuracy: 0.7821 - val_loss: 0.4565 - val_accuracy: 0.9107\n",
2021-12-19 23:09:15 +01:00
"Epoch 10/10\n",
2021-12-20 04:30:49 +01:00
"6/6 [==============================] - 2s 351ms/step - loss: 0.5107 - accuracy: 0.7976 - val_loss: 0.5301 - val_accuracy: 0.8393\n"
2021-12-19 23:09:15 +01:00
]
}
],
"source": [
2021-12-20 04:30:49 +01:00
"# len(generator) is the number of batches needed to cover its subset once, so each epoch\n",
"# trains on every training image and validates on the whole validation subset\n",
"# (hard-coded 6 / 4 steps of 14 covered only 84 of 428 and 56 of 105 images).\n",
"history = model.fit(train_generator, steps_per_epoch=len(train_generator), epochs=10, verbose=1,\n",
"                    validation_data=validation_generator, validation_steps=len(validation_generator))"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
2021-12-20 04:30:49 +01:00
"execution_count": 29,
2021-12-19 23:09:15 +01:00
"id": "inclusive-chess",
"metadata": {},
"outputs": [
{
2021-12-20 04:30:49 +01:00
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x1c0ecf5e190>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAA5cElEQVR4nO3deVhV5fbA8e9iUAQVJ8QB5xlFU8mxnC1N00xNbbYcmq1udZvrVvd3mwfLSiwrmzQ1zaw0zQFLyzQVFYcUJ3BCFBCV+f39sdFQQQ94Nhs46/M8PJ5hn70XR91r73dYrxhjUEop5bm8nA5AKaWUszQRKKWUh9NEoJRSHk4TgVJKeThNBEop5eE0ESillIezLRGIyFQROSwim/J5X0RkoojsEJEoEWlnVyxKKaXyZ+cdwadAvwu83x9okvMzDvjAxliUUkrlw7ZEYIyJBI5eYJPBwDRj+R2oJCI17YpHKaVU3nwcPHZtYF+u57E5rx04d0MRGYd110BAQED75s2bF0mASilVWqxdu/aIMSYor/ecTAQuM8ZEABEA4eHhZs2aNQ5HpJRSJYuI7MnvPSdHDcUBdXI9D8l5TSmlVBFyMhHMA27NGT3UCUgyxpzXLKSUUspetjUNicjXQA+gmojEAs8BvgDGmA+BH4FrgB3ASWC0XbEopZTKn22JwBgz6iLvG+Beu46vlFLKNTqzWCmlPJwmAqWU8nCaCJRSysNpIlBKKQ+niUAppTycJgKllPJwmgiUUsrDaSJQSikPp4lAKaU8nCYCpZTycJoIlFLKw2kiUEopD6eJQCmlPJwmAqWU8nCaCJRSysNpIlBKKQ+niUAppTycJgKllPJwmgiUUsrDaSJQSikPp4lAKaU8nCYCpZTycJoIlFLKw2kiUEopD6eJQCmlPJwmAqWU8nCaCJRSysNpIlBKKQ+niUAppTycJgKllPJwmgiUUsrDaSJQSikPp4lAKaUuZMN02Pen01HYShOBUkrlJ3oezBkPn15jPS6lNBEopVReEvfBvPug5mXWz8zbYO1nTkdlC1sTgYj0E5FtIrJDRB7P4/16IvKLiESJyDIRCbEzHqWUcklWJnw7DrKzYNhUuHUuNOoN3z8AK94EY5yO0K1sSwQi4g1MAvoDocAoEQk9Z7PXgWnGmNbAC8D/7IpHKaVctuJ12LsSBrwJVRtBmQAY9TWEDYdf/gM/Pw3Z2U5H6TY+Nu67A7DDGBMDICLTgcFAdK5tQoGHcx4vBebaGI9SSl3cnpWw/BVoPQLajPjndW9fGBIB5arAqvfg5FEYNNF6vYSzs2moNrAv1/PYnNdy2wBcn/N4CFBBRKqeuyMRGScia0RkTXx8vC3BKqUUp47B7LFQqR5c8/r573t5Qf9XoOdTsOErmHEzZJwq+jjdzOnO4keA7iKyDugOxAFZ525kjIkwxoQbY8KDgoKKOkallCcwBuY9ACkHYdjH4Fcx7+1EoPtjMOAN2L4QPr8eTiUWaajuZmciiAPq5HoekvPaGcaY/caY640xbYGncl5LtDEmpZTK29pPYcs86P0s1G5/8e0vH2N1JMf+CZ8OhOOHbA/RLnYmgj+BJiLSQETKACOBswbiikg1ETkdwxPAVBvjUUqpvB3eAgseh4Y9ofP9rn+u1fVw0zdwNAamXmX9WQLZlgiMMZnAfcBCYAvwjTFms4i8ICKDcjbrAWwTke1AMPBfu+JRSqk8ZaTCrDuhTHkYMtnqByiIRr3gtu8hNRk+vhoObrQnThuJKWHjYcPDw82aNWucDkMpVVr8+CisjoCbZkGTvoXfT/w2+HwIpKXAjdOhXhf3xegGIrLWGBOe13tOdxYrpZRztv5oJYFO915aEgAIagZ3LITy1a2EsO0n98RYBDQRKKU8U/J++O4eqNEa+jznnn1WqmMlg+qhMP0mWP+Ve/ZrM00ESinPk51llZDITIdhn4BPWfftO6Aq3DYPGlwJc++Gle+5b9820USglP
I8v74Ju1fANa9Btcbu33/ZCnDjNxA6GH5+ChY/X6zrE2kiUEp5ln2rYen/oNUwuOxG+47jU9a62wi/A359yypYl5Vp3/EugZ21hpRSqng5lWgNFQ2sDQPftGYJ28nL2ypc518NIl+16hMN/Rh8/ew9bgHpHYFSyjMYA/MfguQ4GDoV/AKL5rgi0Osp6PcKbJ0PXw6z5hwUI5oIlFKeYd0XsPlb66Rc5/KiP36nu+D6KbB3FXw2EFKKTwFNTQRKqdIvfjv89Bg06AZdH3QujtY3wMivrXimXg2Je52LJRdNBEqp0i0zDWbfAT5+1noCXt7OxtP0KmvFs5NH4OOrrDpHDtNEoJQq3RY/b9X/ue59qFjT6WgsdTvB6J+sfoup/ayRTA7SRKCUKr22L4Tf34cO46FZf6ejOVtwS7hzIfhXgWmD4e/FjoWiiUApVTodP2jN7A1uBX1fcDqavFWub5WkqNoYvh4BG2c5EoYmAqVU6ZOdDXPGQ/pJa/GYYjZu/yzlq8Pt86FOJ5g9Bv6IKPIQNBEopUqflRMhZpm1vnBQM6ejuTi/QLh5NjQfAD89Ckv/r0hLUmgiUEqVLrFrYcmLEHodtLvV6Whc5+sHwz+DtjfD8lfgx0es4nhFQEtMKKVKj9Rka6hohZpw7Tv2l5BwN28fGPQe+FeF396xSlIMmQw+ZWw9rCYCpVTp8cO/rElao3+CcpWcjqZwRKzObf9qsOgZSE2EGz6HsuVtO6Q2DSmlSocN02HjN9DjCWucfknX9QEYPAlillvDS08ete1QmgiUUiVfwk7rbqBeV7jyX05H4z5tb4YRX1gT4qb2g6RYWw6jiUApVbJlpsOsO8Db1yrq5nQJCXdrfg3c8i0cP2BNkLOB9hEopUq2JS/AgfUw4ktrnYHSqP4VcO9q20pk6B2BUqrk2vELrHwXwu+EFgOdjsZeNtZJ0kSglCqZUg7DnLsgqAVc/V+noynRtGlIKVXyZGdbdYTSkuHW78C3nNMRlWiaCJRSJc/v78OOxdZ6wMGhTkdT4mnTkFKqZNm/zlpjoPlACL/D6WhKBU0ESqmSIy0FZt1pVewc9G7JKyFRTGnTkFKq5PjpMTi2C2773lrQRbmF3hEopUqGqJmw/kvo9qg1rl65jSYCpVTxd3QXzH/IWryl22NOR1PqaCJQShVvWRkw+07w8oKhU6xSzcqt9BtVShVvS/8P4tZai7ZUqut0NKWS3hEopYqvmGXw61vQ7jZoeZ3T0ZRatiYCEeknIttEZIeIPJ7H+3VFZKmIrBORKBG5xs54lFIlyIkj8O14qNYU+r3sdDSlmm2JQES8gUlAfyAUGCUi504BfBr4xhjTFhgJvG9XPEqpEmbBE3DqGAz7GMr4Ox1NqWbnHUEHYIcxJsYYkw5MBwafs40BKuY8DgT22xiPUqqkOHUMoufC5XdCjTCnoyn17EwEtYF9uZ7H5ryW2/PAzSISC/wI3J/XjkRknIisEZE18fHxdsSqlCpONs+BrHRoM9LpSByXlW34+9BxZq+NJSY+xZZjOD1qaBTwqTHmDRHpDHwuIq2MMdm5NzLGRAARAOHh4caBOJVSRWnDDKu8dI3WTkdSpLKyDTHxKUTFJrExLolNcUlEH0jmZHoWAE8PaEHDIPcvYn/RRCAi1wI/nHtydkEcUCfX85Cc13K7E+gHYIxZJSJ+QDXgcAGPpZQqLY7ugn2/Q+/nSnUtodMn/Y1xSUTFnn/SL+frTWititwQXoew2oGEhQTSyIYkAK7dEYwA3haR2cBUY8xWF/f9J9BERBpgJYCRwI3nbLMX6A18KiItAD9A236U8mRR3wACrW9wOhK3yX3S3xiXxMbY/E/6rWoH0jrnpO/tVTSJ8KKJwBhzs4hUJKcZR0QM8AnwtTHm+AU+lyki9wELAW+sJLJZRF4A1hhj5gH/AqaIyENYHce3G2O06UcpT2UMRE23agkFhjgdTaGce9LfFJfE5v3/nPT9fL1oWSvwzEk/rH
YgjYIC8PF2blqXS30ExphkEZkFlAMeBIYAj4rIRGPMuxf43I9YncC5X3s21+NooGsh4lZKlUaxa+BoDFz5L6cjcUl
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2021-12-19 23:09:15 +01:00
}
],
"source": [
2021-12-20 04:30:49 +01:00
"plt.plot(history.history['accuracy'], label='accuracy')\n",
"plt.plot(history.history['val_accuracy'], label='val_accuracy')\n",
"plt.title('Training vs. validation accuracy')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Accuracy')\n",
"# Accuracy lies in [0, 1]; the full range avoids clipping epochs that fall below 0.5.\n",
"plt.ylim([0, 1])\n",
"plt.legend(loc='lower right')\n",
"# Explicit show() so the cell renders only the figure, not the Legend object's repr.\n",
"plt.show()"
2021-12-19 23:09:15 +01:00
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "marine-satellite",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}