Symulowanie-wizualne/sw-lab5.ipynb

788 lines
25 KiB
Plaintext
Raw Normal View History

2022-11-29 18:58:49 +01:00
{
"cells": [
{
"cell_type": "markdown",
"id": "dd9a88f0",
"metadata": {},
"source": [
"#### Aleksandra Jonas, Aleksandra Gronowska, Iwona Christop"
]
},
{
"cell_type": "markdown",
"id": "acda0087",
"metadata": {},
"source": [
"### Generowanie dodatkowych zdjęć w oparciu o filtry krawędziowe"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f790226b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2 as cv\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import json\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4e3ebfd0",
"metadata": {},
"outputs": [],
"source": [
"def fix_float_img(img, dtype=np.uint8):\n",
"    \"\"\"Min-max rescale an image to the 0-255 range and cast it to an integer dtype.\n",
"\n",
"    Replaces the `np.int` alias (removed in NumPy 1.24) and returns uint8 by default,\n",
"    which cv.imwrite writes natively. A constant image yields all zeros instead of\n",
"    the NaNs produced by dividing by max - min == 0.\n",
"\n",
"    img   -- numeric array (e.g. the float32 Laplacian response of the edge cell)\n",
"    dtype -- dtype of the returned array (default np.uint8)\n",
"    Returns an array of img.shape with values in [0, 255].\n",
"    \"\"\"\n",
"    img = np.asarray(img)\n",
"    if not np.issubdtype(img.dtype, np.floating):\n",
"        # integer input: promote first so 255 * (...) cannot overflow\n",
"        img = img.astype(np.float64)\n",
"    lo, hi = img.min(), img.max()\n",
"    if hi == lo:\n",
"        return np.zeros(img.shape, dtype=dtype)\n",
"    img_normed = 255 * (img - lo) / (hi - lo)\n",
"    # astype truncates toward zero, like the original np.array(..., np.int)\n",
"    return img_normed.astype(dtype)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ffeda62d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\jonas\\AppData\\Local\\Temp\\ipykernel_7316\\1949762618.py:3: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" img_normed = np.array(img_normed, np.int)\n"
]
}
],
"source": [
"# Generate one extra edge-map image (suffix 'K') for every training image.\n",
"# The filter is a Laplacian of a lightly blurred grayscale copy (not Canny).\n",
"# Relative path, the same one the MLP section below loads the training set from.\n",
"TRAIN_DIR = os.path.join('train_test_sw', 'train_sw')\n",
"SUBDIRS = ['Tomato', 'Lemon', 'Beech', 'Mean', 'Gardenia']\n",
"EDGE_SUFFIX = 'K'\n",
"\n",
"for sub in SUBDIRS:\n",
"    path = os.path.join(TRAIN_DIR, sub)\n",
"\n",
"    for filename in sorted(os.listdir(path)):\n",
"        f = os.path.join(path, filename)\n",
"\n",
"        # Skip sub-directories and edge maps written by an earlier run of this cell;\n",
"        # re-running would otherwise turn 'xK.png' into 'xKK.png'.\n",
"        # NOTE(review): assumes no original image name ends in 'K.png'.\n",
"        if not os.path.isfile(f) or filename.endswith(EDGE_SUFFIX + '.png'):\n",
"            continue\n",
"\n",
"        img = cv.imread(f)\n",
"        if img is None:\n",
"            # cv.imread returns None for files it cannot decode (e.g. desktop.ini)\n",
"            continue\n",
"\n",
"        img_gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)\n",
"        img_blurred = cv.GaussianBlur(img_gray, (3, 3), 0, 0)\n",
"        img_laplacian = cv.Laplacian(img_blurred, cv.CV_32F, ksize=3)\n",
"        cv.normalize(img_laplacian, img_laplacian, 0, 1, norm_type=cv.NORM_MINMAX, dtype=cv.CV_32F)\n",
"\n",
"        # splitext handles extensions of any length, unlike f[:-4]\n",
"        filename_edge = os.path.splitext(f)[0] + EDGE_SUFFIX + '.png'\n",
"        final_edge = fix_float_img(img_laplacian)\n",
"        cv.imwrite(filename_edge, final_edge)"
]
},
{
"cell_type": "markdown",
"id": "52952131",
"metadata": {},
"source": [
"### MLP"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f7868480",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import subprocess\n",
"import pkg_resources\n",
"import numpy as np\n",
"\n",
"import importlib.metadata\n",
"\n",
"# pip distribution names this notebook needs; missing ones are installed into the\n",
"# running kernel's interpreter. importlib.metadata replaces the deprecated\n",
"# pkg_resources.working_set scan.\n",
"required = {'scikit-image'}\n",
"missing = set()\n",
"for dist_name in required:\n",
"    try:\n",
"        importlib.metadata.distribution(dist_name)\n",
"    except importlib.metadata.PackageNotFoundError:\n",
"        missing.add(dist_name)\n",
"\n",
"if missing:\n",
"    python = sys.executable\n",
"    subprocess.check_call([python, '-m', 'pip', 'install', *missing], stdout=subprocess.DEVNULL)\n",
"\n",
"def load_train_data(input_dir, newSize=(64,64)):\n",
"    \"\"\"Load training images stored as one sub-directory per class under input_dir.\n",
"\n",
"    Every file in every class directory (except Windows' desktop.ini) is read,\n",
"    converted to 4 channels, resized to newSize with cv.INTER_AREA and divided by 255.\n",
"\n",
"    Returns a dict with keys:\n",
"      'values'           -- np.ndarray stacking all images\n",
"      'categories_name'  -- class directory names, in sorted order\n",
"      'categories_count' -- number of images per class, aligned with categories_name\n",
"      'labels'           -- class directory name of every image, aligned with values\n",
"    \"\"\"\n",
"    from pathlib import Path\n",
"\n",
"    import cv2 as cv\n",
"    import numpy as np\n",
"    from skimage.io import imread\n",
"\n",
"    image_dir = Path(input_dir)\n",
"    # Names and counts come from the same sorted folder list, so they stay aligned\n",
"    # and the class order no longer depends on the file system's listing order.\n",
"    folders = sorted(d for d in image_dir.iterdir() if d.is_dir())\n",
"    categories_name = [d.name for d in folders]\n",
"\n",
"    train_img = []\n",
"    categories_count = []\n",
"    labels = []\n",
"    for direc in folders:\n",
"        count = 0\n",
"        for obj in sorted(direc.iterdir()):\n",
"            if not obj.is_file() or obj.name == 'desktop.ini':\n",
"                continue\n",
"            img = imread(obj)  # ndarray: rows x cols, or rows x cols x channels\n",
"            if img.ndim == 2:\n",
"                # grayscale (e.g. the generated edge maps): replicate into 4 channels;\n",
"                # the old shape[-1] == 256 test only caught 256-pixel-wide images\n",
"                img = np.repeat(img[..., np.newaxis], 4, axis=2)\n",
"            elif img.shape[-1] == 3:\n",
"                # RGB: append an opaque alpha channel (the old img[:, :, 3] = ...\n",
"                # raised IndexError on a 3-channel image); assumes 8-bit data\n",
"                alpha = np.full(img.shape[:2] + (1,), 255, dtype=img.dtype)\n",
"                img = np.concatenate([img, alpha], axis=2)\n",
"            img = cv.resize(img, newSize, interpolation=cv.INTER_AREA)\n",
"            img = img / 255  # normalisation to [0, 1] for 8-bit images\n",
"            train_img.append(img)\n",
"            labels.append(direc.name)\n",
"            count += 1\n",
"        categories_count.append(count)\n",
"\n",
"    X = {}\n",
"    X[\"values\"] = np.array(train_img)\n",
"    X[\"categories_name\"] = categories_name\n",
"    X[\"categories_count\"] = categories_count\n",
"    X[\"labels\"] = labels\n",
"    return X\n",
"\n",
"def load_test_data(input_dir, newSize=(256,256)):\n",
"    \"\"\"Load the test images listed in test_labels.json (stored next to input_dir).\n",
"\n",
"    Each JSON entry must provide 'filename' (relative to input_dir) and 'value'\n",
"    (the class name). Images are brought to 4 channels like the training images,\n",
"    resized to newSize with cv.INTER_AREA and divided by 255.\n",
"\n",
"    Returns a dict with keys:\n",
"      'values'           -- np.ndarray stacking all images, in JSON order\n",
"      'categories_name'  -- class names in order of first appearance\n",
"      'categories_count' -- number of images per class, aligned with categories_name\n",
"      'labels'           -- class name of every image, aligned with values\n",
"    \"\"\"\n",
"    import json\n",
"    from collections import Counter\n",
"    from pathlib import Path\n",
"\n",
"    import cv2 as cv\n",
"    import numpy as np\n",
"    from skimage.io import imread\n",
"\n",
"    image_path = Path(input_dir)\n",
"    labels_path = image_path.parent / 'test_labels.json'\n",
"    # json.loads on bytes detects UTF-8/16/32 itself instead of using the locale codec\n",
"    objects = json.loads(labels_path.read_bytes())\n",
"\n",
"    labels = [e['value'] for e in objects]\n",
"    # Count per class rather than per run of equal consecutive labels: the old\n",
"    # run-length count was only right when the JSON was grouped by class (and\n",
"    # crashed on an empty list). Counter keeps first-appearance order.\n",
"    label_counts = Counter(labels)\n",
"    categories_name = list(label_counts)\n",
"    categories_count = [label_counts[name] for name in categories_name]\n",
"\n",
"    test_img = []\n",
"    for e in objects:\n",
"        img = imread(image_path / e['filename'])\n",
"        # bring grayscale / RGB images to 4 channels so they stack with RGBA ones\n",
"        if img.ndim == 2:\n",
"            img = np.repeat(img[..., np.newaxis], 4, axis=2)\n",
"        elif img.shape[-1] == 3:\n",
"            alpha = np.full(img.shape[:2] + (1,), 255, dtype=img.dtype)\n",
"            img = np.concatenate([img, alpha], axis=2)\n",
"        img = cv.resize(img, newSize, interpolation=cv.INTER_AREA)\n",
"        img = img / 255  # normalisation to [0, 1] for 8-bit images\n",
"        test_img.append(img)\n",
"\n",
"    X = {}\n",
"    X[\"values\"] = np.array(test_img)\n",
"    X[\"categories_name\"] = categories_name\n",
"    X[\"categories_count\"] = categories_count\n",
"    X[\"labels\"] = labels\n",
"    return X\n",
"\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"# Data load (images are shrunk to IMG_SIZE; 16 x 16 pixels x 4 channels = 1024 features)\n",
"IMG_SIZE = (16, 16)\n",
"data_train = load_train_data(\"train_test_sw/train_sw\", newSize=IMG_SIZE)\n",
"X_train = data_train['values']\n",
"y_train = data_train['labels']\n",
"\n",
"data_test = load_test_data(\"train_test_sw/test_sw\", newSize=IMG_SIZE)\n",
"X_test = data_test['values']\n",
"y_test = data_test['labels']\n",
"\n",
"# Fit the encoder on the training labels only and reuse that mapping for the test\n",
"# labels: refitting on y_test could map the same class to a different integer\n",
"# (and transform raises on a class never seen in training instead of hiding it).\n",
"class_le = LabelEncoder()\n",
"y_train_enc = class_le.fit_transform(y_train)\n",
"y_test_enc = class_le.transform(y_test)\n",
"\n",
"# Flatten each image into one row: (n_samples, height * width * channels)\n",
"X_train = X_train.reshape(len(X_train), -1)\n",
"X_test = X_test.reshape(len(X_test), -1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e90f9516",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(4708, 1024)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# feature matrix: (n_samples, n_features)\n",
"np.shape(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "542f53d9",
"metadata": {},
"outputs": [],
"source": [
"# one sample per column, the layout forward_prop / backward_prop work on\n",
"X = np.transpose(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d75f9dd7",
"metadata": {},
"outputs": [],
"source": [
"# X is (n_features, n_samples): one sample per column, so the sample count m is\n",
"# the SECOND dimension (m previously held the feature count, 1024).\n",
"n, m = X.shape\n",
"m_train = m"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "27e98944",
"metadata": {},
"outputs": [],
"source": [
"def init_params(n_features=1024, n_hidden=5, n_classes=5):\n",
"    \"\"\"Randomly initialise a one-hidden-layer MLP.\n",
"\n",
"    All entries are drawn uniformly from [-0.5, 0.5). The defaults keep the\n",
"    previously hard-coded sizes (1024 inputs = 16 x 16 x 4, 5 hidden units,\n",
"    5 classes) and the same draw order, so default calls are unchanged.\n",
"\n",
"    Returns W1 (n_hidden, n_features), b1 (n_hidden, 1),\n",
"            W2 (n_classes, n_hidden), b2 (n_classes, 1).\n",
"    \"\"\"\n",
"    W1 = np.random.rand(n_hidden, n_features) - 0.5\n",
"    b1 = np.random.rand(n_hidden, 1) - 0.5\n",
"    W2 = np.random.rand(n_classes, n_hidden) - 0.5\n",
"    b2 = np.random.rand(n_classes, 1) - 0.5\n",
"    return W1, b1, W2, b2\n",
"\n",
"def ReLU(Z):\n",
"    \"\"\"Rectified linear unit: negative entries become 0, the rest pass through.\"\"\"\n",
"    return np.clip(Z, 0, None)\n",
"\n",
"def softmax(Z):\n",
"    \"\"\"Column-wise softmax: each column of Z (one sample) becomes a probability vector.\n",
"\n",
"    The column maximum is subtracted before exponentiating; the result is the same\n",
"    mathematically, but large logits no longer overflow to inf/inf = NaN.\n",
"    np.sum(axis=0) replaces the builtin sum, which looped over rows in Python.\n",
"    \"\"\"\n",
"    shifted = Z - np.max(Z, axis=0, keepdims=True)\n",
"    exp_Z = np.exp(shifted)\n",
"    return exp_Z / np.sum(exp_Z, axis=0, keepdims=True)\n",
" \n",
"def forward_prop(W1, b1, W2, b2, X):\n",
"    \"\"\"Run the MLP forward on X (one sample per column).\n",
"\n",
"    Returns the hidden pre-activation Z1, hidden activation A1 = ReLU(Z1),\n",
"    the output pre-activation Z2 and class probabilities A2 = softmax(Z2).\n",
"    \"\"\"\n",
"    hidden_pre = W1 @ X + b1\n",
"    hidden_act = ReLU(hidden_pre)\n",
"    output_pre = W2 @ hidden_act + b2\n",
"    return hidden_pre, hidden_act, output_pre, softmax(output_pre)\n",
"\n",
"def ReLU_deriv(Z):\n",
"    \"\"\"Derivative of ReLU as a boolean mask: True where Z > 0 (False at Z == 0).\"\"\"\n",
"    return np.greater(Z, 0)\n",
"\n",
"def one_hot(Y, num_classes=None):\n",
"    \"\"\"One-hot encode integer labels with one sample per column.\n",
"\n",
"    Y           -- 1-D array-like of non-negative integer class indices\n",
"    num_classes -- number of rows of the result; defaults to max(Y) + 1 as before.\n",
"                   Pass the model's output size so the matrix still lines up with\n",
"                   A2 when Y happens to lack the highest class.\n",
"    Returns a float array of shape (num_classes, len(Y)).\n",
"    \"\"\"\n",
"    Y = np.asarray(Y)\n",
"    if num_classes is None:\n",
"        num_classes = int(Y.max()) + 1 if Y.size else 0\n",
"    one_hot_Y = np.zeros((num_classes, Y.size))\n",
"    one_hot_Y[Y, np.arange(Y.size)] = 1\n",
"    return one_hot_Y\n",
"\n",
"def backward_prop(Z1, A1, Z2, A2, W1, W2, X, Y):\n",
"    \"\"\"Gradients of the mean cross-entropy loss w.r.t. W1, b1, W2, b2.\n",
"\n",
"    X is (n_features, m) with one sample per column, Y holds the m integer labels\n",
"    and A2 is the softmax output (n_classes, m). Z2 and W1 are accepted for call\n",
"    compatibility but not needed.\n",
"\n",
"    Returns dW1, db1, dW2, db2 with the same shapes as W1, b1, W2, b2.\n",
"    \"\"\"\n",
"    Y = np.asarray(Y)\n",
"    # batch size from the data itself, not from a notebook-level global\n",
"    m = X.shape[1]\n",
"    # one-hot targets sized like A2, so a batch without the highest class still lines up\n",
"    one_hot_Y = np.zeros_like(A2)\n",
"    one_hot_Y[Y, np.arange(m)] = 1\n",
"    dZ2 = A2 - one_hot_Y\n",
"    dW2 = dZ2.dot(A1.T) / m\n",
"    # bias gradients are per-unit sums over the samples (keepdims keeps the (k, 1)\n",
"    # shape); summing the whole matrix gave a single scalar, which for db2 is ~0\n",
"    # because every column of A2 - one_hot_Y sums to 0\n",
"    db2 = np.sum(dZ2, axis=1, keepdims=True) / m\n",
"    dZ1 = W2.T.dot(dZ2) * (Z1 > 0)  # (Z1 > 0) is the ReLU derivative mask\n",
"    dW1 = dZ1.dot(X.T) / m\n",
"    db1 = np.sum(dZ1, axis=1, keepdims=True) / m\n",
"    return dW1, db1, dW2, db2\n",
"\n",
"def update_params(W1, b1, W2, b2, dW1, db1, dW2, db2, alpha):\n",
"    \"\"\"One plain gradient-descent step: each parameter moves by -alpha * its gradient.\n",
"\n",
"    Returns the updated (W1, b1, W2, b2); the input arrays are not modified.\n",
"    \"\"\"\n",
"    params = (W1, b1, W2, b2)\n",
"    grads = (dW1, db1, dW2, db2)\n",
"    return tuple(p - alpha * g for p, g in zip(params, grads))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fc219ddc",
"metadata": {},
"outputs": [],
"source": [
"def get_predictions(A2):\n",
"    \"\"\"Predicted class per sample: row index of the largest probability in each column.\"\"\"\n",
"    return A2.argmax(axis=0)\n",
"\n",
"def get_accuracy(predictions, Y, verbose=False):\n",
"    \"\"\"Fraction of samples whose predicted class equals the label in Y.\n",
"\n",
"    The unconditional print of both full arrays was leftover debugging that\n",
"    flooded the training log; pass verbose=True to print them again.\n",
"    \"\"\"\n",
"    predictions = np.asarray(predictions)\n",
"    Y = np.asarray(Y)\n",
"    if verbose:\n",
"        print(predictions, Y)\n",
"    return np.sum(predictions == Y) / Y.size\n",
"\n",
"def gradient_descent(X, Y, alpha, iterations):\n",
"    \"\"\"Train the MLP with full-batch gradient descent.\n",
"\n",
"    X          -- inputs, one sample per column\n",
"    Y          -- integer labels, one per column of X\n",
"    alpha      -- learning rate\n",
"    iterations -- number of update steps\n",
"    Every 10th step prints the step number and the accuracy of that step's\n",
"    forward pass (computed before the step's update). Returns W1, b1, W2, b2.\n",
"    \"\"\"\n",
"    params = init_params()\n",
"    for step in range(iterations):\n",
"        Z1, A1, Z2, A2 = forward_prop(*params, X)\n",
"        grads = backward_prop(Z1, A1, Z2, A2, params[0], params[2], X, Y)\n",
"        params = update_params(*params, *grads, alpha)\n",
"        if step % 10 != 0:\n",
"            continue\n",
"        print(\"Iteration: \", step)\n",
"        print(get_accuracy(get_predictions(A2), Y))\n",
"    return params"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a81b8a73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 0\n",
"[0 0 0 ... 0 0 0] [0 0 0 ... 4 4 4]\n",
"0.20348343245539507\n",
"Iteration: 10\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 20\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 30\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 40\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 50\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 60\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 70\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 80\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 90\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 100\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 110\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 120\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 130\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 140\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 150\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 160\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 170\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 180\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 190\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 200\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 210\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 220\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 230\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 240\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 250\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 260\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 270\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 280\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 290\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 300\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 310\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 320\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 330\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 340\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 350\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 360\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 370\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 380\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 390\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 400\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 410\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 420\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 430\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 440\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 450\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 460\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 470\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 480\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 490\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 500\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 510\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 520\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 530\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 540\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 550\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 560\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 570\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 580\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 590\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 600\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 610\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 620\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 630\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 640\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 650\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 660\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 670\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 680\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 690\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 700\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 710\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 720\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 730\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 740\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 750\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 760\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 770\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 780\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 790\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 800\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 810\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 820\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 830\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 840\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 850\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 860\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 870\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 880\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 890\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 900\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 910\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 920\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 930\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 940\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 950\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 960\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 970\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 980\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n",
"Iteration: 990\n",
"[2 2 2 ... 2 2 2] [0 0 0 ... 4 4 4]\n",
"0.2296091758708581\n"
]
}
],
"source": [
"LEARNING_RATE = 0.10\n",
"N_ITERATIONS = 1000\n",
"W1, b1, W2, b2 = gradient_descent(X, y_train_enc, alpha=LEARNING_RATE, iterations=N_ITERATIONS)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "51ff4997",
"metadata": {},
"outputs": [],
"source": [
"def make_predictions(X, W1, b1, W2, b2):\n",
"    \"\"\"Predicted class index for every column (sample) of X under the given weights.\"\"\"\n",
"    probabilities = forward_prop(W1, b1, W2, b2, X)[3]\n",
"    return get_predictions(probabilities)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a0861788",
"metadata": {},
"outputs": [],
"source": [
"# predictions for the held-out test set (one sample per column, like X)\n",
"dev_predictions = make_predictions(np.transpose(X_test), W1, b1, W2, b2)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "971555e3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
" 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4\n",
" 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4]\n"
]
},
{
"data": {
"text/plain": [
"0.20077220077220076"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_accuracy(predictions=dev_predictions, Y=y_test_enc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6642aa68",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}