{ "cells": [ { "cell_type": "markdown", "id": "dd9a88f0", "metadata": {}, "source": [ "#### Aleksandra Jonas, Aleksandra Gronowska, Iwona Christop" ] }, { "cell_type": "markdown", "id": "acda0087", "metadata": {}, "source": [ "### Generowanie dodatkowych zdjęć w oparciu o filtry krawędziowe" ] }, { "cell_type": "code", "execution_count": 1, "id": "f790226b", "metadata": {}, "outputs": [], "source": [ "import os\n", "import cv2 as cv\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import json\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "id": "4e3ebfd0", "metadata": {}, "outputs": [], "source": [ "def fix_float_img(img):\n", " img_normed = 255 * (img - img.min()) / (img.max() - img.min())\n", " img_normed = np.array(img_normed, np.int)\n", " return img_normed" ] }, { "cell_type": "code", "execution_count": 3, "id": "ffeda62d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\jonas\\AppData\\Local\\Temp\\ipykernel_7316\\1949762618.py:3: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " img_normed = np.array(img_normed, np.int)\n" ] } ], "source": [ "directory = r\"C:\\Users\\jonas\\OneDrive\\Pulpit\\train_test_sw\\train_sw\"\n", "subdirs = [r\"\\Tomato\", r\"\\Lemon\", r\"\\Beech\", r\"\\Mean\", r\"\\Gardenia\"]\n", "\n", "json_entries = []\n", "\n", "for sub in subdirs:\n", " path = directory + sub\n", " \n", " for filename in os.listdir(path):\n", " f = os.path.join(path, filename)\n", " \n", " if os.path.isfile(f):\n", " img = cv.imread(f)\n", "\n", " # edge detecting using canny\n", " img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)\n", " img_blurred = cv.GaussianBlur(img_gray, (3, 3), 0, 0)\n", " img_laplacian = cv.Laplacian(img_blurred, cv.CV_32F, ksize=3)\n", "\n", " cv.normalize(img_laplacian, img_laplacian, 0, 1, norm_type=cv.NORM_MINMAX, dtype=cv.CV_32F)\n", "\n", " filename_edge = f[:-4] + 'K.png'\n", " final_edge = fix_float_img(img_laplacian)\n", " cv.imwrite(filename_edge, final_edge)\n", "\n", "# # rotating images\n", "# img_rotated = cv.rotate(img, cv.ROTATE_90_CLOCKWISE)\n", "# img_rot_4 = cv.cvtColor(img_rotated, cv.COLOR_RGB2RGBA)\n", "# img_rot_4[:, :, 3] = np.zeros((256,1))\n", "# filename_rotated = f[:-4] + 'R.png'\n", "# cv.imwrite(filename_rotated, img_rotated)" ] }, { "cell_type": "markdown", "id": "52952131", "metadata": {}, "source": [ "### MLP" ] }, { "cell_type": "code", "execution_count": 4, "id": "f7868480", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import subprocess\n", "import pkg_resources\n", "import numpy as np\n", "\n", "required = { 'scikit-image'}\n", "installed = {pkg.key for pkg in pkg_resources.working_set}\n", "missing = required - installed\n", "\n", "if missing: \n", " python = sys.executable\n", " subprocess.check_call([python, '-m', 'pip', 'install', *missing], stdout=subprocess.DEVNULL)\n", "\n", "def load_train_data(input_dir, newSize=(64,64)):\n", " import numpy as np\n", " import pandas as pd\n", " import os\n", " from skimage.io import imread\n", " import cv2 as cv\n", " from pathlib import Path\n", " import random\n", " from shutil import copyfile, rmtree\n", " import json\n", "\n", " import seaborn as sns\n", " import matplotlib.pyplot as plt\n", "\n", " import matplotlib\n", " \n", " image_dir = Path(input_dir)\n", " categories_name = []\n", " for file in os.listdir(image_dir):\n", " d = os.path.join(image_dir, file)\n", " if os.path.isdir(d):\n", " categories_name.append(file)\n", "\n", " folders = [directory for directory in image_dir.iterdir() if directory.is_dir()]\n", "\n", " train_img = []\n", " categories_count=[]\n", " labels=[]\n", " for i, direc in enumerate(folders):\n", " count = 0\n", " for obj in direc.iterdir():\n", " if os.path.isfile(obj) and os.path.basename(os.path.normpath(obj)) != 'desktop.ini':\n", " labels.append(os.path.basename(os.path.normpath(direc)))\n", " count += 1\n", " img = imread(obj)#zwraca ndarry postaci xSize x ySize x colorDepth\n", " if img.shape[-1] == 256:\n", " img = np.repeat(img[..., np.newaxis], 4, axis=2)\n", " elif img.shape[-1] == 3:\n", " img[:, :, 3] = img[1]\n", " img = cv.resize(img, newSize, interpolation=cv.INTER_AREA)# zwraca ndarray\n", " img = img / 255#normalizacja\n", " train_img.append(img)\n", " categories_count.append(count)\n", " X={}\n", " X[\"values\"] = np.array(train_img)\n", " X[\"categories_name\"] = categories_name\n", " X[\"categories_count\"] = categories_count\n", " X[\"labels\"]=labels\n", " return X\n", "\n", "def load_test_data(input_dir, newSize=(256,256)):\n", " import numpy as np\n", " import pandas as pd\n", " import os\n", " from skimage.io import imread\n", " import cv2 as cv\n", " from pathlib import Path\n", " import random\n", " from shutil import copyfile, rmtree\n", " import json\n", "\n", " import seaborn as sns\n", " import matplotlib.pyplot as plt\n", "\n", " import matplotlib\n", "\n", " image_path = Path(input_dir)\n", "\n", " labels_path = image_path.parents[0] / 'test_labels.json'\n", "\n", " jsonString = labels_path.read_text()\n", " objects = json.loads(jsonString)\n", "\n", " categories_name = []\n", " categories_count=[]\n", " count = 0\n", " c = objects[0]['value']\n", " for e in objects:\n", " if e['value'] != c:\n", " categories_count.append(count)\n", " c = e['value']\n", " count = 1\n", " else:\n", " count += 1\n", " if not e['value'] in categories_name:\n", " categories_name.append(e['value'])\n", "\n", " categories_count.append(count)\n", " \n", " test_img = []\n", "\n", " labels=[]\n", " for e in objects:\n", " p = image_path / e['filename']\n", " img = imread(p)#zwraca ndarry postaci xSize x ySize x colorDepth\n", " img = cv.resize(img, newSize, interpolation=cv.INTER_AREA)# zwraca ndarray\n", " img = img / 255#normalizacja\n", " test_img.append(img)\n", " labels.append(e['value'])\n", "\n", " X={}\n", " X[\"values\"] = np.array(test_img)\n", " X[\"categories_name\"] = categories_name\n", " X[\"categories_count\"] = categories_count\n", " X[\"labels\"]=labels\n", " return X\n", "\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "# Data load\n", "data_train = load_train_data(\"train_test_sw/train_sw\", newSize=(16,16))\n", "X_train = data_train['values']\n", "y_train = data_train['labels']\n", "\n", "data_test = load_test_data(\"train_test_sw/test_sw\", newSize=(16,16))\n", "X_test = data_test['values']\n", "y_test = data_test['labels']\n", "\n", "class_le = LabelEncoder()\n", "y_train_enc = class_le.fit_transform(y_train)\n", "y_test_enc = class_le.fit_transform(y_test)\n", "\n", "X_train = X_train.flatten().reshape(X_train.shape[0], int(np.prod(X_train.shape) / X_train.shape[0]))\n", "X_test = X_test.flatten().reshape(X_test.shape[0], int(np.prod(X_test.shape) / X_test.shape[0]))" ] }, { "cell_type": "code", "execution_count": 5, "id": "e90f9516", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(4708, 1024)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.shape" ] }, { "cell_type": "code", "execution_count": 6, "id": "542f53d9", "metadata": {}, "outputs": [], "source": [ "X = X_train.T" ] }, { "cell_type": "code", "execution_count": 7, "id": "d75f9dd7", "metadata": {}, "outputs": [], "source": [ "m_train, _ = X.shape\n", "m, n = X.shape" ] }, { "cell_type": "code", "execution_count": 8, "id": "27e98944", "metadata": {}, "outputs": [], "source": [ "def init_params():\n", " W1 = np.random.rand(5, 1024) - 0.5\n", " b1 = np.random.rand(5, 1) - 0.5\n", " W2 = np.random.rand(5, 5) - 0.5\n", " b2 = np.random.rand(5, 1) - 0.5\n", " return W1, b1, W2, b2\n", "\n", "def ReLU(Z):\n", " return np.maximum(Z, 0)\n", "\n", "def softmax(Z):\n", " A = np.exp(Z) / sum(np.exp(Z))\n", " return A\n", " \n", "def forward_prop(W1, b1, W2, b2, X):\n", " Z1 = W1.dot(X) + b1\n", " A1 = ReLU(Z1)\n", " Z2 = W2.dot(A1) + b2\n", " A2 = softmax(Z2)\n", " return Z1, A1, Z2, A2\n", "\n", "def ReLU_deriv(Z):\n", " return Z > 0\n", "\n", "def one_hot(Y):\n", " one_hot_Y = np.zeros((Y.size, Y.max() + 1))\n", " one_hot_Y[np.arange(Y.size), Y] = 1\n", " one_hot_Y = one_hot_Y.T\n", " return one_hot_Y\n", "\n", "def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n", " one_hot_Y = one_hot(Y)\n", " dZ2 = A2 - one_hot_Y\n", " dW2 = 1 / m * dZ2.dot(A1.T)\n", " db2 = 1 / m * np.sum(dZ2)\n", " dZ1 = W2.T.dot(dZ2) * ReLU_deriv(Z1)\n", " dW1 = 1 / m * dZ1.dot(X.T)\n", " db1 = 1 / m * np.sum(dZ1)\n", " return dW1, db1, dW2, db2\n", "\n", "def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n", " W1 = W1 - alpha * dW1\n", " b1 = b1 - alpha * db1 \n", " W2 = W2 - alpha * dW2 \n", " b2 = b2 - alpha * db2 \n", " return W1, b1, W2, b2" ] }, { "cell_type": "code", "execution_count": 9, "id": "fc219ddc", "metadata": {}, "outputs": [], "source": [ "def get_predictions(A2):\n", " return np.argmax(A2, 0)\n", "\n", "def get_accuracy(predictions, Y):\n", " print(predictions, Y)\n", " return np.sum(predictions == Y) / Y.size\n", "\n", "def gradient_descent(X, Y, alpha, iterations):\n", " W1, b1, W2, b2 = init_params()\n", " for i in range(iterations):\n", " Z1, A1, Z2, A2 = forward_prop(W1, b1, W2, b2, X)\n", " dW1, db1, dW2, db2 = backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y)\n", " W1, b1, W2, b2 = update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha)\n", " if i % 10 == 0:\n", " print(\"Iteration: \", i)\n", " predictions = get_predictions(A2)\n", " print(get_accuracy(predictions, Y))\n", " return W1, b1, W2, b2" ] }, { "cell_type": "code", "execution_count": 10, "id": "a81b8a73", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iteration: 0\n", "[0 0 0 ... 0 0 0] [0 0 0 ... 4 4 4]\n", "0.20348343245539507\n", "Iteration: 10\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 20\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 30\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 40\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 50\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 60\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 70\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 80\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 90\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 100\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 110\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 120\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 130\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 140\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 150\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 160\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 170\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 180\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 190\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 200\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 210\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 220\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 230\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 240\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 250\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 260\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 270\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 280\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 290\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 300\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 310\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 320\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 330\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 340\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 350\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 360\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 370\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 380\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 390\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 400\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 410\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 420\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 430\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 440\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 450\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 460\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 470\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 480\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 490\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 500\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 510\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 520\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 530\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 540\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 550\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 560\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 570\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 580\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 590\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 600\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 610\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 620\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 630\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 640\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 650\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 660\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 670\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 680\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 690\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 700\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 710\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 720\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 730\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 740\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 750\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 760\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 770\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 780\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 790\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 800\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 810\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 820\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 830\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 840\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 850\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 860\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 870\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 880\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 890\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 900\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 910\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 920\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 930\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 940\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 950\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 960\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 970\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 980\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n", "Iteration: 990\n", "[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n", "0.2296091758708581\n" ] } ], "source": [ "W1, b1, W2, b2 = gradient_descent(X, y_train_enc, 0.10, 1000)" ] }, { "cell_type": "code", "execution_count": 11, "id": "51ff4997", "metadata": {}, "outputs": [], "source": [ "def make_predictions(X, W1, b1, W2, b2):\n", " _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)\n", " predictions = get_predictions(A2)\n", " return predictions" ] }, { "cell_type": "code", "execution_count": 12, "id": "a0861788", "metadata": {}, "outputs": [], "source": [ "dev_predictions = make_predictions(X_test.T, W1, b1, W2, b2)" ] }, { "cell_type": "code", "execution_count": 13, "id": "971555e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", " 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n", " 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n", " 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]\n" ] }, { "data": { "text/plain": [ "0.20077220077220076" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_accuracy(dev_predictions, y_test_enc)" ] }, { "cell_type": "code", "execution_count": null, "id": "6642aa68", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" } }, "nbformat": 4, "nbformat_minor": 5 }