532 lines
14 KiB
Plaintext
532 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "forty-fault",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!kaggle datasets download -d kukuroo3/body-performance-data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "pediatric-tuesday",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"!unzip -o body-performance-data.zip"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 114,
|
|
"id": "interstate-presence",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from sklearn.metrics import classification_report\n",
|
|
"import torch\n",
|
|
"from torch import nn, optim\n",
|
|
"import torch.nn.functional as F"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 115,
|
|
"id": "structural-trigger",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(13393, 12)"
|
|
]
|
|
},
|
|
"execution_count": 115,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Read the extracted Kaggle CSV into a DataFrame; the trailing expression\n",
"# displays its (n_rows, n_columns) shape.\n",
"df = pd.read_csv(filepath_or_buffer='bodyPerformance.csv')\n",
"df.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 116,
|
|
"id": "turkish-category",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>gender</th>\n",
|
|
" <th>height_cm</th>\n",
|
|
" <th>weight_kg</th>\n",
|
|
" <th>body fat_%</th>\n",
|
|
" <th>diastolic</th>\n",
|
|
" <th>systolic</th>\n",
|
|
" <th>gripForce</th>\n",
|
|
" <th>sit and bend forward_cm</th>\n",
|
|
" <th>sit-ups counts</th>\n",
|
|
" <th>broad jump_cm</th>\n",
|
|
" <th>class</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>27.0</td>\n",
|
|
" <td>M</td>\n",
|
|
" <td>172.3</td>\n",
|
|
" <td>75.24</td>\n",
|
|
" <td>21.3</td>\n",
|
|
" <td>80.0</td>\n",
|
|
" <td>130.0</td>\n",
|
|
" <td>54.9</td>\n",
|
|
" <td>18.4</td>\n",
|
|
" <td>60.0</td>\n",
|
|
" <td>217.0</td>\n",
|
|
" <td>C</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>25.0</td>\n",
|
|
" <td>M</td>\n",
|
|
" <td>165.0</td>\n",
|
|
" <td>55.80</td>\n",
|
|
" <td>15.7</td>\n",
|
|
" <td>77.0</td>\n",
|
|
" <td>126.0</td>\n",
|
|
" <td>36.4</td>\n",
|
|
" <td>16.3</td>\n",
|
|
" <td>53.0</td>\n",
|
|
" <td>229.0</td>\n",
|
|
" <td>A</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>31.0</td>\n",
|
|
" <td>M</td>\n",
|
|
" <td>179.6</td>\n",
|
|
" <td>78.00</td>\n",
|
|
" <td>20.1</td>\n",
|
|
" <td>92.0</td>\n",
|
|
" <td>152.0</td>\n",
|
|
" <td>44.8</td>\n",
|
|
" <td>12.0</td>\n",
|
|
" <td>49.0</td>\n",
|
|
" <td>181.0</td>\n",
|
|
" <td>C</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>32.0</td>\n",
|
|
" <td>M</td>\n",
|
|
" <td>174.5</td>\n",
|
|
" <td>71.10</td>\n",
|
|
" <td>18.4</td>\n",
|
|
" <td>76.0</td>\n",
|
|
" <td>147.0</td>\n",
|
|
" <td>41.4</td>\n",
|
|
" <td>15.2</td>\n",
|
|
" <td>53.0</td>\n",
|
|
" <td>219.0</td>\n",
|
|
" <td>B</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>28.0</td>\n",
|
|
" <td>M</td>\n",
|
|
" <td>173.8</td>\n",
|
|
" <td>67.70</td>\n",
|
|
" <td>17.1</td>\n",
|
|
" <td>70.0</td>\n",
|
|
" <td>127.0</td>\n",
|
|
" <td>43.5</td>\n",
|
|
" <td>27.1</td>\n",
|
|
" <td>45.0</td>\n",
|
|
" <td>217.0</td>\n",
|
|
" <td>B</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" age gender height_cm weight_kg body fat_% diastolic systolic \\\n",
|
|
"0 27.0 M 172.3 75.24 21.3 80.0 130.0 \n",
|
|
"1 25.0 M 165.0 55.80 15.7 77.0 126.0 \n",
|
|
"2 31.0 M 179.6 78.00 20.1 92.0 152.0 \n",
|
|
"3 32.0 M 174.5 71.10 18.4 76.0 147.0 \n",
|
|
"4 28.0 M 173.8 67.70 17.1 70.0 127.0 \n",
|
|
"\n",
|
|
" gripForce sit and bend forward_cm sit-ups counts broad jump_cm class \n",
|
|
"0 54.9 18.4 60.0 217.0 C \n",
|
|
"1 36.4 16.3 53.0 229.0 A \n",
|
|
"2 44.8 12.0 49.0 181.0 C \n",
|
|
"3 41.4 15.2 53.0 219.0 B \n",
|
|
"4 43.5 27.1 45.0 217.0 B "
|
|
]
|
|
},
|
|
"execution_count": 116,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 117,
|
|
"id": "received-absence",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cols = ['gender', 'height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']\n",
"\n",
"# Work on an explicit copy of the selected columns so the assignment below\n",
"# targets this frame and never a view of the frame loaded from disk.\n",
"df = df[cols].copy()\n",
"\n",
"# Encode the target: male - 0, female - 1.\n",
"# Assign the result back instead of calling replace(inplace=True) on a column\n",
"# selection: that pattern raises SettingWithCopyWarning and is a silent no-op\n",
"# under pandas copy-on-write. to_numeric guarantees a numeric dtype even when\n",
"# replace() hands back an object column.\n",
"df['gender'] = pd.to_numeric(df['gender'].replace({'M': 0, 'F': 1}))\n",
"\n",
"# Drop every row that is missing any of the selected values.\n",
"df = df.dropna(how='any')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 118,
|
|
"id": "excited-parent",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0 0.632196\n",
|
|
"1 0.367804\n",
|
|
"Name: gender, dtype: float64"
|
|
]
|
|
},
|
|
"execution_count": 118,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Share of each gender code among the remaining rows.\n",
"df['gender'].value_counts() / len(df)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 119,
|
|
"id": "extended-cinema",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Five numeric measurements are the predictors; the gender column is the target.\n",
"X = df.loc[:, ['height_cm', 'weight_kg', 'body fat_%', 'sit-ups counts', 'broad jump_cm']]\n",
"y = df.loc[:, ['gender']]\n",
"\n",
"# Hold out 20% of the rows for testing; the fixed seed keeps the split repeatable.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X,\n",
"    y,\n",
"    test_size=0.2,\n",
"    random_state=42,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 120,
|
|
"id": "animated-farming",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"torch.Size([10714, 5]) torch.Size([10714])\n",
|
|
"torch.Size([2679, 5]) torch.Size([2679])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Features: DataFrame -> NumPy array -> float32 tensor.\n",
"X_train = torch.from_numpy(np.array(X_train)).to(torch.float32)\n",
"X_test = torch.from_numpy(np.array(X_test)).to(torch.float32)\n",
"\n",
"# Targets: single-column frame -> float32 tensor with size-1 axes squeezed away.\n",
"y_train = torch.from_numpy(y_train.values).to(torch.float32).squeeze()\n",
"y_test = torch.from_numpy(y_test.values).to(torch.float32).squeeze()\n",
"\n",
"print(X_train.shape, y_train.shape)\n",
"print(X_test.shape, y_test.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 121,
|
|
"id": "technical-wallet",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Net(nn.Module):\n",
|
|
" def __init__(self, n_features):\n",
|
|
" super(Net, self).__init__()\n",
|
|
" self.fc1 = nn.Linear(n_features, 5)\n",
|
|
" self.fc2 = nn.Linear(5, 3)\n",
|
|
" self.fc3 = nn.Linear(3, 1)\n",
|
|
" def forward(self, x):\n",
|
|
" x = F.relu(self.fc1(x))\n",
|
|
" x = F.relu(self.fc2(x))\n",
|
|
" return torch.sigmoid(self.fc3(x))\n",
|
|
"# Seed PyTorch's RNG before the layers are created so the random weight\n",
"# initialisation, and therefore the metrics reported below, are reproducible\n",
"# across Restart & Run All, in line with the fixed random_state of the split.\n",
"torch.manual_seed(42)\n",
"\n",
"# One input unit per feature column of the training tensor.\n",
"net = Net(X_train.shape[1])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 122,
|
|
"id": "requested-plymouth",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"criterion = nn.BCELoss()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 123,
|
|
"id": "iraqi-english",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Adam over every parameter of the network with a step size of 1e-3.\n",
"optimizer = optim.Adam(params=net.parameters(), lr=1e-3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 124,
|
|
"id": "emerging-helmet",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 125,
|
|
"id": "differential-aviation",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Move all four data tensors onto the selected compute device.\n",
"X_train, y_train = X_train.to(device), y_train.to(device)\n",
"X_test, y_test = X_test.to(device), y_test.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 126,
|
|
"id": "ranging-calgary",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Place the loss module and the model parameters on the same device as the data.\n",
"criterion = criterion.to(device)\n",
"net = net.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 127,
|
|
"id": "iraqi-blanket",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def calculate_accuracy(y_true, y_pred):\n",
|
|
" predicted = y_pred.ge(.5).view(-1)\n",
|
|
" return (y_true == predicted).sum().float() / len(y_true)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 128,
|
|
"id": "robust-serbia",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 0\n",
|
|
"Train set - loss: 1.005, accuracy: 0.37\n",
|
|
"Test set - loss: 1.018, accuracy: 0.358\n",
|
|
"\n",
|
|
"epoch 100\n",
|
|
"Train set - loss: 0.677, accuracy: 0.743\n",
|
|
"Test set - loss: 0.679, accuracy: 0.727\n",
|
|
"\n",
|
|
"epoch 200\n",
|
|
"Train set - loss: 0.636, accuracy: 0.79\n",
|
|
"Test set - loss: 0.64, accuracy: 0.778\n",
|
|
"\n",
|
|
"epoch 300\n",
|
|
"Train set - loss: 0.568, accuracy: 0.839\n",
|
|
"Test set - loss: 0.577, accuracy: 0.833\n",
|
|
"\n",
|
|
"epoch 400\n",
|
|
"Train set - loss: 0.504, accuracy: 0.885\n",
|
|
"Test set - loss: 0.514, accuracy: 0.877\n",
|
|
"\n",
|
|
"epoch 500\n",
|
|
"Train set - loss: 0.441, accuracy: 0.922\n",
|
|
"Test set - loss: 0.45, accuracy: 0.913\n",
|
|
"\n",
|
|
"epoch 600\n",
|
|
"Train set - loss: 0.388, accuracy: 0.944\n",
|
|
"Test set - loss: 0.396, accuracy: 0.938\n",
|
|
"\n",
|
|
"epoch 700\n",
|
|
"Train set - loss: 0.353, accuracy: 0.954\n",
|
|
"Test set - loss: 0.359, accuracy: 0.949\n",
|
|
"\n",
|
|
"epoch 800\n",
|
|
"Train set - loss: 0.327, accuracy: 0.958\n",
|
|
"Test set - loss: 0.333, accuracy: 0.953\n",
|
|
"\n",
|
|
"epoch 900\n",
|
|
"Train set - loss: 0.306, accuracy: 0.961\n",
|
|
"Test set - loss: 0.312, accuracy: 0.955\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def round_tensor(t, decimal_places=3):\n",
|
|
" return round(t.item(), decimal_places)\n",
|
|
"# Full-batch training: every epoch is one gradient step over the whole train set.\n",
"N_EPOCHS = 1000\n",
"LOG_EVERY = 100  # print train/test metrics every LOG_EVERY epochs\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"    # Forward pass; squeeze drops the size-1 output axis so the predictions\n",
"    # line up with the 1-D target tensor.\n",
"    y_pred = torch.squeeze(net(X_train))\n",
"    train_loss = criterion(y_pred, y_train)\n",
"\n",
"    if epoch % LOG_EVERY == 0:\n",
"        # Metrics are reporting-only: skip autograd tracking so the test-set\n",
"        # forward pass does not build a graph that is never back-propagated.\n",
"        with torch.no_grad():\n",
"            train_acc = calculate_accuracy(y_train, y_pred)\n",
"            y_test_pred = torch.squeeze(net(X_test))\n",
"            test_loss = criterion(y_test_pred, y_test)\n",
"            test_acc = calculate_accuracy(y_test, y_test_pred)\n",
"        print(\n",
"f'''epoch {epoch}\n",
"Train set - loss: {round_tensor(train_loss)}, accuracy: {round_tensor(train_acc)}\n",
"Test set - loss: {round_tensor(test_loss)}, accuracy: {round_tensor(test_acc)}\n",
"''')\n",
"\n",
"    # Clear stale gradients, back-propagate the training loss, apply the update.\n",
"    optimizer.zero_grad()\n",
"    train_loss.backward()\n",
"    optimizer.step()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 129,
|
|
"id": "optimum-excerpt",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# torch.save(net, 'model.pth')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 130,
|
|
"id": "dental-seating",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# net = torch.load('model.pth')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 131,
|
|
"id": "german-satisfaction",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" precision recall f1-score support\n",
|
|
"\n",
|
|
" Male 0.97 0.96 0.96 1720\n",
|
|
" Female 0.93 0.94 0.94 959\n",
|
|
"\n",
|
|
" accuracy 0.95 2679\n",
|
|
" macro avg 0.95 0.95 0.95 2679\n",
|
|
"weighted avg 0.95 0.95 0.95 2679\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Display names indexed by class id: 0 -> 'Male', 1 -> 'Female'.\n",
"classes = ['Male', 'Female']\n",
"\n",
"# Inference only: run the test-set forward pass without autograd tracking.\n",
"with torch.no_grad():\n",
"    y_pred = net(X_test)\n",
"\n",
"# Threshold the probabilities at 0.5 and keep integer class ids on the CPU, so\n",
"# both label arrays handed to scikit-learn share one integer dtype and each id\n",
"# can index `classes` directly.\n",
"y_pred = y_pred.ge(.5).view(-1).long().cpu()\n",
"y_test = y_test.cpu()\n",
"print(classification_report(y_test.long().numpy(), y_pred.numpy(), target_names=classes))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 132,
|
|
"id": "british-incidence",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Write one predicted class name per line, in test-set row order.\n",
"with open('test_out.csv', 'w') as file:\n",
"    # The generator variable is scoped to the expression, so the notebook-level\n",
"    # target frame `y` is not overwritten; int() accepts both bool and integer\n",
"    # tensor elements as a list index.\n",
"    file.writelines(f'{classes[int(label)]}\\n' for label in y_pred)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|