Classification with PyTorch

This commit is contained in:
Agata 2022-04-20 11:16:45 +02:00
parent 29e340acf7
commit 8076102c87

970
classification.ipynb Normal file
View File

@ -0,0 +1,970 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "corrected-wholesale",
"metadata": {},
"outputs": [],
"source": [
"# Fetch the dataset from Kaggle (requires configured kaggle CLI credentials).\n",
"!kaggle datasets download -d yasserh/breast-cancer-dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ranging-police",
"metadata": {},
"outputs": [],
"source": [
"# Extract the CSV, overwriting any previous copy.\n",
"!unzip -o breast-cancer-dataset.zip"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "ideal-spouse",
"metadata": {},
"outputs": [],
"source": [
"# Imports — this notebook trains a PyTorch classifier; the previously imported\n",
"# load_iris and tensorflow.keras to_categorical were never used and are removed.\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"# NOTE: Variable has been a no-op wrapper since PyTorch 0.4; kept only because\n",
"# later cells still reference it.\n",
"from torch.autograd import Variable\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import LabelEncoder"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "major-compromise",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Three-layer feed-forward classifier returning raw logits.\n",
"\n",
"    The previous version applied F.softmax in forward(); combined with\n",
"    nn.CrossEntropyLoss (which applies log-softmax internally) this crushed\n",
"    the gradients and left the training loss frozen at ~0.9223 for all 100\n",
"    epochs (see the training output below). Returning logits fixes training\n",
"    and also removes the implicit-softmax-dim UserWarning.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, output_dim=3):\n",
"        super(Model, self).__init__()\n",
"        self.layer1 = nn.Linear(input_dim, 50)\n",
"        self.layer2 = nn.Linear(50, 20)\n",
"        self.layer3 = nn.Linear(20, output_dim)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.layer1(x))\n",
"        x = F.relu(self.layer2(x))\n",
"        return self.layer3(x)  # raw logits; apply F.softmax(..., dim=-1) only for display"
]
},
{
"cell_type": "code",
"execution_count": 111,
"id": "czech-regular",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diagnosis</th>\n",
" <th>radius_mean</th>\n",
" <th>texture_mean</th>\n",
" <th>perimeter_mean</th>\n",
" <th>area_mean</th>\n",
" <th>smoothness_mean</th>\n",
" <th>compactness_mean</th>\n",
" <th>concavity_mean</th>\n",
" <th>concave points_mean</th>\n",
" <th>symmetry_mean</th>\n",
" <th>...</th>\n",
" <th>radius_worst</th>\n",
" <th>texture_worst</th>\n",
" <th>perimeter_worst</th>\n",
" <th>area_worst</th>\n",
" <th>smoothness_worst</th>\n",
" <th>compactness_worst</th>\n",
" <th>concavity_worst</th>\n",
" <th>concave points_worst</th>\n",
" <th>symmetry_worst</th>\n",
" <th>fractal_dimension_worst</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>842517</th>\n",
" <td>M</td>\n",
" <td>20.57</td>\n",
" <td>17.77</td>\n",
" <td>132.90</td>\n",
" <td>1326.0</td>\n",
" <td>0.08474</td>\n",
" <td>0.07864</td>\n",
" <td>0.08690</td>\n",
" <td>0.07017</td>\n",
" <td>0.1812</td>\n",
" <td>...</td>\n",
" <td>24.99</td>\n",
" <td>23.41</td>\n",
" <td>158.80</td>\n",
" <td>1956.0</td>\n",
" <td>0.1238</td>\n",
" <td>0.1866</td>\n",
" <td>0.2416</td>\n",
" <td>0.1860</td>\n",
" <td>0.2750</td>\n",
" <td>0.08902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84300903</th>\n",
" <td>M</td>\n",
" <td>19.69</td>\n",
" <td>21.25</td>\n",
" <td>130.00</td>\n",
" <td>1203.0</td>\n",
" <td>0.10960</td>\n",
" <td>0.15990</td>\n",
" <td>0.19740</td>\n",
" <td>0.12790</td>\n",
" <td>0.2069</td>\n",
" <td>...</td>\n",
" <td>23.57</td>\n",
" <td>25.53</td>\n",
" <td>152.50</td>\n",
" <td>1709.0</td>\n",
" <td>0.1444</td>\n",
" <td>0.4245</td>\n",
" <td>0.4504</td>\n",
" <td>0.2430</td>\n",
" <td>0.3613</td>\n",
" <td>0.08758</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84348301</th>\n",
" <td>M</td>\n",
" <td>11.42</td>\n",
" <td>20.38</td>\n",
" <td>77.58</td>\n",
" <td>386.1</td>\n",
" <td>0.14250</td>\n",
" <td>0.28390</td>\n",
" <td>0.24140</td>\n",
" <td>0.10520</td>\n",
" <td>0.2597</td>\n",
" <td>...</td>\n",
" <td>14.91</td>\n",
" <td>26.50</td>\n",
" <td>98.87</td>\n",
" <td>567.7</td>\n",
" <td>0.2098</td>\n",
" <td>0.8663</td>\n",
" <td>0.6869</td>\n",
" <td>0.2575</td>\n",
" <td>0.6638</td>\n",
" <td>0.17300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84358402</th>\n",
" <td>M</td>\n",
" <td>20.29</td>\n",
" <td>14.34</td>\n",
" <td>135.10</td>\n",
" <td>1297.0</td>\n",
" <td>0.10030</td>\n",
" <td>0.13280</td>\n",
" <td>0.19800</td>\n",
" <td>0.10430</td>\n",
" <td>0.1809</td>\n",
" <td>...</td>\n",
" <td>22.54</td>\n",
" <td>16.67</td>\n",
" <td>152.20</td>\n",
" <td>1575.0</td>\n",
" <td>0.1374</td>\n",
" <td>0.2050</td>\n",
" <td>0.4000</td>\n",
" <td>0.1625</td>\n",
" <td>0.2364</td>\n",
" <td>0.07678</td>\n",
" </tr>\n",
" <tr>\n",
" <th>843786</th>\n",
" <td>M</td>\n",
" <td>12.45</td>\n",
" <td>15.70</td>\n",
" <td>82.57</td>\n",
" <td>477.1</td>\n",
" <td>0.12780</td>\n",
" <td>0.17000</td>\n",
" <td>0.15780</td>\n",
" <td>0.08089</td>\n",
" <td>0.2087</td>\n",
" <td>...</td>\n",
" <td>15.47</td>\n",
" <td>23.75</td>\n",
" <td>103.40</td>\n",
" <td>741.6</td>\n",
" <td>0.1791</td>\n",
" <td>0.5249</td>\n",
" <td>0.5355</td>\n",
" <td>0.1741</td>\n",
" <td>0.3985</td>\n",
" <td>0.12440</td>\n",
" </tr>\n",
" <tr>\n",
" <th>844359</th>\n",
" <td>M</td>\n",
" <td>18.25</td>\n",
" <td>19.98</td>\n",
" <td>119.60</td>\n",
" <td>1040.0</td>\n",
" <td>0.09463</td>\n",
" <td>0.10900</td>\n",
" <td>0.11270</td>\n",
" <td>0.07400</td>\n",
" <td>0.1794</td>\n",
" <td>...</td>\n",
" <td>22.88</td>\n",
" <td>27.66</td>\n",
" <td>153.20</td>\n",
" <td>1606.0</td>\n",
" <td>0.1442</td>\n",
" <td>0.2576</td>\n",
" <td>0.3784</td>\n",
" <td>0.1932</td>\n",
" <td>0.3063</td>\n",
" <td>0.08368</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84458202</th>\n",
" <td>M</td>\n",
" <td>13.71</td>\n",
" <td>20.83</td>\n",
" <td>90.20</td>\n",
" <td>577.9</td>\n",
" <td>0.11890</td>\n",
" <td>0.16450</td>\n",
" <td>0.09366</td>\n",
" <td>0.05985</td>\n",
" <td>0.2196</td>\n",
" <td>...</td>\n",
" <td>17.06</td>\n",
" <td>28.14</td>\n",
" <td>110.60</td>\n",
" <td>897.0</td>\n",
" <td>0.1654</td>\n",
" <td>0.3682</td>\n",
" <td>0.2678</td>\n",
" <td>0.1556</td>\n",
" <td>0.3196</td>\n",
" <td>0.11510</td>\n",
" </tr>\n",
" <tr>\n",
" <th>844981</th>\n",
" <td>M</td>\n",
" <td>13.00</td>\n",
" <td>21.82</td>\n",
" <td>87.50</td>\n",
" <td>519.8</td>\n",
" <td>0.12730</td>\n",
" <td>0.19320</td>\n",
" <td>0.18590</td>\n",
" <td>0.09353</td>\n",
" <td>0.2350</td>\n",
" <td>...</td>\n",
" <td>15.49</td>\n",
" <td>30.73</td>\n",
" <td>106.20</td>\n",
" <td>739.3</td>\n",
" <td>0.1703</td>\n",
" <td>0.5401</td>\n",
" <td>0.5390</td>\n",
" <td>0.2060</td>\n",
" <td>0.4378</td>\n",
" <td>0.10720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84501001</th>\n",
" <td>M</td>\n",
" <td>12.46</td>\n",
" <td>24.04</td>\n",
" <td>83.97</td>\n",
" <td>475.9</td>\n",
" <td>0.11860</td>\n",
" <td>0.23960</td>\n",
" <td>0.22730</td>\n",
" <td>0.08543</td>\n",
" <td>0.2030</td>\n",
" <td>...</td>\n",
" <td>15.09</td>\n",
" <td>40.68</td>\n",
" <td>97.65</td>\n",
" <td>711.4</td>\n",
" <td>0.1853</td>\n",
" <td>1.0580</td>\n",
" <td>1.1050</td>\n",
" <td>0.2210</td>\n",
" <td>0.4366</td>\n",
" <td>0.20750</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>9 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n",
"id \n",
"842517 M 20.57 17.77 132.90 1326.0 \n",
"84300903 M 19.69 21.25 130.00 1203.0 \n",
"84348301 M 11.42 20.38 77.58 386.1 \n",
"84358402 M 20.29 14.34 135.10 1297.0 \n",
"843786 M 12.45 15.70 82.57 477.1 \n",
"844359 M 18.25 19.98 119.60 1040.0 \n",
"84458202 M 13.71 20.83 90.20 577.9 \n",
"844981 M 13.00 21.82 87.50 519.8 \n",
"84501001 M 12.46 24.04 83.97 475.9 \n",
"\n",
" smoothness_mean compactness_mean concavity_mean \\\n",
"id \n",
"842517 0.08474 0.07864 0.08690 \n",
"84300903 0.10960 0.15990 0.19740 \n",
"84348301 0.14250 0.28390 0.24140 \n",
"84358402 0.10030 0.13280 0.19800 \n",
"843786 0.12780 0.17000 0.15780 \n",
"844359 0.09463 0.10900 0.11270 \n",
"84458202 0.11890 0.16450 0.09366 \n",
"844981 0.12730 0.19320 0.18590 \n",
"84501001 0.11860 0.23960 0.22730 \n",
"\n",
" concave points_mean symmetry_mean ... \\\n",
"id ... \n",
"842517 0.07017 0.1812 ... \n",
"84300903 0.12790 0.2069 ... \n",
"84348301 0.10520 0.2597 ... \n",
"84358402 0.10430 0.1809 ... \n",
"843786 0.08089 0.2087 ... \n",
"844359 0.07400 0.1794 ... \n",
"84458202 0.05985 0.2196 ... \n",
"844981 0.09353 0.2350 ... \n",
"84501001 0.08543 0.2030 ... \n",
"\n",
" radius_worst texture_worst perimeter_worst area_worst \\\n",
"id \n",
"842517 24.99 23.41 158.80 1956.0 \n",
"84300903 23.57 25.53 152.50 1709.0 \n",
"84348301 14.91 26.50 98.87 567.7 \n",
"84358402 22.54 16.67 152.20 1575.0 \n",
"843786 15.47 23.75 103.40 741.6 \n",
"844359 22.88 27.66 153.20 1606.0 \n",
"84458202 17.06 28.14 110.60 897.0 \n",
"844981 15.49 30.73 106.20 739.3 \n",
"84501001 15.09 40.68 97.65 711.4 \n",
"\n",
" smoothness_worst compactness_worst concavity_worst \\\n",
"id \n",
"842517 0.1238 0.1866 0.2416 \n",
"84300903 0.1444 0.4245 0.4504 \n",
"84348301 0.2098 0.8663 0.6869 \n",
"84358402 0.1374 0.2050 0.4000 \n",
"843786 0.1791 0.5249 0.5355 \n",
"844359 0.1442 0.2576 0.3784 \n",
"84458202 0.1654 0.3682 0.2678 \n",
"844981 0.1703 0.5401 0.5390 \n",
"84501001 0.1853 1.0580 1.1050 \n",
"\n",
" concave points_worst symmetry_worst fractal_dimension_worst \n",
"id \n",
"842517 0.1860 0.2750 0.08902 \n",
"84300903 0.2430 0.3613 0.08758 \n",
"84348301 0.2575 0.6638 0.17300 \n",
"84358402 0.1625 0.2364 0.07678 \n",
"843786 0.1741 0.3985 0.12440 \n",
"844359 0.1932 0.3063 0.08368 \n",
"84458202 0.1556 0.3196 0.11510 \n",
"844981 0.2060 0.4378 0.10720 \n",
"84501001 0.2210 0.4366 0.20750 \n",
"\n",
"[9 rows x 31 columns]"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# LabelEncoder (used in the next cell) assigns codes in sorted order:\n",
"# 'B' -> 0, 'M' -> 1 — so index 0 must be Benign and index 1 Malignant.\n",
"# The previous order ['Malignant', 'Benign'] was inverted.\n",
"class_names = ['Benign', 'Malignant']\n",
"data = pd.read_csv('breast-cancer.csv', index_col=0)\n",
"data[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "outdoor-element",
"metadata": {},
"outputs": [],
"source": [
"# Encode the diagnosis column to integer codes, then split the frame into\n",
"# a feature matrix and a label vector.\n",
"encoder = LabelEncoder()\n",
"data['diagnosis'] = encoder.fit_transform(data['diagnosis'])\n",
"features = data.iloc[:, 1:32].values\n",
"labels = np.array(data['diagnosis'])"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "buried-community",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n",
" 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n",
" 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n",
" 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n",
" 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n",
" [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n",
" 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n",
" 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n",
" 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n",
" 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n",
" [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n",
" 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n",
" 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n",
" 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n",
" 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n",
" [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n",
" 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n",
" 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n",
" 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n",
" 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n",
" [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n",
" 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n",
" 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n",
" 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n",
" 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n",
" [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n",
" 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n",
" 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n",
" 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n",
" 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n",
" [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n",
" 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n",
" 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n",
" 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n",
" 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n",
" [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n",
" 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n",
" 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n",
" 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n",
" 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n",
" [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n",
" 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n",
" 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n",
" 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n",
" 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview feature rows 1-9 (row 0 is skipped by this slice).\n",
"features[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "incredible-quantum",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 1, 1, 1])"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview label rows 1-9 (row 0 is skipped by this slice).\n",
"labels[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "brazilian-butler",
"metadata": {},
"outputs": [],
"source": [
"# Hold out a test set (default 25%); fixed seed for reproducibility.\n",
"features_train, features_test, labels_train, labels_test = train_test_split(\n",
"    features, labels, random_state=42, shuffle=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "exotic-method",
"metadata": {},
"outputs": [],
"source": [
"# Training setup: model sized to the feature count, Adam optimiser, CE loss.\n",
"model = Model(features_train.shape[1])\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"epochs = 100\n",
"\n",
"def print_(loss):\n",
"    \"\"\"Log the loss for the current epoch.\"\"\"\n",
"    print(\"The loss calculated: \", loss)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "sharp-month",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # 1\n",
"The loss calculated: 0.922476053237915\n",
"Epoch # 2\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 3\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 4\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 5\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 6\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 7\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 8\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 9\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 10\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 11\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 12\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 13\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 14\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 15\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 16\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 17\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 18\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 19\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 20\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 21\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 22\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 23\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 24\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 25\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 26\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 27\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 28\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 29\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 30\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 31\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 32\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 33\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 34\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 35\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 36\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 37\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 38\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 39\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 40\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 41\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 42\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 43\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 44\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 45\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 46\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 47\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 48\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 49\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 50\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 51\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 52\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 53\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 54\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 55\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 56\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 57\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 58\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 59\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 60\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 61\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 62\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 63\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 64\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 65\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 66\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 67\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 68\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 69\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 70\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 71\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 72\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 73\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 74\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 75\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 76\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 77\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 78\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 79\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 80\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 81\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 82\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 83\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 84\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 85\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 86\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 87\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 88\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 89\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 90\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 91\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 92\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 93\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 94\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 95\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 96\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 97\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 98\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 99\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 100\n",
"The loss calculated: 0.9223369359970093\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
}
],
"source": [
"# Full-batch training (no DataLoader): the dataset is small enough for one batch.\n",
"# Variable has been a no-op since PyTorch 0.4, so plain tensors are used directly.\n",
"x_train = torch.from_numpy(features_train).float()\n",
"y_train = torch.from_numpy(labels_train).long()\n",
"for epoch in range(1, epochs + 1):\n",
"    print(\"Epoch #\", epoch)\n",
"    y_pred = model(x_train)          # forward pass\n",
"    loss = loss_fn(y_pred, y_train)  # CrossEntropyLoss takes class indices as targets\n",
"    print_(loss.item())\n",
"\n",
"    optimizer.zero_grad()  # clear gradients from the previous step\n",
"    loss.backward()        # backpropagate\n",
"    optimizer.step()       # update parameters"
]
},
{
"cell_type": "code",
"execution_count": 118,
"id": "mechanical-humidity",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
}
],
"source": [
"# Prediction on the held-out test set. no_grad() skips autograd bookkeeping\n",
"# at inference time; Variable (a no-op since PyTorch 0.4) is no longer needed.\n",
"x_test = torch.from_numpy(features_test).float()\n",
"with torch.no_grad():\n",
"    pred = model(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "based-charleston",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.]], dtype=float32)"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Move the model output to a NumPy array and preview a few rows.\n",
"pred = pred.detach().numpy()\n",
"pred[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "dried-accessory",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy is 0.6223776223776224\n"
]
}
],
"source": [
"# argmax over the class axis turns per-class scores into predicted labels.\n",
"predicted_classes = np.argmax(pred, axis=1)\n",
"print(\"The accuracy is\", accuracy_score(labels_test, predicted_classes))"
]
},
{
"cell_type": "code",
"execution_count": 121,
"id": "effective-characterization",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# True label of the first test sample (0 = 'B' under LabelEncoder's sorted codes).\n",
"labels_test[0]"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "oriented-determination",
"metadata": {},
"outputs": [],
"source": [
"# Persist the whole model object. Filename fixed: this is the breast-cancer\n",
"# model — \"travel_insurance-pytorch.pkl\" was a copy-paste artifact from\n",
"# another project.\n",
"torch.save(model, \"breast-cancer-pytorch.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "infectious-wagon",
"metadata": {},
"outputs": [],
"source": [
"# Reload the saved model to verify it round-trips.\n",
"saved_model = torch.load(\"breast-cancer-pytorch.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "built-contributor",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Single-sample inference with the reloaded model; the output for one sample\n",
"# is 1-D, so argmax over axis 0 yields the predicted class index.\n",
"np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}