352 lines
27 KiB
Plaintext
352 lines
27 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "98cddc6a-2ce1-4933-a2b7-96d2c2d197f4",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"application/javascript": [
|
||
|
"if (window.IPython && IPython.notebook.kernel) IPython.notebook.kernel.execute('jovian.utils.jupyter.get_notebook_name_saved = lambda: \"' + IPython.notebook.notebook_name + '\"')"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"<IPython.core.display.Javascript object>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import jovian\n",
|
||
|
"import torchvision\n",
|
||
|
"import matplotlib\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import seaborn as sns\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"from torchvision.datasets.utils import download_url\n",
|
||
|
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
|
||
|
"import random\n",
|
||
|
"import os\n",
|
||
|
"import sys\n",
|
||
|
"from sklearn.metrics import mean_squared_error\n",
|
||
|
"from sklearn.metrics import mean_absolute_error"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "7bb63556-d009-4d9f-9de0-033a30ad3fc4",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(['matches', 'wins', 'draws', 'loses', 'scored', 'missed', 'pts'],\n",
|
||
|
" ['position'])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"#load data\n",
|
||
|
"dataframe = pd.read_csv(\"understat.csv\")\n",
|
||
|
"\n",
|
||
|
"#choose columns\n",
|
||
|
"input_cols=list(dataframe.columns)[4:11]\n",
|
||
|
"output_cols = ['position']\n",
|
||
|
"input_cols, output_cols"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "c8151c46-c234-42b7-a786-50c73e3aa2f5",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def dataframe_to_arrays(dataframe):\n",
|
||
|
" dataframe_loc = dataframe.copy(deep=True)\n",
|
||
|
" inputs_array = dataframe_loc[input_cols].to_numpy()\n",
|
||
|
" targets_array = dataframe_loc[output_cols].to_numpy()\n",
|
||
|
" return inputs_array, targets_array\n",
|
||
|
"\n",
|
||
|
"inputs_array, targets_array = dataframe_to_arrays(dataframe)\n",
|
||
|
"\n",
|
||
|
"inputs = torch.from_numpy(inputs_array).type(torch.float)\n",
|
||
|
"targets = torch.from_numpy(targets_array).type(torch.float)\n",
|
||
|
"\n",
|
||
|
"dataset = TensorDataset(inputs, targets)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "8c89947b-c2fe-407d-9588-3f0087df5955",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"train_ds, val_ds = random_split(dataset, [548, 136])\n",
|
||
|
"batch_size=50\n",
|
||
|
"train_loader = DataLoader(train_ds, batch_size, shuffle=True)\n",
|
||
|
"val_loader = DataLoader(val_ds, batch_size)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "3b1426a0-5b15-46f8-aea9-871462ca9467",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Model_xPosition(nn.Module):\n",
|
||
|
" def __init__(self):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.linear = nn.Linear(input_size,output_size) \n",
|
||
|
" \n",
|
||
|
" def forward(self, xb): \n",
|
||
|
" out = self.linear(xb)\n",
|
||
|
" return out\n",
|
||
|
" \n",
|
||
|
" def training_step(self, batch):\n",
|
||
|
" inputs, targets = batch \n",
|
||
|
" # Generate predictions\n",
|
||
|
" out = self(inputs) \n",
|
||
|
" # Calculate loss\n",
|
||
|
" loss = F.l1_loss(out,targets) \n",
|
||
|
" return loss\n",
|
||
|
" \n",
|
||
|
" def validation_step(self, batch):\n",
|
||
|
" inputs, targets = batch\n",
|
||
|
" out = self(inputs)\n",
|
||
|
" loss = F.l1_loss(out,targets) \n",
|
||
|
" return {'val_loss': loss.detach()}\n",
|
||
|
" \n",
|
||
|
" def validation_epoch_end(self, outputs):\n",
|
||
|
" batch_losses = [x['val_loss'] for x in outputs]\n",
|
||
|
" epoch_loss = torch.stack(batch_losses).mean() \n",
|
||
|
" return {'val_loss': epoch_loss.item()}\n",
|
||
|
" \n",
|
||
|
" def epoch_end(self, epoch, result, num_epochs):\n",
|
||
|
" if (epoch+1) % 100 == 0 or epoch == num_epochs-1:\n",
|
||
|
" print(\"Epoch {} loss: {:.4f}\".format(epoch+1, result['val_loss']))\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
"def evaluate(model, val_loader):\n",
|
||
|
" outputs = [model.validation_step(batch) for batch in val_loader]\n",
|
||
|
" return model.validation_epoch_end(outputs)\n",
|
||
|
"\n",
|
||
|
"def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):\n",
|
||
|
" history = []\n",
|
||
|
" optimizer = opt_func(model.parameters(), lr)\n",
|
||
|
" for epoch in range(epochs):\n",
|
||
|
" for batch in train_loader:\n",
|
||
|
" loss = model.training_step(batch)\n",
|
||
|
" loss.backward()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" result = evaluate(model, val_loader)\n",
|
||
|
" model.epoch_end(epoch, result, epochs)\n",
|
||
|
" history.append(result)\n",
|
||
|
" return history"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "f2e22e9a-8724-4084-b706-0be266846c05",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"input_size = len(input_cols)\n",
|
||
|
"output_size = len(output_cols)\n",
|
||
|
"model=Model_xPosition()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "efacafe4-797a-4588-b0d8-2e4d883e639a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 100 loss: 2.2152\n",
|
||
|
"Epoch 200 loss: 1.8737\n",
|
||
|
"Epoch 300 loss: 1.8362\n",
|
||
|
"Epoch 400 loss: 1.7904\n",
|
||
|
"Epoch 500 loss: 1.7507\n",
|
||
|
"Epoch 600 loss: 1.7174\n",
|
||
|
"Epoch 700 loss: 1.6977\n",
|
||
|
"Epoch 800 loss: 1.6847\n",
|
||
|
"Epoch 900 loss: 1.6743\n",
|
||
|
"Epoch 1000 loss: 1.6645\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"epochs = 1000\n",
|
||
|
"lr = 1e-5\n",
|
||
|
"learning_proccess = fit(epochs, lr, model, train_loader, val_loader)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "7007ab5a-dc79-4321-beed-cd54dd197858",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def predict_single(input, target, model):\n",
|
||
|
" inputs = input.unsqueeze(0)\n",
|
||
|
" predictions = model(inputs)\n",
|
||
|
" prediction = predictions[0].detach()\n",
|
||
|
"\n",
|
||
|
" return \"Target: \"+str(target)+\" Predicted: \"+str(prediction)+\"\\n\""
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "c946417a-693a-463a-b123-54348266ff6e",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def prediction(input, target, model):\n",
|
||
|
" inputs = input.unsqueeze(0)\n",
|
||
|
" predictions = model(inputs)\n",
|
||
|
" predicted = predictions[0].detach()\n",
|
||
|
" return predicted"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "50c62065-5094-4595-995c-6d0b71f1f28a",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"with open(\"result.txt\", \"a+\") as file:\n",
|
||
|
" for i in range(0, len(val_ds), 1):\n",
|
||
|
" input_, target = val_ds[i]\n",
|
||
|
" file.write(str(predict_single(input_, target, model)))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "fdd6468c-3d50-4131-b2d2-e3190d8b19e5",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"expected = []\n",
|
||
|
"predicted = []\n",
|
||
|
"for i in range(0, len(val_ds), 1):\n",
|
||
|
" input_, target = val_ds[i]\n",
|
||
|
" expected.append(float(target))\n",
|
||
|
" predicted.append(float(prediction(input_, target, model)))\n",
|
||
|
"\n",
|
||
|
"MSE = mean_squared_error(expected, predicted)\n",
|
||
|
"MAE = mean_absolute_error(expected, predicted)\n",
|
||
|
"\n",
|
||
|
"with open(\"metrics.txt\", \"a+\") as file:\n",
|
||
|
" file.write(\"Mean squared error: MSE = \"+ str(MSE) + \"\\n\")\n",
|
||
|
" file.write(\"Mean absolute error: MAE = \"+ str(MAE)+ \"\\n\")\n",
|
||
|
"\n",
|
||
|
"with open(\"MSE.txt\", \"a+\") as file:\n",
|
||
|
" file.write(str(MSE) + \"\\n\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "4a1e5b16-7d80-47b0-8313-e44667687779",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"with open('MSE.txt') as file:\n",
|
||
|
" y_MSE = [float(line) for line in file if line]\n",
|
||
|
" x_builds = list(range(1, len(y_MSE) + 1))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "1da02242-846a-499c-92eb-4c80fe5f43c4",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAy/UlEQVR4nO3dd3hUZfbA8e9JgUDoVXqTKgQIISCBBJQqCIKyggUVaYKAsO6qu8q6Vlx3QxNEsKCCoCIgSguoEDok1NBbKCImVOkQcn5/zMAvhoEEyGSSyfk8zzyZe9/33jk34py8t5xXVBVjjDEmNR9PB2CMMSZrsgRhjDHGJUsQxhhjXLIEYYwxxiVLEMYYY1zy83QAGalYsWJasWJFT4dhjDHZRmxs7FFVLe6qzasSRMWKFYmJifF0GMYYk22IyP4btdkpJmOMMS5ZgjDGGOOSJQhjjDEuedU1CGOMw+XLlzl06BAXLlzwdCgmiwgICKBs2bL4+/unextLEMZ4oUOHDpE/f34qVqyIiHg6HONhqsqxY8c4dOgQlSpVSvd2dorJGC904cIFihYtasnBACAiFC1a9JZHlJYgjPFSlhxMSrfz78GtCUJEConIdBHZLiLbROTeVO2Pi8gm52uFiNRN0RYvIptFZIOIuPXhhtE/7WLjwZPu/AhjjMl23D2CGAXMV9UaQF1gW6r2fUCEqgYBbwITUrW3UNV6qhrirgBPnrvEV6sP0Hncct6Zu43zl66466OMyVFEhCeffPLaclJSEsWLF6dDhw4ejMpz8uXL5+kQbpnbEoSIFADCgU8AVPWSqp5M2UdVV6jqCefiKqCsu+K5kUJ5cxE1NJxHG5ZnQvRe2o2KZuWeY5kdhjFeJzAwkLi4OM6fPw/AwoULKVOmjIejylhJSUmZ8jlXrly56bIrqkpycvIdfa47RxCVgUTgMxFZLyIfi0jgTfo/C8xLsaxAlIjEikifG20kIn1EJEZEYhITE28r0AIB/rzbpQ5f9W6EAt0nruIfMzfzx4XLt7U/Y4xDu3btmDNnDgBTp06le/fu19rOnj1Lz549adiwIfXr1+f7778HID4+nmbNmhEcHExwcDArVqwAYPHixTRv3pxHHnmEGjVq8Pjjj+NqRszRo0dTq1YtgoKC6NatGwDHjh2jdevW1K9fn759+1KhQgWOHj1KfHw8tWvXvrbtf//7X15//XUAJk6cSMOGDalbty4PP/ww586dA+Dpp59m6NChtGjRgpdeeok9e/bQtm1bGjRoQLNmzdi+fTsA+/bt495776Vhw4a89tprN/wdTZ48mdDQUOrVq0ffvn2vffnny5ePYcOG0ahRI1auXHndcmRkJLVr16Z27dqMHDny2u+uZs2a9O/fn+DgYA4ePHjL/81Scudtrn5AMDBQVVeLyCjgZeC635SItMCRIJqmWB2mqodFpASwUES2q2p06m1VdQLOU1MhISF3NH9qkyrFmD84nMiFO/hk2T5+3pbA251rc3/NkneyW2M86t8/bGHr4T8ydJ+1ShfgXw/ek2a/bt268cYbb9ChQwc2bdpEz549Wbp0KQBvv/029913H59++iknT54kNDSUli1bUqJECRYuXEhAQAC7du2ie/fu12qsrV+/ni1btlC6dGnCwsJYvnw5TZs2/dNnDh8+nH379pE7d25Onjzp+B38+980bdqUYcOGMWfOHCZMSH02+3pdunShd+/eALz66qt88sknDBw4EICdO3eyaNEifH19uf/++xk/fjxVq1Zl9erV9O/fn59//pnBgwfz3HPP0aNHD8aOHevyM7Zt28bXX3/N8uXL8ff3p3///kyZMoUePXpw9uxZateuzRtvvAHwp+XY2Fg+++wzVq9ejarSqFEjIiIiKFy4MDt27OCzzz5j3LhxaR5jWtw5gjgEHFLV1c7l6TgSxp+ISBDwMdBJVa+d21HVw86fCcBMINSNsV6TJ5cv/2xfixn9wyiYx59nP4
9h0NT1HDtzMTM+3hivEhQURHx8PFOnTuWBBx74U1tUVBTDhw+nXr16NG/enAsXLnDgwAEuX75M7969qVOnDl27dmXr1q3XtgkNDaVs2bL4+PhQr1494uPjXX7m448/zuTJk/Hzc/wNHB0dzRNPPAFA+/btKVy4cJqxx8XF0axZM+rUqcOUKVPYsmXLtbauXbvi6+vLmTNnWLFiBV27dr02Avjtt98AWL58+bURU8prMSn99NNPxMbG0rBhQ+rVq8dPP/3E3r17AfD19eXhhx++1jfl8rJly+jcuTOBgYHky5ePLl26XEu8FSpUoHHjxmkeX3q4bQShqkdE5KCIVFfVHcD9wNaUfUSkPDADeFJVd6ZYHwj4qOpp5/vWwBvuitWVeuUK8cPApoxbvJuxv+xm2e6j/OvBWnSsW9puHzTZSnr+0nenjh078uKLL7J48WKOHfv/63uqynfffUf16tX/1P/111+nZMmSbNy4keTkZAICAq615c6d+9p7X19fl9cA5syZQ3R0NLNnz+bNN9+89sXu6v9bPz+/P52nT/mcwNNPP82sWbOoW7cukyZNYvHixdfaAgMdZ8uTk5MpVKgQGzZscHnsaX1XqCpPPfUU77777nVtAQEB+Pr6ulx2dWotdWwZwd13MQ0EpojIJqAe8I6I9BORfs72YUBRYFyq21lLAstEZCOwBpijqvPdHOt1cvn58ELLavw4sBnliuRl8LQN9Po8ht9Onc/sUIzJtnr27MmwYcOoU6fOn9a3adOGMWPGXPuyW79+PQCnTp2iVKlS+Pj48OWXX6brguxVycnJHDx4kBYtWvCf//yHkydPcubMGcLDw5kyZQoA8+bN48QJx70xJUuWJCEhgWPHjnHx4kV+/PHHa/s6ffo0pUqV4vLly9e2Ta1AgQJUqlSJb7/9FnB8cW/cuBGAsLAwpk2bBnDD7e+//36mT59OQkICAMePH2f//htW374mPDycWbNmce7cOc6ePcvMmTNp1qxZen5Ft8StCUJVN6hqiKoGqepDqnpCVcer6nhney9VLey8lfXa7ayquldV6zpf96jq2+6MMy3V78rPjOea8Gr7mizfc5TWkdF8tfoAycl3dMnDmByhbNmyDB48+Lr1r732GpcvXyYoKIjatWtfu5Dbv39/Pv/8cxo3bszOnTtv6S/iK1eu8MQTT1CnTh3q16/PkCFDKFSoEP/617+Ijo4mODiYqKgoypcvD4C/v/+1C78dOnSgRo0a1/b15ptv0qhRI1q1avWn9alNmTKFTz75hLp163LPPfdcu9g+atQoxo4dS8OGDTl16pTLbWvVqsVbb71F69atCQoKolWrVtdOUd1McHAwTz/9NKGhoTRq1IhevXpRv379dP+e0ktuNlTJbkJCQtTdEwbtP3aWl7/bzMq9x2hcuQjDuwRRsVjGDemMyQjbtm2jZs2ang4jy7o6uVixYsU8HUqmcvXvQkRib/SsmZXauEUVigbyVe9GDO9Shy2//kGbkdFMiN5D0pU7u9/YGGOyGksQt0FE6BZanoVDI2hWtRjvzN3Owx+uYPuRjL2V0BjjHvHx8Tlu9HA7LEHcgbsKBjCxRwhjutfn0InzdBi9jMiFO7mYZOU6jOd50+ljc+du59+DJYg7JCI8WLc0C4dG0CGoFKN/2sWDY5ax/sCJtDc2xk0CAgI4duyYJQkD/P98EClvGU4Pu0idwX7e/jv/nBnHkT8u0DOsEn9tXY28uWxeJpO5bEY5k9qNZpS72UVqSxBucPrCZd6bv53Jqw5QrkgehncJIuxuO99pjMl67C6mTJY/wJ+3HqrDtD6N8RXh8Y9X8/J3mzh13or/GWOyD0sQbtS4clHmvxBO34jKfBNzkFaRS4jacsTTYRljTLpYgnCzAH9fXmlXk1kDwigSmIs+X8by/FfrOGrF/4wxWZwliEwSVLYQs59vyl9bVSNqy++0jFzCzPWH7C4TY0yWZQkiE+Xy82Hg/VWZM6gplYoFMu
TrjfSctJbDJ634nzEm67EE4QFVS+Zner8mDOtQi1V7j9Mqcglfrtpvxf+MMVmKJQgP8fURejatRNSQcOqXL8xrs+L
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 0 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.xlabel('Number of builds')\n",
|
||
|
"plt.ylabel('MSE')\n",
|
||
|
"plt.plot(x_builds, y_MSE, label='Mean squared error')\n",
|
||
|
"plt.legend()\n",
|
||
|
"plt.show()\n",
|
||
|
"plt.savefig('RMSplot.png')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "8dffe789-1ad5-44f1-8f21-92b9c89ed974",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"!jupyter nbconvert --to script ml_pytorch.ipynb"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.7"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|