{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Self made SVD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import helpers\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "from collections import defaultdict\n",
    "from itertools import chain\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "train_read = pd.read_csv(\"./Datasets/ml-100k/train.csv\", sep=\"\\t\", header=None)\n",
    "test_read = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
    "(\n",
    "    train_ui,\n",
    "    test_ui,\n",
    "    user_code_id,\n",
    "    user_id_code,\n",
    "    item_code_id,\n",
    "    item_id_code,\n",
    ") = helpers.data_to_csr(train_read, test_read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "class SVD:\n",
    "    \"\"\"Matrix factorization without bias terms, trained with plain SGD.\n",
    "\n",
    "    A rating r(u, i) is approximated by the dot product Pu[u] . Qi[i].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
    "        \"\"\"Store hyper-parameters and randomly initialise the factor matrices.\n",
    "\n",
    "        train_ui: sparse users x items rating matrix. Its non-zero entries are the\n",
    "            training triples; recommend() also reads its CSR indices/indptr.\n",
    "        learning_rate: SGD step size.\n",
    "        regularization: L2 penalty applied to the factor rows in every update.\n",
    "        nb_factors: number of latent factors (the init std is 1 / nb_factors).\n",
    "        iterations: number of epochs run by train().\n",
    "        \"\"\"\n",
    "        self.train_ui = train_ui\n",
    "        self.uir = self._to_uir(train_ui)\n",
    "\n",
    "        self.learning_rate = learning_rate\n",
    "        self.regularization = regularization\n",
    "        self.iterations = iterations\n",
    "        self.nb_users, self.nb_items = train_ui.shape\n",
    "        self.nb_ratings = train_ui.nnz\n",
    "        self.nb_factors = nb_factors\n",
    "\n",
    "        self.Pu = np.random.normal(\n",
    "            loc=0, scale=1.0 / self.nb_factors, size=(self.nb_users, self.nb_factors)\n",
    "        )\n",
    "        self.Qi = np.random.normal(\n",
    "            loc=0, scale=1.0 / self.nb_factors, size=(self.nb_items, self.nb_factors)\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_uir(matrix):\n",
    "        \"\"\"Return the non-zero entries of a sparse matrix as (row, col, value) tuples.\n",
    "\n",
    "        Indices and values come from one COO view filtered by a single mask, so an\n",
    "        explicitly stored zero cannot shift ratings onto the wrong (row, col) pair\n",
    "        (nonzero() drops stored zeros while .data keeps them).\n",
    "        \"\"\"\n",
    "        coo = matrix.tocoo()\n",
    "        mask = coo.data != 0\n",
    "        return list(zip(coo.row[mask], coo.col[mask], coo.data[mask]))\n",
    "\n",
    "    def train(self, test_ui=None):\n",
    "        \"\"\"Run `iterations` SGD epochs over the shuffled training triples.\n",
    "\n",
    "        After each epoch appends [epoch, train_RMSE] to self.learning_process, or\n",
    "        [epoch, train_RMSE, test_RMSE] when test_ui is given.\n",
    "        \"\"\"\n",
    "        # Identity check: `==`/`!=` on a sparse matrix dispatches to scipy's\n",
    "        # element-wise comparison operators instead of testing for None.\n",
    "        track_test = test_ui is not None\n",
    "        if track_test:\n",
    "            self.test_uir = self._to_uir(test_ui)\n",
    "\n",
    "        self.learning_process = []\n",
    "        pbar = tqdm(range(self.iterations))\n",
    "        for i in pbar:\n",
    "            pbar.set_description(\n",
    "                f\"Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...\"\n",
    "            )\n",
    "            np.random.shuffle(self.uir)\n",
    "            self.sgd(self.uir)\n",
    "            if track_test:\n",
    "                self.learning_process.append(\n",
    "                    [i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)]\n",
    "                )\n",
    "            else:\n",
    "                self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
    "\n",
    "    def sgd(self, uir):\n",
    "        \"\"\"One SGD pass over (user, item, rating) triples; updates Pu and Qi in place.\"\"\"\n",
    "        for u, i, score in uir:\n",
    "            # Compute prediction and error\n",
    "            prediction = self.get_rating(u, i)\n",
    "            e = score - prediction\n",
    "\n",
    "            # Both updates are computed from the pre-update factors before applying either.\n",
    "            Pu_update = self.learning_rate * (\n",
    "                e * self.Qi[i] - self.regularization * self.Pu[u]\n",
    "            )\n",
    "            Qi_update = self.learning_rate * (\n",
    "                e * self.Pu[u] - self.regularization * self.Qi[i]\n",
    "            )\n",
    "\n",
    "            self.Pu[u] += Pu_update\n",
    "            self.Qi[i] += Qi_update\n",
    "\n",
    "    def get_rating(self, u, i):\n",
    "        \"\"\"Predicted rating of user code u for item code i.\"\"\"\n",
    "        return self.Pu[u].dot(self.Qi[i])\n",
    "\n",
    "    def RMSE_total(self, uir):\n",
    "        \"\"\"Root mean squared error over (user, item, rating) triples; NaN if uir is empty.\"\"\"\n",
    "        if len(uir) == 0:\n",
    "            return float(\"nan\")\n",
    "        RMSE = 0\n",
    "        for u, i, score in uir:\n",
    "            prediction = self.get_rating(u, i)\n",
    "            RMSE += (score - prediction) ** 2\n",
    "        return np.sqrt(RMSE / len(uir))\n",
    "\n",
    "    def estimations(self):\n",
    "        \"\"\"Compute the full users x items prediction matrix Pu @ Qi.T.\n",
    "\n",
    "        NOTE: the matrix is stored under this method's own name, so afterwards\n",
    "        `model.estimations` is the array and calling it again raises TypeError.\n",
    "        Kept as-is because callers read `model.estimations` as the matrix.\n",
    "        \"\"\"\n",
    "        self.estimations = np.dot(self.Pu, self.Qi.T)\n",
    "\n",
    "    def _estimation_matrix(self):\n",
    "        \"\"\"Return the prediction matrix, computing it first if estimations() was never called.\"\"\"\n",
    "        if callable(self.estimations):\n",
    "            self.estimations()\n",
    "        return self.estimations\n",
    "\n",
    "    def recommend(self, user_code_id, item_code_id, topK=10):\n",
    "        \"\"\"Top-K items per user that the user did not rate in train_ui.\n",
    "\n",
    "        Returns rows [user_id, item_id1, score1, item_id2, score2, ...] with scores in\n",
    "        descending order; NaN scores are skipped. user_code_id / item_code_id are\n",
    "        indexed by matrix row / column code to get the ids written to each row.\n",
    "        \"\"\"\n",
    "        estimations = self._estimation_matrix()\n",
    "        top_k = defaultdict(list)\n",
    "        for nb_user, user in enumerate(estimations):\n",
    "            # A set makes the per-item membership test O(1) instead of a scan\n",
    "            # of the user's CSR index slice.\n",
    "            user_rated = set(\n",
    "                self.train_ui.indices[\n",
    "                    self.train_ui.indptr[nb_user] : self.train_ui.indptr[nb_user + 1]\n",
    "                ].tolist()\n",
    "            )\n",
    "            for item, score in enumerate(user):\n",
    "                if item not in user_rated and not np.isnan(score):\n",
    "                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
    "        result = []\n",
    "        # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
    "        for uid, item_scores in top_k.items():\n",
    "            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
    "            result.append([uid] + list(chain(*item_scores[:topK])))\n",
    "        return result\n",
    "\n",
    "    def estimate(self, user_code_id, item_code_id, test_ui):\n",
    "        \"\"\"Predicted score for every non-zero (user, item) of test_ui.\n",
    "\n",
    "        Returns rows [user_id, item_id, score]; a NaN prediction is replaced by 1.\n",
    "        \"\"\"\n",
    "        estimations = self._estimation_matrix()\n",
    "        result = []\n",
    "        for user, item in zip(*test_ui.nonzero()):\n",
    "            score = estimations[user, item]\n",
    "            result.append(\n",
    "                [\n",
    "                    user_code_id[user],\n",
    "                    item_code_id[item],\n",
    "                    score if not np.isnan(score) else 1,\n",
    "                ]\n",
    "            )\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39 RMSE: 0.7477090330529405. Training epoch 40...: 100%|██████████| 40/40 [01:03<00:00, 1.59s/it]\n"
     ]
    }
   ],
   "source": [
    "model = SVD(\n",
    "    train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40\n",
    ")\n",
    "model.train(test_ui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAba0lEQVR4nO3de3RV9Z338fc3IZCEIAkhVYZwsZapjoihxNtIO1jqo1DXSF0uizNPpV2dMlqsONVa6+Pq2Ge005Fp7eCljD7S6jxdxel4rcvq4wW8tWIDBhTQgqMtEWrCLYByEfJ9/tjnkJPDSXJOcpJ9zt6f11q/tffZe59zvmzjZ+/z2zdzd0REpPiVhF2AiIjkhwJdRCQiFOgiIhGhQBcRiQgFuohIRAwJ64tHjx7tEydODOvrRUSK0qpVq7a5e12meaEF+sSJE2lqagrr60VEipKZ/aG7eepyERGJCAW6iEhEKNBFRCIitD50ESk+H330ES0tLezfvz/sUiKvvLyc+vp6ysrKsn6PAl1EstbS0sKIESOYOHEiZhZ2OZHl7mzfvp2WlhaOP/74rN+nLhcRydr+/fupra1VmA8wM6O2tjbnX0IKdBHJicJ8cPRlPRddoLe0wNVXw0cfhV2JiEhhKbpA/93v4N/+Db7//bArEREpLEUX6F/4AvzN38DNN8Pq1WFXIyKDadeuXdx11105v2/27Nns2rUr5/d9+ctf5vjjj6ehoYFTTz2VZ5999si8GTNmMH78eFIfEjRnzhyqqqoA6Ojo4KqrrmLy5MmccsopnHbaabzzzjtAcKX8KaecQkNDAw0NDVx11VU515ZJUZ7lcvvtsHw5zJsHTU0wbFjYFYnIYEgG+te//vUu0w8fPkxpaWm373viiSf6/J2LFi3i4osvZvny5cyfP5+NGzcemVddXc3LL7/M9OnT2bVrF1u3bj0y74EHHmDLli2sXbuWkpISWlpaGD58+JH5y5cvZ/To0X2uK5OiDPRRo+Cee+CCC+B731P3i0gYrr4ampvz+5kNDfDjH3c///rrr+ftt9+moaGBsrIyqqqqGDNmDM3Nzaxfv545c+awefNm9u/fz8KFC5k/fz7Qee+ovXv3MmvWLKZPn85vfvMbxo4dy6OPPkpFRUWvtZ111lm89957XabNnTuXZcuWMX36dB566CEuuugi1q1bB8DWrVsZM2YMJSVBR0h9fX2f1kkuiq7LJenzn4evfAX+5V9g5cqwqxGRwfCDH/yAE044gebmZhYtWsSrr77KLbfcwvr16wFYunQpq1atoqmpicWLF7N9+/ajPmPjxo0sWLCAdevWUV1dzYMPPpjVdz/55JPMmTOny7SZM2fywgsvcPjwYZYtW8YXv/jFI/MuueQSfvWrX9HQ0MA111zDa6+91uW955xzzpEul9tuuy3HNZFZUe6hJ912GzzzTND18tprkMVGVkTypKc96cFy+umnd7nwZvHixTz88MMAbN68mY0bN1JbW9vlPck+cYBp06bx7rvv9vgd3/rWt7juuutobW3llVde6TKvtLSU6dOn88ADD7Bv3z5SbwleX1/PW2+9xXPPPcdzzz3HzJkz+eUvf8nMmTOBgelyKdo9dICRI+Hee+Gtt+DGG8OuRkQGW2qf9IoVK3jmmWf47W9/y5o1a5g6dWrGC3OGpRx0Ky0t5dChQz1+x6JFi9i0aRM333wz8+bNO2r+3Llz+cY3vsEll1yS8btmzZrFokWLuOGGG3jkkUdy+NflrqgDHeDcc+Hyy4O99RdfDLsaERlII0aMYM+ePRnntbe3U1NTQ2VlJW+++eZRe9P9UVJSwsKFC+no6OCpp57qMu/Tn/403/nOd7j00ku7TF+9ejVbtmwBgjNe1q5dy4QJE/JWUyZF3eWStGgRPPVU0Ke+Zg2kbLRFJEJqa2s5++yzmTx5MhUVFRx77LFH5p1//vksWbKEKVOm8MlPfpIzzzwzr99tZtx4443ceuutnHfeeV2mX3v
ttUct39rayte+9jUOHDgABN1DV1555ZH555xzzpEzc6ZMmcL999/f/xpTz6EcTI2NjZ7PJxY9/zzMmAELFsAdd+TtY0UkxYYNGzjppJPCLiM2Mq1vM1vl7o2Zli/6Lpekv/orWLgQ7rwTUs79FxGJjV4D3czKzexVM1tjZuvM7HsZlplhZu1m1pxo3x2Ycnv2/e/DpElBn7qISLYWLFhw5BTCZPvpT38adlk5y6YP/QDwWXffa2ZlwEtm9mt3Tz/i8KK7X5D/ErNXWRmcwnjjjXDwIAwdGmY1ItHk7pG74+Kdd94ZdglH6Ut3eK976B7Ym3hZlmjhdLxnIXnKaYbrCUSkn8rLy9m+fXufwkayl3zARXl5eU7vy+osFzMrBVYBnwDudPdM12aeZWZrgC3Ate6+LsPnzAfmA4wfPz6nQrOVGuhjxgzIV4jEVn19PS0tLbS1tYVdSuQlH0GXi6wC3d0PAw1mVg08bGaT3f2NlEVWAxMS3TKzgUeASRk+527gbgjOcsmp0iyNGhUMtYcukn9lZWU5PRJNBldOZ7m4+y5gBXB+2vTdyW4Zd38CKDOz/F7TmiV1uYhIXGVzlktdYs8cM6sAPge8mbbMcZY4SmJmpyc+N5RIVaCLSFxl0+UyBrgv0Y9eAvynuz9uZpcDuPsS4GLgCjM7BOwD5npIR00U6CISV70GuruvBaZmmL4kZfwOoCCuz6yshPJyBbqIxE9krhRNVVurQBeR+FGgi4hEhAJdRCQiFOgiIhGhQBcRiYjIBvqOHaDbTYhInEQ20A8fhvb2sCsRERk8kQ10ULeLiMSLAl1EJCIU6CIiEaFAFxGJCAW6iEhERDLQq6vBTIEuIvESyUAvLYWaGgW6iMRLJAMddLWoiMSPAl1EJCIU6CIiEaFAFxGJCAW6iEhERDrQP/gADhwIuxIRkcER6UAH7aWLSHwo0EVEIkKBLiISEQp0EZGI6DXQzazczF41szVmts7MvpdhGTOzxWa2yczWmtmnBqbc7CnQRSRuhmSxzAHgs+6+18zKgJfM7Nfu/krKMrOASYl2BvCTxDA0CnQRiZte99A9sDfxsizR0h+/fCFwf2LZV4BqMxuT31JzU1ERNAW6iMRFVn3oZlZqZs1AK/C0u69MW2QssDnldUtiWvrnzDezJjNramtr62PJ2dPFRSISJ1kFursfdvcGoB443cwmpy1imd6W4XPudvdGd2+sq6vLudhcKdBFJE5yOsvF3XcBK4Dz02a1AONSXtcDW/pTWD4o0EUkTrI5y6XOzKoT4xXA54A30xZ7DLgscbbLmUC7u2/Nd7G5UqCLSJxkc5bLGOA+Mysl2AD8p7s/bmaXA7j7EuAJYDawCfgQ+MoA1ZsTBbqIxEmvge7ua4GpGaYvSRl3YEF+S+u/2lrYsQM6OqAkspdQiYgEIh1ztbVBmLe3h12JiMjAi3ygg7pdRCQeFOgiIhGhQBcRiQgFuohIRCjQRUQiItKBXl0dnK6oQBeROIh0oJeUQE2NAl1E4iHSgQ66WlRE4kOBLiISEQp0EZGIUKCLiESEAl1EJCJiEegffgj794ddiYjIwIpFoIP20kUk+hToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEZEP9PJyqKxUoItI9EU+0EEXF4lIPCjQRUQiQoEuIhIRvQa6mY0zs+VmtsHM1pnZwgzLzDCzdjNrTrTvDky5faNAF5E4GJLFMoeAa9x9tZmNAFaZ2dPuvj5tuRfd/YL8l9h/CnQRiYNe99Ddfau7r06M7wE2AGMHurB8qq2FnTuhoyPsSkREBk5OfehmNhGYCqzMMPssM1tjZr82s5O7ef98M2sys6a2trbcq+2j2togzHftGrSvFBEZdFkHuplVAQ8CV7v77rTZq4EJ7n4qcDvwSKbPcPe73b3R3Rvr6ur6WHLudHGRiMRBVoFuZmUEYf5zd38ofb6773b3vYnxJ4AyMxud10r7QYEuInGQzVkuBtwLbHD3H3W
zzHGJ5TCz0xOfWzDxqUAXkTjI5iyXs4EvAa+bWXNi2g3AeAB3XwJcDFxhZoeAfcBcd/f8l9s3CnQRiYNeA93dXwKsl2XuAO7IV1H5pkAXkTiIxZWiI0dCSYkCXUSiLRaBXlICo0Yp0EUk2mIR6KCrRUUk+hToIiIRoUAXEYkIBbqISETEKtB37Ai7ChGRgROrQP/wQ9i/P+xKREQGRqwCHdTtIiLRpUAXEYmI2AT6qFHBUIEuIlEVm0DXHrqIRJ0CXUQkIhToIiIREZtALy+HykoFuohEV2wCHXS1qIhEmwJdRCQiFOgiIhGhQBcRiQgFuohIRMQu0HfuhI6OsCsREcm/2AV6Rwfs2hV2JSIi+Re7QAd1u4hINCnQRUQiQoEuIhIRvQa6mY0zs+VmtsHM1pnZwgzLmJktNrNNZrbWzD41MOX2jwJdRKJsSBbLHAKucffVZjYCWGVmT7v7+pRlZgGTEu0M4CeJYUFRoItIlPW6h+7uW919dWJ8D7ABGJu22IXA/R54Bag2szF5r7afRo6EkhIFuohEU0596GY2EZgKrEybNRbYnPK6haNDHzObb2ZNZtbU1taWY6n9V1ISPLlIgS4iUZR1oJtZFfAgcLW7706fneEtftQE97vdvdHdG+vq6nKrNE90taiIRFVWgW5mZQRh/nN3fyjDIi3AuJTX9cCW/peXfwp0EYmqbM5yMeBeYIO7/6ibxR4DLkuc7XIm0O7uW/NYZ94cfzz8/vdhVyEikn/Z7KGfDXwJ+KyZNSfabDO73MwuTyzzBPDfwCbgHuDrA1Nu/02bBi0t8P77YVciIpJfvZ626O4vkbmPPHUZBxbkq6iB1NgYDFetgtmzw61FRCSfYnWlKMDUqWAGTU1hVyIikl+xC/SqKjjxxGAPXUQkSmIX6BB0u2gPXUSiJpaBPm0abNkCWwvyPBwRkb6JZaCnHhgVEYmKWAZ6Q0NwGwB1u4hIlMQy0IcPh5NO0h66iERLLAMdgn70pibwo+44IyJSnGIb6I2N8Kc/BQdHRUSiILaBPm1aMFS3i4hERWwDXQdGRSRqYhvolZVw8snaQxeR6IhtoIMOjIpItMQ60BsbobUV3nsv7EpERPov1oGePDCqfnQRiYJYB/qpp0JpqQJdRKIh1oFeUaEDoyISHbEOdOi8la4OjIpIsYt9oE+bBtu2webNYVciItI/sQ/05K101Y8uIsUu9oE+ZQoMGaJAF5HiF/tALy+HyZN1YFREil/sAx10YFREokGBTnBgdMcO+MMfwq5ERKTveg10M1tqZq1m9kY382eYWbuZNSfad/Nf5sDSgVERiYJs9tB/BpzfyzIvuntDov3v/pc1uE45BcrK1I8uIsWt10B39xeAHYNQS2iGDQtCXXvoIlLM8tWHfpaZrTGzX5vZyd0tZGbzzazJzJra2try9NX5MW1asIeuA6MiUqzyEeirgQnufipwO/BIdwu6+93u3ujujXV1dXn46vxpbISdO+Gdd8KuRESkb/od6O6+2933JsafAMrMbHS/KxtkupWuiBS7fge6mR1nZpYYPz3xmdv7+7mDbfJkGDpUB0ZFpHgN6W0BM/sFMAMYbWYtwD8CZQDuvgS4GLjCzA4B+4C57sXXEz1sWHAbAO2hi0ix6jXQ3f3SXubfAdyRt4pCNG0aLFsWHBgNfnOIiBQPXSmaorER2tvh7bfDrkREJHcK9BQ6MCoixUyBnuLkk4O+9BUrwq5ERCR3CvQUQ4fC3/4tLF2qbhcRKT4K9DT/9E/BfV2uvz7sSkREcqNAT/NnfwbXXQf/9V/w8sthVyMikj0FegbXXhsE+zXX6N4uIlI8FOgZDB8ON98MK1fCAw+EXY2ISHYU6N247DI49dSgL33//rCrERHpnQK9G6Wl8MMfBo+lW7w47GpERHqnQO/BzJlwwQVwyy1QYLdvFxE5igK9F4sWwQcfwE03hV2JiEjPFOi9OPFE+Pu/h3//d9iwIexqRES6p0DPwk03BWe
+XHdd2JWIiHRPgZ6Fujq44QZ4/HF47rmwqxERyUyBnqWFC2HChOBio8OHw65GRORoCvQslZfDP/8zNDfDP/wDHDoUdkUiIl0p0HMwdy5cfTXcfjucdx5s2xZ2RSIinRToOTCD226Dn/0suHFXY2Owxy4iUggU6H0wbx68+GLQ7fKXfxk8h1REJGwK9D467TRYtSp4bN2ll8K3v62DpSISLgV6Pxx7LDz7LFxxBdx6K8yeDTt2hF2ViMSVAr2fhg6Fu+6Ce+6B5cuDfvWlS2HfvrArE5G4UaDnyd/9HTz/fHBF6Ve/CvX1QTfMu++GXZmIxEWvgW5mS82s1cze6Ga+mdliM9tkZmvN7FP5L7M4nHUWrF0b7Kmfc05w+90TToA5c+CZZ/T0IxEZWNnsof8MOL+H+bOASYk2H/hJ/8sqXmYwY0bwTNJ33gkekPHyy3DuufAXfxGc9vj669DREXalIhI1vQa6u78A9HSo70Lgfg+8AlSb2Zh8FVjMxo0L7qW+eTPcdx9UVcE3vwlTpsDHPgYXXRQ8PGPNGgW8iPTfkDx8xlhgc8rrlsS0rekLmtl8gr14xo8fn4evLg7l5cEj7S67LOhTf/55WLEiGD78cLBMTQ185jNw5plw0klB+/jHYUg+/guJSCzkIy4sw7SMvcXufjdwN0BjY2Mse5QnTgzavHnB6z/+sTPgV6yARx/tXLasDCZNCsL9xBM7Q76+HsaMUdiLSFf5iIQWYFzK63pgSx4+NxbGj4cvfSloAO3t8OabwcM0ksPXX4dHHul64VJJCRx3XNCtU1/f2Y49NujOqavrHA4bFso/TUQGWT4C/THgSjNbBpwBtLv7Ud0tkp2RI+GMM4KW6sAB2LQpeGh1S0tn27wZ1q2DJ58MHpXX3WfW1QWtpiZo1dVBSx2vroZjjgmWP+aYoFVUBAd6RaTw9RroZvYLYAYw2sxagH8EygDcfQnwBDAb2AR8CHxloIqNs2HD4OSTg5aJO+zeDe+/HzzQurW1c5gcb2uDP/0p2PPftStovR2MHTKkM9xHjDi6pU+vquq5lZdrAyEyUHoNdHe/tJf5DizIW0XSJ2bBnvXIkfDnf57dezo6YO9e2LmzM+B37+657dkTLP/HPwbjyZbtWTolJT0H/vDhmce7ayNGBMuVlvZ1zYlEhw6rxVhJSefe94QJff8cd/jww2DjkN727OkcfvBB5mX27g1+OSTnf/BBsHwuNzurqOgM+PTAzzTe0zIjRkBlZbB+RIqJAl36zSzYSx4+PDgomw/ucPBgzxuB9A1HckOQnNbeDu+913XewYPZ15Ae+Jm6nFK7mjJtMFLn6awkGWj6E5OCZBYcNxg2DEaNyt/nZtpIpP6KSLbuXm/ZcvSvjmyVl3d/HCLZXZY6nvo6daizlqQ7CnSJlaFDg1ZTk5/PO3y4+w1E6kYhfQORbG1t8Pbbwa+J9nbYvz+7f0PqmUipG4DeWnL54cN1cDqKFOgi/VBa2hms+XDwYHDwORnw7e1dD0qnvk6d/847wUHt5OvebgSXrDv99NVMw/TxmppgoyKFR4EuUkCGDoXRo4PWV8mzl1I3CukbiOR48uymnTuD01mTZzz1dj//ysog5EeN6tpqa7u+rqnpOhwxQr8MBpICXSRiUs9eGjeu9+UzOXCgM9x37uxs6a937gye0rVpUzDcvj14b3dKSzv38lM3ALW1nS31dXK8qkobgmwo0EXkKMOGBbeWOO643N+7b18Q7jt2dAZ+6njqtPffh/Xrgw3Bnj3df2ZZ2dHBX1sbXP2c/EWTPh7HjYACXUTyqqICxo4NWi4OHgzCfvv2zpbc608f37QJVq6Ebdvgo48yf15q91V64KdPSw6L/diAAl1ECsLQocF1DLlcy5C85cW2bUFra+s6TG2vvRYMe3qQ+zHHdN73aPTozhvcddcqK/v/784nBbqIFK3UW16ccEJ27zl0KAj1TBuA1GFLS7ARaGvr/oK04cOPvrtp6kYgfVp
5ef7+7Zko0EUkVoYMCQL2Yx/Lbvnkr4DkDe5Sb36XOp7cALS2dt8NVFUVfO+CBcHTy/JNgS4i0oPUXwGf+ETvyyc3AKl3PE0P/74cbM6GAl1EJI9SNwCTJg3ud+t+ciIiEaFAFxGJCAW6iEhEKNBFRCJCgS4iEhEKdBGRiFCgi4hEhAJdRCQizHt7tMlAfbFZG/CHHhYZDWwbpHJypdr6RrX1jWrrm6jWNsHd6zLNCC3Qe2NmTe7eGHYdmai2vlFtfaPa+iaOtanLRUQkIhToIiIRUciBfnfYBfRAtfWNausb1dY3sautYPvQRUQkN4W8hy4iIjlQoIuIRETBBbqZnW9mb5nZJjO7Pux6UpnZu2b2upk1m1lTyLUsNbNWM3sjZdooM3vazDYmhjUFVNtNZvZeYt01m9nskGobZ2bLzWyDma0zs4WJ6aGvux5qC33dmVm5mb1qZmsStX0vMb0Q1lt3tYW+3lJqLDWz18zs8cTrAVlvBdWHbmalwO+Bc4EW4HfApe6+PtTCEszsXaDR3UO/WMHMPgPsBe5398mJabcCO9z9B4mNYY27f7tAarsJ2Ovu/zrY9aTVNgYY4+6rzWwEsAqYA3yZkNddD7VdQsjrzswMGO7ue82sDHgJWAhcRPjrrbvazqcA/uYAzOybQCNwjLtfMFD/rxbaHvrpwCZ3/293PwgsAy4MuaaC5O4vADvSJl8I3JcYv48gDAZdN7UVBHff6u6rE+N7gA3AWApg3fVQW+g8sDfxsizRnMJYb93VVhDMrB74PPB/UiYPyHortEAfC2xOed1CgfxBJzjw/8xslZnND7uYDI51960QhAOQ5XPNB82VZrY20SUTSndQKjObCEwFVlJg6y6tNiiAdZfoNmgGWoGn3b1g1ls3tUEBrDfgx8B1QEfKtAFZb4UW6JZhWsFsaYGz3f1TwCxgQaJrQbLzE+AEoAHYCvwwzGLMrAp4ELja3XeHWUu6DLUVxLpz98Pu3gDUA6eb2eQw6sikm9pCX29mdgHQ6u6rBuP7Ci3QW4BxKa/rgS0h1XIUd9+SGLYCDxN0ERWS9xP9sMn+2NaQ6znC3d9P/E/XAdxDiOsu0c/6IPBzd38oMbkg1l2m2gpp3SXq2QWsIOijLoj1lpRaW4Gst7OBv04cf1sGfNbM/i8DtN4KLdB/B0wys+PNbCgwF3gs5JoAMLPhiQNVmNlw4H8Ab/T8rkH3GDAvMT4PeDTEWrpI/vEmfIGQ1l3iANq9wAZ3/1HKrNDXXXe1FcK6M7M6M6tOjFcAnwPepDDWW8baCmG9uft33L3e3ScS5Nlz7v4/Gaj15u4F1YDZBGe6vA38r7DrSanr48CaRFsXdm3ALwh+Rn5E8Mvmq0At8CywMTEcVUC1/QfwOrA28cc8JqTaphN0460FmhNtdiGsux5qC33dAVOA1xI1vAF8NzG9ENZbd7WFvt7S6pwBPD6Q662gTlsUEZG+K7QuFxER6SMFuohIRCjQRUQiQoEuIhIRCnQRkYhQoIuIRIQCXUQkIv4/uXWfnCx11KgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df = pd.DataFrame(model.learning_process).iloc[:, :2]\n", "df.columns = [\"epoch\", \"train_RMSE\"]\n", "plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyqElEQVR4nO3de3zPdf/H8cdrY2bO25yHYYg5DENiOR9TShIuinLJFdFVFFJdKVcHV4Wr80FR/XQuOpFQ5CqHyZnZMMzEmDkP296/P96zhuE7tn2+3+9e99ttN/t+Dt/v67NPPffZ+/P+vN9ijEEppZT38nG6AKWUUvlLg14ppbycBr1SSnk5DXqllPJyGvRKKeXlijhdQE6Cg4NNaGio02UopZTHiI6OPmiMKZ/TOrcM+tDQUFavXu10GUop5TFEZNel1l2x6UZEZorIARHZeIn1IiIzRCRORNaLSLNs67qLSEzmuvFXV75SSqlr4Uob/ftA98us7wHUyfwaDrwOICK+wKuZ6xsAA0SkwbUUq5RSKveuGPTGmKVA8mU26Q3MNtbvQFkRqQy0BOKMMTuMMWeAjzO3VUopVYDyoo2+KrAn2+uEzGU5LW91qTcRkeHYvwioXr16HpSllCooZ8+eJSEhgdTUVKdL8Xr+/v6EhIRQtGhRl/fJi6CXHJaZyyzPkTHmLeAtgMjISB2ARykPkpCQQKlSpQgNDUUkp//1VV4wxnDo0CESEhKoWbOmy/vlRT/6BKBattchQOJlliulvExqaipBQUEa8vlMRAgKCsr1X055EfTzgLsye99cDxwxxuwDVgF1RKSmiPgB/TO3VUp5IQ35gnE1P+crNt2IyBygPRAsIgnAk0BRAGPMG8D3QE8gDjgJDM1clyYio4AFgC8w0xizKdcV5kKGySA6MZoWVVvk58copZRHuWLQG2MGXGG9AUZeYt332F8EBeLzzZ9z5+d30qtuL6Z0nELjio0L6qOVUsptedVYN73q9uLZTs/y6+5fiXgjgkFfDmLH4R1Ol6WUymcpKSm89tprud6vZ8+epKSk5Hq/IUOGULNmTSIiImjSpAmLFi3KWte+fXuqV69O9kmdbr31VkqWLAlARkYGo0ePpmHDhjRq1IgWLVqwc+dOwI4K0KhRIyIiIoiIiGD06NG5ri0nbjkEwtUKKBrA+Lbjua/5fbyw/AWmr5jOmn1r2HT/Jm0/VMqLnQv6+++//7zl6enp+Pr6XnK/77+/+gaHqVOn0rdvX5YsWcLw4cOJjY3NWle2bFmWL19O27ZtSUlJYd++fVnrPvnkExITE1m/fj0+Pj4kJCRQokSJrPVLliwhODj4quvKiVcF/Tnlipfj2c7P8kCrB9hzZA8iwsmzJ3npt5d4oOUDlPEv43SJSnmtBx+EtWvz9j0jImDatEuvHz9+PNu3byciIoKiRYtSsmRJKleuzNq1a9m8eTO33nore/bsITU1lTFjxjB8+HDgr3G1jh8/To8ePWjbti3/+
9//qFq1KnPnzqV48eJXrK1169bs3bv3vGX9+/fn448/pm3btnz55Zf06dOHTZvsLcp9+/ZRuXJlfHxsg0pISMhV/Uxyw6uabi5UpVQVWoXYZ7QWbl/I40sep9aMWkxdPpVTZ085XJ1SKq8899xz1K5dm7Vr1zJ16lRWrlzJlClT2Lx5MwAzZ84kOjqa1atXM2PGDA4dOnTRe8TGxjJy5Eg2bdpE2bJl+eKLL1z67Pnz53Prrbeet6xTp04sXbqU9PR0Pv74Y+68886sdf369eObb74hIiKChx9+mD/++OO8fTt06JDVdPPyyy/n8ieRM6+8os9J7+t6s2b4Gh5b/BiP/PQI01ZMY9wN4xjdajQ+4tW/75QqUJe78i4oLVu2PO+BohkzZvDVV18BsGfPHmJjYwkKCjpvn3Nt7gDNmzcnPj7+sp8xbtw4HnnkEQ4cOMDvv/9+3jpfX1/atm3LJ598wqlTp8g+7HpISAgxMTEsXryYxYsX06lTJz777DM6deoE5E/TTaFKuKaVm/L9377n57t/pm5QXb7a+lVWyJ9NP+twdUqpvJK9zfvnn3/mp59+4rfffmPdunU0bdo0xweOihUrlvW9r68vaWlpl/2MqVOnEhcXxzPPPMPdd9990fr+/fvzwAMP0K9fvxw/q0ePHkydOpWJEyfy9ddf5+Locq9QBf057ULbseTuJXw38DsAEo8lUu3lajyx5AkOnbz4TzqllHsrVaoUx44dy3HdkSNHKFeuHAEBAWzduvWiq+9r4ePjw5gxY8jIyGDBggXnrYuKimLChAkMGHB+D/U1a9aQmGgHCcjIyGD9+vXUqFEjz2rKsc58fXc3V9LPdnc6nXaaNtXb8PTSp6kxrQZjfxzLvmP7rrC3UspdBAUF0aZNGxo2bMi4cePOW9e9e3fS0tJo3Lgxjz/+ONdff32efraIMGnSJF544YWLlo8dO/aiZpgDBw5w880307BhQxo3bkyRIkUYNWpU1vrsbfR33XVX3tSYva+nu4iMjDROzDC16cAmnv31WeZsnIN/EX/2/HMPgcUDC7wOpTzNli1bqF+/vtNlFBo5/bxFJNoYE5nT9oX6iv5C4RXC+bDPh8SMimFat2lZIf/owkd5d8277D++3+EKlVIq9wpNr5vcCAsMIywwDIBTZ0/x2ebP2JmyE0FoFdKK3vV6c0eDO6gdWNvhSpVS+WnkyJEsX778vGVjxoxh6NChDlV0dTTor6B40eJsH72d9fvXMy9mHnNj5jJh0QRK+pVkVMtRHD51mPX719OmehuK+OiPUylv8uqrrzpdQp7QZHKBiNCkUhOaVGrC4+0eJ+FoAiWK2u5bX2/9mnvm3UM5/3LcXO9mBjUaRMeaHfH1ufRj10opVZA06K9CSOm/Hlm+I/wOyviXYV7MPL7e+jWz180mpHQI60esp1zxcg5WqZRSlgb9NSrpV5I+9fvQp34fUtNSmRczj98Tfs8K+SeXPEmFEhXo37A/QQFBV3g3pZTKe9rrJg/5F/GnX3g/Xur2EmAnQpm/fT6jfhhF5Rcrc/untzMvZp4+hauUKlAa9PnIR3xYMWwFf9z3ByNbjGTZrmX0/rg3/172b8D26Pnz+J8OV6mU57va8egBpk2bxsmTJy+7zblx4hs3bky7du3YtWtX1joRYfDgwVmv09LSKF++PL169QJg//799OrViyZNmtCgQQN69uwJQHx8PMWLF896OCoiIoLZs2df1TFckTHG7b6aN29uvNGZtDNm3tZ5ZufhncYYY+ZtnWf4Fyb81XAz+vvRZu7WuSblVIqzRSp1FTZv3uzo5+/cudOEh4df1b41atQwSUlJLm/zxBNPmGHDhmWtK1GihImIiDAnT540xhjz/fffmyZNmpibbrrJGGPM8OHDzbRp07K2X7du3TXXnNPPG1htLpGp2kZfgIr6FuXmejdnvW5UsRHPd36eRTsX8faat5mxcgY+4kPMqBjCAsNIOpFEqWKl8C/i72DVS
uVe+/fbX7SsX3g/7m9xPyfPnqTnRz0vWj8kYghDIoZw8ORB+n7a97x1Pw/5+bKfl308+i5dulChQgU+/fRTTp8+zW233cZTTz3FiRMn6NevHwkJCaSnp/P444+zf/9+EhMT6dChA8HBwSxZsuSKx9a6dWtmzJhx3rIePXrw3Xff0bdvX+bMmcOAAQNYtmwZYMef79q1a9a2jRsX/BSn2nTjoNCyoTzS5hEWDFrA4UcP8/PdPzO5/WRql7MPYk1cNJEKUyvwwPcPEHMwxuFqlXJf2cej79KlC7GxsaxcuZK1a9cSHR3N0qVLmT9/PlWqVGHdunVs3LiR7t27M3r0aKpUqcKSJUtcCnnIefz5cxONpKamsn79elq1apW1buTIkdx777106NCBKVOmZA1oBmT9cjr3de6XQ17TK3o3UaxIMdqFtqNdaLusZX9r/DdS01N5a81bvLLqFbrW7srDrR+ma+2ul3knpZx3uSvwgKIBl10fHBB8xSv4y/nxxx/58ccfadq0KQDHjx8nNjaWqKgoxo4dy6OPPkqvXr2IiorK1ft26NCB/fv3U6FCBZ555pnz1jVu3Jj4+HjmzJmT1QZ/Trdu3dixYwfz58/nhx9+oGnTpmzcuBEg65dTftMrejfWPrQ9H9z2Absf3M3THZ5m44GNfLfNDq1sjCElNcXZApVyQ8YYJkyYwNq1a1m7di1xcXHce++91K1bl+joaBo1asSECROYPHlyrt53yZIl7Nq1i/DwcJ544omL1t9yyy2MHTv2omGJAQIDAxk4cCAffPABLVq0YOnSpVd9fFdDg94DVCxZkUk3TiJ+TDyTO9j/OH+O/5mqL1Xlvm/uY8P+DQ5XqJSzso9H361bN2bOnMnx48cB2Lt3LwcOHCAxMZGAgAAGDRrE2LFjWbNmzUX7Xknx4sWZNm0as2fPJjk5+bx199xzD0888QSNGjU6b/nixYuzevUcO3aM7du3U7169Ws63tzSoPcgRX2LZk1sXrV0VQY0HMDs9bNp/EZj2r/fnpd+e0n76KtCKft49AsXLmTgwIG0bt2aRo0a0bdvX44dO8aGDRto2bIlERERTJkyhUmTJgEwfPhwevToQYcOHVz6rMqVKzNgwICLxsEJCQlhzJgxF20fHR1NZGQkjRs3pnXr1gwbNowWLVoAF7fRX3iTN6/oePQeLvlUMu+ueZc3o98kJTWFpHFJiAiPLnyUnSk7aVa5Gc0qN6NppaaUL1He6XKVl9Lx6AtWbsej15uxHi6weCDj2oxjXJtxHEk9gogAkG7Sid4XzWebP8vatnOtziwcvBCADfs3UDuwNgFFAxypWylVcDTovci5Zh2A/3T9D//p+h8OnzrM2j/XsmbfGkr42RE3M0wG7d5vx/Ezx4msEsmNNW7kxho30qZam/PeQ6nCplWrVpw+ffq8ZR988MFF7e6eRptuCqH0jHTmx81n2e5lLN21lNWJqzmbcZZxN4zjhS4vcDrtNPNi5hFVI4pKJSs5Xa7yAFu2bOG6667L+otS5R9jDFu3btWmG3V5vj6+3FT3Jm6qexMAJ8+eZEXCCqqUqgLAqsRV9Pu8HwANyjegc83OdKndhQ6hHbL+KlAqO39/fw4dOkRQUJCGfT4yxnDo0CH8/XP3tLxe0auLnE0/y5p9a/hl1y8s2rmIpbuWkpqWytIhS4mqEUVcchwHTx4kskqkzqqlADh79iwJCQmkpqY6XYrX8/f3JyQkhKJFi563/HJX9C4FvYh0B6YDvsA7xpjnLlhfDpgJ1AZSgXuMMRsz18UDx4B0IO1ShWSnQe9eUtNSWb57OVE1ovDz9WPsj2N58bcXKVOsDB1qdqBLrS50qdWFsMAwvZpTyiHXFPQi4gtsA7oACcAqYIAxZnO2baYCx40xT4nIdcCrxphOmevigUhjzEFXC9agd28HTx5k8c7FLNy+kIU7FrLryC4CiwdyYOwBfH18WZ24muplqlOhRAWnS1Wq0LjWNvqWQJwxZ
kfmm30M9AY2Z9umAfAsgDFmq4iEikhFY8z+aytduaPggGD6hfejX3g/jDFsP7yd7cnbs+bJHfjFQGKTY2lSsQldanWhc63ORNWI0q6cSjnElSdjqwJ7sr1OyFyW3TqgD4CItARqAOcmVjXAjyISLSLDL/UhIjJcRFaLyOqkpCRX61cOExHCAsPoFtYNsDeLPurzEVM6TiGweCAzVs6g+0fdGfn9yKz1KxJW6BO8ShUgV67oc2p0vbC95zlguoisBTYAfwBpmevaGGMSRaQCsFBEthpjLhrRxxjzFvAW2KYbF+s/z+uvQ4cOcN11V7O3ygsiQouqLWhRtQUToyZy4swJft39K8EBwQBsO7SN69+9npJ+JWlbvS3ta7SnfWh7mldprjd2lconrlzRJwDVsr0OARKzb2CMOWqMGWqMiQDuAsoDOzPXJWb+ewD4CtsUlOdSUmDiRGjYEP7xD9ivjUZuoYRfCbqFdaN5leYAVClVhc/u+Iy7Gt/F7iO7Gb9oPNe/ez3fxHwDwN6je1m1dxVpGWmXe1ulVC64cgm1CqgjIjWBvUB/YGD2DUSkLHDSGHMGGAYsNcYcFZESgI8x5ljm912B3I0N6qKyZSEmBiZPhjffhA8/hEcegYceghLa9dttlCpWir4N+tK3gZ1BaP/x/SzdtZT2oe0B+GjDRzz606OU8itFq5BW1AuqR1hgGMObDyegaADGGO3Zo1Quudq9sicwDdu9cqYxZoqIjAAwxrwhIq2B2dgulJuBe40xh0WkFvYqHuwvlf8zxky50udda6+bbdtgwgT48kuoXNmG/9Ch4Ot71W+pCkjSiSSWxC9hyc4lrEpcRWxyLMfPHOfUY6fw8/XjoQUPMTdmLmGBYdQJrEOdwDqEBYbRs05P/QWgCrVr7kdf0PKqe+Xy5TBuHPz2G4SHwwsvQI8eoHngOYwxHE49TGDxQAA+XP8h38V+R+yhWGKTYzl6+iiVS1Ym8WHbmjhl6RSOnD6SNWJnnaA6+IiOxq28X6ENegBj4IsvYPx42L7d3qydOhWaN8+Tt1cOMsZw8ORBkk4m0aB8AwBu//R2vt32LWfSzwBQ0q8k/cP78/YtbwOw4/AOQsuGavgrr1Oog/6cM2ds2/1TT8GhQzBgAEyaBA0a5OnHKDdwJv0Mm5M288e+P1izbw3VylTjkTaPkGEyKPNcGYr6FKVt9bZZo3Y2rdSUor5Fr/zGSrkxDfpsjhyB556DGTPg1Cm47TbbW0ev8L3f2fSzzNk4h2W7lrF091K2HdoGwKSoSTzd8WlOnT3FqsRVtKzaEv8iuRs0SimnadDnICnJhv1//2vDv2tXeOwxiIrSNvzCYt+xfSzbvYzw8uGEVwhnyc4ldJzdET9fP1pVbUWzys1oWKEhN9e9mYolKzpdrlKXpUF/GUeP2getXnoJDhyANm3sFb7etC18jp4+yi/xv/DLrl9YtnsZG/Zv4FTaKVYMW0HLqi2ZFzOPN1a/kfWLoWGFhtQPrq9DNyu3oEHvglOnYOZM2zNn925o0sQG/u23a7fMwirDZBCfEk/VUlUpVqQYn2z8hOeWP8eWpC2cTrezEAlC4sOJVCpZiRUJK0g+lUzTyk11whZV4DToc+HsWfjoI9uOHxMDdevCww/D4MFQvLgjJSk3k5aRxo7DO9h4YCMxB2MY33Y8IsKgLwfx0YaPAKhUshJNKzWlRZUWPNXhKYcrVoWBBv1VSE+Hr76CZ5+FNWsgKAhGjID774cqVRwtTbmpI6lHWLd/HX/s+4M//rRfRXyKED08GoA+n/Rh77G9WQ961QmqQ8MKDWlcsbHDlStvoEF/DYyBZcvg5Zdh7lzbjHPnnfDPf2pPHXVlGSYjq8/+xEUTWbl3JXHJcew+shuDoVvtbswfNB+AOz67g4CiAVm/COoF16NOYB29B6BcokGfR3bssL103n0Xjh2Dtm1t4Pfure34KndS01LZcXgH6RnpNKrYCGMMXT/sypakLew9tjdru
2FNh/H2LW+TYTIY/cNowgLDqBdUj7pBdQktG5o1B4BSGvR57OhRe+N2xgzYuRNCQ+GBB+Dee6FMGaerU57u5NmTxCXHEXMwhmplqnF9yPX8efxP6r9an5TUlKzt/Hz9eLHri4xqOYp9x/YxafEkyvqXpVzxcvZf/3K0qd6G0LKhnE47zam0U5T1L+vYcan8pUGfT9LTYd48mDYNli61o2QOGmTb8iMinK5OeZtzQz7EHIoh5mAM2w5to1fdXkTViGLTgU10+7AbKakpnDh7ImufD277gEGNB7Fs1zJufP9G6gfX54ZqN9CmWhvaVG9DncA6Ohicl9CgLwBr1sArr8CcOZCaCtdfb8fF79cP/PUhS1WAzqSfISU1hZTUFCqUqEBZ/7LsPbqXWetmsXzPcv63539ZfxksHbKUqBpRxCXHsf/4fiKrRFKsSDFnD0BdFQ36AnT4MMyaBW+8YbtnBgbaIZLvuw/q1HG6OqXsDeKtB7eyfPdyBjUeRPGixRn/03ieX/48fr5+RFaJpE5gHaqUqsLTHZ7G18eXPUf2YDBUKlkJP18/pw9B5UCD3gHGwJIl9qnbr7+GtDTo3Nle5d9yCxTRWfOUG0k6kcTyPctZvns5v+/9nV0puzhx9gSHHjkEcN4zAuUDylO1dFXqBdXj474fA7Bm3xqK+BShdrna2kvIIRr0Dtu3z/bUeest2LPH9sMfNsx+Vat25f2VckL22bx+T/idjQc2kngskcRjiew9tpdivsX4vN/nALR7vx1Ld9mpoCuXrExYYBg31riRZzo+A9i5gov5FqOMfxlK+ZXS3kL5QIPeTaSlwfff26v8BQvsWDq9etlmnW7dtIum8lwb9m9gy8EtxCXHsT15O3GH4wgtG8qsW2cBUHN6TeJT4rO2L+lXkr4N+vJe7/cA6PdZPwyG0n6lCQ4Ipla5WrSo2oJmlZs5cTge6XJBrw0IBahIEdtsc8sttlvm22/bK/1586BGDfj7320XzUo6TIryMI0qNqJRxUaXXD+9+3QOnjzIkdQjHD19lCOnjxBePjxr/YETBzhw4gBHTx8l6WQSZ9LP8I/If/DaTa+RlpFGw9caUr1MdWqVq0WtcrWoXa42zas0J7RsaAEcnefTK3qHnTljn7h94w1YvNj+Mrj1VttFs0MH8NGJkFQhk2Ey2Ht0LyJCSOkQjqQeYcR3I9hxeAfbk7dz6JS9b/Dvjv9mQtQE4lPiafhaQ4ICgggOCCaoeBBBAUH8vdnf6VizI8mnkpkfN5/A4oGU8y9HYPFAAosHUta/rFc1IekVvRvz84M77rBf27bZdvz33oPPP4ewMBg+HO6+GypUcLpSpQqGj/hQrcxfN6/K+Jdhzu1zsl4fST3CjsM7CA4IBsC/iD/Dmw/n0KlDHDx5kEMnD7EzZSe3XXcbAFsPbuVvX/7tos/5pO8n9Avvx297fuOBHx6wvwiKlyO4eDBBAUEMiRhCrXK1OHjyILtSdhEcEExwQDABRQMuevYgLSONY6ePcezMMY6ePsqx08doXqU5fr5+rNm3hj+P/0nNsjWpUbYGAUUD8uPHdll6Re+GUlPtPLdvvAG//mqv8nv3tk07nTtrW75SuZGalkp8SjyHTx0m+VQyh1PtvzfVuYnagbVZuXclT/3yFMmnkkk+lcyhk4dIPpXML0N+IapGFB+u/5DBXw3Oer9ivsUIDgjm5yE/ExYYxvTfp/Pgggcv+txdD+6iepnqTFk6hUlLJmUtr1CiAjXL1uTHwT9SulhpohOjOXjyIKFlQ6lVrtZVT2upN2M92ObNth1/1iw712316nDPPfZLe+wolT/SM9IB8PXxZe/RvaxOXJ31F8O5r8eiHqN2YG1WJKxgwfYFlPIrRelipSlVrBSl/ErRLrQdAUUDSElNYXPSZuJT4rO+Eo4m8N3A7xARhs4dyvtr3wdg3Yh1Vz2aqQa9Fzh92rblv/MOLFxoe+x0726v8nv1gqI6t7VSHmn/8f3EJscSn
xJPn/p9rrppR4Pey+zcadvxZ86EvXtt+/2QITb0w8Kcrk4p5YTLBb326fBANWvC5MkQHw/ffgs33AAvvmiHWOja1U6YkpbmdJVKKXehQe/BihSBm26ywb5nDzz9NGzdCn362KGTn3rKXvErpQo3DXovUbkyTJpkJ0eZOxcaNbJBX6OGDf6FCyEjw+kqlVJO0KD3Mueevv3hB4iLsxObL1tmm3Tq1bNNPIcOOV2lUqogadB7sVq14PnnISEBPvrIDq0wdixUrWpv3q5caUfZVEp5N5eCXkS6i0iMiMSJyPgc1pcTka9EZL2IrBSRhq7uq/JfsWIwcKC9sl+/3o6n88UX0KoVtGhh++mfPOl0lUqp/HLFoBcRX+BVoAfQABggIg0u2GwisNYY0xi4C5iei31VAWrUCF59FRIT4bXX7FO4w4bZq/x//tMOw6CU8i6uXNG3BOKMMTuMMWeAj4HeF2zTAFgEYIzZCoSKSEUX91UOKFXKToKyYYOd77Z7d/sLoF496NJFu2gq5U1cCfqqwJ5srxMyl2W3DugDICItgRpAiIv7KgeJQFSUnet292545hk7BeK5LpqTJ9urf6WU53Il6HOaIv7CW3jPAeVEZC3wAPAHkObivvZDRIaLyGoRWZ2UlORCWSqvVaoEjz1mu2h+/TWEh8OTT9oumn37wqJFevNWKU/kStAnANmHzwoBzrvGM8YcNcYMNcZEYNvoywM7Xdk323u8ZYyJNMZEli9f3vUjUHnu3GiZCxZAbCw8+KCd/7ZzZ7juOnj5ZTsJulLKM7gS9KuAOiJSU0T8gP7AvOwbiEjZzHUAw4Clxpijruyr3FtYGEydap+wnT0bgoLgoYfsvLdDh2oXTaU8wRWD3hiTBowCFgBbgE+NMZtEZISIjMjcrD6wSUS2YnvYjLncvnl/GCq/+fvD4MHwv//B2rV2MpTPPrNdNCMjtYumUu5MR69UV+3oUfjwQzvZ+caNUK6c7ap5//32Rq5SquDo6JUqX5QubUN9/Xr45Rfo1Aleesk+kdu7N/z0kzbrKOUONOjVNROBG2+0TTk7d8KECbaJp0sX23Pntdfg2DGnq1Sq8NKgV3mqWjWYMsUOmzxrFpQoASNHQkgIjBlje/EopQqWBr3KF/7+cNddtlfOb7/Z6Q5ffx3q1oWePW3XTW3WUapgaNCrfCUC119vR8/cvRv+9S9Ys8YOudCggQ3/EyecrlIp76ZBrwpMpUr2Sdtdu2yf/BIl7M3ckBAYN84uV0rlPQ16VeCKFbN98letgl9/tTdtX37Z9ta5/XY7yJo26yiVdzTolWNEoE0b+PRTO77OuHF2qIV27aB5c3j/fTh1yukqlfJ8GvTKLVSvDs89Z2fDevNNOH3aDrEQEmJnxYqLc7pCpTyXBr1yKwEBMHy4fdJ28WLo2BGmT4c6dewN3HnzID3d6SqV8iwa9MotiUCHDvYhrF274Kmn7CQpvXtDzZq2r/7+/U5XqZRn0KBXbq9KFXjiCRv4X35pZ8GaNMk+nDVggN68VepKNOiVxyhSBG67DRYutLNgjRoF8+fbm7cREfDee3YOXKXU+TTolUeqW9cOoLZ3L7zzjr2iv+cee1P3iSdg3z6nK1TKfWjQK48WEAD33gvr1tmpDlu3tvPe1qhh++rraNdKadArLyFie+jMnQvbttknbr/+Glq0gLZt4fPPIS3N6SqVcoYGvfI6YWEwbZrtk//yy7YZ5447oHZt+M9/4PhxpytUqmBp0CuvVaaMndh82zZ7dV+rln36tmZNeP55DXxVeGjQK6/n62v73y9ZYodMjoyE8ePtdIfPPaeToijvp0GvCpXrr4cffoDff4eWLe1sWKGh8OyzGvjKe2nQq0KpVSv4/ntYscKG/8SJNvCnTLGTnivlTTToVaHWsiV8952dCat1a/vEbWio7aJ55IjT1SmVNzTolcJ2w/z2WztGfps28Pjjti/+E0/AoUNOV6fUtdGgVyqbyEj45huIjoZOn
eDpp23gP/II/Pmn09UpdXU06JXKQbNm8MUXdrjk3r3hxRdtt8wHHrBz3yrlSTTolbqM8HA7sXlMDPztb/DGG/aBrL//HbZvd7o6pVyjQa+UC8LC7OBp27fDfffBBx/YgdUGDYLNm52uTqnL06BXKheqV4f//hd27oSHHrJP3DZsaMfF37rV6eqUypkGvVJXoXJlmDoV4uPtU7bffGObeQYPhthYp6tT6nwa9Epdg+Bg+Pe/7RX+2LF2BqzrroMhQ7QNX7kPl4JeRLqLSIyIxInI+BzWlxGRb0RknYhsEpGh2dbFi8gGEVkrIjo6uPJK5cvbgdJ27LADqX3yiZ3y8N577S8BpZx0xaAXEV/gVaAH0AAYICINLthsJLDZGNMEaA+8KCJ+2dZ3MMZEGGMi86ZspdxTxYq2K+aOHTBypO2xU7euvYGr3TKVU1y5om8JxBljdhhjzgAfA70v2MYApUREgJJAMqDTPKhCq3JlmD79r146779ve+784x+wZ4/T1anCxpWgrwpk/08zIXNZdq8A9YFEYAMwxhiTkbnOAD+KSLSIDL/Uh4jIcBFZLSKrk5KSXD4ApdxZ1arwyisQF2fntH33XRv4I0dq4KuC40rQSw7LzAWvuwFrgSpABPCKiJTOXNfGGNMM2/QzUkRuzOlDjDFvGWMijTGR5cuXd6V2pTxGtWr2YavYWBg6FN5++6/AT0hwujrl7VwJ+gSgWrbXIdgr9+yGAl8aKw7YCVwHYIxJzPz3APAVtilIqUKpRo2LA792bRg1SgNf5R9Xgn4VUEdEambeYO0PzLtgm91AJwARqQjUA3aISAkRKZW5vATQFdiYV8Ur5amyB/6QIfDmmzbwH3gA9u51ujrlba4Y9MaYNGAUsADYAnxqjNkkIiNEZETmZk8DN4jIBmAR8Kgx5iBQEfhVRNYBK4HvjDHz8+NAlPJENWrYkI+NhbvvtuFfuzaMGQN6q0rlFTHmwuZ250VGRprVq7XLvSp84uPtLFfvvQclSsBjj8Ho0eDv73Rlyt2JSPSlurDrk7FKuZHQUNtuv2EDtGsHjz5qn7SdMwcyMq64u1I50qBXyg3Vrw/z5sGiRRAYCAMH2rltly1zujLliTTolXJjHTvC6tUwaxYkJsKNN0KfPjpwmsodDXql3JyPD9x1F2zbZictX7gQGjSwN2x1PlvlCg16pTxEQIC9ORsba5+yfeUV20PnxRfh9Gmnq1PuTINeKQ9TqZLtkrl+PbRubYdHDg+3QyS7YSc65QY06JXyUOHh8MMPMH8+FCsGt98OHTrAmjVOV6bcjQa9Uh6uWzdYtw5efx02bYLISPu0beKFA5WoQkuDXikvUKQIjBhhR8kcO9b2u69TByZPhpMnna5OOU2DXikvUqYMvPACbNkCPXvCk0/aiU8++EAfuCrMNOiV8kK1asFnn8HSpXYSlLvusg9c6cgihZMGvVJeLCoKVqyA2bPtRCctW9pZrpKTna5MFSQNeqW8nI8PDB4MW7faAdLeestOXP7ee9qcU1ho0CtVSJQpA9OmQXS0vVF7zz12SIX1652uTOU3DXqlCpmICPj1Vzt/7dat0KwZ/POfcPSo05Wp/KJBr1Qh5ONjr+hjYmDYMJg+/a/hkPXpWu+jQa9UIRYUZGe1+v132ztn4EDo3Nl2z1TeQ4NeKUXLlrByJbz6qh1CoXFj++CVNud4Bw16pRQAvr5w//22Oefuu+Gll2zvnNmztXeOp9OgV0qdp0IFeOcd2/++enUb+lFROliaJ9OgV0rlqEUL+O03mDnTjoEfGWnH09HJTjyPBr1S6pJ8fGDoUDu71ejR9kq/Th07UmZ6utPVKVdp0CulrqhsWfuw1dq10KSJbcuPjITlyx0uTLlEg14p5bKGDWHxYvjkEzh4ENq2heHD4fBhpytTl6NBr5TKFRHo188+VTt2rG3Dr1/fhr8+bOWeNOiVUlelRAmYOhVWrYKQEOjfH266CeLjna5MXUiDXil1TZo2t
V0xp02z49+Hh8OLL0JamtOVqXM06JVS18zXF8aMgc2boWNH26TTsqVOdOIuNOiVUnmmenWYN8/ObvXnn9CqlR0Z8/hxpysr3DTolVJ5SgT69rUDo913n23SadAAvv3W6coKL5eCXkS6i0iMiMSJyPgc1pcRkW9EZJ2IbBKRoa7uq5TyTmXKwGuv2b72pUvDzTfDHXdAYqLTlRU+Vwx6EfEFXgV6AA2AASLS4ILNRgKbjTFNgPbAiyLi5+K+SikvdsMNdpycKVPgm29sV8zXX9eB0gqSK1f0LYE4Y8wOY8wZ4GOg9wXbGKCUiAhQEkgG0lzcVynl5fz8YOJE2LjRjqFz//32YauNG52urHBwJeirAnuyvU7IXJbdK0B9IBHYAIwxxmS4uC8AIjJcRFaLyOqkpCQXy1dKeZKwMFi40A59HBtru2ZOnAinTjldmXdzJeglh2UXPv/WDVgLVAEigFdEpLSL+9qFxrxljIk0xkSWL1/ehbKUUp5IBAYPtjdrBw2CZ5+FRo3gp5+crsx7uRL0CUC1bK9DsFfu2Q0FvjRWHLATuM7FfZVShVBwMLz3nh07x8cHunSxvwD0D/q850rQrwLqiEhNEfED+gPzLthmN9AJQEQqAvWAHS7uq5QqxDp0gPXr4fHH7Xg59evDF184XZV3uWLQG2PSgFHAAmAL8KkxZpOIjBCREZmbPQ3cICIbgEXAo8aYg5faNz8ORCnlufz9YfJkOwxyaKjthz94MKSkOFyYlxDjhsPNRUZGmtX67LRShdLZs7Yr5jPPQOXKtnmnc2enq3J/IhJtjInMaZ0+GauUcitFi8K//mWnMSxRwrbdjx4NJ086XZnn0qBXSrmlFi3gjz9syP/3v9CsmR0SWeWeBr1Sym0VLw7Tp9uulydOQOvW9mr/7FmnK/MsGvRKKbfXqRNs2AADB8JTT9nA37LF6ao8hwa9UsojlC1rn6j9/HM7i1WzZvDyyzpmjis06JVSHuX22+0YOV26wEMP2X74O3Y4XZV706BXSnmcSpVg7lx4/33b975xY3jjDZ2c/FI06JVSHkkE7r7bXt3fcAP84x/QvTvs2XPlfQsbDXqllEerVg0WLLBj3C9fDg0b2it9vbr/iwa9UsrjicCIEXbMnIgIGDoUeve289YqDXqllBepVQuWLLG9cRYuhPBwO1BaYadBr5TyKj4+8OCD9qnasDDo3x/uvBMOHXK6Mudo0CulvNJ119k2+3//G776yk5u8sMPTlflDA16pZTXKlIEJkyAlSshMBB69rS9c06ccLqygqVBr5TyehERsHo1jB0Lb75pX//2m9NVFRwNeqVUoeDvD1On2pu1Z89C27YwaRKcOeN0ZflPg14pVai0a2e7Yd59t53g5PrrYZOXz3unQa+UKnRKl4aZM+HrryEhAZo3h5de8t4B0jTolVKFVu/edgiFbt3g4YftcMi7dztdVd7ToFdKFWoVKtgr+3fftTdsGzeGjz92uqq8pUGvlCr0ROCee2DdOmjQAAYMgLvugqNHna4sb2jQK6VUplq1YOlSO13hRx9Bkyb2oStPp0GvlFLZFCkCTz4Jy5bZK/0bb7Sv09KcruzqadArpVQObrjBTmoyaBBMnmz73W/f7nRVV0eDXimlLqF0aZg1y96cjYmxT9R64lj3GvRKKXUFd95pb9Q2b27Hur/zTkhOdroq12nQK6WUC6pXh0WL4Nln7WiYTZrAL784XZVrNOiVUspFvr4wfrwdEM3fHzp29IwbtRr0SimVS5GRsGYNDB5sb9S2bw+7djld1aW5FPQi0l1EYkQkTkTG57B+nIiszfzaKCLpIhKYuS5eRDZkrlud1weglFJOKFXK3pj96CM7SFqTJvDZZ05XlbMrBr2I+AKvAj2ABsAAEWmQfRtjzFRjTIQxJgKYAPxijMl+q6JD5vrIvCtdKaWcN3Cg7YZZrx706wfDh7vfxCauXNG3BOKMMTuMMWeAj4Hel9l+ADAnL4pTSilPU
KsW/Pqrbb9/5x3btLNundNV/cWVoK8K7Mn2OiFz2UVEJADoDnyRbbEBfhSRaBEZfrWFKqWUOyta1PbIWbgQjhyBVq3gv/91jz73rgS95LDsUqXfDCy/oNmmjTGmGbbpZ6SI3Jjjh4gMF5HVIrI6KSnJhbKUUsr9dOpkr+Y7d4bRo+1QyAcPOluTK0GfAFTL9joESLzEtv25oNnGGJOY+e8B4CtsU9BFjDFvGWMijTGR5cuXd6EspZRyT+XLwzffwPTpsGCBHfp40SLn6nEl6FcBdUSkpoj4YcN83oUbiUgZoB0wN9uyEiJS6tz3QFdgY14UrpRS7kzEXtGvXAllykCXLrYN34k5aq8Y9MaYNGAUsADYAnxqjNkkIiNEZES2TW8DfjTGZL/fXBH4VUTWASuB74wx8/OufKWUcm9NmkB0NPz97/D889CmDcTFFWwNYtzhTsEFIiMjzerV2uVeKeVdvvwShg2Ds2fhlVfs5CaS013QqyAi0Zfqwq5PxiqlVAHp0+evwdGGDIG//c320MlvGvRKKVWAqlWzN2affho+/dQOffzbb/n7mRr0SilVwHx9YdIkO4sVQFQUPPMMpKfnz+dp0CullENat7bDJ/TrB48/bkfDPH487z9Hg14ppRxUpowdGG3WLAgLgxIl8v4ziuT9WyqllMoNEdsD56678uf99YpeKaW8nAa9Ukp5OQ16pZTychr0Sinl5TTolVLKy2nQK6WUl9OgV0opL6dBr5RSXs4thykWkSRgV7ZFwYDDk3HlOW87Jm87HvC+Y/K24wHvO6ZrOZ4axpgcp+dzy6C/kIisvtQ4y57K247J244HvO+YvO14wPuOKb+OR5tulFLKy2nQK6WUl/OUoH/L6QLygbcdk7cdD3jfMXnb8YD3HVO+HI9HtNErpZS6ep5yRa+UUuoqadArpZSXc7ugF5GZInJARDZmWxYoIgtFJDbz33JO1phblzimf4nIXhFZm/nV08kac0NEqonIEhHZIiKbRGRM5nKPPE+XOR5PPkf+IrJSRNZlHtNTmcs99Rxd6ng89hwBiIiviPwhIt9mvs6X8+N2bfQiciNwHJhtjGmYuewFINkY85yIjAfKGWMedbLO3LjEMf0LOG6M+Y+TtV0NEakMVDbGrBGRUkA0cCswBA88T5c5nn547jkSoIQx5riIFAV+BcYAffDMc3Sp4+mOh54jABF5CIgEShtjeuVX1rndFb0xZimQfMHi3sCszO9nYf8n9BiXOCaPZYzZZ4xZk/n9MWALUBUPPU+XOR6PZaxz00wXzfwyeO45utTxeCwRCQFuAt7Jtjhfzo/bBf0lVDTG7AP7PyVQweF68sooEVmf2bTjEX9CX0hEQoGmwAq84DxdcDzgwecos1lgLXAAWGiM8ehzdInjAc89R9OAR4CMbMvy5fx4StB7o9eB2kAEsA940dFqroKIlAS+AB40xhx1up5rlcPxePQ5MsakG2MigBCgpYg0dLika3KJ4/HIcyQivYADxpjogvg8Twn6/ZntqOfaUw84XM81M8bsz/wPNwN4G2jpdE25kdlO+gXwkTHmy8zFHnuecjoeTz9H5xhjUoCfse3ZHnuOzsl+PB58jtoAt4hIPPAx0FFEPiSfzo+nBP084O7M7+8G5jpYS544dzIz3QZsvNS27ibzxti7wBZjzEvZVnnkebrU8Xj4OSovImUzvy8OdAa24rnnKMfj8dRzZIyZYIwJMcaEAv2BxcaYQeTT+XHHXjdzgPbY4Tr3A08CXwOfAtWB3cAdxhiPubl5iWNqj/1z0wDxwH3n2ubcnYi0BZYBG/irfXEitl3b487TZY5nAJ57jhpjb+b5Yi/oPjXGTBaRIDzzHF3qeD7AQ8/ROSLSHhib2esmX86P2wW9UkqpvOUpTTdKKaWukga9Ukp5OQ16pZTychr0Sinl5TTolVLKy2nQK6WUl9OgV0opL/f/92tLuFEDhXQAAAAASUVORK5CYII=\n", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df = pd.DataFrame(\n", " model.learning_process[10:], columns=[\"epoch\", \"train_RMSE\", \"test_RMSE\"]\n", ")\n", "plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n", "plt.plot(\"epoch\", \"test_RMSE\", data=df, color=\"green\", linestyle=\"dashed\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model.estimations()\n", "\n", "top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n", "\n", "top_n.to_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_reco.csv\", index=False, header=False\n", ")\n", "\n", "estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n", "estimations.to_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\",\n", " index=False,\n", " header=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 11138.92it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.9141430.7171310.1018030.0421340.051610.0685430.0919530.0712550.1040150.0488170.1930270.5177840.4718980.8672320.1479083.8712960.97182
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.914143 0.717131 0.101803 0.042134 0.05161 0.068543 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.091953 0.071255 0.104015 0.048817 0.193027 0.517784 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.471898 0.867232 0.147908 3.871296 0.97182 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df = pd.read_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=None\n", ")\n", "reco = np.loadtxt(\"Recommendations generated/ml-100k/Self_SVD_reco.csv\", delimiter=\",\")\n", "\n", "ev.evaluate(\n", " test=pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None),\n", " estimations_df=estimations_df,\n", " reco=reco,\n", " super_reactions=[4, 5],\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 10694.55it/s]\n", "943it [00:00, 11600.99it/s]\n", "943it [00:00, 11461.54it/s]\n", "943it [00:00, 11660.39it/s]\n", "943it [00:00, 9872.18it/s]\n", "943it [00:00, 11443.77it/s]\n", "943it [00:00, 11990.88it/s]\n", "943it [00:00, 11615.02it/s]\n", "943it [00:00, 11874.78it/s]\n", "943it [00:00, 12387.19it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Self_SVD0.9141430.7171310.1018030.0421340.0516100.0685430.0919530.0712550.1040150.0488170.1930270.5177840.4718980.8672320.1479083.8712960.971820
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_Random1.5218451.2259490.0471900.0207530.0248100.0322690.0295060.0237070.0500750.0187280.1219570.5068930.3297990.9865320.1847045.0997060.907217
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated1.0307120.8209040.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Self_SVD 0.914143 0.717131 0.101803 0.042134 0.051610 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.068543 0.091953 0.071255 0.104015 0.048817 0.193027 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.517784 0.471898 0.867232 0.147908 3.871296 0.971820 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 
0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dir_path = \"Recommendations generated/ml-100k/\"\n", "super_reactions = [4, 5]\n", "test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
04051.000000406406Thinner (1996)Horror, Thriller
18270.968354828828Alaska (1996)Adventure, Children's
28180.967103819819Eddie (1996)Comedy
39040.963944905905Great Expectations (1998)Drama, Romance
412930.96277912941294Ayn Rand: A Sense of Life (1997)Documentary
53470.961946348348Desperate Measures (1998)Crime, Drama, Thriller
68070.960952808808Program, The (1993)Action, Drama
75600.960885561561Mary Shelley's Frankenstein (1994)Drama, Horror
813920.95872413931393Stag (1997)Action, Thriller
97870.957891788788Relative Fear (1994)Horror, Thriller
\n", "
" ], "text/plain": [ " code score item_id id title \\\n", "0 405 1.000000 406 406 Thinner (1996) \n", "1 827 0.968354 828 828 Alaska (1996) \n", "2 818 0.967103 819 819 Eddie (1996) \n", "3 904 0.963944 905 905 Great Expectations (1998) \n", "4 1293 0.962779 1294 1294 Ayn Rand: A Sense of Life (1997) \n", "5 347 0.961946 348 348 Desperate Measures (1998) \n", "6 807 0.960952 808 808 Program, The (1993) \n", "7 560 0.960885 561 561 Mary Shelley's Frankenstein (1994) \n", "8 1392 0.958724 1393 1393 Stag (1997) \n", "9 787 0.957891 788 788 Relative Fear (1994) \n", "\n", " genres \n", "0 Horror, Thriller \n", "1 Adventure, Children's \n", "2 Comedy \n", "3 Drama, Romance \n", "4 Documentary \n", "5 Crime, Drama, Thriller \n", "6 Action, Drama \n", "7 Drama, Horror \n", "8 Action, Thriller \n", "9 Horror, Thriller " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item = random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm = (\n", " model.Qi / np.linalg.norm(model.Qi, axis=1)[:, None]\n", ") # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores = np.dot(embeddings_norm, embeddings_norm[item].T)\n", "top_similar_items = pd.DataFrame(\n", " enumerate(similarity_scores), columns=[\"code\", \"score\"]\n", ").sort_values(by=[\"score\"], ascending=[False])[:10]\n", "\n", "top_similar_items[\"item_id\"] = top_similar_items[\"code\"].apply(\n", " lambda x: item_code_id[x]\n", ")\n", "\n", "items = pd.read_csv(\"./Datasets/ml-100k/movies.csv\")\n", "\n", "result = pd.merge(top_similar_items, items, left_on=\"item_id\", right_on=\"id\")\n", "\n", "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# making changes 
to our implementation by considering additional parameters in the gradient descent procedure\n", "# seems to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(\n", " algo,\n", " reco_path=\"Recommendations generated/ml-100k/Ready_SVD_reco.csv\",\n", " estimations_path=\"Recommendations generated/ml-100k/Ready_SVD_estimations.csv\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(\n", " algo,\n", " reco_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv\",\n", " estimations_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv\",\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 11249.52it/s]\n", "943it [00:00, 10927.13it/s]\n", "943it [00:00, 11816.00it/s]\n", "943it [00:00, 11204.84it/s]\n", "943it [00:00, 11803.13it/s]\n", "943it [00:00, 
10580.63it/s]\n", "943it [00:00, 11843.28it/s]\n", "943it [00:00, 12313.76it/s]\n", "943it [00:00, 10678.21it/s]\n", "943it [00:00, 9772.22it/s]\n", "943it [00:00, 10699.52it/s]\n", "943it [00:00, 11789.55it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9503470.7493120.1006360.0505140.0557940.0707530.0912020.0827340.1140540.0532000.2488030.5219830.5174970.9921530.2106784.4186830.952848
0Self_SVD0.9141430.7171310.1018030.0421340.0516100.0685430.0919530.0712550.1040150.0488170.1930270.5177840.4718980.8672320.1479083.8712960.971820
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9394720.7398160.0858960.0360730.0435280.0576430.0770390.0574630.0977530.0455460.2198390.5147090.4316010.9974550.1688314.2175780.962577
0Ready_Random1.5218451.2259490.0471900.0207530.0248100.0322690.0295060.0237070.0500750.0187280.1219570.5068930.3297990.9865320.1847045.0997060.907217
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated1.0307120.8209040.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.950347 0.749312 0.100636 0.050514 0.055794 \n", "0 Self_SVD 0.914143 0.717131 0.101803 0.042134 0.051610 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.939472 0.739816 0.085896 0.036073 0.043528 \n", "0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.070753 0.091202 0.082734 0.114054 0.053200 0.248803 \n", "0 0.068543 0.091953 0.071255 0.104015 0.048817 0.193027 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.057643 0.077039 0.057463 0.097753 0.045546 0.219839 \n", "0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.521983 0.517497 0.992153 0.210678 4.418683 0.952848 \n", "0 0.517784 0.471898 0.867232 0.147908 3.871296 0.971820 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.514709 0.431601 
0.997455 0.168831 4.217578 0.962577 \n", "0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dir_path = \"Recommendations generated/ml-100k/\"\n", "super_reactions = [4, 5]\n", "test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }