ium_478839/ml_pytorch_results.ipynb

352 lines
27 KiB
Plaintext
Raw Normal View History

2022-05-06 22:45:36 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "98cddc6a-2ce1-4933-a2b7-96d2c2d197f4",
"metadata": {},
"outputs": [
{
"data": {
"application/javascript": [
"if (window.IPython && IPython.notebook.kernel) IPython.notebook.kernel.execute('jovian.utils.jupyter.get_notebook_name_saved = lambda: \"' + IPython.notebook.notebook_name + '\"')"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import torch\n",
"import jovian\n",
"import torchvision\n",
"import matplotlib\n",
"import torch.nn as nn\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import torch.nn.functional as F\n",
"from torchvision.datasets.utils import download_url\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
"import random\n",
"import os\n",
"import sys\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.metrics import mean_absolute_error"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7bb63556-d009-4d9f-9de0-033a30ad3fc4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['matches', 'wins', 'draws', 'loses', 'scored', 'missed', 'pts'],\n",
" ['position'])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the CSV that holds both the features and the target column.\n",
"dataframe = pd.read_csv(\"understat.csv\")\n",
"\n",
"# Features are picked by column position (indices 4 to 10); the target is 'position'.\n",
"input_cols = dataframe.columns[4:11].tolist()\n",
"output_cols = ['position']\n",
"input_cols, output_cols"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8151c46-c234-42b7-a786-50c73e3aa2f5",
"metadata": {},
"outputs": [],
"source": [
"def dataframe_to_arrays(dataframe):\n",
"    \"\"\"Split a DataFrame into NumPy arrays of inputs and targets.\n",
"\n",
"    Columns are selected with the notebook-level `input_cols` and\n",
"    `output_cols` lists. The frame is deep-copied first, so the caller's\n",
"    object is never touched.\n",
"\n",
"    Returns:\n",
"        (inputs_array, targets_array) as produced by `DataFrame.to_numpy()`.\n",
"    \"\"\"\n",
"    working_copy = dataframe.copy(deep=True)\n",
"    features = working_copy[input_cols].to_numpy()\n",
"    labels = working_copy[output_cols].to_numpy()\n",
"    return features, labels\n",
"\n",
"# Turn the selected columns into float32 tensors and pair them up as a dataset.\n",
"inputs_array, targets_array = dataframe_to_arrays(dataframe)\n",
"inputs = torch.from_numpy(inputs_array).float()\n",
"targets = torch.from_numpy(targets_array).float()\n",
"dataset = TensorDataset(inputs, targets)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8c89947b-c2fe-407d-9588-3f0087df5955",
"metadata": {},
"outputs": [],
"source": [
"# Hold out one fifth of the samples for validation. The sizes are derived from\n",
"# the dataset instead of being hard-coded, so the split keeps working when the\n",
"# CSV gains or loses rows (random_split requires the sizes to sum to\n",
"# len(dataset)). For 684 rows this yields the previous 548 / 136 split.\n",
"val_size = len(dataset) // 5\n",
"train_size = len(dataset) - val_size\n",
"train_ds, val_ds = random_split(dataset, [train_size, val_size])\n",
"\n",
"batch_size = 50\n",
"# Only the training loader is shuffled; validation order does not matter.\n",
"train_loader = DataLoader(train_ds, batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_ds, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3b1426a0-5b15-46f8-aea9-871462ca9467",
"metadata": {},
"outputs": [],
"source": [
"class Model_xPosition(nn.Module):\n",
"    \"\"\"Single linear layer trained and validated with L1 (mean absolute error) loss.\"\"\"\n",
"\n",
"    def __init__(self, in_features=None, out_features=None):\n",
"        \"\"\"Create the linear layer.\n",
"\n",
"        Args:\n",
"            in_features: number of input columns. When omitted, the\n",
"                notebook-level `input_size` is used, so the existing\n",
"                `Model_xPosition()` call keeps working.\n",
"            out_features: number of target columns. When omitted, the\n",
"                notebook-level `output_size` is used.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        if in_features is None:\n",
"            in_features = input_size\n",
"        if out_features is None:\n",
"            out_features = output_size\n",
"        self.linear = nn.Linear(in_features, out_features)\n",
"\n",
"    def forward(self, xb):\n",
"        \"\"\"Apply the linear layer to a batch of inputs.\"\"\"\n",
"        return self.linear(xb)\n",
"\n",
"    def training_step(self, batch):\n",
"        \"\"\"Return the L1 loss of the predictions for one (inputs, targets) batch.\"\"\"\n",
"        inputs, targets = batch\n",
"        return F.l1_loss(self(inputs), targets)\n",
"\n",
"    def validation_step(self, batch):\n",
"        \"\"\"Return the detached L1 loss of one batch together with its size.\"\"\"\n",
"        inputs, targets = batch\n",
"        loss = F.l1_loss(self(inputs), targets)\n",
"        # The batch size lets validation_epoch_end weight a short final batch correctly.\n",
"        return {'val_loss': loss.detach(), 'batch_size': len(inputs)}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        \"\"\"Combine per-batch validation losses into one epoch-level loss.\n",
"\n",
"        Each batch loss is already a mean over its samples, so the batches are\n",
"        weighted by their size; a plain mean would over-weight a smaller last\n",
"        batch. Entries without 'batch_size' count with weight 1, which\n",
"        reproduces the plain mean.\n",
"        \"\"\"\n",
"        losses = torch.stack([x['val_loss'] for x in outputs])\n",
"        sizes = torch.as_tensor(\n",
"            [x.get('batch_size', 1) for x in outputs],\n",
"            dtype=losses.dtype,\n",
"            device=losses.device,\n",
"        )\n",
"        epoch_loss = (losses * sizes).sum() / sizes.sum()\n",
"        return {'val_loss': epoch_loss.item()}\n",
"\n",
"    def epoch_end(self, epoch, result, num_epochs):\n",
"        \"\"\"Print the validation loss every 100th epoch and on the final epoch.\"\"\"\n",
"        if (epoch+1) % 100 == 0 or epoch == num_epochs-1:\n",
"            print(\"Epoch {} loss: {:.4f}\".format(epoch+1, result['val_loss']))\n",
" \n",
" \n",
"@torch.no_grad()\n",
"def evaluate(model, val_loader):\n",
"    \"\"\"Run `model.validation_step` on every batch and aggregate the results.\n",
"\n",
"    Autograd is disabled for the whole pass: validation never calls\n",
"    backward(), so recording a graph for each batch is wasted work.\n",
"\n",
"    Returns:\n",
"        Whatever `model.validation_epoch_end` returns for the collected\n",
"        per-batch outputs.\n",
"    \"\"\"\n",
"    outputs = [model.validation_step(batch) for batch in val_loader]\n",
"    return model.validation_epoch_end(outputs)\n",
"\n",
"def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):\n",
"    \"\"\"Train `model` for `epochs` passes over `train_loader`.\n",
"\n",
"    After every epoch the model is evaluated on `val_loader`, the result is\n",
"    handed to `model.epoch_end` for reporting and appended to the history.\n",
"\n",
"    Returns:\n",
"        List with one evaluation result per epoch, in order.\n",
"    \"\"\"\n",
"    optimizer = opt_func(model.parameters(), lr)\n",
"    history = []\n",
"    for epoch in range(epochs):\n",
"        for batch in train_loader:\n",
"            model.training_step(batch).backward()\n",
"            optimizer.step()\n",
"            # Clear gradients right after the step so the next backward() starts from zero.\n",
"            optimizer.zero_grad()\n",
"        epoch_result = evaluate(model, val_loader)\n",
"        model.epoch_end(epoch, epoch_result, epochs)\n",
"        history.append(epoch_result)\n",
"    return history"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f2e22e9a-8724-4084-b706-0be266846c05",
"metadata": {},
"outputs": [],
"source": [
"# Layer sizes follow directly from the selected feature and target columns.\n",
"input_size, output_size = len(input_cols), len(output_cols)\n",
"model = Model_xPosition()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "efacafe4-797a-4588-b0d8-2e4d883e639a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 100 loss: 2.2152\n",
"Epoch 200 loss: 1.8737\n",
"Epoch 300 loss: 1.8362\n",
"Epoch 400 loss: 1.7904\n",
"Epoch 500 loss: 1.7507\n",
"Epoch 600 loss: 1.7174\n",
"Epoch 700 loss: 1.6977\n",
"Epoch 800 loss: 1.6847\n",
"Epoch 900 loss: 1.6743\n",
"Epoch 1000 loss: 1.6645\n"
]
}
],
"source": [
"# Training hyper-parameters: number of passes over the data and learning rate.\n",
"epochs, lr = 1000, 1e-5\n",
"learning_proccess = fit(epochs, lr, model, train_loader, val_loader)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7007ab5a-dc79-4321-beed-cd54dd197858",
"metadata": {},
"outputs": [],
"source": [
"def predict_single(input, target, model):\n",
"    \"\"\"Return a one-line report pairing `target` with the model's prediction.\n",
"\n",
"    `input` is a single sample without a batch axis: a leading axis of size 1\n",
"    is added for the forward pass and the first (only) row of the result is\n",
"    reported, detached from the autograd graph.\n",
"    \"\"\"\n",
"    output = model(input.unsqueeze(0))[0].detach()\n",
"    return \"Target: \" + str(target) + \" Predicted: \" + str(output) + \"\\n\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c946417a-693a-463a-b123-54348266ff6e",
"metadata": {},
"outputs": [],
"source": [
"def prediction(input, target, model):\n",
"    \"\"\"Return the detached model output for one sample.\n",
"\n",
"    `input` is a single sample without a batch axis; `target` is accepted\n",
"    but not used in the computation.\n",
"    \"\"\"\n",
"    return model(input.unsqueeze(0))[0].detach()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "50c62065-5094-4595-995c-6d0b71f1f28a",
"metadata": {},
"outputs": [],
"source": [
"# Append one \"Target / Predicted\" line per validation sample to result.txt.\n",
"with open(\"result.txt\", \"a+\") as file:\n",
"    for idx in range(len(val_ds)):\n",
"        sample, expected_position = val_ds[idx]\n",
"        file.write(predict_single(sample, expected_position, model))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fdd6468c-3d50-4131-b2d2-e3190d8b19e5",
"metadata": {},
"outputs": [],
"source": [
"# Collect the true value and the model output for every validation sample.\n",
"pairs = [val_ds[idx] for idx in range(len(val_ds))]\n",
"expected = [float(target) for _, target in pairs]\n",
"predicted = [float(prediction(features, target, model)) for features, target in pairs]\n",
"\n",
"MSE = mean_squared_error(expected, predicted)\n",
"MAE = mean_absolute_error(expected, predicted)\n",
"\n",
"# Human-readable summary of this run, appended to metrics.txt.\n",
"with open(\"metrics.txt\", \"a+\") as metrics_file:\n",
"    metrics_file.write(\"Mean squared error: MSE = \" + str(MSE) + \"\\n\")\n",
"    metrics_file.write(\"Mean absolute error: MAE = \" + str(MAE) + \"\\n\")\n",
"\n",
"# MSE.txt accumulates one value per run; it is read back later for plotting.\n",
"with open(\"MSE.txt\", \"a+\") as mse_log:\n",
"    mse_log.write(str(MSE) + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "4a1e5b16-7d80-47b0-8313-e44667687779",
"metadata": {},
"outputs": [],
"source": [
"# Read the MSE history back, one value per line.\n",
"with open('MSE.txt') as file:\n",
"    # Lines read from a file keep their trailing newline, so a bare `if line`\n",
"    # is always true and a blank line would crash float(). Strip before testing.\n",
"    y_MSE = [float(line) for line in file if line.strip()]\n",
"\n",
"# Builds are numbered from 1 in the order their MSE values were appended.\n",
"x_builds = list(range(1, len(y_MSE) + 1))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1da02242-846a-499c-92eb-4c80fe5f43c4",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEGCAYAAAB/+QKOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAy/UlEQVR4nO3dd3hUZfbA8e9JgUDoVXqTKgQIISCBBJQqCIKyggUVaYKAsO6qu8q6Vlx3QxNEsKCCoCIgSguoEDok1NBbKCImVOkQcn5/zMAvhoEEyGSSyfk8zzyZe9/33jk34py8t5xXVBVjjDEmNR9PB2CMMSZrsgRhjDHGJUsQxhhjXLIEYYwxxiVLEMYYY1zy83QAGalYsWJasWJFT4dhjDHZRmxs7FFVLe6qzasSRMWKFYmJifF0GMYYk22IyP4btdkpJmOMMS5ZgjDGGOOSJQhjjDEuedU1CGOMw+XLlzl06BAXLlzwdCgmiwgICKBs2bL4+/unextLEMZ4oUOHDpE/f34qVqyIiHg6HONhqsqxY8c4dOgQlSpVSvd2dorJGC904cIFihYtasnBACAiFC1a9JZHlJYgjPFSlhxMSrfz78GtCUJEConIdBHZLiLbROTeVO2Pi8gm52uFiNRN0RYvIptFZIOIuPXhhtE/7WLjwZPu/AhjjMl23D2CGAXMV9UaQF1gW6r2fUCEqgYBbwITUrW3UNV6qhrirgBPnrvEV6sP0Hncct6Zu43zl66466OMyVFEhCeffPLaclJSEsWLF6dDhw4ejMpz8uXL5+kQbpnbEoSIFADCgU8AVPWSqp5M2UdVV6jqCefiKqCsu+K5kUJ5cxE1NJxHG5ZnQvRe2o2KZuWeY5kdhjFeJzAwkLi4OM6fPw/AwoULKVOmjIejylhJSUmZ8jlXrly56bIrqkpycvIdfa47RxCVgUTgMxFZLyIfi0jgTfo/C8xLsaxAlIjEikifG20kIn1EJEZEYhITE28r0AIB/rzbpQ5f9W6EAt0nruIfMzfzx4XLt7U/Y4xDu3btmDNnDgBTp06le/fu19rOnj1Lz549adiwIfXr1+f7778HID4+nmbNmhEcHExwcDArVqwAYPHixTRv3pxHHnmEGjVq8Pjjj+NqRszRo0dTq1YtgoKC6NatGwDHjh2jdevW1K9fn759+1KhQgWOHj1KfHw8tWvXvrbtf//7X15//XUAJk6cSMOGDalbty4PP/ww586dA+Dpp59m6NChtGjRgpdeeok9e/bQtm1bGjRoQLNmzdi+fTsA+/bt495776Vhw4a89tprN/wdTZ48mdDQUOrVq0ffvn2vffnny5ePYcOG0ahRI1auXHndcmRkJLVr16Z27dqMHDny2u+uZs2a9O/fn+DgYA4ePHjL/81Scudtrn5AMDBQVVeLyCjgZeC635SItMCRIJqmWB2mqodFpASwUES2q2p06m1VdQLOU1MhISF3NH9qkyrFmD84nMiFO/hk2T5+3pbA251rc3/NkneyW2M86t8/bGHr4T8ydJ+1ShfgXw/ek2a/bt268cYbb9ChQwc2bdpEz549Wbp0KQBvv/029913H59++iknT54kNDSUli1bUqJECRYuXEhAQAC7du2ie/fu12qsrV+/ni1btlC6dGnCwsJYvnw5TZs2/dNnDh8+nH379pE7d25Onjzp+B38+980bdqUYcOGMWfOHCZMSH02+3pdunShd+/eALz66qt88sknDBw4EICdO3eyaNEifH19uf/++xk/fjxVq1Zl9erV9O/fn59//pnBgwfz3HPP0aNHD8aOHevyM7Zt28bXX3/N8uXL8ff3p3///kyZMoUePXpw9uxZateuzRtvvAHwp+XY2Fg+++wzVq9ejarSqFEjIiIiKFy4MDt27OCzzz5j3LhxaR5jWtw5gjgEHFLV1c7l6TgSxp+ISBDwMdBJVa+d21HVw86fCcBMINSNsV6TJ5cv/2xfixn9wyiYx59nP4
9h0NT1HDtzMTM+3hivEhQURHx8PFOnTuWBBx74U1tUVBTDhw+nXr16NG/enAsXLnDgwAEuX75M7969qVOnDl27dmXr1q3XtgkNDaVs2bL4+PhQr1494uPjXX7m448/zuTJk/Hzc/wNHB0dzRNPPAFA+/btKVy4cJqxx8XF0axZM+rUqcOUKVPYsmXLtbauXbvi6+vLmTNnWLFiBV27dr02Avjtt98AWL58+bURU8prMSn99NNPxMbG0rBhQ+rVq8dPP/3E3r17AfD19eXhhx++1jfl8rJly+jcuTOBgYHky5ePLl26XEu8FSpUoHHjxmkeX3q4bQShqkdE5KCIVFfVHcD9wNaUfUSkPDADeFJVd6ZYHwj4qOpp5/vWwBvuitWVeuUK8cPApoxbvJuxv+xm2e6j/OvBWnSsW9puHzTZSnr+0nenjh078uKLL7J48WKOHfv/63uqynfffUf16tX/1P/111+nZMmSbNy4keTkZAICAq615c6d+9p7X19fl9cA5syZQ3R0NLNnz+bNN9+89sXu6v9bPz+/P52nT/mcwNNPP82sWbOoW7cukyZNYvHixdfaAgMdZ8uTk5MpVKgQGzZscHnsaX1XqCpPPfUU77777nVtAQEB+Pr6ulx2dWotdWwZwd13MQ0EpojIJqAe8I6I9BORfs72YUBRYFyq21lLAstEZCOwBpijqvPdHOt1cvn58ELLavw4sBnliuRl8LQN9Po8ht9Onc/sUIzJtnr27MmwYcOoU6fOn9a3adOGMWPGXPuyW79+PQCnTp2iVKlS+Pj48OWXX6brguxVycnJHDx4kBYtWvCf//yHkydPcubMGcLDw5kyZQoA8+bN48QJx70xJUuWJCEhgWPHjnHx4kV+/PHHa/s6ffo0pUqV4vLly9e2Ta1AgQJUqlSJb7/9FnB8cW/cuBGAsLAwpk2bBnDD7e+//36mT59OQkICAMePH2f//htW374mPDycWbNmce7cOc6ePcvMmTNp1qxZen5Ft8StCUJVN6hqiKoGqepDqnpCVcer6nhney9VLey8lfXa7ayquldV6zpf96jq2+6MMy3V78rPjOea8Gr7mizfc5TWkdF8tfoAycl3dMnDmByhbNmyDB48+Lr1r732GpcvXyYoKIjatWtfu5Dbv39/Pv/8cxo3bszOnTtv6S/iK1eu8MQTT1CnTh3q16/PkCFDKFSoEP/617+Ijo4mODiYqKgoypcvD4C/v/+1C78dOnSgRo0a1/b15ptv0qhRI1q1avWn9alNmTKFTz75hLp163LPPfdcu9g+atQoxo4dS8OGDTl16pTLbWvVqsVbb71F69atCQoKolWrVtdOUd1McHAwTz/9NKGhoTRq1IhevXpRv379dP+e0ktuNlTJbkJCQtTdEwbtP3aWl7/bzMq9x2hcuQjDuwRRsVjGDemMyQjbtm2jZs2ang4jy7o6uVixYsU8HUqmcvXvQkRib/SsmZXauEUVigbyVe9GDO9Shy2//kGbkdFMiN5D0pU7u9/YGGOyGksQt0FE6BZanoVDI2hWtRjvzN3Owx+uYPuRjL2V0BjjHvHx8Tlu9HA7LEHcgbsKBjCxRwhjutfn0InzdBi9jMiFO7mYZOU6jOd50+ljc+du59+DJYg7JCI8WLc0C4dG0CGoFKN/2sWDY5ax/sCJtDc2xk0CAgI4duyYJQkD/P98EClvGU4Pu0idwX7e/jv/nBnHkT8u0DOsEn9tXY28uWxeJpO5bEY5k9qNZpS72UVqSxBucPrCZd6bv53Jqw5QrkgehncJIuxuO99pjMl67C6mTJY/wJ+3HqrDtD6N8RXh8Y9X8/J3mzh13or/GWOyD0sQbtS4clHmvxBO34jKfBNzkFaRS4jacsTTYRljTLpYgnCzAH9fXmlXk1kDwigSmIs+X8by/FfrOGrF/4wxWZwliEwSVLYQs59vyl9bVSNqy++0jFzCzPWH7C4TY0yWZQkiE+Xy82Hg/VWZM6gplYoFMu
TrjfSctJbDJ634nzEm67EE4QFVS+Zner8mDOtQi1V7j9Mqcglfrtpvxf+MMVmKJQgP8fURejatRNSQcOqXL8xrs+L
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 0 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the MSE recorded for each build.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x_builds, y_MSE, label='Mean squared error')\n",
"ax.set_xlabel('Number of builds')\n",
"ax.set_ylabel('MSE')\n",
"ax.legend()\n",
"\n",
"# Save BEFORE showing: after plt.show() the current figure is a fresh, empty\n",
"# one, so calling savefig afterwards wrote a blank image.\n",
"fig.savefig('RMSplot.png')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8dffe789-1ad5-44f1-8f21-92b9c89ed974",
"metadata": {},
"outputs": [],
"source": [
"!jupyter nbconvert --to script ml_pytorch.ipynb"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}