{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "forty-fault",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the body-performance dataset archive from Kaggle using the kaggle CLI.\n",
    "!kaggle datasets download -d kukuroo3/body-performance-data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pediatric-tuesday",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the downloaded archive; -o overwrites existing files without prompting.\n",
    "!unzip -o body-performance-data.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "interstate-presence",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "id": "structural-trigger",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13393, 12)"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the extracted CSV into a DataFrame and show its (rows, columns) shape.\n",
    "df = pd.read_csv(filepath_or_buffer='bodyPerformance.csv')\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "turkish-category",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>height_cm</th>\n",
       "      <th>weight_kg</th>\n",
       "      <th>body fat_%</th>\n",
       "      <th>diastolic</th>\n",
       "      <th>systolic</th>\n",
       "      <th>gripForce</th>\n",
       "      <th>sit and bend forward_cm</th>\n",
       "      <th>sit-ups counts</th>\n",
       "      <th>broad jump_cm</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>27.0</td>\n",
       "      <td>M</td>\n",
       "      <td>172.3</td>\n",
       "      <td>75.24</td>\n",
       "      <td>21.3</td>\n",
       "      <td>80.0</td>\n",
       "      <td>130.0</td>\n",
       "      <td>54.9</td>\n",
       "      <td>18.4</td>\n",
       "      <td>60.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>25.0</td>\n",
       "      <td>M</td>\n",
       "      <td>165.0</td>\n",
       "      <td>55.80</td>\n",
       "      <td>15.7</td>\n",
       "      <td>77.0</td>\n",
       "      <td>126.0</td>\n",
       "      <td>36.4</td>\n",
       "      <td>16.3</td>\n",
       "      <td>53.0</td>\n",
       "      <td>229.0</td>\n",
       "      <td>A</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>31.0</td>\n",
       "      <td>M</td>\n",
       "      <td>179.6</td>\n",
       "      <td>78.00</td>\n",
       "      <td>20.1</td>\n",
       "      <td>92.0</td>\n",
       "      <td>152.0</td>\n",
       "      <td>44.8</td>\n",
       "      <td>12.0</td>\n",
       "      <td>49.0</td>\n",
       "      <td>181.0</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>32.0</td>\n",
       "      <td>M</td>\n",
       "      <td>174.5</td>\n",
       "      <td>71.10</td>\n",
       "      <td>18.4</td>\n",
       "      <td>76.0</td>\n",
       "      <td>147.0</td>\n",
       "      <td>41.4</td>\n",
       "      <td>15.2</td>\n",
       "      <td>53.0</td>\n",
       "      <td>219.0</td>\n",
       "      <td>B</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>28.0</td>\n",
       "      <td>M</td>\n",
       "      <td>173.8</td>\n",
       "      <td>67.70</td>\n",
       "      <td>17.1</td>\n",
       "      <td>70.0</td>\n",
       "      <td>127.0</td>\n",
       "      <td>43.5</td>\n",
       "      <td>27.1</td>\n",
       "      <td>45.0</td>\n",
       "      <td>217.0</td>\n",
       "      <td>B</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    age gender  height_cm  weight_kg  body fat_%  diastolic  systolic  \\\n",
       "0  27.0      M      172.3      75.24        21.3       80.0     130.0   \n",
       "1  25.0      M      165.0      55.80        15.7       77.0     126.0   \n",
       "2  31.0      M      179.6      78.00        20.1       92.0     152.0   \n",
       "3  32.0      M      174.5      71.10        18.4       76.0     147.0   \n",
       "4  28.0      M      173.8      67.70        17.1       70.0     127.0   \n",
       "\n",
       "   gripForce  sit and bend forward_cm  sit-ups counts  broad jump_cm class  \n",
       "0       54.9                     18.4            60.0          217.0     C  \n",
       "1       36.4                     16.3            53.0          229.0     A  \n",
       "2       44.8                     12.0            49.0          181.0     C  \n",
       "3       41.4                     15.2            53.0          219.0     B  \n",
       "4       43.5                     27.1            45.0          217.0     B  "
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first five rows of the frame.\n",
    "df.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "received-absence",
   "metadata": {},
   "outputs": [],
   "source": [
    "cols = ['gender', 'height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']\n",
    "\n",
    "# Keep only the modelling columns and drop incomplete rows. The explicit\n",
    "# .copy() makes `df` an independent frame instead of a slice of the loaded data.\n",
    "df = df[cols].dropna(how='any').copy()\n",
    "\n",
    "# male - 0, female - 1\n",
    "# Assign the result back rather than calling replace(..., inplace=True) on\n",
    "# df['gender']: that form is chained assignment on a temporary Series, which\n",
    "# raises SettingWithCopyWarning and leaves `df` unchanged under copy-on-write.\n",
    "# The integer cast keeps the column numeric for the later tensor conversion.\n",
    "df['gender'] = df['gender'].replace({'M': 0, 'F': 1}).astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "excited-parent",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    0.632196\n",
       "1    0.367804\n",
       "Name: gender, dtype: float64"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fraction of rows carrying each gender code.\n",
    "df['gender'].value_counts().div(len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "extended-cinema",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature matrix: the five body measurements. Target: the gender column,\n",
    "# kept as a one-column frame.\n",
    "X = df.loc[:, ['height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']]\n",
    "y = df.loc[:, ['gender']]\n",
    "\n",
    "# Hold out 20% of the rows for testing; the fixed random_state keeps the split repeatable.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X, y, test_size=0.2, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "animated-farming",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10714, 5]) torch.Size([10714])\n",
      "torch.Size([2679, 5]) torch.Size([2679])\n"
     ]
    }
   ],
   "source": [
    "# Convert the pandas splits to float32 tensors. The one-column target frames\n",
    "# are squeezed so that they become 1-D vectors.\n",
    "X_train = torch.from_numpy(X_train.to_numpy()).float()\n",
    "y_train = torch.from_numpy(y_train.to_numpy()).float().squeeze()\n",
    "\n",
    "X_test = torch.from_numpy(X_test.to_numpy()).float()\n",
    "y_test = torch.from_numpy(y_test.to_numpy()).float().squeeze()\n",
    "\n",
    "print(X_train.shape, y_train.shape)\n",
    "print(X_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "technical-wallet",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Small fully connected binary classifier.\n",
    "\n",
    "    Layer widths are n_features, 5, 3 and 1, with ReLU after the first two\n",
    "    layers and a sigmoid on the output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_features):\n",
    "        \"\"\"Create the three linear layers; `n_features` is the input width.\"\"\"\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(n_features, 5)\n",
    "        self.fc2 = nn.Linear(5, 3)\n",
    "        self.fc3 = nn.Linear(3, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the sigmoid output of the network for input batch `x`.\"\"\"\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        return torch.sigmoid(self.fc3(x))\n",
    "\n",
    "\n",
    "# Seed the PyTorch RNG before the layers are created so that the initial\n",
    "# weights, and therefore the training run, are the same after every kernel\n",
    "# restart. The value matches the random_state used for the train/test split.\n",
    "torch.manual_seed(42)\n",
    "net = Net(X_train.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "id": "requested-plymouth",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "iraqi-english",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all parameters of the network, learning rate 0.001.\n",
    "optimizer = optim.Adam(params=net.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "emerging-helmet",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "differential-aviation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place both data splits on the selected device.\n",
    "X_train, y_train = X_train.to(device), y_train.to(device)\n",
    "X_test, y_test = X_test.to(device), y_test.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "ranging-calgary",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the model and the loss module on the same device as the data.\n",
    "net, criterion = net.to(device), criterion.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "iraqi-blanket",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_accuracy(y_true, y_pred):\n",
    "  predicted = y_pred.ge(.5).view(-1)\n",
    "  return (y_true == predicted).sum().float() / len(y_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "robust-serbia",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0\n",
      "Train set - loss: 1.005, accuracy: 0.37\n",
      "Test  set - loss: 1.018, accuracy: 0.358\n",
      "\n",
      "epoch 100\n",
      "Train set - loss: 0.677, accuracy: 0.743\n",
      "Test  set - loss: 0.679, accuracy: 0.727\n",
      "\n",
      "epoch 200\n",
      "Train set - loss: 0.636, accuracy: 0.79\n",
      "Test  set - loss: 0.64, accuracy: 0.778\n",
      "\n",
      "epoch 300\n",
      "Train set - loss: 0.568, accuracy: 0.839\n",
      "Test  set - loss: 0.577, accuracy: 0.833\n",
      "\n",
      "epoch 400\n",
      "Train set - loss: 0.504, accuracy: 0.885\n",
      "Test  set - loss: 0.514, accuracy: 0.877\n",
      "\n",
      "epoch 500\n",
      "Train set - loss: 0.441, accuracy: 0.922\n",
      "Test  set - loss: 0.45, accuracy: 0.913\n",
      "\n",
      "epoch 600\n",
      "Train set - loss: 0.388, accuracy: 0.944\n",
      "Test  set - loss: 0.396, accuracy: 0.938\n",
      "\n",
      "epoch 700\n",
      "Train set - loss: 0.353, accuracy: 0.954\n",
      "Test  set - loss: 0.359, accuracy: 0.949\n",
      "\n",
      "epoch 800\n",
      "Train set - loss: 0.327, accuracy: 0.958\n",
      "Test  set - loss: 0.333, accuracy: 0.953\n",
      "\n",
      "epoch 900\n",
      "Train set - loss: 0.306, accuracy: 0.961\n",
      "Test  set - loss: 0.312, accuracy: 0.955\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def round_tensor(t, decimal_places=3):\n",
    "  return round(t.item(), decimal_places)\n",
    "# Full-batch training for 1000 epochs; metrics are printed every 100 epochs.\n",
    "for epoch in range(1000):\n",
    "    # Forward pass on the whole training set. The network emits one column\n",
    "    # per row, so squeeze it to match the 1-D target vector.\n",
    "    y_pred = torch.squeeze(net(X_train))\n",
    "    train_loss = criterion(y_pred, y_train)\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        # Evaluation needs no gradients: no_grad() avoids building an autograd\n",
    "        # graph for the test forward pass and the metric computations.\n",
    "        with torch.no_grad():\n",
    "            train_acc = calculate_accuracy(y_train, y_pred)\n",
    "            y_test_pred = torch.squeeze(net(X_test))\n",
    "            test_loss = criterion(y_test_pred, y_test)\n",
    "            test_acc = calculate_accuracy(y_test, y_test_pred)\n",
    "        print(\n",
    "f'''epoch {epoch}\n",
    "Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)}\n",
    "Test  set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)}\n",
    "''')\n",
    "\n",
    "    # Standard update: clear stale gradients, backpropagate, take one optimizer step.\n",
    "    optimizer.zero_grad()\n",
    "    train_loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "optimum-excerpt",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to save the whole trained model to disk: torch.save(net, 'model.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "dental-seating",
   "metadata": {},
   "outputs": [],
   "source": [
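    "# Optional: uncomment to reload the saved model (torch.load unpickles the file, so only load files you trust): net = torch.load('model.pth')"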
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "german-satisfaction",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "        Male       0.97      0.96      0.96      1720\n",
      "      Female       0.93      0.94      0.94       959\n",
      "\n",
      "    accuracy                           0.95      2679\n",
      "   macro avg       0.95      0.95      0.95      2679\n",
      "weighted avg       0.95      0.95      0.95      2679\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Label names indexed by the encoded gender value (0, then 1).\n",
    "classes = ['Male', 'Female']\n",
    "\n",
    "# Inference only: no_grad() skips autograd bookkeeping for the forward pass.\n",
    "with torch.no_grad():\n",
    "    y_pred = net(X_test)\n",
    "\n",
    "# Threshold the outputs at 0.5, flatten to 1-D, and move predictions and\n",
    "# targets to the CPU before handing them to scikit-learn.\n",
    "y_pred = y_pred.ge(.5).view(-1).cpu()\n",
    "y_test = y_test.cpu()\n",
    "print(classification_report(y_test, y_pred, target_names=classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "id": "british-incidence",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write one predicted class name per line. The iteration variable lives in\n",
    "# the generator's own scope, so the notebook-level target frame `y` is no\n",
    "# longer overwritten; int() makes the boolean-to-index conversion explicit.\n",
    "with open('test_out.csv', 'w') as file:\n",
    "    file.writelines(f'{classes[int(label)]}\\n' for label in y_pred)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}