introduction_to_recommender.../P4. Matrix Factorization.ipynb

1404 lines
74 KiB
Plaintext
Raw Normal View History

2021-04-16 22:41:06 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self-made SVD"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import helpers\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scipy.sparse as sparse\n",
"from collections import defaultdict\n",
"from itertools import chain\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Fix the random seeds: the SVD factor initialisation (np.random.normal), the SGD\n",
"# shuffling (np.random.shuffle) and the item picked in the Embeddings section\n",
"# (random.choice) are all random, so without a seed no output is reproducible.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"train_read = pd.read_csv(\"./Datasets/ml-100k/train.csv\", sep=\"\\t\", header=None)\n",
"test_read = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"(\n",
"    train_ui,\n",
"    test_ui,\n",
"    user_code_id,\n",
"    user_id_code,\n",
"    item_code_id,\n",
"    item_id_code,\n",
") = helpers.data_to_csr(train_read, test_read)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class SVD:\n",
"    \"\"\"Plain (unbiased) matrix factorization trained with stochastic gradient descent.\n",
"\n",
"    A rating is predicted as the dot product Pu[u] . Qi[i]; there are no bias terms.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
"        \"\"\"\n",
"        train_ui: sparse CSR user x item rating matrix (rows = user codes, columns = item codes).\n",
"        learning_rate: SGD step size.\n",
"        regularization: L2 penalty applied to the latent factors in every SGD step.\n",
"        nb_factors: number of latent factors per user and per item.\n",
"        iterations: number of epochs run by train().\n",
"        \"\"\"\n",
"        self.train_ui = train_ui\n",
"        self.uir = self._to_uir(train_ui)\n",
"\n",
"        self.learning_rate = learning_rate\n",
"        self.regularization = regularization\n",
"        self.iterations = iterations\n",
"        self.nb_users, self.nb_items = train_ui.shape\n",
"        self.nb_ratings = train_ui.nnz\n",
"        self.nb_factors = nb_factors\n",
"\n",
"        self.Pu = np.random.normal(\n",
"            loc=0, scale=1.0 / self.nb_factors, size=(self.nb_users, self.nb_factors)\n",
"        )\n",
"        self.Qi = np.random.normal(\n",
"            loc=0, scale=1.0 / self.nb_factors, size=(self.nb_items, self.nb_factors)\n",
"        )\n",
"        # Dense user x item prediction matrix; filled in by estimations()\n",
"        self.estimations_matrix = None\n",
"\n",
"    @staticmethod\n",
"    def _to_uir(matrix):\n",
"        \"\"\"Return the non-zero entries of a sparse matrix as a list of (row, col, value) triples.\n",
"\n",
"        Coordinates and values come from a single COO view, so they always stay aligned\n",
"        (pairing nonzero() with .data misaligns them if explicit zeros are stored).\n",
"        \"\"\"\n",
"        coo = matrix.tocoo()\n",
"        keep = coo.data != 0\n",
"        return list(zip(coo.row[keep], coo.col[keep], coo.data[keep]))\n",
"\n",
"    def train(self, test_ui=None):\n",
"        \"\"\"Run self.iterations epochs of SGD over the training ratings.\n",
"\n",
"        After every epoch, appends [epoch, train_RMSE] to self.learning_process, or\n",
"        [epoch, train_RMSE, test_RMSE] when a test matrix is given.\n",
"        \"\"\"\n",
"        # `is None` / `is not None` is the reliable identity test; `== None` depends on\n",
"        # how the sparse matrix class overloads comparison operators\n",
"        if test_ui is not None:\n",
"            self.test_uir = self._to_uir(test_ui)\n",
"\n",
"        self.learning_process = []\n",
"        pbar = tqdm(range(self.iterations))\n",
"        for i in pbar:\n",
"            # Shows the training RMSE of the previous epoch (0 before the first one)\n",
"            pbar.set_description(\n",
"                f\"Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...\"\n",
"            )\n",
"            np.random.shuffle(self.uir)\n",
"            self.sgd(self.uir)\n",
"            if test_ui is None:\n",
"                self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
"            else:\n",
"                self.learning_process.append(\n",
"                    [i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)]\n",
"                )\n",
"\n",
"    def sgd(self, uir):\n",
"        \"\"\"One SGD pass over (user, item, rating) triples, updating Pu and Qi in place.\"\"\"\n",
"        for u, i, score in uir:\n",
"            # Compute prediction and error\n",
"            e = score - self.get_rating(u, i)\n",
"\n",
"            # Both updates are computed from the factors as they were before this step\n",
"            Pu_update = self.learning_rate * (\n",
"                e * self.Qi[i] - self.regularization * self.Pu[u]\n",
"            )\n",
"            Qi_update = self.learning_rate * (\n",
"                e * self.Pu[u] - self.regularization * self.Qi[i]\n",
"            )\n",
"\n",
"            self.Pu[u] += Pu_update\n",
"            self.Qi[i] += Qi_update\n",
"\n",
"    def get_rating(self, u, i):\n",
"        \"\"\"Predicted rating of item code i by user code u.\"\"\"\n",
"        return self.Pu[u].dot(self.Qi[i])\n",
"\n",
"    def RMSE_total(self, uir):\n",
"        \"\"\"Root mean squared error of the model on (user, item, rating) triples; NaN if uir is empty.\"\"\"\n",
"        if len(uir) == 0:\n",
"            return np.nan\n",
"        squared_error = 0\n",
"        for u, i, score in uir:\n",
"            squared_error += (score - self.get_rating(u, i)) ** 2\n",
"        return np.sqrt(squared_error / len(uir))\n",
"\n",
"    def estimations(self):\n",
"        \"\"\"Compute, store and return the dense user x item matrix of predictions Pu . Qi^T.\n",
"\n",
"        The result is kept in self.estimations_matrix (used by recommend() and estimate()).\n",
"        It is stored under its own name so this method is not overwritten and can be\n",
"        called again, e.g. after further training or when the cell is re-run.\n",
"        \"\"\"\n",
"        self.estimations_matrix = np.dot(self.Pu, self.Qi.T)\n",
"        return self.estimations_matrix\n",
"\n",
"    def _current_estimations(self):\n",
"        \"\"\"Return the last computed prediction matrix, computing it on first use.\"\"\"\n",
"        if self.estimations_matrix is None:\n",
"            self.estimations()\n",
"        return self.estimations_matrix\n",
"\n",
"    def recommend(self, user_code_id, item_code_id, topK=10):\n",
"        \"\"\"Top-K unseen items per user as rows [user_id, item1, score1, item2, score2, ...].\n",
"\n",
"        Items the user rated in train_ui and NaN predictions are skipped; users with no\n",
"        remaining candidates get no row.\n",
"        \"\"\"\n",
"        estimations = self._current_estimations()\n",
"        result = []\n",
"        for nb_user, user_scores in enumerate(estimations):\n",
"            start, end = self.train_ui.indptr[nb_user], self.train_ui.indptr[nb_user + 1]\n",
"            # set -> O(1) membership tests instead of scanning the index array per item\n",
"            user_rated = set(self.train_ui.indices[start:end])\n",
"            item_scores = [\n",
"                (item_code_id[item], score)\n",
"                for item, score in enumerate(user_scores)\n",
"                if item not in user_rated and not np.isnan(score)\n",
"            ]\n",
"            if not item_scores:\n",
"                continue\n",
"            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
"            result.append([user_code_id[nb_user]] + list(chain(*item_scores[:topK])))\n",
"        return result\n",
"\n",
"    def estimate(self, user_code_id, item_code_id, test_ui):\n",
"        \"\"\"Return [user_id, item_id, predicted_rating] for every non-zero cell of test_ui.\n",
"\n",
"        NaN predictions are replaced by 1.\n",
"        \"\"\"\n",
"        estimations = self._current_estimations()\n",
"        result = []\n",
"        for user, item in zip(*test_ui.nonzero()):\n",
"            score = estimations[user, item]\n",
"            result.append(\n",
"                [user_code_id[user], item_code_id[item], score if not np.isnan(score) else 1]\n",
"            )\n",
"        return result"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 39 RMSE: 0.7477090330529405. Training epoch 40...: 100%|██████████| 40/40 [01:03<00:00, 1.59s/it]\n"
]
}
],
"source": [
"# Self-made SVD: 100 latent factors trained for 40 SGD epochs; the test set is only\n",
"# used to record the test RMSE after every epoch\n",
"model = SVD(\n",
"    train_ui,\n",
"    learning_rate=0.005,\n",
"    regularization=0.02,\n",
"    nb_factors=100,\n",
"    iterations=40,\n",
")\n",
"model.train(test_ui=test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fdc5a9a2370>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAba0lEQVR4nO3de3RV9Z338fc3IZCEIAkhVYZwsZapjoihxNtIO1jqo1DXSF0uizNPpV2dMlqsONVa6+Pq2Ge005Fp7eCljD7S6jxdxel4rcvq4wW8tWIDBhTQgqMtEWrCLYByEfJ9/tjnkJPDSXJOcpJ9zt6f11q/tffZe59zvmzjZ+/z2zdzd0REpPiVhF2AiIjkhwJdRCQiFOgiIhGhQBcRiQgFuohIRAwJ64tHjx7tEydODOvrRUSK0qpVq7a5e12meaEF+sSJE2lqagrr60VEipKZ/aG7eepyERGJCAW6iEhEKNBFRCIitD50ESk+H330ES0tLezfvz/sUiKvvLyc+vp6ysrKsn6PAl1EstbS0sKIESOYOHEiZhZ2OZHl7mzfvp2WlhaOP/74rN+nLhcRydr+/fupra1VmA8wM6O2tjbnX0IKdBHJicJ8cPRlPRddoLe0wNVXw0cfhV2JiEhhKbpA/93v4N/+Db7//bArEREpLEUX6F/4AvzN38DNN8Pq1WFXIyKDadeuXdx11105v2/27Nns2rUr5/d9+ctf5vjjj6ehoYFTTz2VZ5999si8GTNmMH78eFIfEjRnzhyqqqoA6Ojo4KqrrmLy5MmccsopnHbaabzzzjtAcKX8KaecQkNDAw0NDVx11VU515ZJUZ7lcvvtsHw5zJsHTU0wbFjYFYnIYEgG+te//vUu0w8fPkxpaWm373viiSf6/J2LFi3i4osvZvny5cyfP5+NGzcemVddXc3LL7/M9OnT2bVrF1u3bj0y74EHHmDLli2sXbuWkpISWlpaGD58+JH5y5cvZ/To0X2uK5OiDPRRo+Cee+CCC+B731P3i0gYrr4ampvz+5kNDfDjH3c///rrr+ftt9+moaGBsrIyqqqqGDNmDM3Nzaxfv545c+awefNm9u/fz8KFC5k/fz7Qee+ovXv3MmvWLKZPn85vfvMbxo4dy6OPPkpFRUWvtZ111lm89957XabNnTuXZcuWMX36dB566CEuuugi1q1bB8DWrVsZM2YMJSVBR0h9fX2f1kkuiq7LJenzn4evfAX+5V9g5cqwqxGRwfCDH/yAE044gebmZhYtWsSrr77KLbfcwvr16wFYunQpq1atoqmpicWLF7N9+/ajPmPjxo0sWLCAdevWUV1dzYMPPpjVdz/55JPMmTOny7SZM2fywgsvcPjwYZYtW8YXv/jFI/MuueQSfvWrX9HQ0MA111zDa6+91uW955xzzpEul9tuuy3HNZFZUe6hJ912GzzzTND18tprkMVGVkTypKc96cFy+umnd7nwZvHixTz88MMAbN68mY0bN1JbW9vlPck+cYBp06bx7rvv9vgd3/rWt7juuutobW3llVde6TKvtLSU6dOn88ADD7Bv3z5SbwleX1/PW2+9xXPPPcdzzz3HzJkz+eUvf8nMmTOBgelyKdo9dICRI+Hee+Gtt+DGG8OuRkQGW2qf9IoVK3jmmWf47W9/y5o1a5g6dWrGC3OGpRx0Ky0t5dChQz1+x6JFi9i0aRM333wz8+bNO2r+3Llz+cY3vsEll1yS8btmzZrFokWLuOGGG3jkkUdy+NflrqgDHeDcc+Hyy4O99RdfDLsaERlII0aMYM+ePRnntbe3U1NTQ2VlJW+++eZRe9P9UVJSwsKFC+no6OCpp57qMu/Tn/403/nOd7j00ku7TF+9ejVbtmwBgjNe1q5dy4QJE/JWUyZF3eWStGgRPPVU0Ke+Zg2kbLRFJEJqa2s5++yzmTx5MhUVFRx77LFH5p1//vksWbKEKVOm8MlPfpIzzzwzr99tZtx4443ceu
utnHfeeV2mX3vttUct39rayte+9jUOHDgABN1DV1555ZH555xzzpEzc6ZMmcL999/f/xpTz6EcTI2NjZ7PJxY9/zzMmAELFsAdd+TtY0UkxYYNGzjppJPCLiM2Mq1vM1vl7o2Zli/6Lpekv/orWLgQ7rwTUs79FxGJjV4D3czKzexVM1tjZuvM7HsZlplhZu1m1pxo3x2Ycnv2/e/DpElBn7qISLYWLFhw5BTCZPvpT38adlk5y6YP/QDwWXffa2ZlwEtm9mt3Tz/i8KK7X5D/ErNXWRmcwnjjjXDwIAwdGmY1ItHk7pG74+Kdd94ZdglH6Ut3eK976B7Ym3hZlmjhdLxnIXnKaYbrCUSkn8rLy9m+fXufwkayl3zARXl5eU7vy+osFzMrBVYBnwDudPdM12aeZWZrgC3Ate6+LsPnzAfmA4wfPz6nQrOVGuhjxgzIV4jEVn19PS0tLbS1tYVdSuQlH0GXi6wC3d0PAw1mVg08bGaT3f2NlEVWAxMS3TKzgUeASRk+527gbgjOcsmp0iyNGhUMtYcukn9lZWU5PRJNBldOZ7m4+y5gBXB+2vTdyW4Zd38CKDOz/F7TmiV1uYhIXGVzlktdYs8cM6sAPge8mbbMcZY4SmJmpyc+N5RIVaCLSFxl0+UyBrgv0Y9eAvynuz9uZpcDuPsS4GLgCjM7BOwD5npIR00U6CISV70GuruvBaZmmL4kZfwOoCCuz6yshPJyBbqIxE9krhRNVVurQBeR+FGgi4hEhAJdRCQiFOgiIhGhQBcRiYjIBvqOHaDbTYhInEQ20A8fhvb2sCsRERk8kQ10ULeLiMSLAl1EJCIU6CIiEaFAFxGJCAW6iEhERDLQq6vBTIEuIvESyUAvLYWaGgW6iMRLJAMddLWoiMSPAl1EJCIU6CIiEaFAFxGJCAW6iEhERDrQP/gADhwIuxIRkcER6UAH7aWLSHwo0EVEIkKBLiISEQp0EZGI6DXQzazczF41szVmts7MvpdhGTOzxWa2yczWmtmnBqbc7CnQRSRuhmSxzAHgs+6+18zKgJfM7Nfu/krKMrOASYl2BvCTxDA0CnQRiZte99A9sDfxsizR0h+/fCFwf2LZV4BqMxuT31JzU1ERNAW6iMRFVn3oZlZqZs1AK/C0u69MW2QssDnldUtiWvrnzDezJjNramtr62PJ2dPFRSISJ1kFursfdvcGoB443cwmpy1imd6W4XPudvdGd2+sq6vLudhcKdBFJE5yOsvF3XcBK4Dz02a1AONSXtcDW/pTWD4o0EUkTrI5y6XOzKoT4xXA54A30xZ7DLgscbbLmUC7u2/Nd7G5UqCLSJxkc5bLGOA+Mysl2AD8p7s/bmaXA7j7EuAJYDawCfgQ+MoA1ZsTBbqIxEmvge7ua4GpGaYvSRl3YEF+S+u/2lrYsQM6OqAkspdQiYgEIh1ztbVBmLe3h12JiMjAi3ygg7pdRCQeFOgiIhGhQBcRiQgFuohIRCjQRUQiItKBXl0dnK6oQBeROIh0oJeUQE2NAl1E4iHSgQ66WlRE4kOBLiISEQp0EZGIUKCLiESEAl1EJCJiEegffgj794ddiYjIwIpFoIP20kUk+hToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEZEP9PJyqKxUoItI9EU+0EEXF4lIPCjQRUQiQoEuIhIRvQa6mY0zs+VmtsHM1pnZwgzLzDCzdjNrTrTvDky5faNAF5E4GJLFMoeAa9x9tZmNAFaZ2dPuvj5tuRfd/YL8l9h/CnQRiYNe99Ddfau7r06M7wE2AGMHurB8qq2FnTuhoyPsSkREBk5OfehmNhGYCqzMMPssM1tjZr82s5O7ef98M2sys6a2trbcq+2j2togzHftGrSvFBEZdFkHuplVAQ8CV7v77rTZq4EJ7n4qcDvwSKbPcPe73b3R3Rvr6ur6WHLudHGRiMRBVoFuZmUEYf5zd38ofb6773b3vYnxJ4AyMxud10r7QYEuInGQzV
kuBtwLbHD3H3WzzHGJ5TCz0xOfWzDxqUAXkTjI5iyXs4EvAa+bWXNi2g3AeAB3XwJcDFxhZoeAfcBcd/f8l9s3CnQ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Training-set RMSE after every epoch (first two columns of model.learning_process)\n",
"df = pd.DataFrame(model.learning_process).iloc[:, :2]\n",
"df.columns = [\"epoch\", \"train_RMSE\"]\n",
"fig, ax = plt.subplots()\n",
"ax.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"ax.set(xlabel=\"epoch\", ylabel=\"RMSE\", title=\"Self-made SVD: training RMSE\")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fdc6037bb20>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyqElEQVR4nO3de3zPdf/H8cdrY2bO25yHYYg5DENiOR9TShIuinLJFdFVFFJdKVcHV4Wr80FR/XQuOpFQ5CqHyZnZMMzEmDkP296/P96zhuE7tn2+3+9e99ttN/t+Dt/v67NPPffZ+/P+vN9ijEEppZT38nG6AKWUUvlLg14ppbycBr1SSnk5DXqllPJyGvRKKeXlijhdQE6Cg4NNaGio02UopZTHiI6OPmiMKZ/TOrcM+tDQUFavXu10GUop5TFEZNel1l2x6UZEZorIARHZeIn1IiIzRCRORNaLSLNs67qLSEzmuvFXV75SSqlr4Uob/ftA98us7wHUyfwaDrwOICK+wKuZ6xsAA0SkwbUUq5RSKveuGPTGmKVA8mU26Q3MNtbvQFkRqQy0BOKMMTuMMWeAjzO3VUopVYDyoo2+KrAn2+uEzGU5LW91qTcRkeHYvwioXr16HpSllCooZ8+eJSEhgdTUVKdL8Xr+/v6EhIRQtGhRl/fJi6CXHJaZyyzPkTHmLeAtgMjISB2ARykPkpCQQKlSpQgNDUUkp//1VV4wxnDo0CESEhKoWbOmy/vlRT/6BKBattchQOJlliulvExqaipBQUEa8vlMRAgKCsr1X055EfTzgLsye99cDxwxxuwDVgF1RKSmiPgB/TO3VUp5IQ35gnE1P+crNt2IyBygPRAsIgnAk0BRAGPMG8D3QE8gDjgJDM1clyYio4AFgC8w0xizKdcV5kKGySA6MZoWVVvk58copZRHuWLQG2MGXGG9AUZeYt332F8EBeLzzZ9z5+d30qtuL6Z0nELjio0L6qOVUsptedVYN73q9uLZTs/y6+5fiXgjgkFfDmLH4R1Ol6WUymcpKSm89tprud6vZ8+epKSk5Hq/IUOGULNmTSIiImjSpAmLFi3KWte+fXuqV69O9kmdbr31VkqWLAlARkYGo0ePpmHDhjRq1IgWLVqwc+dOwI4K0KhRIyIiIoiIiGD06NG5ri0nbjkEwtUKKBrA+Lbjua/5fbyw/AWmr5jOmn1r2HT/Jm0/VMqLnQv6+++//7zl6enp+Pr6XnK/77+/+gaHqVOn0rdvX5YsWcLw4cOJjY3NWle2bFmWL19O27ZtSUlJYd++fVnrPvnkExITE1m/fj0+Pj4kJCRQokSJrPVLliwhODj4quvKiVcF/Tnlipfj2c7P8kCrB9hzZA8iwsmzJ3npt5d4oOUDlPEv43SJSnmtBx+EtWvz9j0jImDatEuvHz9+PNu3byciIoKiRYtSsmRJKleuzNq1a9m8eTO33nore/bsITU1lTFjxjB8+HDgr3G1jh8/To8ePWjbti3/+9//qFq1KnPnzqV48eJXrK1169bs3bv3vGX9+/fn448/pm3btnz55Zf06dOHTZvsLcp9+/ZRuXJlfHxsg0pISMhV/Uxyw6uabi5UpVQVWoXYZ7QWbl/I40sep9aMWkxdPpVTZ085XJ1SKq8899xz1K5dm7Vr1zJ16lRWrlzJlClT2Lx5MwAzZ84kOjqa1atXM2PGDA4dOnTRe8TGxjJy5Eg2bdpE2bJl+eKLL1z67Pnz53Prrbeet6xTp04sXbqU9PR0Pv74Y+68886sdf369eObb74hIiKChx9+mD/++OO8fTt06JDVdPPyyy/n8ieRM6+8os9J7+t6s2b4Gh5b/BiP/PQI01ZMY9wN4xjdajQ+4tW/75QqUJe78i4oLVu2PO+BohkzZvDVV18BsGfPHmJjYwkKCjpvn3Nt7gDNmzcnPj7+sp8xbtw4HnnkEQ4cOMDvv/9+3jpfX1/atm3LJ598wqlTp8g+7HpISA
gxMTEsXryYxYsX06lTJz777DM6deoE5E/TTaFKuKaVm/L9377n57t/pm5QXb7a+lVWyJ9NP+twdUqpvJK9zfvnn3/mp59+4rfffmPdunU0bdo0xweOihUrlvW9r68vaWlpl/2MqVOnEhcXxzPPPMPdd9990fr+/fvzwAMP0K9fvxw/q0ePHkydOpWJEyfy9ddf5+Locq9QBf057ULbseTuJXw38DsAEo8lUu3lajyx5AkOnbz4TzqllHsrVaoUx44dy3HdkSNHKFeuHAEBAWzduvWiq+9r4ePjw5gxY8jIyGDBggXnrYuKimLChAkMGHB+D/U1a9aQmGgHCcjIyGD9+vXUqFEjz2rKsc58fXc3V9LPdnc6nXaaNtXb8PTSp6kxrQZjfxzLvmP7rrC3UspdBAUF0aZNGxo2bMi4cePOW9e9e3fS0tJo3Lgxjz/+ONdff32efraIMGnSJF544YWLlo8dO/aiZpgDBw5w880307BhQxo3bkyRIkUYNWpU1vrsbfR33XVX3tSYva+nu4iMjDROzDC16cAmnv31WeZsnIN/EX/2/HMPgcUDC7wOpTzNli1bqF+/vtNlFBo5/bxFJNoYE5nT9oX6iv5C4RXC+bDPh8SMimFat2lZIf/owkd5d8277D++3+EKlVIq9wpNr5vcCAsMIywwDIBTZ0/x2ebP2JmyE0FoFdKK3vV6c0eDO6gdWNvhSpVS+WnkyJEsX778vGVjxoxh6NChDlV0dTTor6B40eJsH72d9fvXMy9mHnNj5jJh0QRK+pVkVMtRHD51mPX719OmehuK+OiPUylv8uqrrzpdQp7QZHKBiNCkUhOaVGrC4+0eJ+FoAiWK2u5bX2/9mnvm3UM5/3LcXO9mBjUaRMeaHfH1ufRj10opVZA06K9CSOm/Hlm+I/wOyviXYV7MPL7e+jWz180mpHQI60esp1zxcg5WqZRSlgb9NSrpV5I+9fvQp34fUtNSmRczj98Tfs8K+SeXPEmFEhXo37A/QQFBV3g3pZTKe9rrJg/5F/GnX3g/Xur2EmAnQpm/fT6jfhhF5Rcrc/untzMvZp4+hauUKlAa9PnIR3xYMWwFf9z3ByNbjGTZrmX0/rg3/172b8D26Pnz+J8OV6mU57va8egBpk2bxsmTJy+7zblx4hs3bky7du3YtWtX1joRYfDgwVmv09LSKF++PL169QJg//799OrViyZNmtCgQQN69uwJQHx8PMWLF896OCoiIoLZs2df1TFckTHG7b6aN29uvNGZtDNm3tZ5ZufhncYYY+ZtnWf4Fyb81XAz+vvRZu7WuSblVIqzRSp1FTZv3uzo5+/cudOEh4df1b41atQwSUlJLm/zxBNPmGHDhmWtK1GihImIiDAnT540xhjz/fffmyZNmpibbrrJGGPM8OHDzbRp07K2X7du3TXXnNPPG1htLpGp2kZfgIr6FuXmejdnvW5UsRHPd36eRTsX8faat5mxcgY+4kPMqBjCAsNIOpFEqWKl8C/i72DVSuVe+/fbX7SsX3g/7m9xPyfPnqTnRz0vWj8kYghDIoZw8ORB+n7a97x1Pw/5+bKfl308+i5dulChQgU+/fRTTp8+zW233cZTTz3FiRMn6NevHwkJCaSnp/P444+zf/9+EhMT6dChA8HBwSxZsuSKx9a6dWtmzJhx3rIePXrw3Xff0bdvX+bMmcOAAQNYtmwZYMef79q1a9a2jRsX/BSn2nTjoNCyoTzS5hEWDFrA4UcP8/PdPzO5/WRql7MPYk1cNJEKUyvwwPcPEHMwxuFqlXJf2cej79KlC7GxsaxcuZK1a9cSHR3N0qVLmT9/PlWqVGHdunVs3LiR7t27M3r0aKpUqcKSJUtcCnnIefz5cxONpKamsn79elq1apW1buTIkdx777106NCBKVOmZA1oBmT9cjr3de6XQ17TK3o3UaxIMdqFtqNdaLusZX9r/DdS01N5a81bvLLqFbrW7srDrR+ma+2ul3knpZx3uS
vwgKIBl10fHBB8xSv4y/nxxx/58ccfadq0KQDHjx8nNjaWqKgoxo4dy6OPPkqvXr2IiorK1ft26NCB/fv3U6FCBZ5
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Train vs test RMSE per epoch. The first 10 epochs are dropped, presumably so their\n",
"# much larger early RMSE values do not dominate the y-axis.\n",
"df = pd.DataFrame(\n",
"    model.learning_process[10:], columns=[\"epoch\", \"train_RMSE\", \"test_RMSE\"]\n",
")\n",
"fig, ax = plt.subplots()\n",
"ax.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"ax.plot(\"epoch\", \"test_RMSE\", data=df, color=\"green\", linestyle=\"dashed\")\n",
"ax.set(xlabel=\"epoch\", ylabel=\"RMSE\", title=\"Self-made SVD: train vs test RMSE\")\n",
"ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Saving and evaluating recommendations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Build the full prediction matrix once, then save the top-10 lists and the test-set\n",
"# estimations as header-less CSVs for the evaluation cells below\n",
"model.estimations()\n",
"\n",
"top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"top_n.to_csv(\n",
"    \"Recommendations generated/ml-100k/Self_SVD_reco.csv\",\n",
"    header=False,\n",
"    index=False,\n",
")\n",
"\n",
"estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv(\n",
"    \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=False, index=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 11138.92it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.914143</td>\n",
" <td>0.717131</td>\n",
" <td>0.101803</td>\n",
" <td>0.042134</td>\n",
" <td>0.05161</td>\n",
" <td>0.068543</td>\n",
" <td>0.091953</td>\n",
" <td>0.071255</td>\n",
" <td>0.104015</td>\n",
" <td>0.048817</td>\n",
" <td>0.193027</td>\n",
" <td>0.517784</td>\n",
" <td>0.471898</td>\n",
" <td>0.867232</td>\n",
" <td>0.147908</td>\n",
" <td>3.871296</td>\n",
" <td>0.97182</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 0.914143 0.717131 0.101803 0.042134 0.05161 0.068543 \n",
"\n",
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.091953 0.071255 0.104015 0.048817 0.193027 0.517784 \n",
"\n",
" HR Reco in test Test coverage Shannon Gini \n",
"0 0.471898 0.867232 0.147908 3.871296 0.97182 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"# Reload the saved CSVs and score them against the test set\n",
"reco = np.loadtxt(\"Recommendations generated/ml-100k/Self_SVD_reco.csv\", delimiter=\",\")\n",
"estimations_df = pd.read_csv(\n",
"    \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=None\n",
")\n",
"\n",
"ev.evaluate(\n",
"    test=pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None),\n",
"    estimations_df=estimations_df,\n",
"    reco=reco,\n",
"    super_reactions=[4, 5],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 10694.55it/s]\n",
"943it [00:00, 11600.99it/s]\n",
"943it [00:00, 11461.54it/s]\n",
"943it [00:00, 11660.39it/s]\n",
"943it [00:00, 9872.18it/s]\n",
"943it [00:00, 11443.77it/s]\n",
"943it [00:00, 11990.88it/s]\n",
"943it [00:00, 11615.02it/s]\n",
"943it [00:00, 11874.78it/s]\n",
"943it [00:00, 12387.19it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.914143</td>\n",
" <td>0.717131</td>\n",
" <td>0.101803</td>\n",
" <td>0.042134</td>\n",
" <td>0.051610</td>\n",
" <td>0.068543</td>\n",
" <td>0.091953</td>\n",
" <td>0.071255</td>\n",
" <td>0.104015</td>\n",
" <td>0.048817</td>\n",
" <td>0.193027</td>\n",
" <td>0.517784</td>\n",
" <td>0.471898</td>\n",
" <td>0.867232</td>\n",
" <td>0.147908</td>\n",
" <td>3.871296</td>\n",
" <td>0.971820</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Self_SVD 0.914143 0.717131 0.101803 0.042134 0.051610 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.068543 0.091953 0.071255 0.104015 0.048817 0.193027 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.517784 0.471898 0.867232 0.147908 3.871296 0.971820 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compare the models whose result CSVs are stored in dir_path (see the Model column);\n",
"# ratings 4 and 5 are passed as the super reactions used by precision_super/recall_super\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"super_reactions = [4, 5]\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>score</th>\n",
" <th>item_id</th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>genres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>405</td>\n",
" <td>1.000000</td>\n",
" <td>406</td>\n",
" <td>406</td>\n",
" <td>Thinner (1996)</td>\n",
" <td>Horror, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>827</td>\n",
" <td>0.968354</td>\n",
" <td>828</td>\n",
" <td>828</td>\n",
" <td>Alaska (1996)</td>\n",
" <td>Adventure, Children's</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>818</td>\n",
" <td>0.967103</td>\n",
" <td>819</td>\n",
" <td>819</td>\n",
" <td>Eddie (1996)</td>\n",
" <td>Comedy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>904</td>\n",
" <td>0.963944</td>\n",
" <td>905</td>\n",
" <td>905</td>\n",
" <td>Great Expectations (1998)</td>\n",
" <td>Drama, Romance</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1293</td>\n",
" <td>0.962779</td>\n",
" <td>1294</td>\n",
" <td>1294</td>\n",
" <td>Ayn Rand: A Sense of Life (1997)</td>\n",
" <td>Documentary</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>347</td>\n",
" <td>0.961946</td>\n",
" <td>348</td>\n",
" <td>348</td>\n",
" <td>Desperate Measures (1998)</td>\n",
" <td>Crime, Drama, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>807</td>\n",
" <td>0.960952</td>\n",
" <td>808</td>\n",
" <td>808</td>\n",
" <td>Program, The (1993)</td>\n",
" <td>Action, Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>560</td>\n",
" <td>0.960885</td>\n",
" <td>561</td>\n",
" <td>561</td>\n",
" <td>Mary Shelley's Frankenstein (1994)</td>\n",
" <td>Drama, Horror</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1392</td>\n",
" <td>0.958724</td>\n",
" <td>1393</td>\n",
" <td>1393</td>\n",
" <td>Stag (1997)</td>\n",
" <td>Action, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>787</td>\n",
" <td>0.957891</td>\n",
" <td>788</td>\n",
" <td>788</td>\n",
" <td>Relative Fear (1994)</td>\n",
" <td>Horror, Thriller</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code score item_id id title \\\n",
"0 405 1.000000 406 406 Thinner (1996) \n",
"1 827 0.968354 828 828 Alaska (1996) \n",
"2 818 0.967103 819 819 Eddie (1996) \n",
"3 904 0.963944 905 905 Great Expectations (1998) \n",
"4 1293 0.962779 1294 1294 Ayn Rand: A Sense of Life (1997) \n",
"5 347 0.961946 348 348 Desperate Measures (1998) \n",
"6 807 0.960952 808 808 Program, The (1993) \n",
"7 560 0.960885 561 561 Mary Shelley's Frankenstein (1994) \n",
"8 1392 0.958724 1393 1393 Stag (1997) \n",
"9 787 0.957891 788 788 Relative Fear (1994) \n",
"\n",
" genres \n",
"0 Horror, Thriller \n",
"1 Adventure, Children's \n",
"2 Comedy \n",
"3 Drama, Romance \n",
"4 Documentary \n",
"5 Crime, Drama, Thriller \n",
"6 Action, Drama \n",
"7 Drama, Horror \n",
"8 Action, Thriller \n",
"9 Horror, Thriller "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pick a random item code that has at least one rating in the training set\n",
"item = random.choice(list(set(train_ui.indices)))\n",
"\n",
"# Cosine similarity between item embeddings (rows of Qi); we do not mean-center here.\n",
"# Omitting the normalization also makes sense, but then items whose embeddings have a\n",
"# greater magnitude would be recommended more often.\n",
"embeddings_norm = model.Qi / np.linalg.norm(model.Qi, axis=1, keepdims=True)\n",
"similarity_scores = embeddings_norm @ embeddings_norm[item]\n",
"\n",
"top_similar_items = (\n",
"    pd.DataFrame(enumerate(similarity_scores), columns=[\"code\", \"score\"])\n",
"    .sort_values(by=\"score\", ascending=False)\n",
"    .head(10)\n",
")\n",
"top_similar_items[\"item_id\"] = top_similar_items[\"code\"].apply(\n",
"    lambda code: item_code_id[code]\n",
")\n",
"\n",
"# Attach titles and genres to the most similar items\n",
"items = pd.read_csv(\"./Datasets/ml-100k/movies.csv\")\n",
"result = top_similar_items.merge(items, left_on=\"item_id\", right_on=\"id\")\n",
"\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Project task 5: implement SVD on top of a baseline (as in the Surprise library)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Extending our implementation with the additional baseline parameters (e.g. user and item biases,\n",
"# as in Surprise's biased SVD) and updating them in the same gradient-descent loop seems to be the fastest option.\n",
"# Please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
"# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ready-made SVD - Surprise implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import helpers\n",
"import surprise as sp\n",
"\n",
"# biased=False -> the unbiased version: plain matrix factorization without baseline terms\n",
"algo = sp.SVD(biased=False)\n",
"\n",
"helpers.ready_made(\n",
"    algo,\n",
"    estimations_path=\"Recommendations generated/ml-100k/Ready_SVD_estimations.csv\",\n",
"    reco_path=\"Recommendations generated/ml-100k/Ready_SVD_reco.csv\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Biased SVD - on top of a baseline"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"# Biased SVD on top of the baseline (biased=True is Surprise's default). A fixed\n",
"# random_state makes the factor initialisation - and the generated results - reproducible.\n",
"algo = sp.SVD(random_state=42)\n",
"\n",
"helpers.ready_made(\n",
" algo,\n",
" reco_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv\",\n",
" estimations_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 11249.52it/s]\n",
"943it [00:00, 10927.13it/s]\n",
"943it [00:00, 11816.00it/s]\n",
"943it [00:00, 11204.84it/s]\n",
"943it [00:00, 11803.13it/s]\n",
"943it [00:00, 10580.63it/s]\n",
"943it [00:00, 11843.28it/s]\n",
"943it [00:00, 12313.76it/s]\n",
"943it [00:00, 10678.21it/s]\n",
"943it [00:00, 9772.22it/s]\n",
"943it [00:00, 10699.52it/s]\n",
"943it [00:00, 11789.55it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
" <td>0.950347</td>\n",
" <td>0.749312</td>\n",
" <td>0.100636</td>\n",
" <td>0.050514</td>\n",
" <td>0.055794</td>\n",
" <td>0.070753</td>\n",
" <td>0.091202</td>\n",
" <td>0.082734</td>\n",
" <td>0.114054</td>\n",
" <td>0.053200</td>\n",
" <td>0.248803</td>\n",
" <td>0.521983</td>\n",
" <td>0.517497</td>\n",
" <td>0.992153</td>\n",
" <td>0.210678</td>\n",
" <td>4.418683</td>\n",
" <td>0.952848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.914143</td>\n",
" <td>0.717131</td>\n",
" <td>0.101803</td>\n",
" <td>0.042134</td>\n",
" <td>0.051610</td>\n",
" <td>0.068543</td>\n",
" <td>0.091953</td>\n",
" <td>0.071255</td>\n",
" <td>0.104015</td>\n",
" <td>0.048817</td>\n",
" <td>0.193027</td>\n",
" <td>0.517784</td>\n",
" <td>0.471898</td>\n",
" <td>0.867232</td>\n",
" <td>0.147908</td>\n",
" <td>3.871296</td>\n",
" <td>0.971820</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
" <td>0.939472</td>\n",
" <td>0.739816</td>\n",
" <td>0.085896</td>\n",
" <td>0.036073</td>\n",
" <td>0.043528</td>\n",
" <td>0.057643</td>\n",
" <td>0.077039</td>\n",
" <td>0.057463</td>\n",
" <td>0.097753</td>\n",
" <td>0.045546</td>\n",
" <td>0.219839</td>\n",
" <td>0.514709</td>\n",
" <td>0.431601</td>\n",
" <td>0.997455</td>\n",
" <td>0.168831</td>\n",
" <td>4.217578</td>\n",
" <td>0.962577</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Ready_SVD 0.950347 0.749312 0.100636 0.050514 0.055794 \n",
"0 Self_SVD 0.914143 0.717131 0.101803 0.042134 0.051610 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_SVDBiased 0.939472 0.739816 0.085896 0.036073 0.043528 \n",
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.070753 0.091202 0.082734 0.114054 0.053200 0.248803 \n",
"0 0.068543 0.091953 0.071255 0.104015 0.048817 0.193027 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.057643 0.077039 0.057463 0.097753 0.045546 0.219839 \n",
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.521983 0.517497 0.992153 0.210678 4.418683 0.952848 \n",
"0 0.517784 0.471898 0.867232 0.147908 3.871296 0.971820 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.514709 0.431601 0.997455 0.168831 4.217578 0.962577 \n",
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE(review): `ev` is not defined in this part of the notebook - it must be imported in an\n",
"# earlier cell (presumably the evaluation-measures module); confirm before a fresh Run All.\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"\n",
"# rating values counted as 'super' reactions (presumably feeding the *_super metrics)\n",
"super_reactions = [4, 5]\n",
"# folder holding the saved *_reco.csv / *_estimations.csv files of every model\n",
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}