Change dataset
This commit is contained in:
parent
c932eb6bba
commit
4a8da732fc
BIN
body-performance-data.zip
Normal file
BIN
body-performance-data.zip
Normal file
Binary file not shown.
Binary file not shown.
@ -1,969 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "corrected-wholesale",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!kaggle datasets download -d yasserh/breast-cancer-dataset"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "ranging-police",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!unzip -o breast-cancer-dataset.zip"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 109,
|
|
||||||
"id": "ideal-spouse",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torch import nn\n",
|
|
||||||
"from torch.autograd import Variable\n",
|
|
||||||
"from sklearn.datasets import load_iris\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"from sklearn.metrics import accuracy_score\n",
|
|
||||||
"from sklearn.preprocessing import LabelEncoder\n",
|
|
||||||
"from tensorflow.keras.utils import to_categorical\n",
|
|
||||||
"import torch.nn.functional as F\n",
|
|
||||||
"import pandas as pd"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 110,
|
|
||||||
"id": "major-compromise",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"class Model(nn.Module):\n",
|
|
||||||
" def __init__(self, input_dim):\n",
|
|
||||||
" super(Model, self).__init__()\n",
|
|
||||||
" self.layer1 = nn.Linear(input_dim,50)\n",
|
|
||||||
" self.layer2 = nn.Linear(50, 20)\n",
|
|
||||||
" self.layer3 = nn.Linear(20, 3)\n",
|
|
||||||
" \n",
|
|
||||||
" def forward(self, x):\n",
|
|
||||||
" x = F.relu(self.layer1(x))\n",
|
|
||||||
" x = F.relu(self.layer2(x))\n",
|
|
||||||
" x = F.softmax(self.layer3(x)) # To check with the loss function\n",
|
|
||||||
" return x"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 111,
|
|
||||||
"id": "czech-regular",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>diagnosis</th>\n",
|
|
||||||
" <th>radius_mean</th>\n",
|
|
||||||
" <th>texture_mean</th>\n",
|
|
||||||
" <th>perimeter_mean</th>\n",
|
|
||||||
" <th>area_mean</th>\n",
|
|
||||||
" <th>smoothness_mean</th>\n",
|
|
||||||
" <th>compactness_mean</th>\n",
|
|
||||||
" <th>concavity_mean</th>\n",
|
|
||||||
" <th>concave points_mean</th>\n",
|
|
||||||
" <th>symmetry_mean</th>\n",
|
|
||||||
" <th>...</th>\n",
|
|
||||||
" <th>radius_worst</th>\n",
|
|
||||||
" <th>texture_worst</th>\n",
|
|
||||||
" <th>perimeter_worst</th>\n",
|
|
||||||
" <th>area_worst</th>\n",
|
|
||||||
" <th>smoothness_worst</th>\n",
|
|
||||||
" <th>compactness_worst</th>\n",
|
|
||||||
" <th>concavity_worst</th>\n",
|
|
||||||
" <th>concave points_worst</th>\n",
|
|
||||||
" <th>symmetry_worst</th>\n",
|
|
||||||
" <th>fractal_dimension_worst</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>id</th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>842517</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>20.57</td>\n",
|
|
||||||
" <td>17.77</td>\n",
|
|
||||||
" <td>132.90</td>\n",
|
|
||||||
" <td>1326.0</td>\n",
|
|
||||||
" <td>0.08474</td>\n",
|
|
||||||
" <td>0.07864</td>\n",
|
|
||||||
" <td>0.08690</td>\n",
|
|
||||||
" <td>0.07017</td>\n",
|
|
||||||
" <td>0.1812</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>24.99</td>\n",
|
|
||||||
" <td>23.41</td>\n",
|
|
||||||
" <td>158.80</td>\n",
|
|
||||||
" <td>1956.0</td>\n",
|
|
||||||
" <td>0.1238</td>\n",
|
|
||||||
" <td>0.1866</td>\n",
|
|
||||||
" <td>0.2416</td>\n",
|
|
||||||
" <td>0.1860</td>\n",
|
|
||||||
" <td>0.2750</td>\n",
|
|
||||||
" <td>0.08902</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>84300903</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>19.69</td>\n",
|
|
||||||
" <td>21.25</td>\n",
|
|
||||||
" <td>130.00</td>\n",
|
|
||||||
" <td>1203.0</td>\n",
|
|
||||||
" <td>0.10960</td>\n",
|
|
||||||
" <td>0.15990</td>\n",
|
|
||||||
" <td>0.19740</td>\n",
|
|
||||||
" <td>0.12790</td>\n",
|
|
||||||
" <td>0.2069</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>23.57</td>\n",
|
|
||||||
" <td>25.53</td>\n",
|
|
||||||
" <td>152.50</td>\n",
|
|
||||||
" <td>1709.0</td>\n",
|
|
||||||
" <td>0.1444</td>\n",
|
|
||||||
" <td>0.4245</td>\n",
|
|
||||||
" <td>0.4504</td>\n",
|
|
||||||
" <td>0.2430</td>\n",
|
|
||||||
" <td>0.3613</td>\n",
|
|
||||||
" <td>0.08758</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>84348301</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>11.42</td>\n",
|
|
||||||
" <td>20.38</td>\n",
|
|
||||||
" <td>77.58</td>\n",
|
|
||||||
" <td>386.1</td>\n",
|
|
||||||
" <td>0.14250</td>\n",
|
|
||||||
" <td>0.28390</td>\n",
|
|
||||||
" <td>0.24140</td>\n",
|
|
||||||
" <td>0.10520</td>\n",
|
|
||||||
" <td>0.2597</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>14.91</td>\n",
|
|
||||||
" <td>26.50</td>\n",
|
|
||||||
" <td>98.87</td>\n",
|
|
||||||
" <td>567.7</td>\n",
|
|
||||||
" <td>0.2098</td>\n",
|
|
||||||
" <td>0.8663</td>\n",
|
|
||||||
" <td>0.6869</td>\n",
|
|
||||||
" <td>0.2575</td>\n",
|
|
||||||
" <td>0.6638</td>\n",
|
|
||||||
" <td>0.17300</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>84358402</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>20.29</td>\n",
|
|
||||||
" <td>14.34</td>\n",
|
|
||||||
" <td>135.10</td>\n",
|
|
||||||
" <td>1297.0</td>\n",
|
|
||||||
" <td>0.10030</td>\n",
|
|
||||||
" <td>0.13280</td>\n",
|
|
||||||
" <td>0.19800</td>\n",
|
|
||||||
" <td>0.10430</td>\n",
|
|
||||||
" <td>0.1809</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>22.54</td>\n",
|
|
||||||
" <td>16.67</td>\n",
|
|
||||||
" <td>152.20</td>\n",
|
|
||||||
" <td>1575.0</td>\n",
|
|
||||||
" <td>0.1374</td>\n",
|
|
||||||
" <td>0.2050</td>\n",
|
|
||||||
" <td>0.4000</td>\n",
|
|
||||||
" <td>0.1625</td>\n",
|
|
||||||
" <td>0.2364</td>\n",
|
|
||||||
" <td>0.07678</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>843786</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>12.45</td>\n",
|
|
||||||
" <td>15.70</td>\n",
|
|
||||||
" <td>82.57</td>\n",
|
|
||||||
" <td>477.1</td>\n",
|
|
||||||
" <td>0.12780</td>\n",
|
|
||||||
" <td>0.17000</td>\n",
|
|
||||||
" <td>0.15780</td>\n",
|
|
||||||
" <td>0.08089</td>\n",
|
|
||||||
" <td>0.2087</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>15.47</td>\n",
|
|
||||||
" <td>23.75</td>\n",
|
|
||||||
" <td>103.40</td>\n",
|
|
||||||
" <td>741.6</td>\n",
|
|
||||||
" <td>0.1791</td>\n",
|
|
||||||
" <td>0.5249</td>\n",
|
|
||||||
" <td>0.5355</td>\n",
|
|
||||||
" <td>0.1741</td>\n",
|
|
||||||
" <td>0.3985</td>\n",
|
|
||||||
" <td>0.12440</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>844359</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>18.25</td>\n",
|
|
||||||
" <td>19.98</td>\n",
|
|
||||||
" <td>119.60</td>\n",
|
|
||||||
" <td>1040.0</td>\n",
|
|
||||||
" <td>0.09463</td>\n",
|
|
||||||
" <td>0.10900</td>\n",
|
|
||||||
" <td>0.11270</td>\n",
|
|
||||||
" <td>0.07400</td>\n",
|
|
||||||
" <td>0.1794</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>22.88</td>\n",
|
|
||||||
" <td>27.66</td>\n",
|
|
||||||
" <td>153.20</td>\n",
|
|
||||||
" <td>1606.0</td>\n",
|
|
||||||
" <td>0.1442</td>\n",
|
|
||||||
" <td>0.2576</td>\n",
|
|
||||||
" <td>0.3784</td>\n",
|
|
||||||
" <td>0.1932</td>\n",
|
|
||||||
" <td>0.3063</td>\n",
|
|
||||||
" <td>0.08368</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>84458202</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>13.71</td>\n",
|
|
||||||
" <td>20.83</td>\n",
|
|
||||||
" <td>90.20</td>\n",
|
|
||||||
" <td>577.9</td>\n",
|
|
||||||
" <td>0.11890</td>\n",
|
|
||||||
" <td>0.16450</td>\n",
|
|
||||||
" <td>0.09366</td>\n",
|
|
||||||
" <td>0.05985</td>\n",
|
|
||||||
" <td>0.2196</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>17.06</td>\n",
|
|
||||||
" <td>28.14</td>\n",
|
|
||||||
" <td>110.60</td>\n",
|
|
||||||
" <td>897.0</td>\n",
|
|
||||||
" <td>0.1654</td>\n",
|
|
||||||
" <td>0.3682</td>\n",
|
|
||||||
" <td>0.2678</td>\n",
|
|
||||||
" <td>0.1556</td>\n",
|
|
||||||
" <td>0.3196</td>\n",
|
|
||||||
" <td>0.11510</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>844981</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>13.00</td>\n",
|
|
||||||
" <td>21.82</td>\n",
|
|
||||||
" <td>87.50</td>\n",
|
|
||||||
" <td>519.8</td>\n",
|
|
||||||
" <td>0.12730</td>\n",
|
|
||||||
" <td>0.19320</td>\n",
|
|
||||||
" <td>0.18590</td>\n",
|
|
||||||
" <td>0.09353</td>\n",
|
|
||||||
" <td>0.2350</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>15.49</td>\n",
|
|
||||||
" <td>30.73</td>\n",
|
|
||||||
" <td>106.20</td>\n",
|
|
||||||
" <td>739.3</td>\n",
|
|
||||||
" <td>0.1703</td>\n",
|
|
||||||
" <td>0.5401</td>\n",
|
|
||||||
" <td>0.5390</td>\n",
|
|
||||||
" <td>0.2060</td>\n",
|
|
||||||
" <td>0.4378</td>\n",
|
|
||||||
" <td>0.10720</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>84501001</th>\n",
|
|
||||||
" <td>M</td>\n",
|
|
||||||
" <td>12.46</td>\n",
|
|
||||||
" <td>24.04</td>\n",
|
|
||||||
" <td>83.97</td>\n",
|
|
||||||
" <td>475.9</td>\n",
|
|
||||||
" <td>0.11860</td>\n",
|
|
||||||
" <td>0.23960</td>\n",
|
|
||||||
" <td>0.22730</td>\n",
|
|
||||||
" <td>0.08543</td>\n",
|
|
||||||
" <td>0.2030</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>15.09</td>\n",
|
|
||||||
" <td>40.68</td>\n",
|
|
||||||
" <td>97.65</td>\n",
|
|
||||||
" <td>711.4</td>\n",
|
|
||||||
" <td>0.1853</td>\n",
|
|
||||||
" <td>1.0580</td>\n",
|
|
||||||
" <td>1.1050</td>\n",
|
|
||||||
" <td>0.2210</td>\n",
|
|
||||||
" <td>0.4366</td>\n",
|
|
||||||
" <td>0.20750</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"<p>9 rows × 31 columns</p>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
" diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n",
|
|
||||||
"id \n",
|
|
||||||
"842517 M 20.57 17.77 132.90 1326.0 \n",
|
|
||||||
"84300903 M 19.69 21.25 130.00 1203.0 \n",
|
|
||||||
"84348301 M 11.42 20.38 77.58 386.1 \n",
|
|
||||||
"84358402 M 20.29 14.34 135.10 1297.0 \n",
|
|
||||||
"843786 M 12.45 15.70 82.57 477.1 \n",
|
|
||||||
"844359 M 18.25 19.98 119.60 1040.0 \n",
|
|
||||||
"84458202 M 13.71 20.83 90.20 577.9 \n",
|
|
||||||
"844981 M 13.00 21.82 87.50 519.8 \n",
|
|
||||||
"84501001 M 12.46 24.04 83.97 475.9 \n",
|
|
||||||
"\n",
|
|
||||||
" smoothness_mean compactness_mean concavity_mean \\\n",
|
|
||||||
"id \n",
|
|
||||||
"842517 0.08474 0.07864 0.08690 \n",
|
|
||||||
"84300903 0.10960 0.15990 0.19740 \n",
|
|
||||||
"84348301 0.14250 0.28390 0.24140 \n",
|
|
||||||
"84358402 0.10030 0.13280 0.19800 \n",
|
|
||||||
"843786 0.12780 0.17000 0.15780 \n",
|
|
||||||
"844359 0.09463 0.10900 0.11270 \n",
|
|
||||||
"84458202 0.11890 0.16450 0.09366 \n",
|
|
||||||
"844981 0.12730 0.19320 0.18590 \n",
|
|
||||||
"84501001 0.11860 0.23960 0.22730 \n",
|
|
||||||
"\n",
|
|
||||||
" concave points_mean symmetry_mean ... \\\n",
|
|
||||||
"id ... \n",
|
|
||||||
"842517 0.07017 0.1812 ... \n",
|
|
||||||
"84300903 0.12790 0.2069 ... \n",
|
|
||||||
"84348301 0.10520 0.2597 ... \n",
|
|
||||||
"84358402 0.10430 0.1809 ... \n",
|
|
||||||
"843786 0.08089 0.2087 ... \n",
|
|
||||||
"844359 0.07400 0.1794 ... \n",
|
|
||||||
"84458202 0.05985 0.2196 ... \n",
|
|
||||||
"844981 0.09353 0.2350 ... \n",
|
|
||||||
"84501001 0.08543 0.2030 ... \n",
|
|
||||||
"\n",
|
|
||||||
" radius_worst texture_worst perimeter_worst area_worst \\\n",
|
|
||||||
"id \n",
|
|
||||||
"842517 24.99 23.41 158.80 1956.0 \n",
|
|
||||||
"84300903 23.57 25.53 152.50 1709.0 \n",
|
|
||||||
"84348301 14.91 26.50 98.87 567.7 \n",
|
|
||||||
"84358402 22.54 16.67 152.20 1575.0 \n",
|
|
||||||
"843786 15.47 23.75 103.40 741.6 \n",
|
|
||||||
"844359 22.88 27.66 153.20 1606.0 \n",
|
|
||||||
"84458202 17.06 28.14 110.60 897.0 \n",
|
|
||||||
"844981 15.49 30.73 106.20 739.3 \n",
|
|
||||||
"84501001 15.09 40.68 97.65 711.4 \n",
|
|
||||||
"\n",
|
|
||||||
" smoothness_worst compactness_worst concavity_worst \\\n",
|
|
||||||
"id \n",
|
|
||||||
"842517 0.1238 0.1866 0.2416 \n",
|
|
||||||
"84300903 0.1444 0.4245 0.4504 \n",
|
|
||||||
"84348301 0.2098 0.8663 0.6869 \n",
|
|
||||||
"84358402 0.1374 0.2050 0.4000 \n",
|
|
||||||
"843786 0.1791 0.5249 0.5355 \n",
|
|
||||||
"844359 0.1442 0.2576 0.3784 \n",
|
|
||||||
"84458202 0.1654 0.3682 0.2678 \n",
|
|
||||||
"844981 0.1703 0.5401 0.5390 \n",
|
|
||||||
"84501001 0.1853 1.0580 1.1050 \n",
|
|
||||||
"\n",
|
|
||||||
" concave points_worst symmetry_worst fractal_dimension_worst \n",
|
|
||||||
"id \n",
|
|
||||||
"842517 0.1860 0.2750 0.08902 \n",
|
|
||||||
"84300903 0.2430 0.3613 0.08758 \n",
|
|
||||||
"84348301 0.2575 0.6638 0.17300 \n",
|
|
||||||
"84358402 0.1625 0.2364 0.07678 \n",
|
|
||||||
"843786 0.1741 0.3985 0.12440 \n",
|
|
||||||
"844359 0.1932 0.3063 0.08368 \n",
|
|
||||||
"84458202 0.1556 0.3196 0.11510 \n",
|
|
||||||
"844981 0.2060 0.4378 0.10720 \n",
|
|
||||||
"84501001 0.2210 0.4366 0.20750 \n",
|
|
||||||
"\n",
|
|
||||||
"[9 rows x 31 columns]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 111,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"data = pd.read_csv('breast-cancer.csv', index_col=0)\n",
|
|
||||||
"data[1:10]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 112,
|
|
||||||
"id": "outdoor-element",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"lb = LabelEncoder()\n",
|
|
||||||
"data['diagnosis'] = lb.fit_transform(data['diagnosis'])\n",
|
|
||||||
"features = data.iloc[:, 1:32].values\n",
|
|
||||||
"labels = np.array(data['diagnosis'])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 113,
|
|
||||||
"id": "buried-community",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n",
|
|
||||||
" 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n",
|
|
||||||
" 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n",
|
|
||||||
" 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n",
|
|
||||||
" 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n",
|
|
||||||
" [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n",
|
|
||||||
" 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n",
|
|
||||||
" 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n",
|
|
||||||
" 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n",
|
|
||||||
" 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n",
|
|
||||||
" [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n",
|
|
||||||
" 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n",
|
|
||||||
" 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n",
|
|
||||||
" 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n",
|
|
||||||
" 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n",
|
|
||||||
" [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n",
|
|
||||||
" 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n",
|
|
||||||
" 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n",
|
|
||||||
" 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n",
|
|
||||||
" 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n",
|
|
||||||
" [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n",
|
|
||||||
" 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n",
|
|
||||||
" 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n",
|
|
||||||
" 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n",
|
|
||||||
" 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n",
|
|
||||||
" [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n",
|
|
||||||
" 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n",
|
|
||||||
" 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n",
|
|
||||||
" 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n",
|
|
||||||
" 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n",
|
|
||||||
" [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n",
|
|
||||||
" 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n",
|
|
||||||
" 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n",
|
|
||||||
" 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n",
|
|
||||||
" 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n",
|
|
||||||
" [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n",
|
|
||||||
" 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n",
|
|
||||||
" 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n",
|
|
||||||
" 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n",
|
|
||||||
" 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n",
|
|
||||||
" [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n",
|
|
||||||
" 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n",
|
|
||||||
" 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n",
|
|
||||||
" 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n",
|
|
||||||
" 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 113,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"features[1:10]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 114,
|
|
||||||
"id": "incredible-quantum",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"array([1, 1, 1, 1, 1, 1, 1, 1, 1])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 114,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"labels[1:10]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 115,
|
|
||||||
"id": "brazilian-butler",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42, shuffle=True)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 116,
|
|
||||||
"id": "exotic-method",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Training\n",
|
|
||||||
"model = Model(features_train.shape[1])\n",
|
|
||||||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
|
|
||||||
"loss_fn = nn.CrossEntropyLoss()\n",
|
|
||||||
"epochs = 100\n",
|
|
||||||
"\n",
|
|
||||||
"def print_(loss):\n",
|
|
||||||
" print (\"The loss calculated: \", loss)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 117,
|
|
||||||
"id": "sharp-month",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch # 1\n",
|
|
||||||
"The loss calculated: 0.922476053237915\n",
|
|
||||||
"Epoch # 2\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 3\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 4\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 5\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 6\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 7\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 8\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 9\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 10\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 11\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 12\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 13\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 14\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 15\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 16\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 17\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 18\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 19\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 20\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 21\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 22\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 23\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 24\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 25\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 26\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 27\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 28\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 29\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 30\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 31\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 32\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 33\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 34\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 35\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 36\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 37\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 38\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 39\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 40\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 41\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 42\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 43\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 44\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 45\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 46\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 47\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 48\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 49\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 50\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 51\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 52\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 53\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 54\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 55\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 56\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 57\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 58\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 59\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 60\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 61\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 62\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 63\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 64\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 65\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 66\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 67\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 68\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 69\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 70\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 71\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 72\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 73\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 74\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 75\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 76\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 77\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 78\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 79\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 80\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 81\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 82\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 83\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 84\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 85\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 86\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 87\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 88\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 89\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 90\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 91\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 92\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 93\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 94\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 95\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 96\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 97\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 98\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 99\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n",
|
|
||||||
"Epoch # 100\n",
|
|
||||||
"The loss calculated: 0.9223369359970093\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
|
||||||
" # This is added back by InteractiveShellApp.init_path()\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Not using dataloader\n",
|
|
||||||
"x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()\n",
|
|
||||||
"for epoch in range(1, epochs+1):\n",
|
|
||||||
" print (\"Epoch #\",epoch)\n",
|
|
||||||
" y_pred = model(x_train)\n",
|
|
||||||
" loss = loss_fn(y_pred, y_train)\n",
|
|
||||||
" print_(loss.item())\n",
|
|
||||||
" \n",
|
|
||||||
" # Zero gradients\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" loss.backward() # Gradients\n",
|
|
||||||
" optimizer.step() # Update"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 118,
|
|
||||||
"id": "mechanical-humidity",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
|
||||||
" # This is added back by InteractiveShellApp.init_path()\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Prediction\n",
|
|
||||||
"x_test = Variable(torch.from_numpy(features_test)).float()\n",
|
|
||||||
"pred = model(x_test)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 119,
|
|
||||||
"id": "based-charleston",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"array([[1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.],\n",
|
|
||||||
" [1., 0., 0.]], dtype=float32)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 119,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"pred = pred.detach().numpy()\n",
|
|
||||||
"pred[1:10]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 120,
|
|
||||||
"id": "dried-accessory",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"The accuracy is 0.6223776223776224\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"print (\"The accuracy is\", accuracy_score(labels_test, np.argmax(pred, axis=1)))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 121,
|
|
||||||
"id": "effective-characterization",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 121,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"labels_test[0]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 122,
|
|
||||||
"id": "oriented-determination",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"torch.save(model, \"travel_insurance-pytorch.pkl\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 123,
|
|
||||||
"id": "infectious-wagon",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"saved_model = torch.load(\"travel_insurance-pytorch.pkl\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 124,
|
|
||||||
"id": "built-contributor",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
|
||||||
" # This is added back by InteractiveShellApp.init_path()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"0"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 124,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.7.3"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
530
classification_net.ipynb
Normal file
530
classification_net.ipynb
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "forty-fault",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!kaggle datasets download -d kukuroo3/body-performance-data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "pediatric-tuesday",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!unzip -o body-performance-data.zip"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 114,
|
||||||
|
"id": "interstate-presence",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from sklearn.metrics import classification_report\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch import nn, optim"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 115,
|
||||||
|
"id": "structural-trigger",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"(13393, 12)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 115,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df = pd.read_csv('bodyPerformance.csv')\n",
|
||||||
|
"df.shape"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 116,
|
||||||
|
"id": "turkish-category",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>age</th>\n",
|
||||||
|
" <th>gender</th>\n",
|
||||||
|
" <th>height_cm</th>\n",
|
||||||
|
" <th>weight_kg</th>\n",
|
||||||
|
" <th>body fat_%</th>\n",
|
||||||
|
" <th>diastolic</th>\n",
|
||||||
|
" <th>systolic</th>\n",
|
||||||
|
" <th>gripForce</th>\n",
|
||||||
|
" <th>sit and bend forward_cm</th>\n",
|
||||||
|
" <th>sit-ups counts</th>\n",
|
||||||
|
" <th>broad jump_cm</th>\n",
|
||||||
|
" <th>class</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>27.0</td>\n",
|
||||||
|
" <td>M</td>\n",
|
||||||
|
" <td>172.3</td>\n",
|
||||||
|
" <td>75.24</td>\n",
|
||||||
|
" <td>21.3</td>\n",
|
||||||
|
" <td>80.0</td>\n",
|
||||||
|
" <td>130.0</td>\n",
|
||||||
|
" <td>54.9</td>\n",
|
||||||
|
" <td>18.4</td>\n",
|
||||||
|
" <td>60.0</td>\n",
|
||||||
|
" <td>217.0</td>\n",
|
||||||
|
" <td>C</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>25.0</td>\n",
|
||||||
|
" <td>M</td>\n",
|
||||||
|
" <td>165.0</td>\n",
|
||||||
|
" <td>55.80</td>\n",
|
||||||
|
" <td>15.7</td>\n",
|
||||||
|
" <td>77.0</td>\n",
|
||||||
|
" <td>126.0</td>\n",
|
||||||
|
" <td>36.4</td>\n",
|
||||||
|
" <td>16.3</td>\n",
|
||||||
|
" <td>53.0</td>\n",
|
||||||
|
" <td>229.0</td>\n",
|
||||||
|
" <td>A</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>31.0</td>\n",
|
||||||
|
" <td>M</td>\n",
|
||||||
|
" <td>179.6</td>\n",
|
||||||
|
" <td>78.00</td>\n",
|
||||||
|
" <td>20.1</td>\n",
|
||||||
|
" <td>92.0</td>\n",
|
||||||
|
" <td>152.0</td>\n",
|
||||||
|
" <td>44.8</td>\n",
|
||||||
|
" <td>12.0</td>\n",
|
||||||
|
" <td>49.0</td>\n",
|
||||||
|
" <td>181.0</td>\n",
|
||||||
|
" <td>C</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>32.0</td>\n",
|
||||||
|
" <td>M</td>\n",
|
||||||
|
" <td>174.5</td>\n",
|
||||||
|
" <td>71.10</td>\n",
|
||||||
|
" <td>18.4</td>\n",
|
||||||
|
" <td>76.0</td>\n",
|
||||||
|
" <td>147.0</td>\n",
|
||||||
|
" <td>41.4</td>\n",
|
||||||
|
" <td>15.2</td>\n",
|
||||||
|
" <td>53.0</td>\n",
|
||||||
|
" <td>219.0</td>\n",
|
||||||
|
" <td>B</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>28.0</td>\n",
|
||||||
|
" <td>M</td>\n",
|
||||||
|
" <td>173.8</td>\n",
|
||||||
|
" <td>67.70</td>\n",
|
||||||
|
" <td>17.1</td>\n",
|
||||||
|
" <td>70.0</td>\n",
|
||||||
|
" <td>127.0</td>\n",
|
||||||
|
" <td>43.5</td>\n",
|
||||||
|
" <td>27.1</td>\n",
|
||||||
|
" <td>45.0</td>\n",
|
||||||
|
" <td>217.0</td>\n",
|
||||||
|
" <td>B</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" age gender height_cm weight_kg body fat_% diastolic systolic \\\n",
|
||||||
|
"0 27.0 M 172.3 75.24 21.3 80.0 130.0 \n",
|
||||||
|
"1 25.0 M 165.0 55.80 15.7 77.0 126.0 \n",
|
||||||
|
"2 31.0 M 179.6 78.00 20.1 92.0 152.0 \n",
|
||||||
|
"3 32.0 M 174.5 71.10 18.4 76.0 147.0 \n",
|
||||||
|
"4 28.0 M 173.8 67.70 17.1 70.0 127.0 \n",
|
||||||
|
"\n",
|
||||||
|
" gripForce sit and bend forward_cm sit-ups counts broad jump_cm class \n",
|
||||||
|
"0 54.9 18.4 60.0 217.0 C \n",
|
||||||
|
"1 36.4 16.3 53.0 229.0 A \n",
|
||||||
|
"2 44.8 12.0 49.0 181.0 C \n",
|
||||||
|
"3 41.4 15.2 53.0 219.0 B \n",
|
||||||
|
"4 43.5 27.1 45.0 217.0 B "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 116,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df.head()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 117,
|
||||||
|
"id": "received-absence",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"cols = ['gender', 'height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']\n",
|
||||||
|
"df = df[cols]\n",
|
||||||
|
"\n",
|
||||||
|
"# male - 0, female - 1\n",
|
||||||
|
"df['gender'].replace({'M': 0, 'F': 1}, inplace = True)\n",
|
||||||
|
"df = df.dropna(how='any')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 118,
|
||||||
|
"id": "excited-parent",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"0 0.632196\n",
|
||||||
|
"1 0.367804\n",
|
||||||
|
"Name: gender, dtype: float64"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 118,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df.gender.value_counts() / df.shape[0]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 119,
|
||||||
|
"id": "extended-cinema",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X = df[['height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']]\n",
|
||||||
|
"y = df[['gender']]\n",
|
||||||
|
"\n",
|
||||||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 120,
|
||||||
|
"id": "animated-farming",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"torch.Size([10714, 5]) torch.Size([10714])\n",
|
||||||
|
"torch.Size([2679, 5]) torch.Size([2679])\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"X_train = torch.from_numpy(np.array(X_train)).float()\n",
|
||||||
|
"y_train = torch.squeeze(torch.from_numpy(y_train.values).float())\n",
|
||||||
|
"\n",
|
||||||
|
"X_test = torch.from_numpy(np.array(X_test)).float()\n",
|
||||||
|
"y_test = torch.squeeze(torch.from_numpy(y_test.values).float())\n",
|
||||||
|
"\n",
|
||||||
|
"print(X_train.shape, y_train.shape)\n",
|
||||||
|
"print(X_test.shape, y_test.shape)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 121,
|
||||||
|
"id": "technical-wallet",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class Net(nn.Module):\n",
|
||||||
|
" def __init__(self, n_features):\n",
|
||||||
|
" super(Net, self).__init__()\n",
|
||||||
|
" self.fc1 = nn.Linear(n_features, 5)\n",
|
||||||
|
" self.fc2 = nn.Linear(5, 3)\n",
|
||||||
|
" self.fc3 = nn.Linear(3, 1)\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x = F.relu(self.fc1(x))\n",
|
||||||
|
" x = F.relu(self.fc2(x))\n",
|
||||||
|
" return torch.sigmoid(self.fc3(x))\n",
|
||||||
|
"net = Net(X_train.shape[1])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 122,
|
||||||
|
"id": "requested-plymouth",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"criterion = nn.BCELoss()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 123,
|
||||||
|
"id": "iraqi-english",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"optimizer = optim.Adam(net.parameters(), lr=0.001)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 124,
|
||||||
|
"id": "emerging-helmet",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 125,
|
||||||
|
"id": "differential-aviation",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_train = X_train.to(device)\n",
|
||||||
|
"y_train = y_train.to(device)\n",
|
||||||
|
"X_test = X_test.to(device)\n",
|
||||||
|
"y_test = y_test.to(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 126,
|
||||||
|
"id": "ranging-calgary",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"net = net.to(device)\n",
|
||||||
|
"criterion = criterion.to(device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 127,
|
||||||
|
"id": "iraqi-blanket",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def calculate_accuracy(y_true, y_pred):\n",
|
||||||
|
" predicted = y_pred.ge(.5).view(-1)\n",
|
||||||
|
" return (y_true == predicted).sum().float() / len(y_true)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 128,
|
||||||
|
"id": "robust-serbia",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"epoch 0\n",
|
||||||
|
"Train set - loss: 1.005, accuracy: 0.37\n",
|
||||||
|
"Test set - loss: 1.018, accuracy: 0.358\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 100\n",
|
||||||
|
"Train set - loss: 0.677, accuracy: 0.743\n",
|
||||||
|
"Test set - loss: 0.679, accuracy: 0.727\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 200\n",
|
||||||
|
"Train set - loss: 0.636, accuracy: 0.79\n",
|
||||||
|
"Test set - loss: 0.64, accuracy: 0.778\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 300\n",
|
||||||
|
"Train set - loss: 0.568, accuracy: 0.839\n",
|
||||||
|
"Test set - loss: 0.577, accuracy: 0.833\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 400\n",
|
||||||
|
"Train set - loss: 0.504, accuracy: 0.885\n",
|
||||||
|
"Test set - loss: 0.514, accuracy: 0.877\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 500\n",
|
||||||
|
"Train set - loss: 0.441, accuracy: 0.922\n",
|
||||||
|
"Test set - loss: 0.45, accuracy: 0.913\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 600\n",
|
||||||
|
"Train set - loss: 0.388, accuracy: 0.944\n",
|
||||||
|
"Test set - loss: 0.396, accuracy: 0.938\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 700\n",
|
||||||
|
"Train set - loss: 0.353, accuracy: 0.954\n",
|
||||||
|
"Test set - loss: 0.359, accuracy: 0.949\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 800\n",
|
||||||
|
"Train set - loss: 0.327, accuracy: 0.958\n",
|
||||||
|
"Test set - loss: 0.333, accuracy: 0.953\n",
|
||||||
|
"\n",
|
||||||
|
"epoch 900\n",
|
||||||
|
"Train set - loss: 0.306, accuracy: 0.961\n",
|
||||||
|
"Test set - loss: 0.312, accuracy: 0.955\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def round_tensor(t, decimal_places=3):\n",
|
||||||
|
" return round(t.item(), decimal_places)\n",
|
||||||
|
"for epoch in range(1000):\n",
|
||||||
|
" y_pred = net(X_train)\n",
|
||||||
|
" y_pred = torch.squeeze(y_pred)\n",
|
||||||
|
" train_loss = criterion(y_pred, y_train)\n",
|
||||||
|
" if epoch % 100 == 0:\n",
|
||||||
|
" train_acc = calculate_accuracy(y_train, y_pred)\n",
|
||||||
|
" y_test_pred = net(X_test)\n",
|
||||||
|
" y_test_pred = torch.squeeze(y_test_pred)\n",
|
||||||
|
" test_loss = criterion(y_test_pred, y_test)\n",
|
||||||
|
" test_acc = calculate_accuracy(y_test, y_test_pred)\n",
|
||||||
|
" print(\n",
|
||||||
|
"f'''epoch {epoch}\n",
|
||||||
|
"Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)}\n",
|
||||||
|
"Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)}\n",
|
||||||
|
"''')\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" train_loss.backward()\n",
|
||||||
|
" optimizer.step()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 129,
|
||||||
|
"id": "optimum-excerpt",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# torch.save(net, 'model.pth')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 130,
|
||||||
|
"id": "dental-seating",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# net = torch.load('model.pth')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 131,
|
||||||
|
"id": "german-satisfaction",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" precision recall f1-score support\n",
|
||||||
|
"\n",
|
||||||
|
" Male 0.97 0.96 0.96 1720\n",
|
||||||
|
" Female 0.93 0.94 0.94 959\n",
|
||||||
|
"\n",
|
||||||
|
" accuracy 0.95 2679\n",
|
||||||
|
" macro avg 0.95 0.95 0.95 2679\n",
|
||||||
|
"weighted avg 0.95 0.95 0.95 2679\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"classes = ['Male', 'Female']\n",
|
||||||
|
"y_pred = net(X_test)\n",
|
||||||
|
"y_pred = y_pred.ge(.5).view(-1).cpu()\n",
|
||||||
|
"y_test = y_test.cpu()\n",
|
||||||
|
"print(classification_report(y_test, y_pred, target_names=classes))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 132,
|
||||||
|
"id": "british-incidence",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"with open('test_out.csv', 'w') as file:\n",
|
||||||
|
" for y in y_pred:\n",
|
||||||
|
" file.write(classes[y.item()])\n",
|
||||||
|
" file.write('\\n')"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user