{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Self made SVD" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import helpers\n", "import pandas as pd\n", "import numpy as np\n", "import scipy.sparse as sparse\n", "from collections import defaultdict\n", "from itertools import chain\n", "import random\n", "\n", "train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n", "test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n", "from tqdm import tqdm\n", "\n", "class SVD():\n", " \n", " def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n", " self.train_ui=train_ui\n", " self.uir=list(zip(*[train_ui.nonzero()[0],train_ui.nonzero()[1], train_ui.data]))\n", " \n", " self.learning_rate=learning_rate\n", " self.regularization=regularization\n", " self.iterations=iterations\n", " self.nb_users, self.nb_items=train_ui.shape\n", " self.nb_ratings=train_ui.nnz\n", " self.nb_factors=nb_factors\n", " \n", " self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n", " self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n", "\n", " def train(self, test_ui=None):\n", " if test_ui!=None:\n", " self.test_uir=list(zip(*[test_ui.nonzero()[0],test_ui.nonzero()[1], test_ui.data]))\n", " \n", " self.learning_process=[]\n", " pbar = tqdm(range(self.iterations))\n", " for i in pbar:\n", " pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n", " np.random.shuffle(self.uir)\n", " self.sgd(self.uir)\n", " if test_ui==None:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n", " else:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n", " \n", " def sgd(self, uir):\n", " \n", " for u, i, score in uir:\n", " # Computer prediction and error\n", " prediction = self.get_rating(u,i)\n", " e = (score - prediction)\n", " \n", " # Update user and item latent feature matrices\n", " Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n", " Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n", " \n", " self.Pu[u] += Pu_update\n", " self.Qi[i] += Qi_update\n", " \n", " def get_rating(self, u, i):\n", " prediction = self.Pu[u].dot(self.Qi[i].T)\n", " return prediction\n", " \n", " def RMSE_total(self, uir):\n", " RMSE=0\n", " for u,i, score in uir:\n", " prediction = self.get_rating(u,i)\n", " RMSE+=(score - prediction)**2\n", " return np.sqrt(RMSE/len(uir))\n", " \n", " def estimations(self):\n", " self.estimations=\\\n", " np.dot(self.Pu,self.Qi.T)\n", "\n", " def recommend(self, user_code_id, item_code_id, topK=10):\n", " \n", " top_k = defaultdict(list)\n", " for nb_user, user in enumerate(self.estimations):\n", " \n", " user_rated=self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]]\n", " for item, score in enumerate(user):\n", " if item not in user_rated and not np.isnan(score):\n", " top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n", " result=[]\n", " # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n", " for uid, item_scores in top_k.items():\n", " item_scores.sort(key=lambda x: x[1], reverse=True)\n", " result.append([uid]+list(chain(*item_scores[:topK])))\n", " return result\n", " \n", " def estimate(self, user_code_id, item_code_id, test_ui):\n", " result=[]\n", " for user, item in zip(*test_ui.nonzero()):\n", " result.append([user_code_id[user], item_code_id[item], \n", " self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n", " return result" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 39 RMSE: 0.7480082047970615. Training epoch 40...: 100%|██████████| 40/40 [01:21<00:00, 2.05s/it]\n" ] } ], "source": [ "model=SVD(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n", "model.train(test_ui)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAHCNJREFUeJzt3X2UFPWd7/H3l3lggAEHZkZlGQE17vEBZAQkuGJWJZsDZqMkl2vw3I16T+6yKj6dGHM1V82a1ZMoJ5qw0bi4akzWG0iCjzmuiQY8GqMYQEAezAVXEwYIDA/Dg8rDDN/7R1UzPU3PdM9Mz1RP1ed1Tp2qrqru/lLAp379qydzd0REJF76RV2AiIgUnsJdRCSGFO4iIjGkcBcRiSGFu4hIDCncRURiSOEuIhJDCncRkRhSuIuIxFBpVF9cU1Pjo0ePjurrRUT6pOXLl+9w99pc60UW7qNHj2bZsmVRfb2ISJ9kZn/KZz11y4iIxJDCXUQkhhTuIiIxFFmfu4j0PYcPH6ahoYEDBw5EXUrsVVRUUFdXR1lZWZfer3AXkbw1NDQwePBgRo8ejZlFXU5suTs7d+6koaGBk08+uUufoW4ZEcnbgQMHqK6uVrD3MDOjurq6W7+QFO4i0ikK9t7R3e3c58J9yxa4+WY4fDjqSkREilefC/elS+EHP4A77oi6EhGR4tXnwv2LX4RrroH774df/zrqakSkNzU1NfHwww93+n2XXHIJTU1NnX7f1Vdfzcknn0x9fT3jxo3jt7/97dFlF154ISNHjsTdj86bMWMGlZWVABw5coQbb7yRMWPGMHbsWM4991w++OADILhCf+zYsdTX11NfX8+NN97Y6dpy6ZNnyzzwAPzud/CVr8CqVTB8eNQViUhvSIX7dddd12Z+S0sLJSUl7b7vxRdf7PJ3zp07l5kzZ7JkyRJmz57Nhg0bji6rqqrijTfeYMqUKTQ1NbF169ajyxYuXMiWLVtYvXo1/fr1o6GhgUGDBh1dvmTJEmpqarpcVy59MtwHDICFC2HiRLjyyqAF36/P/QYR6dtuvhlWrizsZ9bXw/e/3/7y2267jffff5/6+nrKysqorKxk+PDhrFy5knXr1jFjxgw2bdrEgQMHuOmmm5g9ezbQei+r/fv3M336dKZMmcLvf/97RowYwXPPPceAAQNy1nbeeeexefPmNvNmzZrFggULmDJlCk8//TRf+tKXWLt2LQBbt25l+PDh9AvDqa6urotbpWv6bCSeeSbMmwevvBJ00YhI/H33u9/l1FNPZeXKlcydO5e3336be++9l3Xr1gHw+OOPs3z5cpYtW8a8efPYuXPnMZ+xYcMG5syZw9q1a6mqqmLRokV5ffdLL73EjBkz2sybOnUqr732Gi0tLSxYsIAvf/nLR5ddfvnlvPDCC9TX13PLLbfwzjvvtHnvRRdddLRb5sEHH+zspsipT7bcU7761SDc77gD/vZv4bzzoq5IJDk6amH3lkmTJrW5yGfevHk888wzAGzatIkNGzZQXV3d5j2pPnSACRMm8OGHH3b4Hbfeeivf+MY32L59O2+99VabZSUlJUyZMoWFCxfyySefkH4b87q6Ov74xz+yePFiFi9ezNSpU/nFL37B1KlTgZ7vlumzLXcAM/i3f4ORI+GKK2D37qgrEpHelN6H/eqrr/LKK6/w5ptvsmrVKs4555ysFwH179//6HRJSQnNzc0dfsfcuXPZuHEj99xzD1ddddUxy2fNmsUNN9zA5ZdfnvW7pk+fzty5c/nmN7/Js88+25k/Xrf06XAHOO44WLAANm+Gf/xHSDtwLSIxM3jwYPbt25d12Z49exg6dCgDBw7kvffeO6aV3R39+vXjpptu4siRI/w64zS9Cy64gNtvv50rrriizfwVK1awZcsWIDhzZvXq1YwaNapgNeWsude+qQdNmgTf+Q4sWhS05EUknqqrqzn//PMZM2YMt956a5tl06ZNo7m5mbPPPps777yTyZMnF/S7zYw77riD+zMO8pkZX//614/pYtm+fTtf+MIXGDNmDGeffTalpaVcf/31R5en97lfeeWVBa0VwDyipu7EiRO9kE9iOnIEPv95WLIE3n4bzj67YB8tIqH169dzxhlnRF1GYmTb3ma23N0n5npvzpa7mVWY2dtmtsrM1prZ3VnWudrMGs1sZTj8r079CQqgXz948kkYOjQ4PVJEJMny6ZY5CFzs7uOAemCamWX7vbPQ3evD4d8LWmWejj8err02uLDp0KEoKhCRvmjOnDlHu0hSwxNPPBF1Wd2S81RID/pt9ocvy8KhaA9b1obPBN+1C048MdpaROLI3WN3Z8iHHnoo6hKO0d0u87wOqJpZiZmtBLYDL7v70iyr/TczW21mvzSzk7pVVTekTmnNcu2CiHRTRUUFO3fu7HbwSMdSD+uoqKjo8mfkdRGTu7cA9WZWBTxjZmPcfU3aKi8AP3P3g2Z2DfAkcHHm55jZbGA2wMiRI7tcdEcU7iI9p66ujoaGBhobG6MuJfZSj9nrqk5doeruTWb2KjANWJM2Pz1KHwXua+f984H5EJwt09li86FwF+k5ZWVlXX7sm/SufM6WqQ1b7JjZAOCzwHsZ66Tfl/FSYH0hi+wMhbuISH4t9+HAk2ZWQrAz+Lm7/8rMvg0sc/fngRvN7FKgGdgFXN1TBeeicBcRye9smdXAOVnm35U2fTtwe2FL65qBA6GiQuEuIskWi9sPZKquVriLSLLFNtx37Ii6ChGR6MQ23NVyF5EkU7iLiMRQLMO9pkbhLiLJFstwr64O7i2jK6RFJKliG+4tLbBnT9SViIhEI7bhDuqaEZHkUriLiMSQwl1EJIYU7iIiMaRwFxGJoViGe1VV8MBshbuIJFUsw71fPxg6VOEuIskVy3AH3TxMRJIt1uGulruIJJXCXUQkhhTuIiIxpHAXEYmhWIf7xx/DgQNRVyIi0vtiG+41NcFYrXcRSaLYhruuUhWRJFO4i4jEkMJdRCSGcoa7mVWY2dtmtsrM1prZ3VnW6W9mC81so5ktNbPRPVFsZyjcRSTJ8mm5HwQudvdxQD0wzcwmZ6zzVWC3u38KeBC4r7Bldp7CXUSSLGe4e2B/+LIsHDIfPX0Z8GQ4/UtgqplZwarsgv79YdAghbuIJFNefe5mVmJmK4HtwMvuvjRjlRHAJgB3bwb2ANWFLLQrdCGTiCRVXuHu7i3uXg/UAZPMbEzGKtla6Zmte8xstpktM7NljY2Nna+2k3RnSBFJqk6dLePuTcCrwLSMRQ3ASQBmVgocB+zK8v757j7R3SfW1tZ2qeDOUMtdRJIqn7Nlas2sKpweAHwWeC9jteeBq8LpmcBidz+m5d7bFO4iklSleawzHHjSzEoIdgY/d/dfmdm3gWXu/jzwGPBTM9tI0GKf1WMVd4LCXUSSKme4u/tq4Jws8+9Kmz4A/PfCltZ91dWweze0tEBJSdTViIj0ntheoQpBuLtDU1PUlYiI9K7Yhzuoa0ZEkkfhLiISQ7EOd93TXUSSKtbhrpa7iCSVwl1EJIZiHe5DhkBpqcJdRJIn1uFuBsOGKdxFJHliHe6gq1RFJJkSEe66M6SIJE0iwl0tdxFJGoW7iEgMJSbco78BsYhI70lEuB88CB9/HHUlIiK9JxHhDuqaEZFkUbiLiMSQwl1EJIYU7iIiMRT7cNdtf0UkiWIf7sOGBWOFu4gkSezDvawsuDukwl1EkiT24Q66SlVEkkfhLiISQ4kJd90ZUkSSJDHhrpa7iCRJznA3s5PMbImZrTeztWZ2U5Z1LjSzPWa2Mhzu6plyu0bhLiJJU5rHOs3ALe6+wswGA8vN7GV3X5ex3uvu/veFL7H7qqthzx5obg6eqSoiEnc5W+7uvtXdV4TT+4D1wIieLqyQUlep7toVbR0iIr2lU33uZjYaOAdYmmXxeWa2ysz+08zOauf9s81smZkta2xs7HSxXaVbEIhI0uQd7mZWCSwCbnb3vRmLVwCj3H0c8K/As9k+w93nu/tEd59YW1vb1Zo7TeEuIkmTV7ibWRlBsD/l7k9nLnf3ve6+P5x+ESgzs5qCVtoNCncRSZp8zpYx4DFgvbs/0M46J4brYWaTws8tmihVuItI0uRz7sj5wFeAd81sZTjvm8BIAHd/BJgJXGtmzcAnwCz34nlqqcJdRJImZ7i7++8Ay7HOD4EfFqqoQqusDG4gpnAXkaRIxBWqZsF93RXuIpIUiQh30FWqIpIsCncRkRhKVLjrzpAikhSJCne13EUkKRIX7sVzgqaISM9JVLg3N8O+fVFXIiLS8xIV7qCuGRFJBoW7iEgMKdxFRGJI4S4iEkMKdxGRGEpMuA8dGowV7iKSBIkJ99JSqKpSuItIMiQm3EFXqYpIcijcRURiKFHhrnu6i0hSJCrcdWdIEUmKxIW7Wu4ikgSJC/f9++HQoagrERHpWYkLd1DrXUTiT+EuIhJDCncRkRhSuIuIxFDOcDezk8xsiZmtN7O1ZnZTlnXMzOaZ2UYzW21m43um3O5RuItIUpTmsU4zcIu7rzCzwcByM3vZ3delrTMdOC0cPg38KBwXFYW7iCRFzpa7u2919xXh9D5gPTAiY7XLgJ944C2gysyGF7zabho4ECoqFO4iEn+d6nM3s9HAOcDSjEUjgE1prxs4dgdQFHQhk4gkQd7hbmaVwCLgZnffm7k4y1s8y2fMNrNlZrassbGxc5UWiMJdRJIgr3A3szKCYH/K3Z/OskoDcFLa6zpgS+ZK7j7f3Se6+8Ta2tqu1NtttbWwbVskXy0i0mvyOVvGgMeA9e7+QDurPQ9cGZ41MxnY4+5bC1hnwZx5JqxZAy0tUVciItJz8mm5nw98BbjYzFaGwyVmdo2ZXROu8yLwX8BG4FHgup4pt/vGj4ePPoING6KuRESk5+Q8FdLdf0f2PvX0dRyYU6iietKECcF4+XI4/fRoaxER6SmJukIV4IwzgtMhV6yIuhIRkZ6TuHAvLYVx44KWu4hIXCUu3CHod3/nHThyJOpKRER6RiLDfcIE2LsX3n8/6kpERHpGIsN9fHhbM/W7i0hcJTLczzoLysvV7y4i8ZXIcC8vh7Fj1XIXkfhKZLhD0O++YgX4MXfAERHp+xIb7uPHw+7d8OGHUVciIlJ4iQ339CtVRUTiJrHhPmZMcEGT+t1FJI4SG+4VFUHAq+UuInGU2HCHoN9dB1VFJI4SHe4TJsCOHbBpU+51RUT6kkSHu65UFZG4SnS4jxsHJSXqdxeR+El0uA8YENzfXS13EYmbRIc7BP3uy5froKqIxIvCfQJs2wZbi/Jx3iIiXZP4cE8dVFW/u4jESeLDvb4ezNTvLiLxkvhwHzQITj9dLXcRiZfEhzu03v5XRCQuFO4E/e6bNwcHVkVE4iBnuJvZ42a23czWtLP8QjPbY2Yrw+GuwpfZs1K3/1XrXUTiIp+W+4+BaTnWed3d68Ph290vq3fV1wdj9buLSFzkDHd3fw3Y1Qu1RGbIEPjrv1bLXUTio1B97ueZ2Soz+08zO6tAn9mrxo9Xy11E4qMQ4b4CGOXu44B/BZ5tb0Uzm21my8xsWWNjYwG+unAmTIA//zm4BbCISF/X7XB3973uvj+cfhEoM7Oadtad7+4T3X1ibW1td7+6oHT7XxGJk26Hu5mdaGYWTk8KP3Nndz+3t+k2BCISJ6W5VjCznwEXAjVm1gB8CygDcPdHgJnAtWbWDHwCzHLve/dYrKqCU05Ry11E4iFnuLv7FTmW/xD4YcEqitCECbBsWdRViIh0n65QTTN+PHzwAezeHXUlIiLdo3BPoytVRSQuFO5pdFBVROJC4Z6muhrGjoUnnoDm5qirERHpOoV7hn/5F3jvPXjssagrERHpOoV7hksvhSlT4Fvfgv37o65GRKRrFO4ZzGDu3ODe7t/7XtTViIh0jcI9i8mTYebMIOT/8peoqxER6TyFezu+8x04eBDuvjvqSkREOk/h3o5PfQquvRYefTQ4wCoi0pco3Dtw550wcCDcdlvUlYiIdI7CvQO1tUGwP/ccvP561NWIiORP4Z7DzTfDX/0V3Hor9L17XYpIUinccxg4MLiwaelSWLQo6mpERPKjcM/DVVfBWWfB7bfDoUNRVyMikpvCPQ8lJXD//bBxI8yfH3U1IiK5KdzzNH06XHxxcN77nj1RVyMi0jGFe57Mgtb7zp3wuc/Bli1RVyQi0j6FeydMmABPPw1r18K558If/hB1RSIi2SncO2nGDHjzTSgvhwsugKeeiroiEZFjKdy7YOzYoNU+eTL8wz8EFzq1tERdlYhIK4V7F9XUwG9+A//0T3DffXDZZbB3b9RViYgEFO7dUF4OjzwCDz8ML70UtOQ3boy6KhERhXtBXHstvPxy8ICPc8+Fu+7S2TQiEi2Fe4FcdFHQDz9lCtxzD4waBVdcERx81T1pRKS35Qx3M3vczLab2Zp2lpuZzTOzjWa22szGF77MvuGUU+CFF2DDBrjhBnjxRfibv4FJk+CnPw0e/iEi0hvyabn/GJjWwfLpwGnhMBv4UffL6ttOPRUeeAA2b4aHHgoetH3llTByZHBmzeLF8MknUVcpInGWM9zd/TVgVwerXAb8xANvAVVmNrxQBfZllZVw3XWwbl1wZs2kScFzWadOhaoq+Mxngv55hb2IFFppAT5jBLAp7XVDOG9r5opmNpugdc/IkSML8NV9gxn83d8Fw9698MYb8OqrwXDvvcEthcvL4dOfDs64OeOM1uG446KuXkT6okKEu2WZl/UQorvPB+YDTJw4MZGHGYcMCW5CNn168Do97JcsgR/8oO1thYcPbxv2p5wCdXXBUFUV7DhERDIVItwbgJPSXtcBOhEwT5lh39wMH3wA69e3HX7yE9i3r+17BwxoDfoRI4LxCSfA8ce3HWpqoLQQf9Mi0mcU4r/888D1ZrYA+DSwx92P6ZKR/JSWwmmnBcOll7bOdw8O0P75z9DQEEw3NLROv/56cG794cPHfqYZDBsWPBO2qiro6kmNM6ePOy7Y4QwZ0jp93HFQVtZ720BEui9nuJvZz4ALgRozawC+BZQBuPsjwIvAJcBG4GPgf/ZUsUlm1tpKb8+RI9DUBNu3Zx8aG4Plu3YFvw727AmGAwdyf39FRRD0gwcHQ2XlsdOVlcEwaFD744EDg/GgQdC/v7qVRHpKznB39ytyLHdgTsEqki7r1y9ooQ8bBqefnv/7Dh5sDfq9e9sf790bdA3t3x+MGxuDncS+fa3zjxzpXL3pYZ++I8ic7sxQXq6dhoh6YoX+/Vv757vDPdhRfPRREPT797dOf/RR2+Hjj4+dl/6+v/yl7fs//jj/OkpLs+8g0n9dpE+n/wrJNlRWBo9aFOlLFO5SMGZB901FBVRXF/azW1qCgE+Ff2pI/yWRuYPIHLZsaV0/NeR7a4hBg1q7pVLHJFLTqZ1A+nRH83RwW3qD/plJn1BS0hqOhXLkSHDxWKpbqb0h1R2Vmk69bmxsuzzbwexsBgw4dgeQfvA6fbq9eUOGaCchHdM/D0msfv1a+/tPPLH7n3fwYNsdQbadQ3s7j02b2h7jaG7O/X2pXxPpZzplnvmUOaQv1w4i3vRXK1Ig/fsHp5vW1nbvc9yDXxTpB7lToZ/tYHdqaGqCP/2p9XU+t7SorAwCv6oKhg5tO842L31cWakD18VM4S5SZMyCs4gGDgyuUO6qQ4fa7gCamtruDFLzUsPu3cEviNWrW9ftSElJEPLV1cEZWpnjzOnU68GDtVPoDQp3kZgqLw+uTq6p6dr7W1qCnUMq+Hfvbp1OjXftCoadO4OL6VavDl7v39/+55aUHBv47U1XVwf1V1cHOzvtFPKncBeRrFIt86FD4eSTO/feQ4dagz99B5Dt9ebN8O67weuOdgr9+7cGfSr0cw0DB3ZvG/RlCncRKbjy8uAgdWcPVB86FPwi2Lmz7bBjR9vpHTuCXwk7dgQ7iPZOaR04sPU4SLYhtRNI7SyGDg0OtMeBwl1EikZ5eXDzuxNOyP89LS3BDiEV+jt2BKeppsapYds2WLMmmG7vlhupq7xraoLwP/74jsfV1cV7gZvCXUT6tJKSzh1bcA8udkv/FZBtaGwMHrSzfXv7vw7MgoDPDP7jj2/dSaUPlZWF/bN3ROEuIoli1nrbiVGj8ntPc3OwM2hsbHsjvtSQev3uu8F4VzvPrhs0KAj5OXPga18r3J8pG4W7iEgOpaWd6y46fDgI/G3bsg+FuGguF4W7iEiBlZUFD9AZMSK6GmJyXFhERNIp3EVEYkjhLiISQwp3EZEYUriLiMSQwl1EJIYU7iIiMaRwFxGJIfN8nxBc6C82awT+1MEqNcCOXiqns1Rb16i2rlFtXRPX2ka5e87nfUUW7rmY2TJ3nxh1Hdmotq5RbV2j2rom6bWpW0ZEJIYU7iIiMVTM4T4/6gI6oNq6RrV1jWrrmkTXVrR97iIi0nXF3HIXEZEuKrpwN7NpZvZHM9toZrdFXU86M/vQzN41s5VmtiziWh43s+1mtiZt3jAze9nMNoTjoUVU2z+b2eZw2600s0siqu0kM1tiZuvNbK2Z3RTOj3zbdVBb5NvOzCrM7G0zWxXWdnc4/2QzWxput4VmVl5Etf3YzD5I2271vV1bWo0lZvaOmf0qfN3z283di2YASoD3gVOAcmAVcGbUdaXV9yFQE3UdYS2fAcYDa9Lm3Q/cFk7fBtxXRLX9M/D1Ithuw4Hx4fRg4P8BZxbDtuugtsi3HWBAZThdBiwFJgM/B2aF8x8Bri2i2n4MzIz631xY19eA/wv8Knzd49ut2Fruk4CN7v5f7n4IWABcFnFNRcndXwMyn9R4GfBkOP0kMKNXiwq1U1tRcPet7r4inN4HrAdGUATbroPaIueB/eHLsnBw4GLgl+H8qLZbe7UVBTOrAz4P/Hv42uiF7VZs4T4C2JT2uoEi+ccdcuA3ZrbczGZHXUwWJ7j7VgiCAjg+4noyXW9mq8Num0i6jNKZ2WjgHIKWXlFtu4zaoAi2Xdi1sBLYDrxM8Cu7yd2bw1Ui+/+aWZu7p7bbveF2e9DM+kdRG/B94BvAkfB1Nb2w3Yot3C3LvKLZAwPnu/t4YDowx8w+E3VBfciPgFOBemAr8L0oizGzSmARcLO7742ylkxZaiuKbefuLe5eD9QR/Mo+I9tqvVtV+KUZtZnZGOB24HTgXGAY8L97uy4z+3tgu7svT5+dZdWCb7diC/cG4KS013XAlohqOYa7bwnH24FnCP6BF5NtZjYcIBxvj7ieo9x9W/gf8AjwKBFuOzMrIwjPp9z96XB2UWy7bLUV07YL62kCXiXo164ys9JwUeT/X9NqmxZ2c7m7HwSeIJrtdj5wqZl9SNDNfDFBS77Ht1uxhfsfgNPCI8nlwCzg+YhrAsDMBpnZ4NQ08DlgTcfv6nXPA1eF01cBz0VYSxup4Ax9kYi2Xdjf+Riw3t0fSFsU+bZrr7Zi2HZmVmtmVeH0AOCzBMcElgAzw9Wi2m7ZansvbWdtBH3avb7d3P12d69z99EEebbY3f8HvbHdoj6KnOWo8iUEZwm8D/yfqOtJq+sUgrN3VgFro64N+BnBT/TDBL94vkrQl/dbYEM4HlZEtf0UeBdYTRCkwyOqbQrBT+DVwMpwuKQYtl0HtUW+7YCzgXfCGtYAd4XzTwHeBjYCvwD6F1Fti8Pttgb4D8IzaqIagAtpPVumx7ebrlAVEYmhYuuWERGRAlC4i4jEkMJdRCSGFO4iIjGkcBcRiSGFu4hIDCncRURiSOEuIhJD/x8UYvW3JHYhtQAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process).iloc[:,:2]\n", "df.columns=['epoch', 'train_RMSE']\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAIABJREFUeJzt3XuczXX+wPHXe8YwLl1cpsIktiSEoSGKsC4hSSXRijZlt1y6EKl0UW02FdnVxZassknpuokUVhfSyCUSM+6DEOlHovD5/fH+DscYM2du53su7+fjcR4z5/v9njPvb0fv7/d8Lu+POOcwxhgTG+L8DsAYY0zoWNI3xpgYYknfGGNiiCV9Y4yJIZb0jTEmhljSN8aYGGJJ3xhjYoglfWOMiSGW9I0xJoaU8DuA7CpVquSqV6/udxjGGBNRFi9e/KNzLimv48Iu6VevXp20tDS/wzDGmIgiIhuDOS7P5h0RmSgiO0RkxUn2i4iME5EMEVkuIo0C9vURkXTv0Sf48I0xxhSHYNr0JwEdctnfEajpPfoBzwOISAXgIeBioAnwkIiUL0ywxhhjCifPpO+cmw/szuWQq4DJTi0ETheRysDlwGzn3G7n3E/AbHK/eBhjjClmRdGmXxXYHPA809t2su3GmCjz+++/k5mZyYEDB/wOJeolJiaSnJxMQkJCgV5fFElfctjmctl+4huI9EObhqhWrVoRhGSMCaXMzExOOeUUqlevjkhO/+ubouCcY9euXWRmZlKjRo0CvUdRjNPPBM4OeJ4MbM1l+wmccxOcc6nOudSkpDxHHBljwsyBAweoWLGiJfxiJiJUrFixUN+oiiLpvw/09kbxNAV+ds5tA2YB7UWkvNeB297bZoyJQpbwQ6Ow/53zbN4RkdeBVkAlEclER+QkADjnXgBmAJ2ADGA/8Gdv324ReRT42nurkc653DqEi8BeIB1olNeBxhgTk/JM+s65nnnsd0D/k+ybCEwsWGgFMQEYgg4Sug9oEbo/bYwxESDKau/cAvwNWAxchib9j3yNyBhT/Pbs2cNzzz2X79d16tSJPXv25Pt1N910EzVq1CAlJYUGDRrw6aefHt3XqlUrqlWrht4Pq65du1KuXDkAjhw5wqBBg7jwwgupV68ejRs3Zv369YBWJKhXrx4pKSmkpKQwaNCgfMeWl7Arw1A4pwHDgTuAl4EngafQ+WNw8kFFxphIlpX0b7/99uO2Hz58mPj4+JO+bsaMGQX+m6NHj6Zbt27MnTuXfv36kZ6efnTf6aefzhdffEHz5s3Zs2cP27ZtO7rvjTfeYOvWrSxfvpy4uDgyMzMpW7bs0f1z586lUqVKBY4rL1GW9LOUAQYCfwF2etu2oF0PdwM34HVLGGOK2J13wtKlRfueKSkwduzJ9997772sXbuWlJQUEhISKFeuHJUrV2bp0qV89913dO3alc2bN3PgwAHuuOMO+vXrBxyr9bVv3z46duxI8+bN+fLLL6latSrvvfcepUuXzjO2Zs2asWXLluO29ejRg6lTp9K8eXPefvttrrnmGlauXAnAtm3bqFy5MnFx2tCSnJxcwP8qBRNlzTvZleTYfLDt6F3+TcB5wHjgV3/CMsYUqVGjRnHuueeydOlSRo8ezaJFi3j88cf57rvvAJg4cSKLFy8mLS2NcePGsWvXrhPeIz09nf79+7Ny5UpOP/10pk+fHtTfnjlzJl27dj1uW5s2bZg/fz6HDx9m6tSpXH/99Uf3de/enQ8++ICUlBQGDx7MkiVLjntt69atjzbvjBkzJr//KfIUpXf6OWkELEEHGz0ODPB+pgNlc3mdMSY/crsjD5UmTZocN3lp3LhxvPPOOwBs3ryZ9PR0KlaseNxrstroAS666CI2bNiQ69+45557GDp0KDt27GDhwoXH7YuPj6d58+a88cYb/PrrrwSWi09OTmb16tXMmTOHOXPm0KZNG958803atGkDFH/zTpTf6WcnwBXAF8D/gMEcS/gTgR98issYU5QC28jnzZvHJ598woIFC1i2bBkNGzbMcXJTqVKljv4eHx/PoUOHcv0bo0ePJiMjg8cee4w+fU4sItyjRw8GDhxI9+7dc/xbHTt2ZPTo0dx33328++67+Tm9QomxpJ9F0NE9g73nm4BbgerA7cB6f8IyxhTIKaecwt69e3Pc9/PPP1O+fHnKlCnD999/f8JdeWHExcVxxx13cOTIEWbNOn7uaYsWLRg+fDg9ex4/6v2bb75h61YtTnDkyBGWL1/OOeecU2Qx5SWGmndyUw34Hh3t8xI63r8H8HesRpwx4a9ixYpceumlXHjhhZQuXZozzzzz6L4OHTrwwgsvUL9+fWrVqkXTpk2L9G+LCA888ABPPvkkl19++XHbhwwZcsLxO3bs4NZbb+XgwYOANkUNGDDg6P7WrVsfHXFUv359Jk+eXLTxBo4lDQepqanO35WztgBjgCnASqACsBZN/ok+xmVM+Fq1ahW1a9f2O4yYkdN/bxFZ7JxLzeu1Mdq8k5uq6Nj+jWjCB7geOAPoBbwHWPlYY0xksqR/UiUDfn8MuA6d3dsVSEKbgowx0ax///5Hh09mPV555RW/wyoUa9MPSgfv8QIwF3iLY1WjfwDuArqhM3/L+BGgMaYYjB8/3u8Qipwl/XxJQCtEtw/Y9j3wCTAVTfhXoBeAK4G8Z/MZY0woWfNOobUCtqGJvzc6/v96IGvG3wa05LMxxvjPkn6RKAG0AZ5HFwf7Gl0oDHTmbxLaFzAF+NmPAI0xBrCkXwzigcBRU/ehhd/S0NE/Z6BF37L8FrrQjDExz5J+sbsEeBad9fsFut5MTW/ffrQcdCPgNmAS8B1wJORRGhPJClpPH2Ds2LHs378/12Oy6tzXr1+fli1bsnHjxqP7RIQbb7zx6PNDhw6RlJRE586dAdi+fTudO3emQYMG1KlTh06dOgGwYcMGSpcufdzIoKKeiJUTS/ohE4deAJ5BEzzAQbQUREXgP+hKk3WBcd7+3cBM4PeQRmpMpCnupA9aCG358uW0atWKxx577Oj2smXLsmLFCn79Vav2zp49m6pVj83kf/DBB2nXrh3Lli3ju+++Y9SoUUf3ZVUGzXr07t27QOeQHzZ6x1fl0TkAoHf3a4Cv0IsDaKfwNehFoRtaGuIy7Fptwl+rHLZ1R2tb7UfXtsjuJu/xI/rvPdC8XP9aYD39du3accYZZzBt2jQOHjzI1VdfzSOPPMIvv/xC9+7dyczM5PDhw4wYMYLt27ezdetWWrduTaVKlZg7d26eZ9asWTPGjRt33LaOHTvy4Ycf0q1bN15//XV69uzJZ599Bmj9/Pbtj434q1+/fp5/ozhZ9ggbccAFQB+ONf90QGcAtwdeBVqj8wO2+xGgMWErsJ5+u3btSE9PZ9GiRSxdupTFixczf/58Zs6cSZUqVVi2bBkrVqygQ4cODBo0iCpVqjB37tygEj7kXD8/a9GUAwcOsHz5ci6++OKj+/r370/fvn1p3bo1jz/++NFia8DRC1XWI+tCUZzsTj+slQa6eI9fgA+Az9DOYIAH0IqhPdBmIWPCxbxc9pXJY3+lPPbn7uOPP+bjjz+mYcOGAOzbt4/09HRatGjBkCFDGDZsGJ07d6ZFixb5et/WrVuzfft2zjjjjOOad0Dv3jds2MDrr79+tM0+y+WXX866deuYOXMmH330EQ0bNmTFihXAseadULI7/YhRFk3u4zm2zu/36ELwFwL1gRFAaP8BGRNunHMMHz78aDt5RkYGffv25fzzz2fx4sXUq1eP4cOHM3LkyHy979y5c9m4cSN169blwQcfPGF/ly5dGDJkyAmllAEqVKjADTfcwKuvvkrjxo2ZP39+gc+vsCzpR7S30Kqg49BRQH8DXvP2/Y52Dv/oT2jGhFBgPf3LL7+ciRMnsm/fPgC2bNnCjh072Lp1K2XKlKFXr14MGTKEb7755oTX5qV06dKMHTuWyZMns3v37uP23XzzzTz44IPUq1fvuO1z5sw52lG8d+9e1q5dS7Vq1Qp1voVhzTsR7yx0EfiB6GifrHH/C4A/od8KmqIdZ52AFOxab6JNYD39jh07csMNN9CsWTMAypUrx2uvvUZGRgb33HMPcXFxJCQk8PzzzwPQr18/OnbsSOXKlYNq169cuTI9e/Zk/PjxjBgx4uj25ORk7rjjjhOOX7x4MQMGDKBEiRIcOXKEW265hcaNG7Nhw4ajbfpZbr75ZgYNGlTY/xy5snr6UesIsBj4EF0X+Gtv+//QEUBbAMexmcPGFJzV0w8tq6dvchAHNAYeBhah1UD/DTTz9o9FRwKdD/wVmAbsCHmUxpjQsuadmHEmWhAuS1+gCloq+nXgRXQ+wA70grEcvSiUD22Yxvjo4osvPrqMYZZXX331hHb6SGZJP2Zd4D3uAg4B3wCZHPvydx2Qjn5byCon3RQtL23MiZxziEjeB4axr776yu8Q8lTYJnlr3jHotb8JOvsXtK3/JeBBb98TaD/AbQH7M7yfxkBiYiK7du0qdEIyuXPOsWvXLhITC75et93pmxwI0MJ7PAzsQZuBqnj7VwO1geroN4B2aGlpawqKVcnJyWRmZrJz506/Q4l6iYmJJCcXfABGUElfRDqgpSLjgZecc6Oy7T8HmIgWjt8N9HLOZXr7DgPfeoducs51KXC0xienA1cHPE9C1w74GF0xbAJ6ofgYaIteJOKBU0IbpvFNQkICNWrU8DsME4Q8m3dEJB6dBtoRqAP0FJE62Q57CpjsnKsPjETbA7L86pxL8R6W8KNCRXTEz9voCmGfo98IGnn7X0Tv+i9BS0XMAX4NeZTGmBMF06bfBMhwzq1zzv2G3tpdle2YOsCn3u9zc9hvolYJ4FK0/b+Ct60dMAxt8x+FNv2cxbGJY5vQstLGmFALJulXBTYHPM/0tgVaBlzr/X41cIqIVPSeJ4pImogsFJGu5EBE+nnHpFmbYDRoBDyOzgreDfwXLSFd0tvfE20yag08hH4TyLueuTGm8IJJ+jmNwcreRT8EaCkiS4CW6HTPQ96+at4ssRuAsSJy7glv5twE51yqcy41KSkp+OgDHD4MN90Ec+YU6OWm2JwKXIGWichyHzoS6P/Qi0EbtJhcls+AfaEK0JiYEkxHbiY6SydLMrr691HOua144/1EpBxwrXPu54B9OOfWicg8oCGwttCRZ7NhA3zyCfz739CqFTz6KDRvXtR/xRSNK7wH6ELxX6BVREELxF2GdgTXB+qhVUSvQFsRjTGFEcyd/tdATRGpISIl0Vuy9wMPEJFKIpL1XsPRkTyISHkRKZV1DNr4+11RBR/o3HMhIwOefRZWrYIWLaB9e4iAuRYx7jS0EFxL7/kpwCy0TyAJ+AQYCnzp7V8FNEc7ksejtYR2hTBeYyJbnknfOXcIGID+n7gKmOacWykiI0UkazROK2C1iKxB5/s/7m2vDaSJyDK0g3eUc65Ykj5AYiIMGgTr1sFTT8GSJdC0KXTuDF4VVRP2SqFj/x9H/8ltQZN6d2//L+g/2zfQf5at0EU3PvL2/+wdY4zJSVRX2dy3D/75T3jySfjpJ+jaFR55BHxeotIUCQdsQ6eALAX6ocNE/44uJnMpOoqoHdqxHO9PmMaEiFXZBMqVg3vv1fb+Rx7RTt4GDaB7d/iu2L5vmNAQdIbw5WhTUNZs4LZoPaE9wP3oiOOz0UVlQEcTHQlppMaEk6hO+llOPRUefFCT/wMPwEcfwYUXavJftszv6EzRugi921+CLiA/BRjEsUJxndHhoi287a8AK0MfpjE+iermnZP58UcYMwb+8Q/Yuxe6dNGLQePGxfpnTVh4DViIXhSWoe3/XYF3vP2DgXPQFcbOR7uoIrtypIkNwTbvxGTSz/LTT5r4x47V3y+/HEaMgEsvDcmfN747glYL/R2oi04Qq8Hxi8mURfsIhqGziCcC53qPc7CahSZcWJt+EMqX12afjRth1Cgd4dO8ObRure3/YXY9NEUuDr2br+s9L4OuMLYFXWLyWXSxmaz964Hb0X6E84DS3s8p3v4jWLlpE+5i+k4/u/37YcIEHe2zbRtccok2+3ToABG+NoQpEkfQeYlrsz3uAi5G5xTchM4wbuv9rJLTGxlT5Kx5pxAOHIBXXtG7/02boG5dGDAAbrwRypbN+/UmVi0CnkFrD/7obasDzERHEDmsf8AUF2veKYTERLjtNkhPh0mToFQpfV61Ktx9N6wt8iISJjo0QYvQbkc7ikejzUdZd/t3oR3D9dH5A73QYaVZVqCdyz8Ah0MTsok5lvRzUbIk9OkDaWnwxRfQsaN2/NasqbN8Z82CIzbk25wgDh39MwQdFZQ1MawuWnW8BrAXrTn0TsDrBnuvqwyUQ4efDg7YvxfrMzCFZc07+bR1K7z4oj62b4fzz9emnz59dD6AMQW3BFiH3umvQ2cbJ3Gso7gO2mxUL+DRDCtEZ8Da9IvdwYPw5pt6579okc7+7d0bbr0VUlL8js5Ep/HoheFbtCloP9px/Ar6DaAp2pRUI+DRgOOL5JpoZUk/hBYt0uT/5pt6MWjUCPr2hRtugNNP9zs6E52OoN8GjqD9Br8A16PDStdzbHnKEegKpruBK4EL0DqIF3iPGlhdouhgSd8Hu3fDlCnw8sta3iExEa69Fm65BVq2tGGfJlQcOsFsPXAG8AdgI9AH+B7taM7yAvAXb//LHLsYJKPVS63bL1JY0veRczrR6+WX4T//gZ9/1nr/N9+sbf9Vsy82aUxI7QZWo5XSW6Kzi2eg3wQCRyaUQIefXoauZzAZ7WSu4v2sjHZOJ4YqcJMLS/phYv9+ePtteOkl+N//IC5ORwH16QNXXqnfBowJDwfRshSr0Ulo29BlLZOB/wB3cGz+QZZV6DeDV9BvDVXRi0LWz+7ozOVDWMmK4mVJPwxlZMDEiTr2f9s2OO006NZNJ321aKEXBGPC229o89A29MLQAb3Tnwa85G3bgpa2Bh1mWg4dvvoq2p+Q1adQG525bP/wi4Il/TB2+LDW9nntNZg+HX75BapVgz/9CXr1gjo2As9EvP3oBeA87/kHwLvoN4NV6EWhPLoqmgAPoX0QtdE+iLPQbxjnhjTqSGZJP0L88gu8955eAD7+WC8IjRpp8u/ZE846y+8IjSlqWR3NW9BVzUCXvnwPyAw4riGQtc5pV/TbxVkBj/rAtd7+H9GLSOyORLKkH4G2b4epU/UCkJamzT3t2+vonyuv1BnCxkS3vcAmdIJaHNDa2z4YnZvwg/fYiS6I8763v6q3/SyO9Sd0QpfRBPgYvSAkeo/S6MS3JPQidABdnzlym5os6Ue477/X5P/vf0NmJpxxhnb+9u0LtWr5HZ0xfjuEzkU4xXs+AdiMfnvYgjYtdUJXUfsdTejZc93dwNPAvoD3KQWcCpwG3An0B372fj8t26M5UMt7/QLvtSUDHsnoKm2H0E7yMhRnwT1L+lHi8GFt9vnXv+CDD+DQIe30vfVWnQNQpozfERoT7g6hM5kPeI9fvZ/noc1LB9C1Ew6gfRH/hyb6a4Bu6MXkEm/b3oD3HQcMRL+B1Mvh704E/oyu1NYMXbLzdLQZ6nT0gtQKWIOOfiqPLuFZsCF9lvSj0A8/wOTJOvwzPV1H//Tqpc0/VvrBmFA4jCb+n9FvBOXR2dBL0JFNgY+L0BnPm9Ehr3uAnwJ+PoxeDD5E+yyyvhEUrB3Xkn4Ucw7mz9e7/7fe0tIPqanQr592/pYr53eExpj8cei3jIIv2GH19KOYiJZ1eO01He//j39o4u/XD6pUgf79Yflyv6M0xgRPKEzCzw9L+hGufHkt7bxsGXz5JVx9tU4Aa9BAl3ucPBl+/TXv9zHGxAZL+lFCBJo109E+W7bAmDFaAC6r1s9dd+mIIGNMbLOkH4UqVIA774RVq2DuXB3rP3481K4NrVvrXICDB/2O0hjjB0v6UUwEWrXSJL95MzzxBGzYoJ29yclwzz2wZo3fURpjQsmSfow480y4915d1H3WLO0IHjtWJ3q1agWvvw4HDvgdpTGmuAWV9EWkg4isFpEMEbk3h/3niMinIrJcROaJSHLAvj4iku49+hRl8Cb/sko7vPXWsbv/zZt1la/kZBg82Nr+jYlmeSZ9EYlHF+fsiK7A3FNEsteBfAqY7Jyrj67N9oT32gpo+byLgSbAQyJSvujCN4Vx1ll695+eDrNnwx//COPGadv/ZZfpkFC7+zcmugRzp98EyHDOrXPO/QZMBa7KdkwddIkdgLkB+y8HZjvndjvnfgJmowW4TRiJi4O2bWHaNK3z8/e/w9atWuc/ORmGDYP16/2O0hhTFIJJ+lXRecRZMr1tgZZxrMbp1cApIlIxyNeaMHLmmTB0qHbwzp6tbf9PP63LPV5xBcyYofWAjDGRKZikn1NZuOy1G4YALUVkCbro5ha0kEQwr0VE+olImoik7dy5M4iQTHHLuvufPl1H/IwYoev+XnEF1KwJTz4JP2ZfOc8YE/aCSfqZwNkBz5PRuqVHOee2Oueucc41BO73tv0czGu9Yyc451Kdc6lJSUn5PAVT3JKT4ZFHYONGeOMNXeVr2DDd3rs3LFyo9YCMMeEvmKT/NVBTRGqISEmgB8dWLgBARCqJSNZ7DUdrigLMAtqLSHmvA7e9t81EoJIloXt3mDcPVqzQ6p7vvqszgVNT4dVX4bff/I7SGJObPJO+c+4QupbZLHRxy2nOuZUiMlJEuniHtQJWi8ga4Ezgce+1u4FH0QvH18BIb5uJcHXrwj//qSUfnntO6/v07g3nnAOPPgo7dvgdoTEmJ1Za2RQJ57Tjd+xY+OgjKFVKF3q/806ol9P6EsaYImWllU1IieikrxkztObPzTfrLN/69bVD+L//hSNH/I7SGGNJ3xS5Cy7QJp+sMf+rV+vC7rVqaZPQvn1+R2hM7LKkb4pNhQo65n/dOh31k5QEAwfq6J/779flH40xoWVJ3xS7hAQd9fPll7BggZZ7eOIJ7fS95RZtDjLGhIYlfRNSTZtqsbc1a6BvX5gyBerU0eaf+fNtvL8xxc2SvvHFeedpu/+mTfDwwzrBq2VLuPhiePNNK/VgTHGxpG98lZQEDz2ks32fe06XeOzeXUs9WKevMUXPkr4JC2XKwG236Uif6dO18NvAgXD22VryYfPmvN/DGJM3S/omrMTHwzXXaIfvl1/qGP+nnoIaNXShl6+/9jtCYyKbJX0Ttpo10/b9tWth0CCd4NWkCbRoAW+/be3+xhSEJX0T9qpXh2ee0cleY8boz2uv1Xb/Z5+FvXv9jtCYyGFJ30SMU0/VWj4ZGTrss0oVfZ61tu/GjX5HaEz4s6RvIk58vN7pf/45fPUVdOqkd/x/+IOO/FmwwO8IjQlflvRNRGvSRAu7rV8PQ4Zopc9LLtFJYFOnwu+/+x2hMeHFkr6JCmefrcXdNm/W8f27dkHPnrq275NPwk8/+R2hMeHBkr6JKuXKQf/+Ot7//fd15u+wYXpRGDBA+wOMiWWW9E1UiovTej5z5sCSJdCtG0yYoOWdb7wRvv/e7wiN8YclfRP1UlJg0iQd3XP33TrGv04dney1cqXf0RkTWpb0TcyoXBlGj9ZO36FDtfmnXj0d8bN8ud/RGRMalvRNzDnjDBg1CjZsgPvug5kzoUEDLf+wZInf0RlTvCzpm5hVqRI89pg2+zz0kLb/N2oEXbpYjR8TvSzpm5hXvrzW9N+4ER59VCd9NWkCHTvaRC8TfSzpG+M57TR44AFt9vnb3yAtTSd6tW2rq3oZEw0s6RuTzamnwvDhmvyfegpWrNBVvVq2hE8/tSUdTWSzpG/MSZQtq4Xc1q/X2j4ZGXrXf+ml2vlryd9EIkv6xuShdGmt5792rS7pmJmp7f1NmuiwT0v+JpJY0jcmSImJuqRjRgb8619a3+eqq3TEjyV/Eyks6RuTTyVLwi23aH2fV17RRVyuugoaN4YPP7Tkb8KbJX1jCighAW66CVatgokT9c6/c2ct62xt/iZcWdI3ppASEuDPf4Y1a7TZ54cftM3/0ku1vr8lfxNOgkr6ItJBRFaLSIaI3JvD/moiMldElojIchHp5G2vLiK/ishS7/FCUZ+AMeEiIUGbfdLT4YUXtLZ/+/Zw2WUwd67f0Rmj8kz6IhIPjAc6AnWAniJSJ9thDwDTnHMNgR7AcwH71jrnUrzHX4sobmPCVsmS8Je/aIfv+PGwbh388Y/QqpVN8jL+C+ZOvwmQ4Zxb55z7DZgKXJXtGAec6v1+GrC16EI0JjKVKgW3365DPceN047fli11rP+XX/odnYlVwST9qsDmgOeZ3rZADwO9RCQTmAEMDNhXw2v2+Z+ItMjpD4hIPxFJE5G0nTt3Bh+9MREgMREGDtQ7/meegW+/1fb+Dh10YXdjQimYpC85bMveNdUTmOScSwY6Aa+KSBywDajmNfvcDfxHRE7N9lqccxOcc6nOudSkpKT8nYExEaJ0abjrLk3+o0fD4sU60ueKK7TOjzGhEEzSzwTODniezInNN32BaQDOuQVAIlDJOXfQObfL274YWAucX9igjYlkZcvCkCFa3uGJJ2DhQh3jf9VVsHSp39GZaBdM0v8aqCkiNUSkJNpR+362YzYBbQBEpDaa9HeKSJLXEYyI/AGoCawrquCNiWTlysG992ryf/RR7eRt2BCuvVabgIwpDnkmfefcIWAAMAtYhY7SWSkiI0Wki3fYYOBWEVkGvA7c5JxzwGXAcm/7W8BfnXO7i+NEjIlUp56qJZ3Xr9fFXD75RFfyuvlmrfNjTFESF2YzR1JTU12aNXCaGLZ7t9bz/8c/ID5e+wGGDtV6/8acjIgsds6l5nWczcg1JsxUqKB1/Fevhquv1gvAeefpReC33/yOzkQ6S/rGhKnq1WHKFB3ZU6+elneuWxfeestKO5iCs6RvTJi76CJdsevDD3XC13XX6TKOn3/ud2QmElnSNyYCiECnTrBsGbz8MmzaBC1aaPPP6tV+R2ciiSV9YyJIfLyO6lmzBh57TEf6XHihNv38+KPf0ZlIYEnfmAhUtizcf78WdbvlFi3sdt55OtP3wAG/ozPhzJK+MRHszDPh+eeP1fMZOhRq14Y33rDOXpMzS/pXVBehAAAOD0lEQVTGRIE6dbSjd/ZsnezVowc0a2bVPM2JLOkbE0XatoVvvtHlGzdt0rv/667T8s7GgCV9Y6JOfLwu35ieDg8/DDNmaJPPkCGwb5/f0Rm/WdI3JkqVLau1fNLToVcvePppTf7vvGPt/bHMkr4xUa5KFW3u+eILLfFwzTXQpQts2OB3ZMYPlvSNiRGXXKILtzz9tC7UXqcOjBpl9XxijSV9Y2JIiRJw992wahV07AjDh2sNf1uwPXZY0jcmBp19NkyfDv/9L+zfrwu2//nPYEtURz9L+sbEsCuugJUr9Y7/tdfgggvgpZfgyBG/IzPFxZK+MTGuTBmt2b90qdbxufVWLea2fLnfkZniYEnfGANorf5582DSJC3o1qgRDB4Me/f6HZkpSpb0jTFHiUCfPlqu+ZZbYMwYHdtvC7dED0v6xpgTVKgAL7ygtXuSkrSUQ6dOVs4hGljSN8acVNOm8PXXMHasTu6qWxdGjoSDB/2OzBSUJX1jTK5KlIA77oDvv4euXbW0Q716uoCLiTyW9I0xQalSBaZOhVmztH2/XTst4bxtm9+RmfywpG+MyZf27XXRlocfhnff1bH9//wnHD7sd2QmGJb0jTH5lpiozTzffqvt/gMHwsUXQ1qa35GZvFjSN8YUWM2aMHOmNvts3QpNmsCAAbBnj9+RmZOxpG+MKRQRuP56LeI2YICu2XvBBfCf/9jY/nBkSd8YUyROOw3GjYNFi6BaNfjTn7Szd80avyMzgSzpG2OK1EUXwYIFMH68tvHXqwcPPggHDvgdmYEgk76IdBCR1SKSISL35rC/mojMFZElIrJcRDoF7BvuvW61iFxelMEbY8JTfDzcfruO7e/WDR59FFJSdIKX8VeeSV9E4oHxQEegDtBTROpkO+wBYJpzriHQA3jOe20d73ldoAPwnPd+xpgYcNZZMGWKju0/cECrdw4caEXc/BTMnX4TIMM5t8459xswFbgq2zEOONX7/TRgq/f7VcBU59xB59x6IMN7P2NMDGnfHlas0IQ/fryWcJ41y++oYlMwSb8qsDngeaa3LdDDQC8RyQRmAAPz8VpjTAwoVw6efRY+/1xr+HfooBU9d+/2O7LYEkzSlxy2ZR+I1ROY5JxLBjoBr4pIXJCvRUT6iUiaiKTttPXajIlql1wCS5bAAw/osE4r3RxawST9TODsgOfJHGu+ydIXmAbgnFsAJAKVgnwtzrkJzrlU51xqUlJS8NEbYyJSYqJ27qal6Xq9110H115rdXxCIZik/zVQU0RqiEhJtGP2/WzHbALaAIhIbTTp7/SO6yEipUSkBlATWFRUwRtjIluDBrBwITz5JHz0EdSpA6+8Ynf9xSnPpO+cOwQMAGYBq9BROitFZKSIdPEOGwzcKiLLgNeBm5xaiX4D+A6YCfR3zllZJmPMUSVKwD336Jq89evDzTfrgu1btvgdWXQSF2aX1NTUVJdmVZuMiUlHjsBzz8GwYZCQoB2/vXtrqQeTOxFZ7JxLzes4m5FrjAkbcXFavyfrrv+mm+DKK7WYmykalvSNMWHn3HNh3jxdpnHOHF2m8dVXra2/KFjSN8aEpbg4XaZx2TJN+r1763KNNsKncCzpG2PCWs2a8L//wdNPw8cf6wVgyhS76y8oS/rGmLAXHw933w1Ll0KtWtCrF1xzDezY4XdkkceSvjEmYtSqpWUcRo/Wcf316sGMGX5HFVks6RtjIkp8PAwZorN5zzxTx/QPGAC//up3ZJHBkr4xJiJdeKGu0nXXXVq586KLtPnH5M6SvjEmYiUmwjPPaAfvnj26MPvo0TrJy+TMkr4xJuK1awfffgudO8PQodC2LWzenPfrYpElfWNMVKhYEaZPh5de0maf+vVh2jS/owo/lvSNMVFDBPr21Xr9558P11+vpRz+7//8jix8WNI3xkSdmjV1aOeIEVq+oUEDLedgLOkbY6JUQgKMHAnz5+vvbdrAbbfZouyW9I0xUe3SS3Uo5913w4sv6oSuTz/1Oyr/WNI3xkS9MmW0ds/nn0OpUjq6569/jc22fkv6xpiYcckletc/ZAhMmKB3/bNn+x1VaFnSN8bElNKldQLXF1/o7+3bQ79+sXPXb0nfGBOTmjXToZ333AMvv6xlHWbN8juq4mdJ3xgTs0qXhief1Lv+smWhQ4foL95mSd8YE/OaNtW7/qzibRdfDN9953dUxcOSvjHGcKx424wZ8MMPWrXzxRejb4UuS/rGGBOgY0dYvhxatNBhnd26we7dfkdVdCzpG2NMNmedBTNn6iifDz7QMg7z5/sdVdGwpG+MMTmIi9Px/F9+qU0/rVvDQw/BoUN+R1Y4lvSNMSYXqanwzTdw441ay6dVK9i40e+oCs6SvjHG5OGUU2DSJJgyRdv7GzSAN9/0O6qCsaRvjDFBuuEGLeNwwQXQvbvW7t+3z++o8seSvjHG5MMf/gCffQb33w+vvAKNGkFamt9RBc+SvjHG5FNCAjz2GMydq7N3mzXTmb2RsCB7UElfRDqIyGoRyRCRe3PYP0ZElnqPNSKyJ2Df4YB97xdl8MYY46eWLWHZMujaFYYN0wXat2zxO6rc5Zn0RSQeGA90BOoAPUWkTuAxzrm7nHMpzrkU4B/A2wG7f83a55zrUoSxG2OM7ypU0AXYX3oJFi7UBdnffdfvqE4umDv9JkCGc26dc+43YCpwVS7H9wReL4rgjDEmEmQtyP7NN1C9Olx9tS7NuH+/35GdKJikXxXYHPA809t2AhE5B6gBBC5BnCgiaSKyUES6nuR1/bxj0nbu3Blk6MYYE15q1YIFC2DoUHjhBR3jv3Sp31EdL5ikLzlsO1kJoh7AW865wwHbqjnnUoEbgLEicu4Jb+bcBOdcqnMuNSkpKYiQjDEmPJUsCX//u67ItWePVux8/vnwKdwWTNLPBM4OeJ4MbD3JsT3I1rTjnNvq/VwHzAMa5jtKY4yJMG3b6kSutm3h9tuhd2/45Re/owou6X8N1BSRGiJSEk3sJ4zCEZFaQHlgQcC28iJSyvu9EnApEKVVqo0x5niVKmnBtpEjdTZv06awZo2/MeWZ9J1zh4ABwCxgFTDNObdSREaKSOBonJ7AVOeO+xJTG0gTkWXAXGCUc86SvjEmZsTFwYgRWrVz2zZo3Bjeece/eMSFS0OTJzU11aVF0vQ2Y4wJ0qZNcN11sGiRrs37t79BiRJF894istjrP82Vzcg1xpgQqVZN6/LfdpvW6m/TRlfpCiVL+sYYE0KlSsFzz8HkyfD119CwodbyCRVL+sYY44Mbb4SvvtKyza1bw5gxoRnWaUnfGGN8Uq+e3u136QJ33w3XX1/8RduKqAvBGGNMQZx2GkyfDk8/DT//rKN9ipMlfWOM8ZmIrscbCta8Y4wxMcSSvjHGxBBL+sYYE0Ms6RtjTAyxpG+MMTHEkr4xxsQQS/rGGBNDLOkbY0wMCbvSyiKyE9iYbXMl4EcfwilO0XZO0XY+EH3nFG3nA9F3ToU5n3Occ3muNxt2ST8nIpIWTJ3oSBJt5xRt5wPRd07Rdj4QfecUivOx5h1jjIkhlvSNMSaGRErSn+B3AMUg2s4p2s4Hou+cou18IPrOqdjPJyLa9I0xxhSNSLnTN8YYUwTCLumLyEQR2SEiKwK2VRCR2SKS7v0s72eM+XGS83lYRLaIyFLv0cnPGPNLRM4WkbkiskpEVorIHd72iPyccjmfiP2cRCRRRBaJyDLvnB7xttcQka+8z+gNESnpd6zByOV8JonI+oDPKMXvWPNDROJFZImI/Nd7XuyfT9glfWAS0CHbtnuBT51zNYFPveeRYhInng/AGOdciveYEeKYCusQMNg5VxtoCvQXkTpE7ud0svOByP2cDgJ/dM41AFKADiLSFPg7ek41gZ+Avj7GmB8nOx+AewI+o6X+hVggdwCrAp4X++cTdknfOTcf2J1t81XAv73f/w10DWlQhXCS84lozrltzrlvvN/3ov9oqxKhn1Mu5xOxnNrnPU3wHg74I/CWtz2SPqOTnU/EEpFk4ArgJe+5EILPJ+yS/kmc6ZzbBvo/KHCGz/EUhQEistxr/omIZpCciEh1oCHwFVHwOWU7H4jgz8lrOlgK7ABmA2uBPc65Q94hmUTQxS37+Tjnsj6jx73PaIyIlPIxxPwaCwwFspZCr0gIPp9ISfrR5nngXPRr6jbgaX/DKRgRKQdMB+50zv2f3/EUVg7nE9Gfk3PusHMuBUgGmgC1czostFEVXPbzEZELgeHABUBjoAIwzMcQgyYinYEdzrnFgZtzOLTIP59ISfrbRaQygPdzh8/xFIpzbrv3D/gI8C/0f8iIIiIJaIKc4px729scsZ9TTucTDZ8TgHNuDzAP7a84XURKeLuSga1+xVVQAefTwWuac865g8ArRM5ndCnQRUQ2AFPRZp2xhODziZSk/z7Qx/u9D/Cej7EUWlZi9FwNrDjZseHIa3t8GVjlnHsmYFdEfk4nO59I/pxEJElETvd+Lw20Rfsq5gLdvMMi6TPK6Xy+D7jJELT9OyI+I+fccOdcsnOuOtADmOOc+xMh+HzCbnKWiLwOtEKrzW0HHgLeBaYB1YBNwHXOuYjoHD3J+bRCmwwcsAH4S1ZbeCQQkebAZ8C3HGuPvA9tB4+4zymX8+lJhH5OIlIf7QiMR2/upjnnRorIH9A7ywrAEqCXd5cc1nI5nzlAEto0shT4a0CHb0QQkVbAEOdc51B8PmGX9I0xxhSfSGneMcYYUwQs6RtjTAyxpG+MMTHEkr4xxsQQS/rGGBNDLOkbY0wMsaRvjDExxJK+McbEkP8H+io0ZWuq5r0AAAAASUVORK5CYII=\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process[10:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.plot('epoch', 'test_RMSE', data=df, color='yellow', linestyle='dashed')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model.estimations()\n", "\n", "top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n", "\n", "top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv', index=False, header=False)\n", "\n", "estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n", "estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 8982.19it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.9148560.7183840.1004240.0408590.0505230.0674310.0906650.0683680.1013280.0479170.1837920.5171410.4591730.8605510.1464653.8532360.971798
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.914856 0.718384 0.100424 0.040859 0.050523 0.067431 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.090665 0.068368 0.101328 0.047917 0.183792 0.517141 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.459173 0.860551 0.146465 3.853236 0.971798 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n", "reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n", "\n", "ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n", " estimations_df=estimations_df, \n", " reco=reco,\n", " super_reactions=[4,5])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 9603.22it/s]\n", "943it [00:00, 8786.72it/s]\n", "943it [00:00, 8141.95it/s]\n", "943it [00:00, 8884.14it/s]\n", "943it [00:00, 10117.77it/s]\n", "943it [00:00, 8687.46it/s]\n", "943it [00:00, 10361.84it/s]\n", "943it [00:00, 10162.64it/s]\n", "943it [00:00, 8493.19it/s]\n", "943it [00:00, 9153.50it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Self_SVD0.9148560.7183840.1004240.0408590.0505230.0674310.0906650.0683680.1013280.0479170.1837920.5171410.4591730.8605510.1464653.8532360.971798
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5189641.2221590.0465540.0206030.0236790.0312160.0289700.0211790.0504890.0191850.1238560.5068120.3223750.9878050.1847045.1031720.906873
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Self_SVD 0.914856 0.718384 0.100424 0.040859 0.050523 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.518964 1.222159 0.046554 0.020603 0.023679 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.067431 0.090665 0.068368 0.101328 0.047917 0.183792 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.031216 0.028970 0.021179 0.050489 0.019185 0.123856 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.517141 0.459173 0.860551 0.146465 3.853236 0.971798 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.506812 0.322375 0.987805 0.184704 5.103172 0.906873 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2],\n", " [3, 4]])" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([[0.4472136 , 0.89442719],\n", " [0.6 , 0.8 ]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=np.array([[1,2],[3,4]])\n", "display(x)\n", "x/np.linalg.norm(x, axis=1)[:,None]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
09161.000000917917Mercury Rising (1998)Action, Drama, Thriller
19140.991506915915Primary Colors (1998)Drama
29080.990078909909Dangerous Beauty (1998)Drama
36900.989487691691Dark City (1998)Film-Noir, Sci-Fi, Thriller
43590.988384360360Wonderland (1997)Documentary
58100.987781811811Thirty-Two Short Films About Glenn Gould (1993)Documentary
69170.986770918918City of Angels (1998)Romance
78690.986746870870Touch (1997)Romance
87560.986005757757Across the Sea of Time (1995)Documentary
97320.985919733733Go Fish (1994)Drama, Romance
\n", "
" ], "text/plain": [ " code score item_id id \\\n", "0 916 1.000000 917 917 \n", "1 914 0.991506 915 915 \n", "2 908 0.990078 909 909 \n", "3 690 0.989487 691 691 \n", "4 359 0.988384 360 360 \n", "5 810 0.987781 811 811 \n", "6 917 0.986770 918 918 \n", "7 869 0.986746 870 870 \n", "8 756 0.986005 757 757 \n", "9 732 0.985919 733 733 \n", "\n", " title \\\n", "0 Mercury Rising (1998) \n", "1 Primary Colors (1998) \n", "2 Dangerous Beauty (1998) \n", "3 Dark City (1998) \n", "4 Wonderland (1997) \n", "5 Thirty-Two Short Films About Glenn Gould (1993) \n", "6 City of Angels (1998) \n", "7 Touch (1997) \n", "8 Across the Sea of Time (1995) \n", "9 Go Fish (1994) \n", "\n", " genres \n", "0 Action, Drama, Thriller \n", "1 Drama \n", "2 Drama \n", "3 Film-Noir, Sci-Fi, Thriller \n", "4 Documentary \n", "5 Documentary \n", "6 Romance \n", "7 Romance \n", "8 Documentary \n", "9 Drama, Romance " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item=random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm=model.Qi/np.linalg.norm(model.Qi, axis=1)[:,None] # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores=np.dot(embeddings_norm,embeddings_norm[item].T)\n", "top_similar_items=pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\\\n", ".sort_values(by=['score'], ascending=[False])[:10]\n", "\n", "top_similar_items['item_id']=top_similar_items['code'].apply(lambda x: item_code_id[x])\n", "\n", "items=pd.read_csv('./Datasets/ml-100k/movies.csv')\n", "\n", "result=pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n", "\n", "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# making changes to our implementation by considering additional parameters in the gradient descent procedure \n", "# seems to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 8010.33it/s]\n", "943it [00:00, 7939.12it/s]\n", "943it [00:00, 8331.15it/s]\n", "943it [00:00, 8696.10it/s]\n", "943it [00:00, 8172.62it/s]\n", "943it [00:00, 8807.34it/s]\n", "943it [00:00, 8646.67it/s]\n", "943it [00:00, 7192.36it/s]\n", "943it [00:00, 8888.67it/s]\n", "943it [00:00, 8736.94it/s]\n", "943it [00:00, 8047.44it/s]\n", "943it [00:00, 8326.85it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9528890.7506740.0988340.0478990.0536630.0685810.0878760.0768310.1134460.0541270.2429180.5206770.4888650.9980910.2049064.4403360.952374
0Self_SVD0.9148560.7183840.1004240.0408590.0505230.0674310.0906650.0683680.1013280.0479170.1837920.5171410.4591730.8605510.1464653.8532360.971798
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9398070.7416100.0820780.0326910.0406110.0545030.0733910.0514000.0885310.0397390.1881870.5129980.4231180.9958640.1724394.1766120.963967
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5189641.2221590.0465540.0206030.0236790.0312160.0289700.0211790.0504890.0191850.1238560.5068120.3223750.9878050.1847045.1031720.906873
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.952889 0.750674 0.098834 0.047899 0.053663 \n", "0 Self_SVD 0.914856 0.718384 0.100424 0.040859 0.050523 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.939807 0.741610 0.082078 0.032691 0.040611 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.518964 1.222159 0.046554 0.020603 0.023679 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.068581 0.087876 0.076831 0.113446 0.054127 0.242918 \n", "0 0.067431 0.090665 0.068368 0.101328 0.047917 0.183792 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.054503 0.073391 0.051400 0.088531 0.039739 0.188187 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.031216 0.028970 0.021179 0.050489 0.019185 0.123856 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.520677 0.488865 0.998091 0.204906 4.440336 0.952374 \n", "0 0.517141 0.459173 0.860551 0.146465 3.853236 0.971798 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.512998 0.423118 0.995864 0.172439 4.176612 0.963967 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.506812 0.322375 0.987805 0.184704 5.103172 0.906873 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }