diff --git a/body-performance-data.zip b/body-performance-data.zip new file mode 100644 index 0000000..d3f3f80 Binary files /dev/null and b/body-performance-data.zip differ diff --git a/breast-cancer-dataset.zip b/breast-cancer-dataset.zip deleted file mode 100644 index ac2a6f8..0000000 Binary files a/breast-cancer-dataset.zip and /dev/null differ diff --git a/classification.ipynb b/classification.ipynb deleted file mode 100644 index 7da3764..0000000 --- a/classification.ipynb +++ /dev/null @@ -1,969 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "corrected-wholesale", - "metadata": {}, - "outputs": [], - "source": [ - "!kaggle datasets download -d yasserh/breast-cancer-dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ranging-police", - "metadata": {}, - "outputs": [], - "source": [ - "!unzip -o breast-cancer-dataset.zip" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "id": "ideal-spouse", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "from torch import nn\n", - "from torch.autograd import Variable\n", - "from sklearn.datasets import load_iris\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.metrics import accuracy_score\n", - "from sklearn.preprocessing import LabelEncoder\n", - "from tensorflow.keras.utils import to_categorical\n", - "import torch.nn.functional as F\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "id": "major-compromise", - "metadata": {}, - "outputs": [], - "source": [ - "class Model(nn.Module):\n", - " def __init__(self, input_dim):\n", - " super(Model, self).__init__()\n", - " self.layer1 = nn.Linear(input_dim,50)\n", - " self.layer2 = nn.Linear(50, 20)\n", - " self.layer3 = nn.Linear(20, 3)\n", - " \n", - " def forward(self, x):\n", - " x = F.relu(self.layer1(x))\n", - " x = F.relu(self.layer2(x))\n", - " x = F.softmax(self.layer3(x)) # To check with the loss function\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "id": "czech-regular", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
diagnosisradius_meantexture_meanperimeter_meanarea_meansmoothness_meancompactness_meanconcavity_meanconcave points_meansymmetry_mean...radius_worsttexture_worstperimeter_worstarea_worstsmoothness_worstcompactness_worstconcavity_worstconcave points_worstsymmetry_worstfractal_dimension_worst
id
842517M20.5717.77132.901326.00.084740.078640.086900.070170.1812...24.9923.41158.801956.00.12380.18660.24160.18600.27500.08902
84300903M19.6921.25130.001203.00.109600.159900.197400.127900.2069...23.5725.53152.501709.00.14440.42450.45040.24300.36130.08758
84348301M11.4220.3877.58386.10.142500.283900.241400.105200.2597...14.9126.5098.87567.70.20980.86630.68690.25750.66380.17300
84358402M20.2914.34135.101297.00.100300.132800.198000.104300.1809...22.5416.67152.201575.00.13740.20500.40000.16250.23640.07678
843786M12.4515.7082.57477.10.127800.170000.157800.080890.2087...15.4723.75103.40741.60.17910.52490.53550.17410.39850.12440
844359M18.2519.98119.601040.00.094630.109000.112700.074000.1794...22.8827.66153.201606.00.14420.25760.37840.19320.30630.08368
84458202M13.7120.8390.20577.90.118900.164500.093660.059850.2196...17.0628.14110.60897.00.16540.36820.26780.15560.31960.11510
844981M13.0021.8287.50519.80.127300.193200.185900.093530.2350...15.4930.73106.20739.30.17030.54010.53900.20600.43780.10720
84501001M12.4624.0483.97475.90.118600.239600.227300.085430.2030...15.0940.6897.65711.40.18531.05801.10500.22100.43660.20750
\n", - "

9 rows × 31 columns

\n", - "
" - ], - "text/plain": [ - " diagnosis radius_mean texture_mean perimeter_mean area_mean \\\n", - "id \n", - "842517 M 20.57 17.77 132.90 1326.0 \n", - "84300903 M 19.69 21.25 130.00 1203.0 \n", - "84348301 M 11.42 20.38 77.58 386.1 \n", - "84358402 M 20.29 14.34 135.10 1297.0 \n", - "843786 M 12.45 15.70 82.57 477.1 \n", - "844359 M 18.25 19.98 119.60 1040.0 \n", - "84458202 M 13.71 20.83 90.20 577.9 \n", - "844981 M 13.00 21.82 87.50 519.8 \n", - "84501001 M 12.46 24.04 83.97 475.9 \n", - "\n", - " smoothness_mean compactness_mean concavity_mean \\\n", - "id \n", - "842517 0.08474 0.07864 0.08690 \n", - "84300903 0.10960 0.15990 0.19740 \n", - "84348301 0.14250 0.28390 0.24140 \n", - "84358402 0.10030 0.13280 0.19800 \n", - "843786 0.12780 0.17000 0.15780 \n", - "844359 0.09463 0.10900 0.11270 \n", - "84458202 0.11890 0.16450 0.09366 \n", - "844981 0.12730 0.19320 0.18590 \n", - "84501001 0.11860 0.23960 0.22730 \n", - "\n", - " concave points_mean symmetry_mean ... \\\n", - "id ... \n", - "842517 0.07017 0.1812 ... \n", - "84300903 0.12790 0.2069 ... \n", - "84348301 0.10520 0.2597 ... \n", - "84358402 0.10430 0.1809 ... \n", - "843786 0.08089 0.2087 ... \n", - "844359 0.07400 0.1794 ... \n", - "84458202 0.05985 0.2196 ... \n", - "844981 0.09353 0.2350 ... \n", - "84501001 0.08543 0.2030 ... \n", - "\n", - " radius_worst texture_worst perimeter_worst area_worst \\\n", - "id \n", - "842517 24.99 23.41 158.80 1956.0 \n", - "84300903 23.57 25.53 152.50 1709.0 \n", - "84348301 14.91 26.50 98.87 567.7 \n", - "84358402 22.54 16.67 152.20 1575.0 \n", - "843786 15.47 23.75 103.40 741.6 \n", - "844359 22.88 27.66 153.20 1606.0 \n", - "84458202 17.06 28.14 110.60 897.0 \n", - "844981 15.49 30.73 106.20 739.3 \n", - "84501001 15.09 40.68 97.65 711.4 \n", - "\n", - " smoothness_worst compactness_worst concavity_worst \\\n", - "id \n", - "842517 0.1238 0.1866 0.2416 \n", - "84300903 0.1444 0.4245 0.4504 \n", - "84348301 0.2098 0.8663 0.6869 \n", - "84358402 0.1374 0.2050 0.4000 \n", - "843786 0.1791 0.5249 0.5355 \n", - "844359 0.1442 0.2576 0.3784 \n", - "84458202 0.1654 0.3682 0.2678 \n", - "844981 0.1703 0.5401 0.5390 \n", - "84501001 0.1853 1.0580 1.1050 \n", - "\n", - " concave points_worst symmetry_worst fractal_dimension_worst \n", - "id \n", - "842517 0.1860 0.2750 0.08902 \n", - "84300903 0.2430 0.3613 0.08758 \n", - "84348301 0.2575 0.6638 0.17300 \n", - "84358402 0.1625 0.2364 0.07678 \n", - "843786 0.1741 0.3985 0.12440 \n", - "844359 0.1932 0.3063 0.08368 \n", - "84458202 0.1556 0.3196 0.11510 \n", - "844981 0.2060 0.4378 0.10720 \n", - "84501001 0.2210 0.4366 0.20750 \n", - "\n", - "[9 rows x 31 columns]" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data = pd.read_csv('breast-cancer.csv', index_col=0)\n", - "data[1:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "id": "outdoor-element", - "metadata": {}, - "outputs": [], - "source": [ - "lb = LabelEncoder()\n", - "data['diagnosis'] = lb.fit_transform(data['diagnosis'])\n", - "features = data.iloc[:, 1:32].values\n", - "labels = np.array(data['diagnosis'])" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "id": "buried-community", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[2.057e+01, 1.777e+01, 1.329e+02, 1.326e+03, 8.474e-02, 7.864e-02,\n", - " 8.690e-02, 7.017e-02, 1.812e-01, 5.667e-02, 5.435e-01, 7.339e-01,\n", - " 3.398e+00, 7.408e+01, 5.225e-03, 1.308e-02, 1.860e-02, 1.340e-02,\n", - " 1.389e-02, 3.532e-03, 2.499e+01, 2.341e+01, 1.588e+02, 1.956e+03,\n", - " 1.238e-01, 1.866e-01, 2.416e-01, 1.860e-01, 2.750e-01, 8.902e-02],\n", - " [1.969e+01, 2.125e+01, 1.300e+02, 1.203e+03, 1.096e-01, 1.599e-01,\n", - " 1.974e-01, 1.279e-01, 2.069e-01, 5.999e-02, 7.456e-01, 7.869e-01,\n", - " 4.585e+00, 9.403e+01, 6.150e-03, 4.006e-02, 3.832e-02, 2.058e-02,\n", - " 2.250e-02, 4.571e-03, 2.357e+01, 2.553e+01, 1.525e+02, 1.709e+03,\n", - " 1.444e-01, 4.245e-01, 4.504e-01, 2.430e-01, 3.613e-01, 8.758e-02],\n", - " [1.142e+01, 2.038e+01, 7.758e+01, 3.861e+02, 1.425e-01, 2.839e-01,\n", - " 2.414e-01, 1.052e-01, 2.597e-01, 9.744e-02, 4.956e-01, 1.156e+00,\n", - " 3.445e+00, 2.723e+01, 9.110e-03, 7.458e-02, 5.661e-02, 1.867e-02,\n", - " 5.963e-02, 9.208e-03, 1.491e+01, 2.650e+01, 9.887e+01, 5.677e+02,\n", - " 2.098e-01, 8.663e-01, 6.869e-01, 2.575e-01, 6.638e-01, 1.730e-01],\n", - " [2.029e+01, 1.434e+01, 1.351e+02, 1.297e+03, 1.003e-01, 1.328e-01,\n", - " 1.980e-01, 1.043e-01, 1.809e-01, 5.883e-02, 7.572e-01, 7.813e-01,\n", - " 5.438e+00, 9.444e+01, 1.149e-02, 2.461e-02, 5.688e-02, 1.885e-02,\n", - " 1.756e-02, 5.115e-03, 2.254e+01, 1.667e+01, 1.522e+02, 1.575e+03,\n", - " 1.374e-01, 2.050e-01, 4.000e-01, 1.625e-01, 2.364e-01, 7.678e-02],\n", - " [1.245e+01, 1.570e+01, 8.257e+01, 4.771e+02, 1.278e-01, 1.700e-01,\n", - " 1.578e-01, 8.089e-02, 2.087e-01, 7.613e-02, 3.345e-01, 8.902e-01,\n", - " 2.217e+00, 2.719e+01, 7.510e-03, 3.345e-02, 3.672e-02, 1.137e-02,\n", - " 2.165e-02, 5.082e-03, 1.547e+01, 2.375e+01, 1.034e+02, 7.416e+02,\n", - " 1.791e-01, 5.249e-01, 5.355e-01, 1.741e-01, 3.985e-01, 1.244e-01],\n", - " [1.825e+01, 1.998e+01, 1.196e+02, 1.040e+03, 9.463e-02, 1.090e-01,\n", - " 1.127e-01, 7.400e-02, 1.794e-01, 5.742e-02, 4.467e-01, 7.732e-01,\n", - " 3.180e+00, 5.391e+01, 4.314e-03, 1.382e-02, 2.254e-02, 1.039e-02,\n", - " 1.369e-02, 2.179e-03, 2.288e+01, 2.766e+01, 1.532e+02, 1.606e+03,\n", - " 1.442e-01, 2.576e-01, 3.784e-01, 1.932e-01, 3.063e-01, 8.368e-02],\n", - " [1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n", - " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n", - " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n", - " 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n", - " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01],\n", - " [1.300e+01, 2.182e+01, 8.750e+01, 5.198e+02, 1.273e-01, 1.932e-01,\n", - " 1.859e-01, 9.353e-02, 2.350e-01, 7.389e-02, 3.063e-01, 1.002e+00,\n", - " 2.406e+00, 2.432e+01, 5.731e-03, 3.502e-02, 3.553e-02, 1.226e-02,\n", - " 2.143e-02, 3.749e-03, 1.549e+01, 3.073e+01, 1.062e+02, 7.393e+02,\n", - " 1.703e-01, 5.401e-01, 5.390e-01, 2.060e-01, 4.378e-01, 1.072e-01],\n", - " [1.246e+01, 2.404e+01, 8.397e+01, 4.759e+02, 1.186e-01, 2.396e-01,\n", - " 2.273e-01, 8.543e-02, 2.030e-01, 8.243e-02, 2.976e-01, 1.599e+00,\n", - " 2.039e+00, 2.394e+01, 7.149e-03, 7.217e-02, 7.743e-02, 1.432e-02,\n", - " 1.789e-02, 1.008e-02, 1.509e+01, 4.068e+01, 9.765e+01, 7.114e+02,\n", - " 1.853e-01, 1.058e+00, 1.105e+00, 2.210e-01, 4.366e-01, 2.075e-01]])" - ] - }, - "execution_count": 113, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "features[1:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "id": "incredible-quantum", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([1, 1, 1, 1, 1, 1, 1, 1, 1])" - ] - }, - "execution_count": 114, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "labels[1:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "id": "brazilian-butler", - "metadata": {}, - "outputs": [], - "source": [ - "features_train, features_test, labels_train, labels_test = train_test_split(features, labels, random_state=42, shuffle=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "id": "exotic-method", - "metadata": {}, - "outputs": [], - "source": [ - "# Training\n", - "model = Model(features_train.shape[1])\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", - "loss_fn = nn.CrossEntropyLoss()\n", - "epochs = 100\n", - "\n", - "def print_(loss):\n", - " print (\"The loss calculated: \", loss)" - ] - }, - { - "cell_type": "code", - "execution_count": 117, - "id": "sharp-month", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch # 1\n", - "The loss calculated: 0.922476053237915\n", - "Epoch # 2\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 3\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 4\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 5\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 6\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 7\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 8\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 9\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 10\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 11\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 12\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 13\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 14\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 15\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 16\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 17\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 18\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 19\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 20\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 21\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 22\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 23\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 24\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 25\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 26\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 27\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 28\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 29\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 30\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 31\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 32\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 33\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 34\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 35\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 36\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 37\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 38\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 39\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 40\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 41\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 42\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 43\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 44\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 45\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 46\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 47\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 48\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 49\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 50\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 51\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 52\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 53\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 54\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 55\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 56\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 57\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 58\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 59\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 60\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 61\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 62\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 63\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 64\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 65\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 66\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 67\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 68\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 69\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 70\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 71\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 72\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 73\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 74\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 75\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 76\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 77\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 78\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 79\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 80\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 81\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 82\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 83\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 84\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 85\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 86\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 87\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 88\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 89\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 90\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 91\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 92\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 93\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 94\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 95\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 96\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 97\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 98\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 99\n", - "The loss calculated: 0.9223369359970093\n", - "Epoch # 100\n", - "The loss calculated: 0.9223369359970093\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", - " # This is added back by InteractiveShellApp.init_path()\n" - ] - } - ], - "source": [ - "# Not using dataloader\n", - "x_train, y_train = Variable(torch.from_numpy(features_train)).float(), Variable(torch.from_numpy(labels_train)).long()\n", - "for epoch in range(1, epochs+1):\n", - " print (\"Epoch #\",epoch)\n", - " y_pred = model(x_train)\n", - " loss = loss_fn(y_pred, y_train)\n", - " print_(loss.item())\n", - " \n", - " # Zero gradients\n", - " optimizer.zero_grad()\n", - " loss.backward() # Gradients\n", - " optimizer.step() # Update" - ] - }, - { - "cell_type": "code", - "execution_count": 118, - "id": "mechanical-humidity", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", - " # This is added back by InteractiveShellApp.init_path()\n" - ] - } - ], - "source": [ - "# Prediction\n", - "x_test = Variable(torch.from_numpy(features_test)).float()\n", - "pred = model(x_test)" - ] - }, - { - "cell_type": "code", - "execution_count": 119, - "id": "based-charleston", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.],\n", - " [1., 0., 0.]], dtype=float32)" - ] - }, - "execution_count": 119, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pred = pred.detach().numpy()\n", - "pred[1:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "id": "dried-accessory", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The accuracy is 0.6223776223776224\n" - ] - } - ], - "source": [ - "print (\"The accuracy is\", accuracy_score(labels_test, np.argmax(pred, axis=1)))" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "id": "effective-characterization", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 121, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "labels_test[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "id": "oriented-determination", - "metadata": {}, - "outputs": [], - "source": [ - "torch.save(model, \"travel_insurance-pytorch.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "id": "infectious-wagon", - "metadata": {}, - "outputs": [], - "source": [ - "saved_model = torch.load(\"travel_insurance-pytorch.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "id": "built-contributor", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", - " # This is added back by InteractiveShellApp.init_path()\n" - ] - }, - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 124, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.argmax(saved_model(x_test[0]).detach().numpy(), axis=0)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/classification_net.ipynb b/classification_net.ipynb new file mode 100644 index 0000000..67b646d --- /dev/null +++ b/classification_net.ipynb @@ -0,0 +1,530 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "forty-fault", + "metadata": {}, + "outputs": [], + "source": [ + "!kaggle datasets download -d kukuroo3/body-performance-data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "pediatric-tuesday", + "metadata": {}, + "outputs": [], + "source": [ + "!unzip -o body-performance-data.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "id": "interstate-presence", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import classification_report\n", + "import torch\n", + "from torch import nn, optim" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "structural-trigger", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(13393, 12)" + ] + }, + "execution_count": 115, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('bodyPerformance.csv')\n", + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "turkish-category", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
agegenderheight_cmweight_kgbody fat_%diastolicsystolicgripForcesit and bend forward_cmsit-ups countsbroad jump_cmclass
027.0M172.375.2421.380.0130.054.918.460.0217.0C
125.0M165.055.8015.777.0126.036.416.353.0229.0A
231.0M179.678.0020.192.0152.044.812.049.0181.0C
332.0M174.571.1018.476.0147.041.415.253.0219.0B
428.0M173.867.7017.170.0127.043.527.145.0217.0B
\n", + "
" + ], + "text/plain": [ + " age gender height_cm weight_kg body fat_% diastolic systolic \\\n", + "0 27.0 M 172.3 75.24 21.3 80.0 130.0 \n", + "1 25.0 M 165.0 55.80 15.7 77.0 126.0 \n", + "2 31.0 M 179.6 78.00 20.1 92.0 152.0 \n", + "3 32.0 M 174.5 71.10 18.4 76.0 147.0 \n", + "4 28.0 M 173.8 67.70 17.1 70.0 127.0 \n", + "\n", + " gripForce sit and bend forward_cm sit-ups counts broad jump_cm class \n", + "0 54.9 18.4 60.0 217.0 C \n", + "1 36.4 16.3 53.0 229.0 A \n", + "2 44.8 12.0 49.0 181.0 C \n", + "3 41.4 15.2 53.0 219.0 B \n", + "4 43.5 27.1 45.0 217.0 B " + ] + }, + "execution_count": 116, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 117, + "id": "received-absence", + "metadata": {}, + "outputs": [], + "source": [ + "cols = ['gender', 'height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']\n", + "df = df[cols]\n", + "\n", + "# male - 0, female - 1\n", + "df['gender'].replace({'M': 0, 'F': 1}, inplace = True)\n", + "df = df.dropna(how='any')" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "id": "excited-parent", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 0.632196\n", + "1 0.367804\n", + "Name: gender, dtype: float64" + ] + }, + "execution_count": 118, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.gender.value_counts() / df.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "id": "extended-cinema", + "metadata": {}, + "outputs": [], + "source": [ + "X = df[['height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']]\n", + "y = df[['gender']]\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "id": "animated-farming", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([10714, 5]) torch.Size([10714])\n", + "torch.Size([2679, 5]) torch.Size([2679])\n" + ] + } + ], + "source": [ + "X_train = torch.from_numpy(np.array(X_train)).float()\n", + "y_train = torch.squeeze(torch.from_numpy(y_train.values).float())\n", + "\n", + "X_test = torch.from_numpy(np.array(X_test)).float()\n", + "y_test = torch.squeeze(torch.from_numpy(y_test.values).float())\n", + "\n", + "print(X_train.shape, y_train.shape)\n", + "print(X_test.shape, y_test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "id": "technical-wallet", + "metadata": {}, + "outputs": [], + "source": [ + "class Net(nn.Module):\n", + " def __init__(self, n_features):\n", + " super(Net, self).__init__()\n", + " self.fc1 = nn.Linear(n_features, 5)\n", + " self.fc2 = nn.Linear(5, 3)\n", + " self.fc3 = nn.Linear(3, 1)\n", + " def forward(self, x):\n", + " x = F.relu(self.fc1(x))\n", + " x = F.relu(self.fc2(x))\n", + " return torch.sigmoid(self.fc3(x))\n", + "net = Net(X_train.shape[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "id": "requested-plymouth", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.BCELoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "id": "iraqi-english", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = optim.Adam(net.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "id": "emerging-helmet", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "differential-aviation", + "metadata": {}, + "outputs": [], + "source": [ + "X_train = X_train.to(device)\n", + "y_train = y_train.to(device)\n", + "X_test = X_test.to(device)\n", + "y_test = y_test.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "id": "ranging-calgary", + "metadata": {}, + "outputs": [], + "source": [ + "net = net.to(device)\n", + "criterion = criterion.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "id": "iraqi-blanket", + "metadata": {}, + "outputs": [], + "source": [ + "def calculate_accuracy(y_true, y_pred):\n", + " predicted = y_pred.ge(.5).view(-1)\n", + " return (y_true == predicted).sum().float() / len(y_true)" + ] + }, + { + "cell_type": "code", + "execution_count": 128, + "id": "robust-serbia", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0\n", + "Train set - loss: 1.005, accuracy: 0.37\n", + "Test set - loss: 1.018, accuracy: 0.358\n", + "\n", + "epoch 100\n", + "Train set - loss: 0.677, accuracy: 0.743\n", + "Test set - loss: 0.679, accuracy: 0.727\n", + "\n", + "epoch 200\n", + "Train set - loss: 0.636, accuracy: 0.79\n", + "Test set - loss: 0.64, accuracy: 0.778\n", + "\n", + "epoch 300\n", + "Train set - loss: 0.568, accuracy: 0.839\n", + "Test set - loss: 0.577, accuracy: 0.833\n", + "\n", + "epoch 400\n", + "Train set - loss: 0.504, accuracy: 0.885\n", + "Test set - loss: 0.514, accuracy: 0.877\n", + "\n", + "epoch 500\n", + "Train set - loss: 0.441, accuracy: 0.922\n", + "Test set - loss: 0.45, accuracy: 0.913\n", + "\n", + "epoch 600\n", + "Train set - loss: 0.388, accuracy: 0.944\n", + "Test set - loss: 0.396, accuracy: 0.938\n", + "\n", + "epoch 700\n", + "Train set - loss: 0.353, accuracy: 0.954\n", + "Test set - loss: 0.359, accuracy: 0.949\n", + "\n", + "epoch 800\n", + "Train set - loss: 0.327, accuracy: 0.958\n", + "Test set - loss: 0.333, accuracy: 0.953\n", + "\n", + "epoch 900\n", + "Train set - loss: 0.306, accuracy: 0.961\n", + "Test set - loss: 0.312, accuracy: 0.955\n", + "\n" + ] + } + ], + "source": [ + "def round_tensor(t, decimal_places=3):\n", + " return round(t.item(), decimal_places)\n", + "for epoch in range(1000):\n", + " y_pred = net(X_train)\n", + " y_pred = torch.squeeze(y_pred)\n", + " train_loss = criterion(y_pred, y_train)\n", + " if epoch % 100 == 0:\n", + " train_acc = calculate_accuracy(y_train, y_pred)\n", + " y_test_pred = net(X_test)\n", + " y_test_pred = torch.squeeze(y_test_pred)\n", + " test_loss = criterion(y_test_pred, y_test)\n", + " test_acc = calculate_accuracy(y_test, y_test_pred)\n", + " print(\n", + "f'''epoch {epoch}\n", + "Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)}\n", + "Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)}\n", + "''')\n", + " optimizer.zero_grad()\n", + " train_loss.backward()\n", + " optimizer.step()" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "id": "optimum-excerpt", + "metadata": {}, + "outputs": [], + "source": [ + "# torch.save(net, 'model.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 130, + "id": "dental-seating", + "metadata": {}, + "outputs": [], + "source": [ + "# net = torch.load('model.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "id": "german-satisfaction", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " Male 0.97 0.96 0.96 1720\n", + " Female 0.93 0.94 0.94 959\n", + "\n", + " accuracy 0.95 2679\n", + " macro avg 0.95 0.95 0.95 2679\n", + "weighted avg 0.95 0.95 0.95 2679\n", + "\n" + ] + } + ], + "source": [ + "classes = ['Male', 'Female']\n", + "y_pred = net(X_test)\n", + "y_pred = y_pred.ge(.5).view(-1).cpu()\n", + "y_test = y_test.cpu()\n", + "print(classification_report(y_test, y_pred, target_names=classes))" + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "id": "british-incidence", + "metadata": {}, + "outputs": [], + "source": [ + "with open('test_out.csv', 'w') as file:\n", + " for y in y_pred:\n", + " file.write(classes[y.item()])\n", + " file.write('\\n')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}