diff --git a/classification.ipynb b/classification.ipynb
new file mode 100644
index 0000000..406014f
--- /dev/null
+++ b/classification.ipynb
@@ -0,0 +1,970 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "corrected-wholesale",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!kaggle datasets download -d yasserh/breast-cancer-dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ranging-police",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!unzip -o breast-cancer-dataset.zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 109,
+ "id": "ideal-spouse",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torch.autograd import Variable\n",
+ "from sklearn.datasets import load_iris\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "from tensorflow.keras.utils import to_categorical\n",
+ "import torch.nn.functional as F\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "id": "major-compromise",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Model(nn.Module):\n",
+ " def __init__(self, input_dim):\n",
+ " super(Model, self).__init__()\n",
+ " self.layer1 = nn.Linear(input_dim,50)\n",
+ " self.layer2 = nn.Linear(50, 20)\n",
+ " self.layer3 = nn.Linear(20, 3)\n",
+ " \n",
+ " def forward(self, x):\n",
+ " x = F.relu(self.layer1(x))\n",
+ " x = F.relu(self.layer2(x))\n",
+ " x = F.softmax(self.layer3(x)) # To check with the loss function\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 111,
+ "id": "czech-regular",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " diagnosis | \n",
+ " radius_mean | \n",
+ " texture_mean | \n",
+ " perimeter_mean | \n",
+ " area_mean | \n",
+ " smoothness_mean | \n",
+ " compactness_mean | \n",
+ " concavity_mean | \n",
+ " concave points_mean | \n",
+ " symmetry_mean | \n",
+ " ... | \n",
+ " radius_worst | \n",
+ " texture_worst | \n",
+ " perimeter_worst | \n",
+ " area_worst | \n",
+ " smoothness_worst | \n",
+ " compactness_worst | \n",
+ " concavity_worst | \n",
+ " concave points_worst | \n",
+ " symmetry_worst | \n",
+ " fractal_dimension_worst | \n",
+ "
\n",
+ " \n",
+ " id | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 842517 | \n",
+ " M | \n",
+ " 20.57 | \n",
+ " 17.77 | \n",
+ " 132.90 | \n",
+ " 1326.0 | \n",
+ " 0.08474 | \n",
+ " 0.07864 | \n",
+ " 0.08690 | \n",
+ " 0.07017 | \n",
+ " 0.1812 | \n",
+ " ... | \n",
+ " 24.99 | \n",
+ " 23.41 | \n",
+ " 158.80 | \n",
+ " 1956.0 | \n",
+ " 0.1238 | \n",
+ " 0.1866 | \n",
+ " 0.2416 | \n",
+ " 0.1860 | \n",
+ " 0.2750 | \n",
+ " 0.08902 | \n",
+ "
\n",
+ " \n",
+ " 84300903 | \n",
+ " M | \n",
+ " 19.69 | \n",
+ " 21.25 | \n",
+ " 130.00 | \n",
+ " 1203.0 | \n",
+ " 0.10960 | \n",
+ " 0.15990 | \n",
+ " 0.19740 | \n",
+ " 0.12790 | \n",
+ " 0.2069 | \n",
+ " ... | \n",
+ " 23.57 | \n",
+ " 25.53 | \n",
+ " 152.50 | \n",
+ " 1709.0 | \n",
+ " 0.1444 | \n",
+ " 0.4245 | \n",
+ " 0.4504 | \n",
+ " 0.2430 | \n",
+ " 0.3613 | \n",
+ " 0.08758 | \n",
+ "
\n",
+ " \n",
+ " 84348301 | \n",
+ " M | \n",
+ " 11.42 | \n",
+ " 20.38 | \n",
+ " 77.58 | \n",
+ " 386.1 | \n",
+ " 0.14250 | \n",
+ " 0.28390 | \n",
+ " 0.24140 | \n",
+ " 0.10520 | \n",
+ " 0.2597 | \n",
+ " ... | \n",
+ " 14.91 | \n",
+ " 26.50 | \n",
+ " 98.87 | \n",
+ " 567.7 | \n",
+ " 0.2098 | \n",
+ " 0.8663 | \n",
+ " 0.6869 | \n",
+ " 0.2575 | \n",
+ " 0.6638 | \n",
+ " 0.17300 | \n",
+ "
\n",
+ " \n",
+ " 84358402 | \n",
+ " M | \n",
+ " 20.29 | \n",
+ " 14.34 | \n",
+ " 135.10 | \n",
+ " 1297.0 | \n",
+ " 0.10030 | \n",
+ " 0.13280 | \n",
+ " 0.19800 | \n",
+ " 0.10430 | \n",
+ " 0.1809 | \n",
+ " ... | \n",
+ " 22.54 | \n",
+ " 16.67 | \n",
+ " 152.20 | \n",
+ " 1575.0 | \n",
+ " 0.1374 | \n",
+ " 0.2050 | \n",
+ " 0.4000 | \n",
+ " 0.1625 | \n",
+ " 0.2364 | \n",
+ " 0.07678 | \n",
+ "
\n",
+ " \n",
+ " 843786 | \n",
+ " M | \n",
+ " 12.45 | \n",
+ " 15.70 | \n",
+ " 82.57 | \n",
+ " 477.1 | \n",
+ " 0.12780 | \n",
+ " 0.17000 | \n",
+ " 0.15780 | \n",
+ " 0.08089 | \n",
+ " 0.2087 | \n",
+ " ... | \n",
+ " 15.47 | \n",
+ " 23.75 | \n",
+ " 103.40 | \n",
+ " 741.6 | \n",
+ " 0.1791 | \n",
+ " 0.5249 | \n",
+ " 0.5355 | \n",
+ " 0.1741 | \n",
+ " 0.3985 | \n",
+ " 0.12440 | \n",
+ "
\n",
+ " \n",
+ " 844359 | \n",
+ " M | \n",
+ " 18.25 | \n",
+ " 19.98 | \n",
+ " 119.60 | \n",
+ " 1040.0 | \n",
+ " 0.09463 | \n",
+ " 0.10900 | \n",
+ " 0.11270 | \n",
+ " 0.07400 | \n",
+ " 0.1794 | \n",
+ " ... | \n",
+ " 22.88 | \n",
+ " 27.66 | \n",
+ " 153.20 | \n",
+ " 1606.0 | \n",
+ " 0.1442 | \n",
+ " 0.2576 | \n",
+ " 0.3784 | \n",
+ " 0.1932 | \n",
+ " 0.3063 | \n",
+ " 0.08368 | \n",
+ "
\n",
+ " \n",
+ " 84458202 | \n",
+ " M | \n",
+ " 13.71 | \n",
+ " 20.83 | \n",
+ " 90.20 | \n",
+ " 577.9 | \n",
+ " 0.11890 | \n",
+ " 0.16450 | \n",
+ " 0.09366 | \n",
+ " 0.05985 | \n",
+ " 0.2196 | \n",
+ " ... | \n",
+ " 17.06 | \n",
+ " 28.14 | \n",
+ " 110.60 | \n",
+ " 897.0 | \n",
+ " 0.1654 | \n",
+ " 0.3682 | \n",
+ " 0.2678 | \n",
+ " 0.1556 | \n",
+ " 0.3196 | \n",
+ " 0.11510 | \n",
+ "
\n",
+ " \n",
+ " 844981 | \n",
+ " M | \n",
+ " 13.00 | \n",
+ " 21.82 | \n",
+ " 87.50 | \n",
+ " 519.8 | \n",
+ " 0.12730 | \n",
+ " 0.19320 | \n",
+ " 0.18590 | \n",
+ " 0.09353 | \n",
+ " 0.2350 | \n",
+ " ... | \n",
+ " 15.49 | \n",
+ " 30.73 | \n",
+ " 106.20 | \n",
+ " 739.3 | \n",
+ " 0.1703 | \n",
+ " 0.5401 | \n",
+ " 0.5390 | \n",
+ " 0.2060 | \n",
+ " 0.4378 | \n",
+ " 0.10720 | \n",
+ "
\n",
+ " \n",
+ " 84501001 | \n",
+ " M | \n",
+ " 12.46 | \n",
+ " 24.04 | \n",
+ " 83.97 | \n",
+ " 475.9 | \n",
+ " 0.11860 | \n",
+ " 0.23960 | \n",
+ " 0.22730 | \n",
+ " 0.08543 | \n",
+ " 0.2030 | \n",
+ " ... | \n",
+ " 15.09 | \n",
+ " 40.68 | \n",
+ " 97.65 | \n",
+ " 711.4 | \n",
+ " 0.1853 | \n",
+ " 1.0580 | \n",
+ " 1.1050 | \n",
+ " 0.2210 | \n",
+ " 0.4366 | \n",
+ " 0.20750 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
9 rows × 31 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n",
+ "id \n",
+ "842517 M 20.57 17.77 132.90 1326.0 \n",
+ "84300903 M 19.69 21.25 130.00 1203.0 \n",
+ "84348301 M 11.42 20.38 77.58 386.1 \n",
+ "84358402 M 20.29 14.34 135.10 1297.0 \n",
+ "843786 M 12.45 15.70 82.57 477.1 \n",
+ "844359 M 18.25 19.98 119.60 1040.0 \n",
+ "84458202 M 13.71 20.83 90.20 577.9 \n",
+ "844981 M 13.00 21.82 87.50 519.8 \n",
+ "84501001 M 12.46 24.04 83.97 475.9 \n",
+ "\n",
+ " smoothness_mean compactness_mean concavity_mean \\\n",
+ "id \n",
+ "842517 0.08474 0.07864 0.08690 \n",
+ "84300903 0.10960 0.15990 0.19740 \n",
+ "84348301 0.14250 0.28390 0.24140 \n",
+ "84358402 0.10030 0.13280 0.19800 \n",
+ "843786 0.12780 0.17000 0.15780 \n",
+ "844359 0.09463 0.10900 0.11270 \n",
+ "84458202 0.11890 0.16450 0.09366 \n",
+ "844981 0.12730 0.19320 0.18590 \n",
+ "84501001 0.11860 0.23960 0.22730 \n",
+ "\n",
+ " concave points_mean symmetry_mean ... \\\n",
+ "id ... \n",
+ "842517 0.07017 0.1812 ... \n",
+ "84300903 0.12790 0.2069 ... \n",
+ "84348301 0.10520 0.2597 ... \n",
+ "84358402 0.10430 0.1809 ... \n",
+ "843786 0.08089 0.2087 ... \n",
+ "844359 0.07400 0.1794 ... \n",
+ "84458202 0.05985 0.2196 ... \n",
+ "844981 0.09353 0.2350 ... \n",
+ "84501001 0.08543 0.2030 ... \n",
+ "\n",
+ " radius_worst texture_worst perimeter_worst area_worst \\\n",
+ "id \n",
+ "842517 24.99 23.41 158.80 1956.0 \n",
+ "84300903 23.57 25.53 152.50 1709.0 \n",
+ "84348301 14.91 26.50 98.87 567.7 \n",
+ "84358402 22.54 16.67 152.20 1575.0 \n",
+ "843786 15.47 23.75 103.40 741.6 \n",
+ "844359 22.88 27.66 153.20 1606.0 \n",
+ "84458202 17.06 28.14 110.60 897.0 \n",
+ "844981 15.49 30.73 106.20 739.3 \n",
+ "84501001 15.09 40.68 97.65 711.4 \n",
+ "\n",
+ " smoothness_worst compactness_worst concavity_worst \\\n",
+ "id \n",
+ "842517 0.1238 0.1866 0.2416 \n",
+ "84300903 0.1444 0.4245 0.4504 \n",
+ "84348301 0.2098 0.8663 0.6869 \n",
+ "84358402 0.1374 0.2050 0.4000 \n",
+ "843786 0.1791 0.5249 0.5355 \n",
+ "844359 0.1442 0.2576 0.3784 \n",
+ "84458202 0.1654 0.3682 0.2678 \n",
+ "844981 0.1703 0.5401 0.5390 \n",
+ "84501001 0.1853 1.0580 1.1050 \n",
+ "\n",
+ " concave points_worst symmetry_worst fractal_dimension_worst \n",
+ "id \n",
+ "842517 0.1860 0.2750 0.08902 \n",
+ "84300903 0.2430 0.3613 0.08758 \n",
+ "84348301 0.2575 0.6638 0.17300 \n",
+ "84358402 0.1625 0.2364 0.07678 \n",
+ "843786 0.1741 0.3985 0.12440 \n",
+ "844359 0.1932 0.3063 0.08368 \n",
+ "84458202 0.1556 0.3196 0.11510 \n",
+ "844981 0.2060 0.4378 0.10720 \n",
+ "84501001 0.2210 0.4366 0.20750 \n",
+ "\n",
+ "[9 rows x 31 columns]"
+ ]
+ },
+ "execution_count": 111,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "class_names = ['Malignant', 'Benign']\n",
+ "data = pd.read_csv('breast-cancer.csv', index_col=0)\n",
+ "data[1:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 112,
+ "id": "outdoor-element",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lb = LabelEncoder()\n",
+ "data['diagnosis'] = lb.fit_transform(data['diagnosis'])\n",
+ "features = data.iloc[:, 1:32].values\n",
+ "labels = np.array(data['diagnosis'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 113,
+ "id": "buried-community",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n",
+ " 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n",
+ " 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n",
+ " 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n",
+ " 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n",
+ " [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n",
+ " 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n",
+ " 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n",
+ " 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n",
+ " 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n",
+ " [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n",
+ " 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n",
+ " 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n",
+ " 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n",
+ " 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n",
+ " [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n",
+ " 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n",
+ " 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n",
+ " 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n",
+ " 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n",
+ " [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n",
+ " 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n",
+ " 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n",
+ " 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n",
+ " 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n",
+ " [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n",
+ " 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n",
+ " 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n",
+ " 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n",
+ " 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n",
+ " [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n",
+ " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n",
+ " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n",
+ " 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n",
+ " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n",
+ " [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n",
+ " 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n",
+ " 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n",
+ " 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n",
+ " 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n",
+ " [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n",
+ " 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n",
+ " 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n",
+ " 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n",
+ " 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])"
+ ]
+ },
+ "execution_count": 113,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "features[1:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 114,
+ "id": "incredible-quantum",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([1, 1, 1, 1, 1, 1, 1, 1, 1])"
+ ]
+ },
+ "execution_count": 114,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "labels[1:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 115,
+ "id": "brazilian-butler",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42, shuffle=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 116,
+ "id": "exotic-method",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Training\n",
+ "model = Model(features_train.shape[1])\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
+ "loss_fn = nn.CrossEntropyLoss()\n",
+ "epochs = 100\n",
+ "\n",
+ "def print_(loss):\n",
+ " print (\"The loss calculated: \", loss)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 117,
+ "id": "sharp-month",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch # 1\n",
+ "The loss calculated: 0.922476053237915\n",
+ "Epoch # 2\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 3\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 4\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 5\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 6\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 7\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 8\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 9\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 10\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 11\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 12\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 13\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 14\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 15\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 16\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 17\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 18\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 19\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 20\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 21\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 22\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 23\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 24\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 25\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 26\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 27\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 28\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 29\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 30\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 31\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 32\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 33\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 34\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 35\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 36\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 37\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 38\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 39\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 40\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 41\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 42\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 43\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 44\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 45\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 46\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 47\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 48\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 49\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 50\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 51\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 52\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 53\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 54\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 55\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 56\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 57\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 58\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 59\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 60\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 61\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 62\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 63\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 64\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 65\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 66\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 67\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 68\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 69\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 70\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 71\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 72\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 73\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 74\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 75\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 76\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 77\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 78\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 79\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 80\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 81\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 82\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 83\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 84\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 85\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 86\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 87\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 88\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 89\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 90\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 91\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 92\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 93\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 94\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 95\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 96\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 97\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 98\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 99\n",
+ "The loss calculated: 0.9223369359970093\n",
+ "Epoch # 100\n",
+ "The loss calculated: 0.9223369359970093\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " # This is added back by InteractiveShellApp.init_path()\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Not using dataloader\n",
+ "x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()\n",
+ "for epoch in range(1, epochs+1):\n",
+ " print (\"Epoch #\",epoch)\n",
+ " y_pred = model(x_train)\n",
+ " loss = loss_fn(y_pred, y_train)\n",
+ " print_(loss.item())\n",
+ " \n",
+ " # Zero gradients\n",
+ " optimizer.zero_grad()\n",
+ " loss.backward() # Gradients\n",
+ " optimizer.step() # Update"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 118,
+ "id": "mechanical-humidity",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " # This is added back by InteractiveShellApp.init_path()\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Prediction\n",
+ "x_test = Variable(torch.from_numpy(features_test)).float()\n",
+ "pred = model(x_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 119,
+ "id": "based-charleston",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.],\n",
+ " [1., 0., 0.]], dtype=float32)"
+ ]
+ },
+ "execution_count": 119,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pred = pred.detach().numpy()\n",
+ "pred[1:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 120,
+ "id": "dried-accessory",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The accuracy is 0.6223776223776224\n"
+ ]
+ }
+ ],
+ "source": [
+ "print (\"The accuracy is\", accuracy_score(labels_test, np.argmax(pred, axis=1)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 121,
+ "id": "effective-characterization",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 121,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "labels_test[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 122,
+ "id": "oriented-determination",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.save(model, \"travel_insurance-pytorch.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 123,
+ "id": "infectious-wagon",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "saved_model = torch.load(\"travel_insurance-pytorch.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 124,
+ "id": "built-contributor",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+ " # This is added back by InteractiveShellApp.init_path()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 124,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}