Change dataset

Agata 2022-04-23 23:27:19 +02:00
parent c932eb6bba
commit 4a8da732fc
4 changed files with 530 additions and 969 deletions

BIN
body-performance-data.zip Normal file

Binary file not shown.

Binary file not shown.

@@ -1,969 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "corrected-wholesale",
"metadata": {},
"outputs": [],
"source": [
"!kaggle datasets download -d yasserh/breast-cancer-dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ranging-police",
"metadata": {},
"outputs": [],
"source": [
"!unzip -o breast-cancer-dataset.zip"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "ideal-spouse",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from torch import nn\n",
"from torch.autograd import Variable\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from tensorflow.keras.utils import to_categorical\n",
"import torch.nn.functional as F\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "major-compromise",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
" def __init__(self, input_dim):\n",
" super(Model, self).__init__()\n",
" self.layer1 = nn.Linear(input_dim,50)\n",
" self.layer2 = nn.Linear(50, 20)\n",
" self.layer3 = nn.Linear(20, 3)\n",
" \n",
" def forward(self, x):\n",
" x = F.relu(self.layer1(x))\n",
" x = F.relu(self.layer2(x))\n",
" x = F.softmax(self.layer3(x)) # To check with the loss function\n",
" return x"
]
},
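{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-logits",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (hypothetical, not part of the original run): nn.CrossEntropyLoss\n",
"# applies log-softmax internally, so a variant that returns raw logits from the last\n",
"# layer avoids the double softmax noted above; 2 outputs suffice for the binary labels.\n",
"class LogitModel(nn.Module):\n",
"    def __init__(self, input_dim, n_classes=2):\n",
"        super().__init__()\n",
"        self.layer1 = nn.Linear(input_dim, 50)\n",
"        self.layer2 = nn.Linear(50, 20)\n",
"        self.layer3 = nn.Linear(20, n_classes)\n",
"\n",
"    def forward(self, x):\n",
"        x = F.relu(self.layer1(x))\n",
"        x = F.relu(self.layer2(x))\n",
"        return self.layer3(x)  # raw logits; apply softmax only when probabilities are needed"
]
},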
{
"cell_type": "code",
"execution_count": 111,
"id": "czech-regular",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>diagnosis</th>\n",
" <th>radius_mean</th>\n",
" <th>texture_mean</th>\n",
" <th>perimeter_mean</th>\n",
" <th>area_mean</th>\n",
" <th>smoothness_mean</th>\n",
" <th>compactness_mean</th>\n",
" <th>concavity_mean</th>\n",
" <th>concave points_mean</th>\n",
" <th>symmetry_mean</th>\n",
" <th>...</th>\n",
" <th>radius_worst</th>\n",
" <th>texture_worst</th>\n",
" <th>perimeter_worst</th>\n",
" <th>area_worst</th>\n",
" <th>smoothness_worst</th>\n",
" <th>compactness_worst</th>\n",
" <th>concavity_worst</th>\n",
" <th>concave points_worst</th>\n",
" <th>symmetry_worst</th>\n",
" <th>fractal_dimension_worst</th>\n",
" </tr>\n",
" <tr>\n",
" <th>id</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>842517</th>\n",
" <td>M</td>\n",
" <td>20.57</td>\n",
" <td>17.77</td>\n",
" <td>132.90</td>\n",
" <td>1326.0</td>\n",
" <td>0.08474</td>\n",
" <td>0.07864</td>\n",
" <td>0.08690</td>\n",
" <td>0.07017</td>\n",
" <td>0.1812</td>\n",
" <td>...</td>\n",
" <td>24.99</td>\n",
" <td>23.41</td>\n",
" <td>158.80</td>\n",
" <td>1956.0</td>\n",
" <td>0.1238</td>\n",
" <td>0.1866</td>\n",
" <td>0.2416</td>\n",
" <td>0.1860</td>\n",
" <td>0.2750</td>\n",
" <td>0.08902</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84300903</th>\n",
" <td>M</td>\n",
" <td>19.69</td>\n",
" <td>21.25</td>\n",
" <td>130.00</td>\n",
" <td>1203.0</td>\n",
" <td>0.10960</td>\n",
" <td>0.15990</td>\n",
" <td>0.19740</td>\n",
" <td>0.12790</td>\n",
" <td>0.2069</td>\n",
" <td>...</td>\n",
" <td>23.57</td>\n",
" <td>25.53</td>\n",
" <td>152.50</td>\n",
" <td>1709.0</td>\n",
" <td>0.1444</td>\n",
" <td>0.4245</td>\n",
" <td>0.4504</td>\n",
" <td>0.2430</td>\n",
" <td>0.3613</td>\n",
" <td>0.08758</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84348301</th>\n",
" <td>M</td>\n",
" <td>11.42</td>\n",
" <td>20.38</td>\n",
" <td>77.58</td>\n",
" <td>386.1</td>\n",
" <td>0.14250</td>\n",
" <td>0.28390</td>\n",
" <td>0.24140</td>\n",
" <td>0.10520</td>\n",
" <td>0.2597</td>\n",
" <td>...</td>\n",
" <td>14.91</td>\n",
" <td>26.50</td>\n",
" <td>98.87</td>\n",
" <td>567.7</td>\n",
" <td>0.2098</td>\n",
" <td>0.8663</td>\n",
" <td>0.6869</td>\n",
" <td>0.2575</td>\n",
" <td>0.6638</td>\n",
" <td>0.17300</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84358402</th>\n",
" <td>M</td>\n",
" <td>20.29</td>\n",
" <td>14.34</td>\n",
" <td>135.10</td>\n",
" <td>1297.0</td>\n",
" <td>0.10030</td>\n",
" <td>0.13280</td>\n",
" <td>0.19800</td>\n",
" <td>0.10430</td>\n",
" <td>0.1809</td>\n",
" <td>...</td>\n",
" <td>22.54</td>\n",
" <td>16.67</td>\n",
" <td>152.20</td>\n",
" <td>1575.0</td>\n",
" <td>0.1374</td>\n",
" <td>0.2050</td>\n",
" <td>0.4000</td>\n",
" <td>0.1625</td>\n",
" <td>0.2364</td>\n",
" <td>0.07678</td>\n",
" </tr>\n",
" <tr>\n",
" <th>843786</th>\n",
" <td>M</td>\n",
" <td>12.45</td>\n",
" <td>15.70</td>\n",
" <td>82.57</td>\n",
" <td>477.1</td>\n",
" <td>0.12780</td>\n",
" <td>0.17000</td>\n",
" <td>0.15780</td>\n",
" <td>0.08089</td>\n",
" <td>0.2087</td>\n",
" <td>...</td>\n",
" <td>15.47</td>\n",
" <td>23.75</td>\n",
" <td>103.40</td>\n",
" <td>741.6</td>\n",
" <td>0.1791</td>\n",
" <td>0.5249</td>\n",
" <td>0.5355</td>\n",
" <td>0.1741</td>\n",
" <td>0.3985</td>\n",
" <td>0.12440</td>\n",
" </tr>\n",
" <tr>\n",
" <th>844359</th>\n",
" <td>M</td>\n",
" <td>18.25</td>\n",
" <td>19.98</td>\n",
" <td>119.60</td>\n",
" <td>1040.0</td>\n",
" <td>0.09463</td>\n",
" <td>0.10900</td>\n",
" <td>0.11270</td>\n",
" <td>0.07400</td>\n",
" <td>0.1794</td>\n",
" <td>...</td>\n",
" <td>22.88</td>\n",
" <td>27.66</td>\n",
" <td>153.20</td>\n",
" <td>1606.0</td>\n",
" <td>0.1442</td>\n",
" <td>0.2576</td>\n",
" <td>0.3784</td>\n",
" <td>0.1932</td>\n",
" <td>0.3063</td>\n",
" <td>0.08368</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84458202</th>\n",
" <td>M</td>\n",
" <td>13.71</td>\n",
" <td>20.83</td>\n",
" <td>90.20</td>\n",
" <td>577.9</td>\n",
" <td>0.11890</td>\n",
" <td>0.16450</td>\n",
" <td>0.09366</td>\n",
" <td>0.05985</td>\n",
" <td>0.2196</td>\n",
" <td>...</td>\n",
" <td>17.06</td>\n",
" <td>28.14</td>\n",
" <td>110.60</td>\n",
" <td>897.0</td>\n",
" <td>0.1654</td>\n",
" <td>0.3682</td>\n",
" <td>0.2678</td>\n",
" <td>0.1556</td>\n",
" <td>0.3196</td>\n",
" <td>0.11510</td>\n",
" </tr>\n",
" <tr>\n",
" <th>844981</th>\n",
" <td>M</td>\n",
" <td>13.00</td>\n",
" <td>21.82</td>\n",
" <td>87.50</td>\n",
" <td>519.8</td>\n",
" <td>0.12730</td>\n",
" <td>0.19320</td>\n",
" <td>0.18590</td>\n",
" <td>0.09353</td>\n",
" <td>0.2350</td>\n",
" <td>...</td>\n",
" <td>15.49</td>\n",
" <td>30.73</td>\n",
" <td>106.20</td>\n",
" <td>739.3</td>\n",
" <td>0.1703</td>\n",
" <td>0.5401</td>\n",
" <td>0.5390</td>\n",
" <td>0.2060</td>\n",
" <td>0.4378</td>\n",
" <td>0.10720</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84501001</th>\n",
" <td>M</td>\n",
" <td>12.46</td>\n",
" <td>24.04</td>\n",
" <td>83.97</td>\n",
" <td>475.9</td>\n",
" <td>0.11860</td>\n",
" <td>0.23960</td>\n",
" <td>0.22730</td>\n",
" <td>0.08543</td>\n",
" <td>0.2030</td>\n",
" <td>...</td>\n",
" <td>15.09</td>\n",
" <td>40.68</td>\n",
" <td>97.65</td>\n",
" <td>711.4</td>\n",
" <td>0.1853</td>\n",
" <td>1.0580</td>\n",
" <td>1.1050</td>\n",
" <td>0.2210</td>\n",
" <td>0.4366</td>\n",
" <td>0.20750</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>9 rows × 31 columns</p>\n",
"</div>"
],
"text/plain": [
" diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n",
"id \n",
"842517 M 20.57 17.77 132.90 1326.0 \n",
"84300903 M 19.69 21.25 130.00 1203.0 \n",
"84348301 M 11.42 20.38 77.58 386.1 \n",
"84358402 M 20.29 14.34 135.10 1297.0 \n",
"843786 M 12.45 15.70 82.57 477.1 \n",
"844359 M 18.25 19.98 119.60 1040.0 \n",
"84458202 M 13.71 20.83 90.20 577.9 \n",
"844981 M 13.00 21.82 87.50 519.8 \n",
"84501001 M 12.46 24.04 83.97 475.9 \n",
"\n",
" smoothness_mean compactness_mean concavity_mean \\\n",
"id \n",
"842517 0.08474 0.07864 0.08690 \n",
"84300903 0.10960 0.15990 0.19740 \n",
"84348301 0.14250 0.28390 0.24140 \n",
"84358402 0.10030 0.13280 0.19800 \n",
"843786 0.12780 0.17000 0.15780 \n",
"844359 0.09463 0.10900 0.11270 \n",
"84458202 0.11890 0.16450 0.09366 \n",
"844981 0.12730 0.19320 0.18590 \n",
"84501001 0.11860 0.23960 0.22730 \n",
"\n",
" concave points_mean symmetry_mean ... \\\n",
"id ... \n",
"842517 0.07017 0.1812 ... \n",
"84300903 0.12790 0.2069 ... \n",
"84348301 0.10520 0.2597 ... \n",
"84358402 0.10430 0.1809 ... \n",
"843786 0.08089 0.2087 ... \n",
"844359 0.07400 0.1794 ... \n",
"84458202 0.05985 0.2196 ... \n",
"844981 0.09353 0.2350 ... \n",
"84501001 0.08543 0.2030 ... \n",
"\n",
" radius_worst texture_worst perimeter_worst area_worst \\\n",
"id \n",
"842517 24.99 23.41 158.80 1956.0 \n",
"84300903 23.57 25.53 152.50 1709.0 \n",
"84348301 14.91 26.50 98.87 567.7 \n",
"84358402 22.54 16.67 152.20 1575.0 \n",
"843786 15.47 23.75 103.40 741.6 \n",
"844359 22.88 27.66 153.20 1606.0 \n",
"84458202 17.06 28.14 110.60 897.0 \n",
"844981 15.49 30.73 106.20 739.3 \n",
"84501001 15.09 40.68 97.65 711.4 \n",
"\n",
" smoothness_worst compactness_worst concavity_worst \\\n",
"id \n",
"842517 0.1238 0.1866 0.2416 \n",
"84300903 0.1444 0.4245 0.4504 \n",
"84348301 0.2098 0.8663 0.6869 \n",
"84358402 0.1374 0.2050 0.4000 \n",
"843786 0.1791 0.5249 0.5355 \n",
"844359 0.1442 0.2576 0.3784 \n",
"84458202 0.1654 0.3682 0.2678 \n",
"844981 0.1703 0.5401 0.5390 \n",
"84501001 0.1853 1.0580 1.1050 \n",
"\n",
" concave points_worst symmetry_worst fractal_dimension_worst \n",
"id \n",
"842517 0.1860 0.2750 0.08902 \n",
"84300903 0.2430 0.3613 0.08758 \n",
"84348301 0.2575 0.6638 0.17300 \n",
"84358402 0.1625 0.2364 0.07678 \n",
"843786 0.1741 0.3985 0.12440 \n",
"844359 0.1932 0.3063 0.08368 \n",
"84458202 0.1556 0.3196 0.11510 \n",
"844981 0.2060 0.4378 0.10720 \n",
"84501001 0.2210 0.4366 0.20750 \n",
"\n",
"[9 rows x 31 columns]"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = pd.read_csv('breast-cancer.csv', index_col=0)\n",
"data[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "outdoor-element",
"metadata": {},
"outputs": [],
"source": [
"lb = LabelEncoder()\n",
"data['diagnosis'] = lb.fit_transform(data['diagnosis'])\n",
"features = data.iloc[:, 1:32].values\n",
"labels = np.array(data['diagnosis'])"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "buried-community",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n",
" 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n",
" 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n",
" 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n",
" 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n",
" [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n",
" 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n",
" 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n",
" 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n",
" 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n",
" [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n",
" 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n",
" 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n",
" 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n",
" 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n",
" [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n",
" 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n",
" 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n",
" 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n",
" 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n",
" [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n",
" 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n",
" 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n",
" 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n",
" 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n",
" [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n",
" 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n",
" 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n",
" 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n",
" 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n",
" [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n",
" 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n",
" 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n",
" 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n",
" 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n",
" [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n",
" 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n",
" 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n",
" 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n",
" 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n",
" [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n",
" 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n",
" 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n",
" 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n",
" 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"features[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "incredible-quantum",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 1, 1, 1, 1, 1, 1, 1, 1])"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "brazilian-butler",
"metadata": {},
"outputs": [],
"source": [
"features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42, shuffle=True)"
]
},
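{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-scaling",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (an optional step the notebook skips): the 30 features span very\n",
"# different ranges (areas in the thousands vs. fractal dimensions around 0.1), which\n",
"# makes the network hard to train; standardising them is one common remedy.\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"scaler = StandardScaler()\n",
"features_train = scaler.fit_transform(features_train)\n",
"features_test = scaler.transform(features_test)"
]
},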
{
"cell_type": "code",
"execution_count": 116,
"id": "exotic-method",
"metadata": {},
"outputs": [],
"source": [
"# Training\n",
"model = Model(features_train.shape[1])\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"epochs = 100\n",
"\n",
"def print_(loss):\n",
" print (\"The loss calculated: \", loss)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "sharp-month",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch # 1\n",
"The loss calculated: 0.922476053237915\n",
"Epoch # 2\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 3\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 4\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 5\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 6\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 7\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 8\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 9\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 10\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 11\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 12\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 13\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 14\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 15\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 16\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 17\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 18\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 19\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 20\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 21\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 22\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 23\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 24\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 25\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 26\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 27\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 28\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 29\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 30\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 31\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 32\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 33\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 34\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 35\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 36\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 37\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 38\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 39\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 40\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 41\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 42\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 43\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 44\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 45\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 46\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 47\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 48\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 49\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 50\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 51\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 52\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 53\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 54\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 55\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 56\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 57\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 58\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 59\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 60\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 61\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 62\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 63\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 64\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 65\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 66\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 67\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 68\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 69\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 70\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 71\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 72\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 73\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 74\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 75\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 76\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 77\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 78\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 79\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 80\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 81\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 82\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 83\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 84\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 85\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 86\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 87\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 88\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 89\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 90\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 91\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 92\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 93\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 94\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 95\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 96\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 97\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 98\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 99\n",
"The loss calculated: 0.9223369359970093\n",
"Epoch # 100\n",
"The loss calculated: 0.9223369359970093\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
}
],
"source": [
"# Not using dataloader\n",
"x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()\n",
"for epoch in range(1, epochs+1):\n",
" print (\"Epoch #\",epoch)\n",
" y_pred = model(x_train)\n",
" loss = loss_fn(y_pred, y_train)\n",
" print_(loss.item())\n",
" \n",
" # Zero gradients\n",
" optimizer.zero_grad()\n",
" loss.backward() # Gradients\n",
" optimizer.step() # Update"
]
},
{
"cell_type": "code",
"execution_count": 118,
"id": "mechanical-humidity",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
}
],
"source": [
"# Prediction\n",
"x_test = Variable(torch.from_numpy(features_test)).float()\n",
"pred = model(x_test)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "based-charleston",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.],\n",
" [1., 0., 0.]], dtype=float32)"
]
},
"execution_count": 119,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = pred.detach().numpy()\n",
"pred[1:10]"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "dried-accessory",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The accuracy is 0.6223776223776224\n"
]
}
],
"source": [
"print (\"The accuracy is\", accuracy_score(labels_test, np.argmax(pred, axis=1)))"
]
},
{
"cell_type": "code",
"execution_count": 121,
"id": "effective-characterization",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels_test[0]"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "oriented-determination",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model, \"travel_insurance-pytorch.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "infectious-wagon",
"metadata": {},
"outputs": [],
"source": [
"saved_model = torch.load(\"travel_insurance-pytorch.pkl\")"
]
},
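{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-state-dict",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (file name below is illustrative): torch.save(model, ...) pickles the\n",
"# whole object and ties the file to this exact class definition; saving only the\n",
"# state_dict is the more portable pattern.\n",
"torch.save(model.state_dict(), 'breast-cancer-state.pt')\n",
"\n",
"restored = Model(features_train.shape[1])\n",
"restored.load_state_dict(torch.load('breast-cancer-state.pt'))\n",
"restored.eval()"
]
},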
{
"cell_type": "code",
"execution_count": 124,
"id": "built-contributor",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" # This is added back by InteractiveShellApp.init_path()\n"
]
},
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

classification_net.ipynb Normal file

@@ -0,0 +1,530 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "forty-fault",
"metadata": {},
"outputs": [],
"source": [
"!kaggle datasets download -d kukuroo3/body-performance-data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pediatric-tuesday",
"metadata": {},
"outputs": [],
"source": [
"!unzip -o body-performance-data.zip"
]
},
{
"cell_type": "code",
"execution_count": 114,
"id": "interstate-presence",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report\n",
"import torch\n",
"from torch import nn, optim"
]
},
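{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-seed",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original run): fixing the random seeds makes the\n",
"# split and the weight initialisation repeatable; the seed value is arbitrary.\n",
"np.random.seed(42)\n",
"torch.manual_seed(42)"
]
},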
{
"cell_type": "code",
"execution_count": 115,
"id": "structural-trigger",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(13393, 12)"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv('bodyPerformance.csv')\n",
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 116,
"id": "turkish-category",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>gender</th>\n",
" <th>height_cm</th>\n",
" <th>weight_kg</th>\n",
" <th>body fat_%</th>\n",
" <th>diastolic</th>\n",
" <th>systolic</th>\n",
" <th>gripForce</th>\n",
" <th>sit and bend forward_cm</th>\n",
" <th>sit-ups counts</th>\n",
" <th>broad jump_cm</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>27.0</td>\n",
" <td>M</td>\n",
" <td>172.3</td>\n",
" <td>75.24</td>\n",
" <td>21.3</td>\n",
" <td>80.0</td>\n",
" <td>130.0</td>\n",
" <td>54.9</td>\n",
" <td>18.4</td>\n",
" <td>60.0</td>\n",
" <td>217.0</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25.0</td>\n",
" <td>M</td>\n",
" <td>165.0</td>\n",
" <td>55.80</td>\n",
" <td>15.7</td>\n",
" <td>77.0</td>\n",
" <td>126.0</td>\n",
" <td>36.4</td>\n",
" <td>16.3</td>\n",
" <td>53.0</td>\n",
" <td>229.0</td>\n",
" <td>A</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31.0</td>\n",
" <td>M</td>\n",
" <td>179.6</td>\n",
" <td>78.00</td>\n",
" <td>20.1</td>\n",
" <td>92.0</td>\n",
" <td>152.0</td>\n",
" <td>44.8</td>\n",
" <td>12.0</td>\n",
" <td>49.0</td>\n",
" <td>181.0</td>\n",
" <td>C</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>32.0</td>\n",
" <td>M</td>\n",
" <td>174.5</td>\n",
" <td>71.10</td>\n",
" <td>18.4</td>\n",
" <td>76.0</td>\n",
" <td>147.0</td>\n",
" <td>41.4</td>\n",
" <td>15.2</td>\n",
" <td>53.0</td>\n",
" <td>219.0</td>\n",
" <td>B</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>28.0</td>\n",
" <td>M</td>\n",
" <td>173.8</td>\n",
" <td>67.70</td>\n",
" <td>17.1</td>\n",
" <td>70.0</td>\n",
" <td>127.0</td>\n",
" <td>43.5</td>\n",
" <td>27.1</td>\n",
" <td>45.0</td>\n",
" <td>217.0</td>\n",
" <td>B</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age gender height_cm weight_kg body fat_% diastolic systolic \\\n",
"0 27.0 M 172.3 75.24 21.3 80.0 130.0 \n",
"1 25.0 M 165.0 55.80 15.7 77.0 126.0 \n",
"2 31.0 M 179.6 78.00 20.1 92.0 152.0 \n",
"3 32.0 M 174.5 71.10 18.4 76.0 147.0 \n",
"4 28.0 M 173.8 67.70 17.1 70.0 127.0 \n",
"\n",
" gripForce sit and bend forward_cm sit-ups counts broad jump_cm class \n",
"0 54.9 18.4 60.0 217.0 C \n",
"1 36.4 16.3 53.0 229.0 A \n",
"2 44.8 12.0 49.0 181.0 C \n",
"3 41.4 15.2 53.0 219.0 B \n",
"4 43.5 27.1 45.0 217.0 B "
]
},
"execution_count": 116,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "received-absence",
"metadata": {},
"outputs": [],
"source": [
"cols = ['gender', 'height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']\n",
"df = df[cols]\n",
"\n",
"# male - 0, female - 1\n",
"df['gender'].replace({'M': 0, 'F': 1}, inplace = True)\n",
"df = df.dropna(how='any')"
]
},
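{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-encoding",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch: replace(..., inplace=True) on a frame produced by df[cols] can trigger\n",
"# pandas' SettingWithCopyWarning; a copy-safe alternative to the cell above would be:\n",
"# df = df.assign(gender=df['gender'].map({'M': 0, 'F': 1})).dropna(how='any')"
]
},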
{
"cell_type": "code",
"execution_count": 118,
"id": "excited-parent",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 0.632196\n",
"1 0.367804\n",
"Name: gender, dtype: float64"
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.gender.value_counts() / df.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "extended-cinema",
"metadata": {},
"outputs": [],
"source": [
"X = df[['height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']]\n",
"y = df[['gender']]\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "animated-farming",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10714, 5]) torch.Size([10714])\n",
"torch.Size([2679, 5]) torch.Size([2679])\n"
]
}
],
"source": [
"X_train = torch.from_numpy(np.array(X_train)).float()\n",
"y_train = torch.squeeze(torch.from_numpy(y_train.values).float())\n",
"\n",
"X_test = torch.from_numpy(np.array(X_test)).float()\n",
"y_test = torch.squeeze(torch.from_numpy(y_test.values).float())\n",
"\n",
"print(X_train.shape, y_train.shape)\n",
"print(X_test.shape, y_test.shape)"
]
},
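{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-dataloader",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch: this notebook trains on the full tensors in one batch; for larger\n",
"# datasets the same tensors could be wrapped in a DataLoader to get mini-batches\n",
"# (batch size below is illustrative).\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"\n",
"train_ds = TensorDataset(X_train, y_train)\n",
"train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)"
]
},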
{
"cell_type": "code",
"execution_count": 121,
"id": "technical-wallet",
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self, n_features):\n",
" super(Net, self).__init__()\n",
" self.fc1 = nn.Linear(n_features, 5)\n",
" self.fc2 = nn.Linear(5, 3)\n",
" self.fc3 = nn.Linear(3, 1)\n",
" def forward(self, x):\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" return torch.sigmoid(self.fc3(x))\n",
"net = Net(X_train.shape[1])"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "requested-plymouth",
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCELoss()"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "iraqi-english",
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(net.parameters(), lr=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "emerging-helmet",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "differential-aviation",
"metadata": {},
"outputs": [],
"source": [
"X_train = X_train.to(device)\n",
"y_train = y_train.to(device)\n",
"X_test = X_test.to(device)\n",
"y_test = y_test.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "ranging-calgary",
"metadata": {},
"outputs": [],
"source": [
"net = net.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 127,
"id": "iraqi-blanket",
"metadata": {},
"outputs": [],
"source": [
"def calculate_accuracy(y_true, y_pred):\n",
" predicted = y_pred.ge(.5).view(-1)\n",
" return (y_true == predicted).sum().float() / len(y_true)"
]
},
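{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-accuracy-check",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch: a quick sanity check of calculate_accuracy on toy tensors;\n",
"# two of the three thresholded predictions match the labels, so the result is ~0.667.\n",
"calculate_accuracy(torch.tensor([1., 0., 1.]), torch.tensor([0.9, 0.2, 0.4]))"
]
},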
{
"cell_type": "code",
"execution_count": 128,
"id": "robust-serbia",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0\n",
"Train set - loss: 1.005, accuracy: 0.37\n",
"Test set - loss: 1.018, accuracy: 0.358\n",
"\n",
"epoch 100\n",
"Train set - loss: 0.677, accuracy: 0.743\n",
"Test set - loss: 0.679, accuracy: 0.727\n",
"\n",
"epoch 200\n",
"Train set - loss: 0.636, accuracy: 0.79\n",
"Test set - loss: 0.64, accuracy: 0.778\n",
"\n",
"epoch 300\n",
"Train set - loss: 0.568, accuracy: 0.839\n",
"Test set - loss: 0.577, accuracy: 0.833\n",
"\n",
"epoch 400\n",
"Train set - loss: 0.504, accuracy: 0.885\n",
"Test set - loss: 0.514, accuracy: 0.877\n",
"\n",
"epoch 500\n",
"Train set - loss: 0.441, accuracy: 0.922\n",
"Test set - loss: 0.45, accuracy: 0.913\n",
"\n",
"epoch 600\n",
"Train set - loss: 0.388, accuracy: 0.944\n",
"Test set - loss: 0.396, accuracy: 0.938\n",
"\n",
"epoch 700\n",
"Train set - loss: 0.353, accuracy: 0.954\n",
"Test set - loss: 0.359, accuracy: 0.949\n",
"\n",
"epoch 800\n",
"Train set - loss: 0.327, accuracy: 0.958\n",
"Test set - loss: 0.333, accuracy: 0.953\n",
"\n",
"epoch 900\n",
"Train set - loss: 0.306, accuracy: 0.961\n",
"Test set - loss: 0.312, accuracy: 0.955\n",
"\n"
]
}
],
"source": [
"def round_tensor(t, decimal_places=3):\n",
" return round(t.item(), decimal_places)\n",
"for epoch in range(1000):\n",
" y_pred = net(X_train)\n",
" y_pred = torch.squeeze(y_pred)\n",
" train_loss = criterion(y_pred, y_train)\n",
" if epoch % 100 == 0:\n",
" train_acc = calculate_accuracy(y_train, y_pred)\n",
" y_test_pred = net(X_test)\n",
" y_test_pred = torch.squeeze(y_test_pred)\n",
" test_loss = criterion(y_test_pred, y_test)\n",
" test_acc = calculate_accuracy(y_test, y_test_pred)\n",
" print(\n",
"f'''epoch {epoch}\n",
"Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)}\n",
"Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)}\n",
"''')\n",
" optimizer.zero_grad()\n",
" train_loss.backward()\n",
" optimizer.step()"
]
},
{
"cell_type": "code",
"execution_count": 129,
"id": "optimum-excerpt",
"metadata": {},
"outputs": [],
"source": [
"# torch.save(net, 'model.pth')"
]
},
{
"cell_type": "code",
"execution_count": 130,
"id": "dental-seating",
"metadata": {},
"outputs": [],
"source": [
"# net = torch.load('model.pth')"
]
},
{
"cell_type": "code",
"execution_count": 131,
"id": "german-satisfaction",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" Male 0.97 0.96 0.96 1720\n",
" Female 0.93 0.94 0.94 959\n",
"\n",
" accuracy 0.95 2679\n",
" macro avg 0.95 0.95 0.95 2679\n",
"weighted avg 0.95 0.95 0.95 2679\n",
"\n"
]
}
],
"source": [
"classes = ['Male', 'Female']\n",
"y_pred = net(X_test)\n",
"y_pred = y_pred.ge(.5).view(-1).cpu()\n",
"y_test = y_test.cpu()\n",
"print(classification_report(y_test, y_pred, target_names=classes))"
]
},
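{
"cell_type": "code",
"execution_count": null,
"id": "editor-sketch-confusion",
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch: a confusion matrix complements the report above by showing the raw\n",
"# counts of correct and misclassified Male/Female predictions.\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"confusion_matrix(y_test, y_pred)"
]
},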
{
"cell_type": "code",
"execution_count": 132,
"id": "british-incidence",
"metadata": {},
"outputs": [],
"source": [
"with open('test_out.csv', 'w') as file:\n",
" for y in y_pred:\n",
" file.write(classes[y.item()])\n",
" file.write('\\n')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}