introduction_to_recommender.../P4. Matrix Factorization.ipynb

1456 lines
76 KiB
Plaintext
Raw Permalink Normal View History

2021-04-16 22:41:06 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self-made SVD"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import helpers\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scipy.sparse as sparse\n",
"from collections import defaultdict\n",
"from itertools import chain\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"\n",
"train_read = pd.read_csv(\"./Datasets/ml-100k/train.csv\", sep=\"\\t\", header=None)\n",
"test_read = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"(\n",
" train_ui,\n",
" test_ui,\n",
" user_code_id,\n",
" user_id_code,\n",
" item_code_id,\n",
" item_id_code,\n",
") = helpers.data_to_csr(train_read, test_read)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class SVD:\n",
" def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
" self.train_ui = train_ui\n",
" self.uir = list(\n",
" zip(*[train_ui.nonzero()[0], train_ui.nonzero()[1], train_ui.data])\n",
" )\n",
"\n",
" self.learning_rate = learning_rate\n",
" self.regularization = regularization\n",
" self.iterations = iterations\n",
" self.nb_users, self.nb_items = train_ui.shape\n",
" self.nb_ratings = train_ui.nnz\n",
" self.nb_factors = nb_factors\n",
"\n",
" self.Pu = np.random.normal(\n",
" loc=0, scale=1.0 / self.nb_factors, size=(self.nb_users, self.nb_factors)\n",
" )\n",
" self.Qi = np.random.normal(\n",
" loc=0, scale=1.0 / self.nb_factors, size=(self.nb_items, self.nb_factors)\n",
" )\n",
"\n",
" def train(self, test_ui=None):\n",
" if test_ui != None:\n",
" self.test_uir = list(\n",
" zip(*[test_ui.nonzero()[0], test_ui.nonzero()[1], test_ui.data])\n",
" )\n",
"\n",
" self.learning_process = []\n",
" pbar = tqdm(range(self.iterations))\n",
" for i in pbar:\n",
" pbar.set_description(\n",
" f\"Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...\"\n",
" )\n",
" np.random.shuffle(self.uir)\n",
" self.sgd(self.uir)\n",
" if test_ui == None:\n",
" self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
" else:\n",
" self.learning_process.append(\n",
" [i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)]\n",
" )\n",
"\n",
" def sgd(self, uir):\n",
"\n",
" for u, i, score in uir:\n",
" # Computer prediction and error\n",
" prediction = self.get_rating(u, i)\n",
" e = score - prediction\n",
"\n",
" # Update user and item latent feature matrices\n",
" Pu_update = self.learning_rate * (\n",
" e * self.Qi[i] - self.regularization * self.Pu[u]\n",
" )\n",
" Qi_update = self.learning_rate * (\n",
" e * self.Pu[u] - self.regularization * self.Qi[i]\n",
" )\n",
"\n",
" self.Pu[u] += Pu_update\n",
" self.Qi[i] += Qi_update\n",
"\n",
" def get_rating(self, u, i):\n",
" prediction = self.Pu[u].dot(self.Qi[i].T)\n",
" return prediction\n",
"\n",
" def RMSE_total(self, uir):\n",
" RMSE = 0\n",
" for u, i, score in uir:\n",
" prediction = self.get_rating(u, i)\n",
" RMSE += (score - prediction) ** 2\n",
" return np.sqrt(RMSE / len(uir))\n",
"\n",
" def estimations(self):\n",
" self.estimations = np.dot(self.Pu, self.Qi.T)\n",
"\n",
" def recommend(self, user_code_id, item_code_id, topK=10):\n",
"\n",
" top_k = defaultdict(list)\n",
2021-05-07 22:16:28 +02:00
" for nb_user, user_scores in enumerate(self.estimations):\n",
2021-04-16 22:41:06 +02:00
"\n",
" user_rated = self.train_ui.indices[\n",
" self.train_ui.indptr[nb_user] : self.train_ui.indptr[nb_user + 1]\n",
" ]\n",
2021-05-07 22:16:28 +02:00
" for item, score in enumerate(user_scores):\n",
2021-04-16 22:41:06 +02:00
" if item not in user_rated and not np.isnan(score):\n",
" top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
" result = []\n",
" # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
" for uid, item_scores in top_k.items():\n",
" item_scores.sort(key=lambda x: x[1], reverse=True)\n",
" result.append([uid] + list(chain(*item_scores[:topK])))\n",
" return result\n",
"\n",
" def estimate(self, user_code_id, item_code_id, test_ui):\n",
" result = []\n",
" for user, item in zip(*test_ui.nonzero()):\n",
" result.append(\n",
" [\n",
" user_code_id[user],\n",
" item_code_id[item],\n",
" self.estimations[user, item]\n",
" if not np.isnan(self.estimations[user, item])\n",
" else 1,\n",
" ]\n",
" )\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-05-07 22:16:28 +02:00
"Epoch 39 RMSE: 0.7489999966900885. Training epoch 40...: 100%|██████████| 40/40 [01:02<00:00, 1.57s/it]\n"
2021-04-16 22:41:06 +02:00
]
}
],
"source": [
"model = SVD(\n",
" train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40\n",
")\n",
"model.train(test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-05-07 22:16:28 +02:00
"<matplotlib.legend.Legend at 0x7ff52c7e5100>"
2021-04-16 22:41:06 +02:00
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2021-05-07 22:16:28 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbmklEQVR4nO3de3RV9Z338fc3FwmQQCBEpIRbfewaO6BBIuqIXVQ7LkGtjGMtrjXeVldZtVhx6rXWtqPLdlwy1Y7VljoVpz61RTteqo629YJaa1EDBhTBB2ytRLJMDCQQlEvg+/yxzyEn4SQ5JznJPtn781rrt/Y5e+9z9jdb/Ox9fvtm7o6IiAx9BWEXICIiuaFAFxGJCAW6iEhEKNBFRCJCgS4iEhFFYS143LhxPnXq1LAWLyIyJK1evfojd69MNy20QJ86dSq1tbVhLV5EZEgys791N01dLiIiEaFAFxGJCAW6iEhEhNaHLiJDz759+6ivr2f37t1hlxJ5JSUlVFVVUVxcnPFnFOgikrH6+nrKysqYOnUqZhZ2OZHl7jQ3N1NfX8+0adMy/py6XEQkY7t376aiokJhPsDMjIqKiqx/CSnQRSQrCvPB0Zf1POQCfcsWuPJK2Lcv7EpERPLLkAv0NWvgP/8Tli4NuxIRkfwy5AL9nHPgS1+Cm2+GjRvDrkZEBlNLSws/+clPsv7c/PnzaWlpyfpzl1xyCdOmTaO6uppjjz2W55577uC0uXPnMnnyZFIfErRgwQJKS0sBOHDgAFdccQXTp09nxowZHH/88fz1r38FgivlZ8yYQXV1NdXV1VxxxRVZ15bOkDzL5cc/hmefha9+FV58EQqG3GZJRPoiGehf//rXO43fv38/hYWF3X7uqaee6vMyly5dynnnncfKlStZtGgRmzZtOjitvLycP/3pT8yZM4eWlhYaGhoOTnvwwQfZunUr69ato6CggPr6ekaOHHlw+sqVKxk3blyf60pnSAb6+PFwxx1wySWwbBl0+W8rIoPgyiuhri6331ldDT/6UffTr7/+et59912qq6spLi6mtLSUCRMmUFdXx9tvv82CBQvYsmULu3fvZsmSJSxatAjouHdUW1sb8+bNY86cObzyyitMnDiR3/72twwfPrzX2k466SQ++OCDTuMWLlzIihUrmDNnDo888gjnnnsu69evB6ChoYEJEyZQkNjjrKqq6tM6ycaQ3be96CI4/XS47jp4//2wqxGRwXDrrbdy5JFHUldXx9KlS3nttdf4/ve/z9tvvw3A8uXLWb16NbW1tdx55500Nzcf8h2bNm1i8eLFrF+/nvLych5++OGMlv273/2OBQsWdBp32mmn8dJLL7F//35WrFjBl7/85YPTzj//fJ544gmqq6u56qqreOONNzp99vOf//zBLpc77rgjyzWR3pDcQwcwg5/9DKZPh699Df73f4NxIjI4etqTHiyzZ8/udOHNnXfeyaOPPgrAli1b2LRpExUVFZ0+k+wTB5g1axbvvfdej8u45ppruPbaa2lsbGTVqlWdphUWFjJnzhwefPBBPvnkE1JvCV5VVcU777zD888/z/PPP89pp53Gb37zG0477TRgYLpchuweOsDUqfCDH8DTT8OvfhV2NSIy2FL7pF944QWeffZZ/vznP7N27VpmzpyZ9sKcYcOGHXxdWFhIe3t7j8tYunQpmzdv5pZbbuHiiy8+ZPrChQv5xje+wfnnn592WfPmzWPp0qXccMMNPPbYY1n8ddkb0oEOsHgxnHgiLFkCTU1hVyMiA6msrIydO3emndba2sqYMWMYMWIEGzduPGRvuj8KCgpYsmQJBw4c4Pe//32naaeccgrf+ta3uOCCCzqNX7NmDVu3bgWCM17WrVvHlClTclZT2joH9NsHQWEh3Hsv7NwZhLqIRFdFRQUnn3wy06dP55prruk07YwzzqC9vZ1jjjmG73znO5x44ok5XbaZceONN3LbbbcdMv7qq6
8+pPuksbGRs88+m+nTp3PMMcdQVFTE5ZdffnB6ah/6RRddlJsaU8+hHEw1NTWeyycW3XwzfO978MQTcNZZOftaEUmxYcMGjj766LDLiI1069vMVrt7Tbr5h/weetL11wcHSC+7DHbsCLsaEZHB12ugm1mJmb1mZmvNbL2Z3ZRmnrlm1mpmdYn23YEpt3uHHRZ0vWzdGoS7iEimFi9efLD7I9nuu+++sMvKWianLe4BTnX3NjMrBl42s6fdvesRhz+6e6idHbNnw6WXws9/DnffrdMYRQaCu0fujot333132CUcoi/d4b3uoXugLfG2ONHC6XjPwGc/G9yJUd0uIrlXUlJCc3Nzn8JGMpd8wEVJSUlWn8vowiIzKwRWA/8HuNvdX00z20lmthbYClzt7uvTfM8iYBHA5MmTsyo0U8lrCJqbYfToAVmESGxVVVVRX19Pk84RHnDJR9BlI6NAd/f9QLWZlQOPmtl0d38rZZY1wJREt8x84DHgqDTfcw9wDwRnuWRVaYZSA/3Tnx6IJYjEV3FxcVaPRJPBldVZLu7eArwAnNFl/I5kt4y7PwUUm1lur2nNUGqgi4jESSZnuVQm9swxs+HAF4CNXeY5whJHScxsduJ7Q4lUBbqIxFUmXS4TgF8k+tELgIfc/Ukz+xqAuy8DzgMuM7N24BNgoYd01GTs2GCoQBeRuOk10N19HTAzzfhlKa/vAu7KbWl9M2ZMcLqiAl1E4iYyV4omFRZCebkCXUTiJ3KBDkE/ugJdROJGgS4iEhEKdBGRiFCgi4hEhAJdRCQiIhvobW2wd2/YlYiIDJ7IBjrAtm3h1iEiMpgiHejqdhGROFGgi4hEhAJdRCQiIhnoukGXiMRRJANde+giEkeRDPQRI2DYMAW6iMRLJAPdTBcXiUj8RDLQQYEuIvGjQBcRiQgFuohIRCjQRUQiItKBvm0bhPOoahGRwRfpQG9vh507w65ERGRwRDrQQd0uIhIfCnQRkYhQoIuIRESvgW5mJWb2mpmtNbP1ZnZTmnnMzO40s81mts7MjhuYcjOnG3SJSNwUZTDPHuBUd28zs2LgZTN72t1XpcwzDzgq0U4AfpoYhkZ76CISN73uoXugLfG2ONG6ngx4DnB/Yt5VQLmZTchtqdkZMyYYKtBFJC4y6kM3s0IzqwMagWfc/dUus0wEtqS8r0+M6/o9i8ys1sxqm5qa+lhyZoqKoLxcgS4i8ZFRoLv7fnevBqqA2WY2vcsslu5jab7nHnevcfeaysrKrIvNlq4WFZE4yeosF3dvAV4AzugyqR6YlPK+Ctjan8JyQYEuInGSyVkulWZWnng9HPgCsLHLbI8DFyXOdjkRaHX3hlwXmy0FuojESSZnuUwAfmFmhQQbgIfc/Ukz+xqAuy8DngLmA5uBj4FLB6jerFRUwIYNYVchIjI4eg10d18HzEwzflnKawcW57a0/tMeuojESWSvFIUg0HfuhL17w65ERGTgRT7QAbZvD7cOEZHBEItAV7eLiMSBAl1EJCIiHei6QZeIxEmkA1176CISJwp0EZGIiHSgjxwJhx2mQBeReIh0oJvp4iIRiY9IBzoo0EUkPhToIiIRoUAXEYkIBbqISETEJtD9kOcniYhESywCvb0d2tp6n1dEZCiLRaCDul1EJPoU6CIiERH5QNcNukQkLiIf6NpDF5G4UKCLiERE5ANdXS4iEheRD/SiIhg9WoEuItEX+UAHXS0qIvGgQBcRiYheA93MJpnZSjPbYGbrzWxJmnnmmlmrmdUl2ncHpty+UaCLSBwUZTBPO3CVu68xszJgtZk94+5vd5nvj+5+Vu5L7L+KCnjnnbCrEBEZWL3uobt7g7uvSbzeCWwAJg50YbmkPXQRiYOs+tDNbCowE3g1zeSTzGytmT1tZn/fzecXmVmtmdU2NTVlX20fVVTAjh2wb9+gLVJEZNBlHOhmVgo8DFzp7ju6TF4DTHH3Y4EfA4+l+w53v8fda9y9prKyso8lZy
95cdH27YO2SBGRQZdRoJtZMUGYP+Duj3Sd7u473L0t8fopoNjMxuW00n7Q1aIiEgeZnOViwL3ABne/vZt5jkjMh5n
2021-04-16 22:41:06 +02:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"df = pd.DataFrame(model.learning_process).iloc[:, :2]\n",
"df.columns = [\"epoch\", \"train_RMSE\"]\n",
"plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-05-07 22:16:28 +02:00
"<matplotlib.legend.Legend at 0x7ff52f336dc0>"
2021-04-16 22:41:06 +02:00
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2021-05-07 22:16:28 +02:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyp0lEQVR4nO3de3zO5f/A8dd7B8wmNoYxzJk5DXMqhaQIoa9EB5VKwpcOVPrp+O3gmw4o36JzVOigFKGYFEUbm7M2TGbOQ5jTtuv3xzVrGO6x+dz3vffz8djD7s/hvt/XPvXetetzfd6XGGNQSinlvXycDkAppVTh0kSvlFJeThO9Ukp5OU30Sinl5TTRK6WUl/NzOoC8lCtXzkRERDgdhlJKeYy4uLi9xpjQvPa5ZaKPiIggNjbW6TCUUspjiMjWc+3ToRullPJyF0z0IvKBiOwWkTXn2C8iMkFEkkRklYg0y7Wvs4hszN73REEGrpRSyjWu9Og/AjqfZ38XoHb210DgbQAR8QUmZu+PBPqJSOSlBKuUUir/LjhGb4xZLCIR5zmkB/CJsbUUfheRMiISBkQAScaYzQAiMi372HWXHLVSyq2cPHmSlJQUjh075nQoXq9EiRKEh4fj7+/v8jkFcTO2MrAt1+uU7G15bW91rjcRkYHYvwioWrVqAYSllLpcUlJSKFWqFBEREYiI0+F4LWMM+/btIyUlherVq7t8XkHcjM3rqprzbM+TMWayMSbaGBMdGprnDCGllJs6duwYZcuW1SRfyESEsmXL5vsvp4Lo0acAVXK9DgdSgWLn2K6U8kKa5C+Pi/k5F0SPfhbQP3v2TWvgoDFmB/AHUFtEqotIMaBv9rGFJstksWLHisL8CKWU8jiuTK/8HPgNqCsiKSJyr4gMEpFB2YfMATYDScC7wGAAY0wGMBSYB6wHZhhj1hZCG3J8sfYLmk9uTp8v+rB5/+bC/CillPIYF0z0xph+xpgwY4y/MSbcGPO+MeYdY8w72fuNMWaIMaamMaaRMSY217lzjDF1sve9WJgNAehapyvPtHuG2YmzqT+xPiPmj2D/0f2F/bFKKYcdOHCA//3vf/k+78Ybb+TAgQP5Pu/uu++mevXqREVF0aRJExYsWJCzr3379lStWpXcizr17NmToKAgALKyshg2bBgNGzakUaNGtGjRgi1btgC2KkCjRo2IiooiKiqKYcOG5Tu2vLhlCYSLFVQsiGfbP8v9ze5ndMxoXv/tdRZvXcyy+5bp+KFSXuxUoh88ePBp2zMzM/H19T3neXPmzLnozxw7diy9e/cmJiaGgQMHkpiYmLOvTJkyLFmyhLZt23LgwAF27NiRs2/69OmkpqayatUqfHx8SElJITAwMGd/TEwM5cqVu+i48uJVif6UyldU5sMeHzK81XAOHDuAiJB+Mp2FWxbStXZXTfpKFaKHHoL4+IJ9z6goGDfu3PufeOIJNm3aRFRUFP7+/gQFBREWFkZ8fDzr1q2jZ8+ebNu2jWPHjjF8+HAGDhwI/FNX6/Dhw3Tp0oW2bduydOlSKleuzLfffktAQMAFY2vTpg3bt28/bVvfvn2ZNm0abdu25euvv+bmm29m7Vo7cr1jxw7CwsLw8bEDKuHh4Rf1M8kPr651E1UxivYR7QH4YOUHdP+8Ox0+7qA3bJXyMmPGjKFmzZrEx8czduxYli9fzosvvsi6dfb5zA8++IC4uDhiY2OZMGEC+/btO+s9EhMTGTJkCGvXrqVMmTJ89dVXLn323Llz6dmz52nbOnbsyOLFi8nMzGTatGnceuutOfv69OnDd999R1RUFI8++igrV6487dwOHTrkDN288cYb+fxJ5M0re/R5eaD5A/iID88seoboydH0a9SPQc0HcXW1q50OTSmvcr6e9+XSsmXL0x4omjBhAjNnzgRg27ZtJCYmUrZs2dPOOTXmDtC8eXOSk5PP+xkjR47kscceY/fu3fz++++n7fP19aVt27
ZMnz6do0ePkrvsenh4OBs3bmThwoUsXLiQjh078sUXX9CxY0egcIZuvLpHn5u/rz+DWwwm6d9JjLxyJN9u+JYXf/nn/nDK3ykORqeUKki5x7wXLVrETz/9xG+//UZCQgJNmzbN84Gj4sWL53zv6+tLRkbGeT9j7NixJCUl8cILL3DXXXedtb9v3778+9//pk+fPnl+VpcuXRg7dixPPvkk33zzTT5al39FJtGfUrpEaf7b6b/sGrGLd7q9A9gkX21cNVq914qJyyeyN32vw1EqpfKjVKlSHDp0KM99Bw8eJDg4mJIlS7Jhw4azet+XwsfHh+HDh5OVlcW8efNO23f11VczatQo+vXrd9r2FStWkJpqnx3Nyspi1apVVKtWrcBiyjPOQn13NxZYLJCIMhH2e/9AXrnuFY5lHGPoD0MJey2MHtN6sHHvRmeDVEq5pGzZslx11VU0bNiQkSNHnravc+fOZGRk0LhxY5566ilat25doJ8tIowePZpXXnnlrO0jRow4axhm9+7ddO/enYYNG9K4cWP8/PwYOnRozv7cY/T9+/cvmBhzz/V0F9HR0capFaYSdiYwZdUUpq+dzvL7lhNWKoyEnQkEFQuiZkhNR2JSyt2tX7+e+vXrOx1GkZHXz1tE4owx0XkdX2R79OfSpGITXr3+VbY+tJWwUmEAPDr/UWq/WZue03oSsyUGd/zlqJRS56KJ/hx85J8fzZReUxh9zWiWbFvCtZ9cS9SkKGaun+lgdEqpy2HIkCE5wyinvj788EOnw8q3IjO98lKElQrj+Q7PM6rtKD5f8znjfh/HXwf/AuBYxjH2H92f0/tXSnmPiRMnOh1CgdAefT4E+AcwoOkAEgYlMKTlEAA+W/0Z1cZVo//M/izfvlyHdZRSbkcT/UUQEfx87B9D7SPa82D0g8zcMJNW77Wi6riqDJ49mIys88/BVUqpy0UT/SWqEVyD8V3Gk/JwCh/c9AEtKrVgze41Ob8Ixvw6hg9WfsCuw7scjlQpVVTpGH0BKV2iNPc0vYd7mt6TM3yTZbL4KP4jNu7biCC0rNyS7nW60zuyN3XL1XU4YqVUUaE9+kJwqjqmj/iwfsh6Vj6wkufaP0eWyWJ0zGg+W/0ZACcyT5CUluRkqEp5hYutRw8wbtw40tPTz3vMqTrxjRs3pl27dmzdujVnn4hw55135rzOyMggNDSUbt26AbBr1y66detGkyZNiIyM5MYbbwQgOTmZgICA02b0fPLJJxfVhgsyxrjdV/PmzY23Sv071ew8tNMYY8ysDbMMz2JavdvKTPh9gtl1eJfD0Sl1cdatW+fo52/ZssU0aNDgos6tVq2a2bNnj8vHPP300+a+++7L2RcYGGiioqJMenq6McaYOXPmmCZNmpiuXbsaY4wZOHCgGTduXM7xCQkJlxxzXj9vINacI6fq0M1llnsaZovKLXjlulf4dPWnDJs7jIfnPUynmp2Y2msqZUuWPc+7KOXe2n/U/qxtfRr0YXCLwaSfTOfGT288a//dUXdzd9Td7E3fS+8ZvU/bt+juRef9vNz16Dt16kT58uWZMWMGx48fp1evXjz33HMcOXKEPn36kJKSQmZmJk899RS7du0iNTWVDh06UK5cOWJiYi7YtjZt2jBhwoTTtnXp0oXZs2fTu3dvPv/8c/r168cvv/wC2Prz119/fc6xjRs3vuBnFDQdunFQxaCKjLxqJPGD4lnz4Boeu+ox0k+mExwQDMAnCZ8wc/1Mdh/Z7XCkSrm33PXoO3XqRGJiIsuXLyc+Pp64uDgWL17M3LlzqVSpEgkJCaxZs4bOnTszbNgwKlWqRExMjEtJHvKuP39qoZFjx46xatUqWrVqlbNvyJAh3HvvvXTo0IEXX3wxp6AZkPPL6dTXqV8OBU179G6iQfkGvNTxpZzXxhjG/DqG9XvXA1C9THVahbeiR90e9G3Y16kwlXLJ+XrgJf1Lnnd/uZLlLtiDP5/58+czf/58mjZtCsDhw4
dJTEzk6quvZsSIETz++ON069aNq6/O31oUHTp0YNeuXZQvX54XXnjhtH2NGzcmOTmZzz//PGcM/pQbbriBzZs3M3f
2021-04-16 22:41:06 +02:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"df = pd.DataFrame(\n",
" model.learning_process[10:], columns=[\"epoch\", \"train_RMSE\", \"test_RMSE\"]\n",
")\n",
"plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"plt.plot(\"epoch\", \"test_RMSE\", data=df, color=\"green\", linestyle=\"dashed\")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Saving and evaluating recommendations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model.estimations()\n",
"\n",
"top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"\n",
"top_n.to_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_reco.csv\", index=False, header=False\n",
")\n",
"\n",
"estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\",\n",
" index=False,\n",
" header=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-05-07 22:16:28 +02:00
"943it [00:00, 8683.10it/s]\n"
2021-04-16 22:41:06 +02:00
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2021-05-07 22:16:28 +02:00
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.07236</td>\n",
" <td>0.104839</td>\n",
" <td>0.04897</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2021-05-07 22:16:28 +02:00
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 0.914393 0.717199 0.101697 0.042334 0.051787 0.068811 \n",
2021-04-16 22:41:06 +02:00
"\n",
2021-05-07 22:16:28 +02:00
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.092489 0.07236 0.104839 0.04897 0.196117 0.517889 \n",
2021-04-16 22:41:06 +02:00
"\n",
2021-05-07 22:16:28 +02:00
" HR Reco in test Test coverage Shannon Gini \n",
"0 0.480382 0.867338 0.147186 3.852545 0.972694 "
2021-04-16 22:41:06 +02:00
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"estimations_df = pd.read_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=None\n",
")\n",
"reco = np.loadtxt(\"Recommendations generated/ml-100k/Self_SVD_reco.csv\", delimiter=\",\")\n",
"\n",
"ev.evaluate(\n",
" test=pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None),\n",
" estimations_df=estimations_df,\n",
" reco=reco,\n",
" super_reactions=[4, 5],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-05-07 22:16:28 +02:00
"943it [00:00, 8505.85it/s]\n",
"943it [00:00, 9544.72it/s]\n",
"943it [00:00, 9154.80it/s]\n",
"943it [00:00, 8282.66it/s]\n",
"943it [00:00, 8432.23it/s]\n",
"943it [00:00, 9601.30it/s]\n",
"943it [00:00, 9158.89it/s]\n",
"943it [00:00, 12283.59it/s]\n",
"943it [00:00, 9500.43it/s]\n",
"943it [00:00, 10085.91it/s]\n",
"943it [00:00, 10260.90it/s]\n",
"943it [00:00, 9691.20it/s]\n"
2021-04-16 22:41:06 +02:00
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
2021-05-07 22:16:28 +02:00
" <td>Ready_SVD</td>\n",
" <td>0.950347</td>\n",
" <td>0.749312</td>\n",
" <td>0.100636</td>\n",
" <td>0.050514</td>\n",
" <td>0.055794</td>\n",
" <td>0.070753</td>\n",
" <td>0.091202</td>\n",
" <td>0.082734</td>\n",
" <td>0.114054</td>\n",
" <td>0.053200</td>\n",
" <td>0.248803</td>\n",
" <td>0.521983</td>\n",
" <td>0.517497</td>\n",
" <td>0.992153</td>\n",
" <td>0.210678</td>\n",
" <td>4.418683</td>\n",
" <td>0.952848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
2021-04-16 22:41:06 +02:00
" <td>Self_SVD</td>\n",
2021-05-07 22:16:28 +02:00
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.072360</td>\n",
" <td>0.104839</td>\n",
" <td>0.048970</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
2021-05-07 22:16:28 +02:00
" <td>Ready_SVDBiased</td>\n",
" <td>0.939472</td>\n",
" <td>0.739816</td>\n",
" <td>0.085896</td>\n",
" <td>0.036073</td>\n",
" <td>0.043528</td>\n",
" <td>0.057643</td>\n",
" <td>0.077039</td>\n",
" <td>0.057463</td>\n",
" <td>0.097753</td>\n",
" <td>0.045546</td>\n",
" <td>0.219839</td>\n",
" <td>0.514709</td>\n",
" <td>0.431601</td>\n",
" <td>0.997455</td>\n",
" <td>0.168831</td>\n",
" <td>4.217578</td>\n",
" <td>0.962577</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
2021-04-16 22:41:06 +02:00
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
2021-05-07 22:16:28 +02:00
"0 Ready_SVD 0.950347 0.749312 0.100636 0.050514 0.055794 \n",
"0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n",
2021-04-16 22:41:06 +02:00
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
2021-05-07 22:16:28 +02:00
"0 Ready_SVDBiased 0.939472 0.739816 0.085896 0.036073 0.043528 \n",
2021-04-16 22:41:06 +02:00
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
2021-05-07 22:16:28 +02:00
"0 0.070753 0.091202 0.082734 0.114054 0.053200 0.248803 \n",
"0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n",
2021-04-16 22:41:06 +02:00
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
2021-05-07 22:16:28 +02:00
"0 0.057643 0.077039 0.057463 0.097753 0.045546 0.219839 \n",
2021-04-16 22:41:06 +02:00
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
2021-05-07 22:16:28 +02:00
"0 0.521983 0.517497 0.992153 0.210678 4.418683 0.952848 \n",
"0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n",
2021-04-16 22:41:06 +02:00
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
2021-05-07 22:16:28 +02:00
"0 0.514709 0.431601 0.997455 0.168831 4.217578 0.962577 \n",
2021-04-16 22:41:06 +02:00
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"super_reactions = [4, 5]\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>score</th>\n",
" <th>item_id</th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>genres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
2021-05-07 22:16:28 +02:00
" <td>321</td>\n",
2021-04-16 22:41:06 +02:00
" <td>1.000000</td>\n",
2021-05-07 22:16:28 +02:00
" <td>322</td>\n",
" <td>322</td>\n",
" <td>Murder at 1600 (1997)</td>\n",
" <td>Mystery, Thriller</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
2021-05-07 22:16:28 +02:00
" <td>983</td>\n",
" <td>0.902748</td>\n",
" <td>984</td>\n",
" <td>984</td>\n",
" <td>Shadow Conspiracy (1997)</td>\n",
" <td>Thriller</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
2021-05-07 22:16:28 +02:00
" <td>985</td>\n",
" <td>0.894696</td>\n",
" <td>986</td>\n",
" <td>986</td>\n",
" <td>Turbulence (1997)</td>\n",
" <td>Thriller</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
2021-05-07 22:16:28 +02:00
" <td>778</td>\n",
" <td>0.890524</td>\n",
" <td>779</td>\n",
" <td>779</td>\n",
" <td>Drop Zone (1994)</td>\n",
" <td>Action</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
2021-05-07 22:16:28 +02:00
" <td>686</td>\n",
" <td>0.889220</td>\n",
" <td>687</td>\n",
" <td>687</td>\n",
" <td>McHale's Navy (1997)</td>\n",
" <td>Comedy, War</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
2021-05-07 22:16:28 +02:00
" <td>331</td>\n",
" <td>0.887596</td>\n",
" <td>332</td>\n",
" <td>332</td>\n",
" <td>Kiss the Girls (1997)</td>\n",
2021-04-16 22:41:06 +02:00
" <td>Crime, Drama, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
2021-05-07 22:16:28 +02:00
" <td>987</td>\n",
" <td>0.886547</td>\n",
" <td>988</td>\n",
" <td>988</td>\n",
" <td>Beautician and the Beast, The (1997)</td>\n",
" <td>Comedy, Romance</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
2021-05-07 22:16:28 +02:00
" <td>1039</td>\n",
" <td>0.882845</td>\n",
" <td>1040</td>\n",
" <td>1040</td>\n",
" <td>Two if by Sea (1996)</td>\n",
" <td>Comedy, Romance</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
2021-05-07 22:16:28 +02:00
" <td>1022</td>\n",
" <td>0.882782</td>\n",
" <td>1023</td>\n",
" <td>1023</td>\n",
" <td>Fathers' Day (1997)</td>\n",
" <td>Comedy</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
2021-05-07 22:16:28 +02:00
" <td>929</td>\n",
" <td>0.877662</td>\n",
" <td>930</td>\n",
" <td>930</td>\n",
" <td>Chain Reaction (1996)</td>\n",
" <td>Action, Adventure, Thriller</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
2021-05-07 22:16:28 +02:00
" code score item_id id title \\\n",
"0 321 1.000000 322 322 Murder at 1600 (1997) \n",
"1 983 0.902748 984 984 Shadow Conspiracy (1997) \n",
"2 985 0.894696 986 986 Turbulence (1997) \n",
"3 778 0.890524 779 779 Drop Zone (1994) \n",
"4 686 0.889220 687 687 McHale's Navy (1997) \n",
"5 331 0.887596 332 332 Kiss the Girls (1997) \n",
"6 987 0.886547 988 988 Beautician and the Beast, The (1997) \n",
"7 1039 0.882845 1040 1040 Two if by Sea (1996) \n",
"8 1022 0.882782 1023 1023 Fathers' Day (1997) \n",
"9 929 0.877662 930 930 Chain Reaction (1996) \n",
2021-04-16 22:41:06 +02:00
"\n",
2021-05-07 22:16:28 +02:00
" genres \n",
"0 Mystery, Thriller \n",
"1 Thriller \n",
"2 Thriller \n",
"3 Action \n",
"4 Comedy, War \n",
"5 Crime, Drama, Thriller \n",
"6 Comedy, Romance \n",
"7 Comedy, Romance \n",
"8 Comedy \n",
"9 Action, Adventure, Thriller "
2021-04-16 22:41:06 +02:00
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# pick a random item code among the column indices stored in train_ui\n",
"item = random.choice(list(set(train_ui.indices)))\n",
"\n",
"# rescale every row of model.Qi to unit L2 norm (we do not mean-center here), so the\n",
"# dot products below are cosine similarities; omitting normalization also makes sense,\n",
"# but items with a greater magnitude will be recommended more often\n",
"embeddings_norm = model.Qi / np.linalg.norm(model.Qi, axis=1, keepdims=True)\n",
"\n",
"similarity_scores = embeddings_norm.dot(embeddings_norm[item])\n",
"top_similar_items = (\n",
"    pd.DataFrame(enumerate(similarity_scores), columns=[\"code\", \"score\"])\n",
"    .sort_values(\"score\", ascending=False)\n",
"    .head(10)\n",
")\n",
"\n",
"# look every code up in item_code_id to get the matching item id\n",
"top_similar_items[\"item_id\"] = top_similar_items[\"code\"].map(\n",
"    lambda code: item_code_id[code]\n",
")\n",
"\n",
"items = pd.read_csv(\"./Datasets/ml-100k/movies.csv\")\n",
"\n",
"result = top_similar_items.merge(items, left_on=\"item_id\", right_on=\"id\")\n",
"\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Project task 5: implement SVD on top of the baseline (as it is done in the Surprise library)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Extending our implementation so that the gradient descent procedure also updates the additional\n",
"# (baseline) parameters seems to be the fastest option.\n",
"# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
2021-05-29 13:08:49 +02:00
"# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'\n",
"\n",
"# link to the relevant Surprise documentation https://surprise.readthedocs.io/en/stable/matrix_factorization.html#matrix-factorization-based-algorithms"
2021-04-16 22:41:06 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ready-made SVD - Surprise implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import helpers\n",
"import surprise as sp\n",
"\n",
"# biased=False -> plain matrix factorization without the baseline (bias) terms.\n",
"# random_state pins the random initialization of the factors, so the saved\n",
"# recommendations/estimations (and the metrics computed from them) are reproducible.\n",
"algo = sp.SVD(biased=False, random_state=42)\n",
"\n",
"helpers.ready_made(\n",
"    algo,\n",
"    reco_path=\"Recommendations generated/ml-100k/Ready_SVD_reco.csv\",\n",
"    estimations_path=\"Recommendations generated/ml-100k/Ready_SVD_estimations.csv\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Biased SVD - on top of the baseline"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"# default is biased=True (factorization on top of the baseline terms).\n",
"# random_state pins the random initialization of the factors, so the saved\n",
"# recommendations/estimations (and the metrics computed from them) are reproducible.\n",
"algo = sp.SVD(random_state=42)\n",
"\n",
"helpers.ready_made(\n",
"    algo,\n",
"    reco_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv\",\n",
"    estimations_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-05-07 22:16:28 +02:00
"943it [00:00, 11456.53it/s]\n",
"943it [00:00, 11932.50it/s]\n",
"943it [00:00, 10853.07it/s]\n",
"943it [00:00, 9426.44it/s]\n",
"943it [00:00, 8757.09it/s]\n",
"943it [00:00, 9999.67it/s]\n",
"943it [00:00, 11323.49it/s]\n",
"943it [00:00, 9764.72it/s]\n",
"943it [00:00, 9692.41it/s]\n",
"943it [00:00, 9052.77it/s]\n",
"943it [00:00, 8645.18it/s]\n",
"943it [00:00, 10594.54it/s]\n"
2021-04-16 22:41:06 +02:00
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
2021-05-07 22:16:28 +02:00
" <td>0.951652</td>\n",
" <td>0.750975</td>\n",
" <td>0.096394</td>\n",
" <td>0.047252</td>\n",
" <td>0.052870</td>\n",
" <td>0.067257</td>\n",
" <td>0.085515</td>\n",
" <td>0.074754</td>\n",
" <td>0.109578</td>\n",
" <td>0.051562</td>\n",
" <td>0.235567</td>\n",
" <td>0.520341</td>\n",
" <td>0.496288</td>\n",
" <td>0.995546</td>\n",
" <td>0.208514</td>\n",
" <td>4.455755</td>\n",
" <td>0.951624</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
2021-05-07 22:16:28 +02:00
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.072360</td>\n",
" <td>0.104839</td>\n",
" <td>0.048970</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
2021-05-07 22:16:28 +02:00
" <td>0.940413</td>\n",
" <td>0.739571</td>\n",
" <td>0.086002</td>\n",
" <td>0.035478</td>\n",
" <td>0.043196</td>\n",
" <td>0.057507</td>\n",
" <td>0.075751</td>\n",
" <td>0.053460</td>\n",
" <td>0.094897</td>\n",
" <td>0.043361</td>\n",
" <td>0.209124</td>\n",
" <td>0.514405</td>\n",
" <td>0.428420</td>\n",
" <td>0.997349</td>\n",
" <td>0.177489</td>\n",
" <td>4.212509</td>\n",
" <td>0.962656</td>\n",
2021-04-16 22:41:06 +02:00
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
2021-05-07 22:16:28 +02:00
"0 Ready_SVD 0.951652 0.750975 0.096394 0.047252 0.052870 \n",
"0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n",
2021-04-16 22:41:06 +02:00
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
2021-05-07 22:16:28 +02:00
"0 Ready_SVDBiased 0.940413 0.739571 0.086002 0.035478 0.043196 \n",
2021-04-16 22:41:06 +02:00
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
2021-05-07 22:16:28 +02:00
"0 0.067257 0.085515 0.074754 0.109578 0.051562 0.235567 \n",
"0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n",
2021-04-16 22:41:06 +02:00
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
2021-05-07 22:16:28 +02:00
"0 0.057507 0.075751 0.053460 0.094897 0.043361 0.209124 \n",
2021-04-16 22:41:06 +02:00
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
2021-05-07 22:16:28 +02:00
"0 0.520341 0.496288 0.995546 0.208514 4.455755 0.951624 \n",
"0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n",
2021-04-16 22:41:06 +02:00
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
2021-05-07 22:16:28 +02:00
"0 0.514405 0.428420 0.997349 0.177489 4.212509 0.962656 \n",
2021-04-16 22:41:06 +02:00
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test interactions: tab-separated file without a header row\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"# folder that the cells above write their *_reco.csv / *_estimations.csv files into\n",
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"# presumably the rating values counted as strongly positive for the *_super metrics - TODO confirm in ev.evaluate_all\n",
"super_reactions = [4, 5]\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}