WSS-project/P4. Matrix Factorization.ipynb
2021-05-29 13:08:49 +02:00

1456 lines
76 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self-made SVD"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import helpers\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scipy.sparse as sparse\n",
"from collections import defaultdict\n",
"from itertools import chain\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Read the tab-separated, header-less train and test splits (train first, then test).\n",
"train_read, test_read = (\n",
"    pd.read_csv(csv_path, sep=\"\\t\", header=None)\n",
"    for csv_path in (\n",
"        \"./Datasets/ml-100k/train.csv\",\n",
"        \"./Datasets/ml-100k/test.csv\",\n",
"    )\n",
")\n",
"\n",
"# helpers.data_to_csr is project-local and returns the six values unpacked here.\n",
"# Presumably the *_code_id / *_id_code dicts map between matrix indices and the\n",
"# original user/item ids -- TODO confirm against helpers.py.\n",
"train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = (\n",
"    helpers.data_to_csr(train_read, test_read)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"class SVD:\n",
" def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
" self.train_ui = train_ui\n",
" self.uir = list(\n",
" zip(*[train_ui.nonzero()[0], train_ui.nonzero()[1], train_ui.data])\n",
" )\n",
"\n",
" self.learning_rate = learning_rate\n",
" self.regularization = regularization\n",
" self.iterations = iterations\n",
" self.nb_users, self.nb_items = train_ui.shape\n",
" self.nb_ratings = train_ui.nnz\n",
" self.nb_factors = nb_factors\n",
"\n",
" self.Pu = np.random.normal(\n",
" loc=0, scale=1.0 / self.nb_factors, size=(self.nb_users, self.nb_factors)\n",
" )\n",
" self.Qi = np.random.normal(\n",
" loc=0, scale=1.0 / self.nb_factors, size=(self.nb_items, self.nb_factors)\n",
" )\n",
"\n",
" def train(self, test_ui=None):\n",
" if test_ui != None:\n",
" self.test_uir = list(\n",
" zip(*[test_ui.nonzero()[0], test_ui.nonzero()[1], test_ui.data])\n",
" )\n",
"\n",
" self.learning_process = []\n",
" pbar = tqdm(range(self.iterations))\n",
" for i in pbar:\n",
" pbar.set_description(\n",
" f\"Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...\"\n",
" )\n",
" np.random.shuffle(self.uir)\n",
" self.sgd(self.uir)\n",
" if test_ui == None:\n",
" self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
" else:\n",
" self.learning_process.append(\n",
" [i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)]\n",
" )\n",
"\n",
" def sgd(self, uir):\n",
"\n",
" for u, i, score in uir:\n",
" # Computer prediction and error\n",
" prediction = self.get_rating(u, i)\n",
" e = score - prediction\n",
"\n",
" # Update user and item latent feature matrices\n",
" Pu_update = self.learning_rate * (\n",
" e * self.Qi[i] - self.regularization * self.Pu[u]\n",
" )\n",
" Qi_update = self.learning_rate * (\n",
" e * self.Pu[u] - self.regularization * self.Qi[i]\n",
" )\n",
"\n",
" self.Pu[u] += Pu_update\n",
" self.Qi[i] += Qi_update\n",
"\n",
" def get_rating(self, u, i):\n",
" prediction = self.Pu[u].dot(self.Qi[i].T)\n",
" return prediction\n",
"\n",
" def RMSE_total(self, uir):\n",
" RMSE = 0\n",
" for u, i, score in uir:\n",
" prediction = self.get_rating(u, i)\n",
" RMSE += (score - prediction) ** 2\n",
" return np.sqrt(RMSE / len(uir))\n",
"\n",
" def estimations(self):\n",
" self.estimations = np.dot(self.Pu, self.Qi.T)\n",
"\n",
" def recommend(self, user_code_id, item_code_id, topK=10):\n",
"\n",
" top_k = defaultdict(list)\n",
" for nb_user, user_scores in enumerate(self.estimations):\n",
"\n",
" user_rated = self.train_ui.indices[\n",
" self.train_ui.indptr[nb_user] : self.train_ui.indptr[nb_user + 1]\n",
" ]\n",
" for item, score in enumerate(user_scores):\n",
" if item not in user_rated and not np.isnan(score):\n",
" top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
" result = []\n",
" # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
" for uid, item_scores in top_k.items():\n",
" item_scores.sort(key=lambda x: x[1], reverse=True)\n",
" result.append([uid] + list(chain(*item_scores[:topK])))\n",
" return result\n",
"\n",
" def estimate(self, user_code_id, item_code_id, test_ui):\n",
" result = []\n",
" for user, item in zip(*test_ui.nonzero()):\n",
" result.append(\n",
" [\n",
" user_code_id[user],\n",
" item_code_id[item],\n",
" self.estimations[user, item]\n",
" if not np.isnan(self.estimations[user, item])\n",
" else 1,\n",
" ]\n",
" )\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 39 RMSE: 0.7489999966900885. Training epoch 40...: 100%|██████████| 40/40 [01:02<00:00, 1.57s/it]\n"
]
}
],
"source": [
"# Seed NumPy's global RNG: the factor initialisation and the per-epoch shuffling\n",
"# both draw from np.random, so this makes the training run reproducible.\n",
"SEED = 42\n",
"np.random.seed(SEED)\n",
"\n",
"model = SVD(\n",
"    train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40\n",
")\n",
"# Passing test_ui makes train() log the test RMSE next to the training RMSE.\n",
"model.train(test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7ff52c7e5100>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbmklEQVR4nO3de3RV9Z338fc3FwmQQCBEpIRbfewaO6BBIuqIXVQ7LkGtjGMtrjXeVldZtVhx6rXWtqPLdlwy1Y7VljoVpz61RTteqo629YJaa1EDBhTBB2ytRLJMDCQQlEvg+/yxzyEn4SQ5JznJPtn781rrt/Y5e+9z9jdb/Ox9fvtm7o6IiAx9BWEXICIiuaFAFxGJCAW6iEhEKNBFRCJCgS4iEhFFYS143LhxPnXq1LAWLyIyJK1evfojd69MNy20QJ86dSq1tbVhLV5EZEgys791N01dLiIiEaFAFxGJCAW6iEhEhNaHLiJDz759+6ivr2f37t1hlxJ5JSUlVFVVUVxcnPFnFOgikrH6+nrKysqYOnUqZhZ2OZHl7jQ3N1NfX8+0adMy/py6XEQkY7t376aiokJhPsDMjIqKiqx/CSnQRSQrCvPB0Zf1POQCfcsWuPJK2Lcv7EpERPLLkAv0NWvgP/8Tli4NuxIRkfwy5AL9nHPgS1+Cm2+GjRvDrkZEBlNLSws/+clPsv7c/PnzaWlpyfpzl1xyCdOmTaO6uppjjz2W55577uC0uXPnMnnyZFIfErRgwQJKS0sBOHDgAFdccQXTp09nxowZHH/88fz1r38FgivlZ8yYQXV1NdXV1VxxxRVZ15bOkDzL5cc/hmefha9+FV58EQqG3GZJRPoiGehf//rXO43fv38/hYWF3X7uqaee6vMyly5dynnnncfKlStZtGgRmzZtOjitvLycP/3pT8yZM4eWlhYaGhoOTnvwwQfZunUr69ato6CggPr6ekaOHHlw+sqVKxk3blyf60pnSAb6+PFwxx1wySWwbBl0+W8rIoPgyiuhri6331ldDT/6UffTr7/+et59912qq6spLi6mtLSUCRMmUFdXx9tvv82CBQvYsmULu3fvZsmSJSxatAjouHdUW1sb8+bNY86cObzyyitMnDiR3/72twwfPrzX2k466SQ++OCDTuMWLlzIihUrmDNnDo888gjnnnsu69evB6ChoYEJEyZQkNjjrKqq6tM6ycaQ3be96CI4/XS47jp4//2wqxGRwXDrrbdy5JFHUldXx9KlS3nttdf4/ve/z9tvvw3A8uXLWb16NbW1tdx55500Nzcf8h2bNm1i8eLFrF+/nvLych5++OGMlv273/2OBQsWdBp32mmn8dJLL7F//35WrFjBl7/85YPTzj//fJ544gmqq6u56qqreOONNzp99vOf//zBLpc77rgjyzWR3pDcQwcwg5/9DKZPh699Df73f4NxIjI4etqTHiyzZ8/udOHNnXfeyaOPPgrAli1b2LRpExUVFZ0+k+wTB5g1axbvvfdej8u45ppruPbaa2lsbGTVqlWdphUWFjJnzhwefPBBPvnkE1JvCV5VVcU777zD888/z/PPP89pp53Gb37zG0477TRgYLpchuweOsDUqfCDH8DTT8OvfhV2NSIy2FL7pF944QWeffZZ/vznP7N27VpmzpyZ9sKcYcOGHXxdWFhIe3t7j8tYunQpmzdv5pZbbuHiiy8+ZPrChQv5xje+wfnnn592WfPmzWPp0qXccMMNPPbYY1n8ddkb0oEOsHgxnHgiLFkCTU1hVyMiA6msrIydO3emndba2sqYMWMYMWIEGzduPGRvuj8KCgpYsmQJBw4c4Pe//32naaeccgrf+ta3uOCCCzqNX7NmDVu3bgWCM17WrVvHlClTclZT2joH9NsHQWEh3Hsv7NwZhLqIRFdFRQUnn3wy06dP55prruk07YwzzqC9vZ1jjjmG73znO5x44ok5XbaZceONN3LbbbcdMv7qq6
8+pPuksbGRs88+m+nTp3PMMcdQVFTE5ZdffnB6ah/6RRddlJsaU8+hHEw1NTWeyycW3XwzfO978MQTcNZZOftaEUmxYcMGjj766LDLiI1069vMVrt7Tbr5h/weetL11wcHSC+7DHbsCLsaEZHB12ugm1mJmb1mZmvNbL2Z3ZRmnrlm1mpmdYn23YEpt3uHHRZ0vWzdGoS7iEimFi9efLD7I9nuu+++sMvKWianLe4BTnX3NjMrBl42s6fdvesRhz+6e6idHbNnw6WXws9/DnffrdMYRQaCu0fujot333132CUcoi/d4b3uoXugLfG2ONHC6XjPwGc/G9yJUd0uIrlXUlJCc3Nzn8JGMpd8wEVJSUlWn8vowiIzKwRWA/8HuNvdX00z20lmthbYClzt7uvTfM8iYBHA5MmTsyo0U8lrCJqbYfToAVmESGxVVVVRX19Pk84RHnDJR9BlI6NAd/f9QLWZlQOPmtl0d38rZZY1wJREt8x84DHgqDTfcw9wDwRnuWRVaYZSA/3Tnx6IJYjEV3FxcVaPRJPBldVZLu7eArwAnNFl/I5kt4y7PwUUm1lur2nNUGqgi4jESSZnuVQm9swxs+HAF4CNXeY5whJHScxsduJ7Q4lUBbqIxFUmXS4TgF8k+tELgIfc/Ukz+xqAuy8DzgMuM7N24BNgoYd01GTs2GCoQBeRuOk10N19HTAzzfhlKa/vAu7KbWl9M2ZMcLqiAl1E4iYyV4omFRZCebkCXUTiJ3KBDkE/ugJdROJGgS4iEhEKdBGRiFCgi4hEhAJdRCQiIhvobW2wd2/YlYiIDJ7IBjrAtm3h1iEiMpgiHejqdhGROFGgi4hEhAJdRCQiIhnoukGXiMRRJANde+giEkeRDPQRI2DYMAW6iMRLJAPdTBcXiUj8RDLQQYEuIvGjQBcRiQgFuohIRCjQRUQiItKBvm0bhPOoahGRwRfpQG9vh507w65ERGRwRDrQQd0uIhIfCnQRkYhQoIuIRESvgW5mJWb2mpmtNbP1ZnZTmnnMzO40s81mts7MjhuYcjOnG3SJSNwUZTDPHuBUd28zs2LgZTN72t1XpcwzDzgq0U4AfpoYhkZ76CISN73uoXugLfG2ONG6ngx4DnB/Yt5VQLmZTchtqdkZMyYYKtBFJC4y6kM3s0IzqwMagWfc/dUus0wEtqS8r0+M6/o9i8ys1sxqm5qa+lhyZoqKoLxcgS4i8ZFRoLv7fnevBqqA2WY2vcsslu5jab7nHnevcfeaysrKrIvNlq4WFZE4yeosF3dvAV4AzugyqR6YlPK+Ctjan8JyQYEuInGSyVkulWZWnng9HPgCsLHLbI8DFyXOdjkRaHX3hlwXmy0FuojESSZnuUwAfmFmhQQbgIfc/Ukz+xqAuy8DngLmA5uBj4FLB6jerFRUwIYNYVchIjI4eg10d18HzEwzflnKawcW57a0/tMeuojESWSvFIUg0HfuhL17w65ERGTgRT7QAbZvD7cOEZHBEItAV7eLiMSBAl1EJCIiHei6QZeIxEmkA1176CISJwp0EZGIiHSgjxwJhx2mQBeReIh0oJvp4iIRiY9IBzoo0EUkPhToIiIRoUAXEYkIBbqISETEJtD9kOcniYhESywCvb0d2tp6n1dEZCiLRaCDul1EJPoU6CIiERH5QNcNukQkLiIf6NpDF5G4UKCLiERE5ANdXS4iEheRD/SiIhg9WoEuItEX+UAHXS0qIvGgQBcRiYheA93MJpnZSjPbYGbrzWxJmnnmmlmrmdUl2ncHpty+UaCLSBwUZTBPO3CVu68xszJgtZk94+5vd5nvj+5+Vu5L7L+KCnjnnbCrEBEZWL3uobt7g7uvSbzeCWwAJg50YbmkPXQRiYOs+tDNbCowE3g1zeSTzGytmT1tZn/fzecXmVmtmdU2NTVlX20fVVTAjh2wb9+gLVJEZNBlHOhmVgo8DFzp7ju6TF4DTHH3Y4EfA4+l+w53v8fda9y9prKyso8lZy
95cdH27YO2SBGRQZdRoJtZMUGYP+Duj3Sd7u473L0t8fopoNjMxuW00n7Q1aIiEgeZnOViwL3ABne/vZt5jkjMh5nNTnxv3sSnrhYVkTjI5CyXk4ELgTfNrC4x7gZgMoC7LwPOAy4zs3bgE2Che/48I0h76CISB70Guru/DFgv89wF3JWronJNgS4icRCbK0VBgS4i0RaLQC8theJiBbqIRFssAt1MFxeJSPTFItBBgS4i0adAFxGJCAW6iEhEKNBFRCIidoGeP5c7iYjkVqwCfd8+2LUr7EpERAZGrAId1O0iItEVm0DXDbpEJOpiE+jaQxeRqFOgi4hEhAJdRCQiYhPo6kMXkaiLTaAXF8OoUQp0EYmu2AQ66GpREYk2BbqISEQo0EVEIkKBLiISEQp0EZGIiF2gt7ZCe3vYlYiI5F7sAh1g+/Zw6xARGQixCnRdXCQiURarQNfl/yISZb0GuplNMrOVZrbBzNab2ZI085iZ3Wlmm81snZkdNzDl9o8CXUSirCiDedqBq9x9jZmVAavN7Bl3fztlnnnAUYl2AvDTxDCvKNBFJMp63UN39wZ3X5N4vRPYAEzsMts5wP0eWAWUm9mEnFfbTwp0EYmyrPrQzWwqMBN4tcukicCWlPf1HBr6mNkiM6s1s9qmpqYsS+2/sjIoKlKgi0g0ZRzoZlYKPAxc6e47uk5O8xE/ZIT7Pe5e4+41lZWV2VWaA2a6uEhEoiujQDezYoIwf8DdH0kzSz0wKeV9FbC1/+XlngJdRKIqk7NcDLgX2ODut3cz2+PARYmzXU4EWt29IYd15kxFBXz0UdhViIjkXiZnuZwMXAi8aWZ1iXE3AJMB3H0Z8BQwH9gMfAxcmvNKc+Soo+Dxx8E96IIREYmKXgPd3V8mfR956jwOLM5VUQOppgaWL4f334cpU8KuRkQkd2J1pSjA8ccHw9dfD7cOEZFci12gz5gRPF+0tjbsSkREcit2gT5sGBx7rAJdRKIndoEOQT96bS0cOBB2JSIiuRPbQG9thXffDbsSEZHciWWgJw+MqttFRKIkloH+2c9CSYnOdBGRaIlloBcVwcyZ2kMXkWiJZaBD0O2yZg3s3x92JSIiuRHbQK+pgV27YOPGsCsREcmN2Aa6DoyKSNTENtA/8xkoLdWBURGJjtgGekEBzJqlPXQRiY7YBjoE3S51dbB3b9iViIj0X6wDvaYG9uyB9evDrkREpP9iHeg6MCoiURLrQJ82DcaM0YFREYmGWAe6WcedF0VEhrpYBzoE3S5vvgm7d4ddiYhI/8Q+0GtqoL0d1q4NuxIRkf5RoNcEQ3W7iMhQF/tAr6qC8eN1YFREhr7YB7oOjIpIVMQ+0CE4MLphA7S1hV2JiEjf9RroZrbczBrN7K1ups81s1Yzq0u07+a+zIFVUxM8MPqNN8KuRESk7zLZQ/9v4Ixe5vmju1cn2s39L2tw6cCoiERBr4Hu7i8B2wahltCMHw+TJunAqIgMbbnqQz/JzNaa2dNm9vfdzWRmi8ys1sxqm5qacrTo3NCBUREZ6nIR6GuAKe5+LPBj4LHuZnT3e9y9xt1rKisrc7Do3KmpgU2boKUl7EpERPqm34Hu7jvcvS3x+img2MzG9buyQZa88+Lq1eHWISLSV/0OdDM7wsws8Xp24jub+/u9g23WrGCobhcRGaqKepvBzH4NzAXGmVk98D2gGMDdlwHnAZeZWTvwCbDQ3X3AKh4gY8fCkUfqwKiIDF29Brq7X9DL9LuAu3JWUYhqamDVqrCrEBHpG10pmqKmBv72N8izE3BERDKiQE+hR9KJyFCmQE8xc2Zwsy51u4jIUKRATzFqFMydC8uWwY4dYVcjIpIdBXoXt90GjY3w7/8ediUiItlRoHdRUwMXXgh33AHvvRd2NSIimVOgp/GDH0BBAVx/fdiViIhkToGeRlUVXHstPPggvPJK2NWIiGRGgd6Na66BT30K/vVfg4dfiIjkOwV6N0aODLpeXnsNVq
wIuxoRkd4p0Htw4YVw3HFBX/rHH4ddjYhIzxToPSgoCM522bIFbr897GpERHqmQO/F5z4H554Lt94KDQ1hVyMi0j0FegZuuw327oUbbwy7EhGR7inQM3DkkXDFFXDfffDGG2FXIyKSngI9QzfeCBUVcNVVMPQe3yEicaBAz1B5Odx0E6xcCUuXKtRFJP8o0LOwaBH88z/DddcFpzTqVEYRyScK9CwUFcFDD8Ett8CvfgVz5gRPOBIRyQcK9CwVFMC3vw1PPAF/+QvMmgXPPx92VSIiCvQ+O/NMeP11GD8eTj89uABJ/eoiEiYFej8cdVTwuLovfhG++U31q4tIuBTo/VRWBv/zPx396v/wD/Dkk7B/f9iViUjcKNBzILVfvakJzj4bPvMZ+OEPYdu2sKsTkbjoNdDNbLmZNZrZW91MNzO708w2m9k6Mzsu92UODWeeGTy27qGHgodkXH11MPzqV6GuLuzqRCTqMtlD/2/gjB6mzwOOSrRFwE/7X9bQVVwMX/oSvPgirF0b9Ks/8ADMnAmnnAL33w9bt4ZdpYhEUa+B7u4vAT11HJwD3O+BVUC5mU3IVYFD2THHwM9+Bh98EHS/bN0KF18MEycGB1S/8pUg4PUwahHJhaIcfMdEYEvK+/rEuENuNmtmiwj24pk8eXIOFj00jBkTnAVz5ZXBzb1efBFeegkefRSWLw/mmTw5uFXvCSfA0UcHbcIEMAu1dBEZQnIR6OkiJ+0Z2e5+D3APQE1NTezO2i4oCC5EmjUrCPgDB2D9+o6A/8Mf4Je/7Jh/1Cj4u78L2tFHB8Np04J++bFjFfYi0lkuAr0emJTyvgpQL3EGCgpgxoygXX55cGFSQwNs3AgbNgTDjRvhueeCrplUJSVB101VVefh4YdDZWXQxo0L2mGHhfP3icjgykWgPw5cbmYrgBOAVnfXs336wAw+9amgnXpq52k7dwbh/v77UF8f9MvX1wdt1apguHdv+u8dPToI9srKoPunvLz74ahRHW30aBg+XL8ERIaKXgPdzH4NzAXGmVk98D2gGMDdlwFPAfOBzcDHwKUDVWyclZXB8ccHLR13aG6GxsbgXPiPPgqGqa+Tw02bYPt2aGnp/QKowsKOgC8rg9LSoI0c2fE6ddyIEcEwXUudt7g456tIJPZ6DXR3v6CX6Q4szllF0idmHV0smXKHtrYg2JMBv2NHz23XruAzjY3BMNmyveXBsGHpNwjJYepGIN3GI10bOTLoxhKJq1x0ucgQZRbsdZeVwaRJvc/fkwMHglDftevQ9vHHQegnNwZtbUEXUuoGYefOYHpzc8d8yc8fOJB5HSNGpN9IlJZ2/K3J1nVc8n3qsKREXU4ydCjQJScKCjqCM5fcYc+ezuGfriU3CKkbjdRpDQ3BMNn27cts+YWF6TcGyS6o1Nepxx9Su6mSwxEjtHGQgaVAl7xmFuwll5Rk153Umz17Ogd8T78c0rXGxqALaufOYNje3vsykxu9dBuH0aM7hqmv040rLVXXkqSnQJdYGjYsaLnYSCR/RfR07KG7DUPqxqG1NRj2dl99s85nInUN/K7vy8s7WvK9NgrRpEAX6afUXxGHH96/70oeqE4GfDLkU4ddx7e2wocfBmcvJd/v2dPzcgoKDg37ri31lNaxY4Nhsg0b1r+/UwaGAl0kj6QeqJ44se/fk/zF0NISBHxLS0dLfb99e8f7TZs6xre19fz9I0Z0Dvp0w66vx44NfjXol8HAUaCLRNCwYR1XDPfFvn0dQb9tWxD827enf71tG7z7bsf7nk5hLSjo2MsfOxYqKoKW+rprGzdOB5QzpUAXkUMUF2d/XUPSnj1BsDc3dw7/1Jac3tgY3OZi27bgF0V3ksc7kgGfybCsLH4bAQW6iOTUsGFwxBFBy8a+fUGwNzd3bh99dOhw3bqO6d0dRE7dKCVvfZH6Pl0bPrz/f3+YFOgikheKi2H8+KBl6sCBoFsoXeh3bWvXBsNt27rfCIwY0XkDkK
6l3gBv1Kj8+hWgQBeRIaugoOOAa6b27+/4JZAu+FPvg/TOO8Fw167035X8FZAu9A8/vGMDlXw9cmRu/u7uKNBFJFYKC7M/YPzxxx0hn9pSw7+pCWprg2Fra/rvGTkyCPfLLw+eiZBrCnQRkV6MGAFTpgQtE3v3Bgd8P/ywo6W+z/b4QqYU6CIiOXbYYcFDZ6qqBne5OsVfRCQiFOgiIhGhQBcRiQgFuohIRCjQRUQiQoEuIhIRCnQRkYhQoIuIRIR5b8+7GqgFmzUBf+thlnHAR4NUTrZUW9+otr5RbX0T1dqmuHvaGxeEFui9MbNad68Ju450VFvfqLa+UW19E8fa1OUiIhIRCnQRkYjI50C/J+wCeqDa+ka19Y1q65vY1Za3fegiIpKdfN5DFxGRLCjQRUQiIu8C3czOMLN3zGyzmV0fdj2pzOw9M3vTzOrMrDbkWpabWaOZvZUybqyZPWNmmxLDMXlU27+Z2QeJdVdnZvNDqm2Sma00sw1mtt7MliTGh77ueqgt9HVnZiVm9pqZrU3UdlNifD6st+5qC329pdRYaGZvmNmTifcDst7yqg/dzAqB/wf8I1APvA5c4O5vh1pYgpm9B9S4e+gXK5jZ54A24H53n54Ydxuwzd1vTWwMx7j7dXlS278Bbe7+H4NdT5faJgAT3H2NmZUBq4EFwCWEvO56qO18Ql53ZmbASHdvM7Ni4GVgCXAu4a+37mo7gzz4NwdgZt8EaoBR7n7WQP2/mm976LOBze7+F3ffC6wAzgm5przk7i8B27qMPgf4ReL1LwjCYNB1U1tecPcGd1+TeL0T2ABMJA/WXQ+1hc4DbYm3xYnm5Md66662vGBmVcCZwM9TRg/Iesu3QJ8IbEl5X0+e/INOcOAPZrbazBaFXUwa4929AYJwAA4PuZ6uLjezdYkumVC6g1KZ2VRgJvAqebbuutQGebDuEt0GdUAj8Iy7581666Y2yIP1BvwIuBY4kDJuQNZbvgW6pRmXN1ta4GR3Pw6YByxOdC1IZn4KHAlUAw3AD8MsxsxKgYeBK919R5i1dJWmtrxYd+6+392rgSpgtplND6OOdLqpLfT1ZmZnAY3uvnowlpdvgV4PTEp5XwVsDamWQ7j71sSwEXiUoIson3yY6IdN9sc2hlzPQe7+YeJ/ugPAfxHiukv0sz4MPODujyRG58W6S1dbPq27RD0twAsEfdR5sd6SUmvLk/V2MvDFxPG3FcCpZvZLBmi95Vugvw4cZWbTzOwwYCHweMg1AWBmIxMHqjCzkcDpwFs9f2rQPQ5cnHh9MfDbEGvpJPmPN+GfCGndJQ6g3QtscPfbUyaFvu66qy0f1p2ZVZpZeeL1cOALwEbyY72lrS0f1pu7f8vdq9x9KkGePe/u/8JArTd3z6sGzCc40+Vd4Nth15NS16eBtYm2PuzagF8T/IzcR/DL5itABfAcsCkxHJtHtf1f4E1gXeIf84SQaptD0I23DqhLtPn5sO56qC30dQccA7yRqOEt4LuJ8fmw3rqrLfT11qXOucCTA7ne8uq0RRER6bt863IREZE+UqCLiESEAl1EJCIU6CIiEaFAFxGJCAW6iEhEKNBFRCLi/wO1vKB6yNy13AAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Keep only the first two columns of the training log: epoch and training RMSE.\n",
"df = pd.DataFrame(model.learning_process).iloc[:, :2]\n",
"df.columns = [\"epoch\", \"train_RMSE\"]\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"ax.set_xlabel(\"epoch\")\n",
"ax.set_ylabel(\"RMSE\")\n",
"ax.set_title(\"Training RMSE per epoch\")\n",
"# The trailing semicolon keeps the Legend repr out of the cell output.\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7ff52f336dc0>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyp0lEQVR4nO3de3zO5f/A8dd7B8wmNoYxzJk5DXMqhaQIoa9EB5VKwpcOVPrp+O3gmw4o36JzVOigFKGYFEUbm7M2TGbOQ5jTtuv3xzVrGO6x+dz3vffz8djD7s/hvt/XPvXetetzfd6XGGNQSinlvXycDkAppVTh0kSvlFJeThO9Ukp5OU30Sinl5TTRK6WUl/NzOoC8lCtXzkRERDgdhlJKeYy4uLi9xpjQvPa5ZaKPiIggNjbW6TCUUspjiMjWc+3ToRullPJyF0z0IvKBiOwWkTXn2C8iMkFEkkRklYg0y7Wvs4hszN73REEGrpRSyjWu9Og/AjqfZ38XoHb210DgbQAR8QUmZu+PBPqJSOSlBKuUUir/LjhGb4xZLCIR5zmkB/CJsbUUfheRMiISBkQAScaYzQAiMi372HWXHLVSyq2cPHmSlJQUjh075nQoXq9EiRKEh4fj7+/v8jkFcTO2MrAt1+uU7G15bW91rjcRkYHYvwioWrVqAYSllLpcUlJSKFWqFBEREYiI0+F4LWMM+/btIyUlherVq7t8XkHcjM3rqprzbM+TMWayMSbaGBMdGprnDCGllJs6duwYZcuW1SRfyESEsmXL5vsvp4Lo0acAVXK9DgdSgWLn2K6U8kKa5C+Pi/k5F0SPfhbQP3v2TWvgoDFmB/AHUFtEqotIMaBv9rGFJstksWLHisL8CKWU8jiuTK/8HPgNqCsiKSJyr4gMEpFB2YfMATYDScC7wGAAY0wGMBSYB6wHZhhj1hZCG3J8sfYLmk9uTp8v+rB5/+bC/CillPIYF0z0xph+xpgwY4y/MSbcGPO+MeYdY8w72fuNMWaIMaamMaaRMSY217lzjDF1sve9WJgNAehapyvPtHuG2YmzqT+xPiPmj2D/0f2F/bFKKYcdOHCA//3vf/k+78Ybb+TAgQP5Pu/uu++mevXqREVF0aRJExYsWJCzr3379lStWpXcizr17NmToKAgALKyshg2bBgNGzakUaNGtGjRgi1btgC2KkCjRo2IiooiKiqKYcOG5Tu2vLhlCYSLFVQsiGfbP8v9ze5ndMxoXv/tdRZvXcyy+5bp+KFSXuxUoh88ePBp2zMzM/H19T3neXPmzLnozxw7diy9e/cmJiaGgQMHkpiYmLOvTJkyLFmyhLZt23LgwAF27NiRs2/69OmkpqayatUqfHx8SElJITAwMGd/TEwM5cqVu+i48uJVif6UyldU5sMeHzK81XAOHDuAiJB+Mp2FWxbStXZXTfpKFaKHHoL4+IJ9z6goGDfu3PufeOIJNm3aRFRUFP7+/gQFBREWFkZ8fDzr1q2jZ8+ebNu2jWPHjjF8+HAGDhwI/FNX6/Dhw3Tp0oW2bduydOlSKleuzLfffktAQMAFY2vTpg3bt28/bVvfvn2ZNm0abdu25euvv+bmm29m7Vo7cr1jxw7CwsLw8bEDKuHh4Rf1M8kPr651E1UxivYR7QH4YOUHdP+8Ox0+7qA3bJXyMmPGjKFmzZrEx8czduxYli9fzosvvsi6dfb5zA8++IC4uDhiY2OZMGEC+/btO+s9EhMTGTJkCGvXrqVMmTJ89dVXLn323Llz6dmz52nbOnbsyOLFi8nMzGTatGnceuutOfv69OnDd999R1RUFI8++igrV6487dwOHTrkDN288cYb+fxJ5M0re/R5eaD5A/iID88seoboydH0a9SPQc0HcXW1q50OTSmvcr6e9+XSsmXL0x4omjBhAjNnzgRg27ZtJCYmUrZs2dPOOTXmDtC8eXOSk5PP+xkjR47kscceY/fu3fz++++n7fP19aVt27
ZMnz6do0ePkrvsenh4OBs3bmThwoUsXLiQjh078sUXX9CxY0egcIZuvLpHn5u/rz+DWwwm6d9JjLxyJN9u+JYXf/nn/nDK3ykORqeUKki5x7wXLVrETz/9xG+//UZCQgJNmzbN84Gj4sWL53zv6+tLRkbGeT9j7NixJCUl8cILL3DXXXedtb9v3778+9//pk+fPnl+VpcuXRg7dixPPvkk33zzTT5al39FJtGfUrpEaf7b6b/sGrGLd7q9A9gkX21cNVq914qJyyeyN32vw1EqpfKjVKlSHDp0KM99Bw8eJDg4mJIlS7Jhw4azet+XwsfHh+HDh5OVlcW8efNO23f11VczatQo+vXrd9r2FStWkJpqnx3Nyspi1apVVKtWrcBiyjPOQn13NxZYLJCIMhH2e/9AXrnuFY5lHGPoD0MJey2MHtN6sHHvRmeDVEq5pGzZslx11VU0bNiQkSNHnravc+fOZGRk0LhxY5566ilat25doJ8tIowePZpXXnnlrO0jRow4axhm9+7ddO/enYYNG9K4cWP8/PwYOnRozv7cY/T9+/cvmBhzz/V0F9HR0capFaYSdiYwZdUUpq+dzvL7lhNWKoyEnQkEFQuiZkhNR2JSyt2tX7+e+vXrOx1GkZHXz1tE4owx0XkdX2R79OfSpGITXr3+VbY+tJWwUmEAPDr/UWq/WZue03oSsyUGd/zlqJRS56KJ/hx85J8fzZReUxh9zWiWbFvCtZ9cS9SkKGaun+lgdEqpy2HIkCE5wyinvj788EOnw8q3IjO98lKElQrj+Q7PM6rtKD5f8znjfh/HXwf/AuBYxjH2H92f0/tXSnmPiRMnOh1CgdAefT4E+AcwoOkAEgYlMKTlEAA+W/0Z1cZVo//M/izfvlyHdZRSbkcT/UUQEfx87B9D7SPa82D0g8zcMJNW77Wi6riqDJ49mIys88/BVUqpy0UT/SWqEVyD8V3Gk/JwCh/c9AEtKrVgze41Ob8Ixvw6hg9WfsCuw7scjlQpVVTpGH0BKV2iNPc0vYd7mt6TM3yTZbL4KP4jNu7biCC0rNyS7nW60zuyN3XL1XU4YqVUUaE9+kJwqjqmj/iwfsh6Vj6wkufaP0eWyWJ0zGg+W/0ZACcyT5CUluRkqEp5hYutRw8wbtw40tPTz3vMqTrxjRs3pl27dmzdujVnn4hw55135rzOyMggNDSUbt26AbBr1y66detGkyZNiIyM5MYbbwQgOTmZgICA02b0fPLJJxfVhgsyxrjdV/PmzY23Sv071ew8tNMYY8ysDbMMz2JavdvKTPh9gtl1eJfD0Sl1cdatW+fo52/ZssU0aNDgos6tVq2a2bNnj8vHPP300+a+++7L2RcYGGiioqJMenq6McaYOXPmmCZNmpiuXbsaY4wZOHCgGTduXM7xCQkJlxxzXj9vINacI6fq0M1llnsaZovKLXjlulf4dPWnDJs7jIfnPUynmp2Y2msqZUuWPc+7KOXe2n/U/qxtfRr0YXCLwaSfTOfGT288a//dUXdzd9Td7E3fS+8ZvU/bt+juRef9vNz16Dt16kT58uWZMWMGx48fp1evXjz33HMcOXKEPn36kJKSQmZmJk899RS7du0iNTWVDh06UK5cOWJiYi7YtjZt2jBhwoTTtnXp0oXZs2fTu3dvPv/8c/r168cvv/wC2Prz119/fc6xjRs3vuBnFDQdunFQxaCKjLxqJPGD4lnz4Boeu+ox0k+mExwQDMAnCZ8wc/1Mdh/Z7XCkSrm33PXoO3XqRGJiIsuXLyc+Pp64uDgWL17M3LlzqVSpEgkJCaxZs4bOnTszbNgwKlWqRExMjEtJHvKuP39qoZFjx46xatUqWrVqlbNvyJAh3HvvvXTo0IEXX3wxp6AZkPPL6dTXqV8OBU179G6iQfkGvNTxpZzXxhjG/DqG9XvXA1C9THVahbeiR90e9G3Y16kwlXLJ+XrgJf1Lnnd/uZLlLtiDP5/58+czf/58mjZtCsDhw4
dJTEzk6quvZsSIETz++ON069aNq6/O31oUHTp0YNeuXZQvX54XXnjhtH2NGzcmOTmZzz//PGcM/pQbbriBzZs3M3fuXH744QeaNm3KmjVrAHJ+ORU27dG7KREhflA8i+9ezNhOY2kW1oxf//qVn5N/BiAzK5N2H7Vj2A/D+HTVp2xK26QPaymF7SSNGjWK+Ph44uPjSUpK4t5776VOnTrExcXRqFEjRo0axfPPP5+v942JiWHr1q00aNCAp59++qz9N910EyNGjDirLDFASEgIt912G1OmTKFFixYsXrz4ott3MbRH78aK+Rbj6mpXn7YK1onMEwDsO7oPQXh/5fu8ufxNwPb6X7/hdXrW6+lEuEo5Jnc9+htuuIGnnnqK22+/naCgILZv346/vz8ZGRmEhIRwxx13EBQUxEcffXTaua6s6hQQEMC4ceNo1KgRo0ePJiQkJGffgAEDKF26NI0aNWLRokU52xcuXEjr1q0pWbIkhw4dYtOmTVStWrVA238hmug9TDHfYgCUDyzPorsXkZGVwdrda1m6bSnfJ35P2QB7E3f59uW8t+I9etXrxbXVr6W4X/Hzva1SHi13PfouXbpw22230aZNGwCCgoKYOnUqSUlJjBw5Eh8fH/z9/Xn77bcBGDhwIF26dCEsLMylcfqwsDD69evHxIkTeeqpp3K2h4eHM3z48LOOj4uLY+jQofj5+ZGVlcV9991HixYtSE5OzhmjP2XAgAEMGzbsEn8aZ9N69F5qSsIUBs8ZzOETh7mi+BV0rd2VXvV60aNej5xfFkoVFK1Hf3lpPXoFwJ1N7mTPyD183+97bom8hR83/8j9392fsz9mSwwJOxPIMlkORqmUuhx06MaLlfArQdc6XelapyuTsiaRmJaY05sf+sNQ1u1ZR7mS5egQ0YFrq1/LdTWuo1ZILYejVso5rVq14vjx46dtmzJlCo0aNXIoooKhib6I8PXxpV65ejmv594+l5jkGBZsWcCCzQv4Yt0X3NboNj69+VOMMUxfO52rqlxFldJVHIxaeRJjTE75D0+1bNkyp0O4oIsZbtdEX0RVKV2F/k36079Jf4wxJKUlkWkyAdi8fzP9vrJTxCqXqkzzSs1pVrEZtzS4hcjQSCfDVm6qRIkS7Nu3j7Jly3p8sndnxhj27dtHiRIl8nWeJnqFiFC7bO2c19WDq5MwKIGYLTH8kfoHK3as4LuN3xEZGklkaCRxqXGMjhlN87DmNAtrRrOwZlQrXU3/By/CwsPDSUlJYc+ePU6H4vVKlChBeHh4vs5xKdGLSGdgPOALvGeMGXPG/mDgA6AmcAwYYIxZk70vGTgEZAIZ57orrNyHj/jQuEJjGlf4pybH4ROH8RVfAA4cO0DqoVR+3PRjzl8B5QPL8+OdP552jio6/P39qV69utNhqHO4YKIXEV9gItAJSAH+EJFZxph1uQ57Eog3xvQSkXrZx3fMtb+DMWZvAcatLrOgYkE533es0ZGEQQkcPXmU1btXs2LHCpZsW0LtEPtXwYuLX2Rh8kK61e5G1zpdqVO2jlNhK6VwbXplSyDJGLPZGHMCmAb0OOOYSGABgDFmAxAhIhUKNFLldgL8A2hZuSWDogcxpdcUAvwDAAgOCGbX4V08Mv8R6r5Vlzpv1mH0wtEOR6tU0eVKoq8MbMv1OiV7W24JwM0AItISqAacGkQywHwRiRORgZcWrvIEg1sMZs3gNWwZvoW3urxFrZBaJKYl5uwf9sMw3lz2Juv2rNP6PEpdBq6M0ed1h+3M/zvHAONFJB5YDawETq2OfZUxJlVEygM/isgGY8xZFX2yfwkMBC66DsSrr0LHjpBdtE45LKJMBENaDmFIyyE5CT39ZDpzEuewaf8mAMKCwuhYoyP3Nb2PdhHtnAxXKa/lSo8+Bcg9mTocSM19gDHmb2PMPcaYKKA/EApsyd6Xmv3vbmAmdijoLMaYycaYaGNMdGhoaH7bwYED8Mor0Lw53HUXbNt2wVPUZX
RqRk5J/5IkDUti87DNvNv9XdpFtGNe0jz+3PcnAH8d/IvBswfz9fqvSTua5mTISnmNC9a6ERE/4E/szdXtwB/AbcaYtbmOKQOkG2NOiMj9wNXGmP4iEgj4GGMOZX//I/C8MWbu+T7zYmvdHDgAL78M48eDCDzyCDz+OFxxRb7fSl1GWSaLjKwMivkWY27SXHrP6M2Rk0cQhKZhTelYvSMPtX6ISqUqOR2qUm7rkmrdGGMygKHAPGA9MMMYs1ZEBonIoOzD6gNrRWQD0AU4VcKtAvCriCQAy4HZF0ryl6JMGfjvf2HDBrj5ZnjpJahVC95+G06eLKxPVZfKR3xySjN0rtWZtMfT+OWeX3im3TMEFQti/LLx+Ij9T/WrdV/xTMwzLN66mOMZx8/3tkqpbF5dvTI2Fh59FBYvhnr17C+B7t1tb195jqMnj+bM6Hl03qOMWzaOLJNFgF8Abau25aa6NzG05VDAzvEvVawUvj6+Toas1GV3vh69Vyd6AGPgu+/gscdg40Zo187etI3Wx7Y81v6j+1m8dTELtixg4ZaFFPcrTtzAOABavtuSFTtWUDGoIpWvqEylUpW4MvxKRl41EoD4nfHUK1ePEn75e4RcKXdXpBP9KSdPwrvvwjPPwN690K8fjB4NkVq6xeMdPnE454GuKQlT2LhvI9sPbSf1UCqph1JpWrEpn/T6BIBKr1Ui7WgarcJb0b5ae9pFtKNNeJucvxiU8lSa6HM5eNAO4YwfD+np0KsXjBoFLVoUyscpNzP7z9nEJMfw89afWbFjBVkmi+GthjOu8zhOZp5k8dbFtKnShpL+JZ0OVal80USfh717YcIEePNNO1vnuuvgySehfXsdwy8qDh47yK9//UrV0lVpVKERy1KW0fr91vj7+NOycktah7emacWmXF/zekID8z/lV6nLSRP9efz9N0yaBK+9Brt2QevWtoffrRv46PpbRcqRE0dYvHUxi5IX8fPWn4nfGc/xzOPE3BVD+4j2/PrXr0xbMy2nYmdkaKQuy6jchiZ6Fxw7Bh9+aB+6Sk6Ghg1twu/TB/y0mHORdDLzJBv2bqBmSE1K+pfk/RXv89C8hzh84jBgF2pvWL4hP9z+A+UDy7Nm9xrSjqYRUSaCyqUq68wfdVlpos+HjAyYNg3GjIG1a6FGDRg2DO68E0JCHAlJuZEsk0VSWhIrd6xkxY4VrNmzhll9Z+Hr48uD3z/IO3HvAODn40eVK6pQK6QW8+6Yh4iwfPty/Hz8aBDagOJ+xR1uifI2mugvQlaWnZY5Zgz8/juUKAG33AIPPABXXqnj+Ops2w5uY8PeDSQfSCb5QDJbDmzhaMZRZt46E4Cun3VlTuIc/H38aVC+AU0rNqVt1bYMaDrA4ciVN9BEf4ni42HyZJg6FQ4dggYNYOBA28sPDnY6OuUpNu/fTGxqLCt2rGDlTvsXQcPyDYm5KwaAf834F8V9i9MsrBmNKzSmdkhtqpSugp+Pjh2qC9NEX0COHLHDOpMmwR9/2F5+nz62l9+mjfbyVf4YYzh84jClipfCGEOfL/uwLGUZ2/7+pyLfXU3u4qOeH2GM4ZF5j1CtTDVqBtekZkhNagTX0Ae/VA5N9IVg5Urby//0U9vLb9gQ7r/fPoh1EcU3lcqxN30va3avYVPaJqoHV+fa6tey/+h+IsZH8Pfxv0879uWOL/NE2yc4dPwQE5ZNoGZIzZxfBMElgnUd3yJEE30hOnzY9vInT7a9fD8/6NLFDut07257/UoVBGMM+47uY1PaJpLSkti0fxPtI9pzTbVrSNiZQNSkqNOOL128NO92f5dbGtxC6qFU5iTOoU7ZOtQpW4cKgRX0l4CX0UR/maxeDVOm2F5+aiqULm2Hdvr3h6uu0qEdVbiOnDjC5v2b2bR/k/03bRP3NruXZmHN+GbDN/Sa3ivn2KBiQdQOqc17N71Hs7BmbP97O9v+3kadsnUICdDpZZ5IE/1llpkJCxfapP/113Zsv3p1uO
MO29OvXdvpCFVRk5mVyV8H/yIxLZE/9/1J4r5E/kz7k4k3TqRGcA3eXPYmw+YOA6BsQFnqlqtL/XL1ebnjy4QGhnLkxBFK+JXQZwPcmCZ6Bx0+DDNn2qT/00+2mmbr1nDvvXY8PzDQ6QiVgh2HdhCbGpvzi2DD3g1s2LuBTcM2EVgskMd/fJwJyydQp2wd6perb79C69M7sjc+4oMxRoeCHKaJ3k1s3w6ffQYff2wfxipd2g7rPPgg1K/vdHRKndv8TfOZv2k+6/euZ/2e9SQfSKZMiTLse2wfIkLfL/uycMtCKgRVoGJQRSoEVqBO2To83e5pABJ2JgBQPrA85UqWw9/X38nmeCVN9G7GGFiyxK589eWXcOKErZM/eDD07AnFtHyKcnNHTx5l+6Ht1AqpBcBH8R+xLGUZO4/sZNfhXew8vJMKQRX47d7fALjy/Sv5LeW3nPODSwRzbfVr+bLPlwCM+XUMxzKOUT6wPJVKVaJWSC1qBtfU8tH5oIneje3ebWvsTJoEW7ZAhQpw3332gayqVZ2OTqmCEZsay18H/2L3kd3sObKH3Ud2U6lUJUZdPQqAppOakrAzAcM/+ahr7a58f9v3AIz6aRShgaHUDqlN7bK1qRFcQwvKnUETvQfIyoJ58+B//4PZs+0Mna5d7bDO9deDr94DU14uIyuDfen7SPk7hcS0RIJLBHNDrRs4mXmSyq9XZk/6npxjfcSHUW1H8cK1L3Do+CGun3o9fj5++Pv44+/rj5+PH/0b9+fWhreyL30fj/34GCEBITlfZUuWpXlYc6oHVycjK4NjGccI9A/06PsM50v0+my1m/DxsfPvu3SBrVvtvPz33rP1dipXtmP5d98Ndeo4HalShcPPx48KQRWoEFSB5pWa52z39/Vn98jd7EvfR2JaIon7EklMS6RV5VYAGAylipUiIyuDk1knST+ZTkZWRk6V0aMZR5m7aS77j+7naMbRnPedeONEBrcYzLo962jyThOK+RYjJCCEikEVqVSqEo9d+RjtItqxN30vS7ctpVKpSlQqVYnygeU9riyF9ujd2IkTNtF/+CH88IPt9V91Fdxzj52fX6qU0xEq5VmOnjxK2tE00o6mUSGoAuUDy7Pj0A6mrppK2tE09qbvZeeRnaQeSmVMxzF0qtmJHxJ/4MbPbsx5Dx/xoUJgBb6+9Wtah7dmweYFvPjLixTzLUZxv+L2X9/iPN/heWoE12BZyjLmbZpHuZLlKFeyHKElQylXshx1y9Ut0OEn7dF7qGLF4F//sl87dtgpmh9+aMfwhw2D3r1t0r/mGl0kRSlXBPgHUNm/MpWvqJyzLaxUWM7i8Xm5utrV/HH/HzlrEJ/6Kl28NACZJpOMrAyOnDzCicwTHM84zonMExw9af96WLVrFc8seuas9036dxI1Q2oycflExi0bR7mS5Zjaayo1Q2oWcKu1R+9xjIFly2zCnzbNrpBVvbod1rnnHqhSxekIlVJnOpl5Mucvhj3pe9ibvpdudbpRwq8E3274lulrp7MnfQ8f9/yYSqUqXdRn6M1YL5Webh/G+vBD+yTuqRu4gwbBDTfoDVylipLzJXr9g9+DlSwJt99un7jdvNkufbh8uU32NWvCSy/Bzp1OR6mUcpomei8REQEvvADbtsEXX0CtWvB//2eHcm65BRYssDdzlVJFjyZ6L+Pvb2/S/vQTbNwIw4dDTAxcdx3UqwevvQZ79zodpVLqctJE78Xq1IFXX4WUFLsMYoUKMGKEnZd/++2weLG9uauU8m6a6IuAEiVsYv/lF1sz/4EH7NO37dpBZCS88Qbs2+d0lEqpwqKJvohp2BAmTLALo3z4oV3c/JFHbC//jjvsLwPt5SvlXVxK9CLSWUQ2ikiSiDyRx/5gEZkpIqtEZLmINHT1XOWMkiXt3PulS2HVKvsQ1nff2YevGjSAceMgLc3pKJVSBeGCiV5EfIGJQBcgEugnIpFnHPYkEG+MaQz0B8bn41zlsEaN4K
23bC//gw9snfyHH4ZKlexDWPHxTkeolLoUrvToWwJJxpjNxpgTwDSgxxnHRAILAIwxG4AIEang4rnKTQQG2sT+2282uQ8YADNmQNOmdjz/66/tMolKKc/iSqKvDGzL9Tole1tuCcDNACLSEqgGhLt4LtnnDRSRWBGJ3bNnT16HqMuoSRNbMnn7djtzZ+tWW3OnVi07RfPAAacjVEq5ypVEn1eB5jNv140BgkUkHvg3sBLIcPFcu9GYycaYaGNMdGhoqAthqcuhTBl49FFISoKvvrKLoYwYAeHhMGSInauvlHJvriT6FCB3qaxwIDX3AcaYv40x9xhjorBj9KHAFlfOVZ7Bzw9uvhl+/hlWrrRP2773nn0Iq0sXmDtXn7xVyl25kuj/AGqLSHURKQb0BWblPkBEymTvA7gPWGyM+duVc5XniYqyUzO3bYPnn7fj+V26QN268PrrOltHKXdzwURvjMkAhgLzgPXADGPMWhEZJCKDsg+rD6wVkQ3YGTbDz3duwTdDOaF8eXjqKTt+/9lnULGiHeapXBnuvRfi4pyOUCkFWqZYFbCEBHsTd+pUW0a5VSsYPNiuiFWihNPRKeW9tEyxumyaNIFJk+yc/PHj7eycu+6yN28ffxy2bHE6QqWKHk30qlCULm2XO1y/3lbSbNfOTsusWRN69LA3dd3wj0mlvJImelWoRKBjRzs1MzkZnnwSliyB9u0hOho+/dQugq6UKjya6NVlEx7+z+IokybZMfw77rBr3o4Zo7N1lCosmujVZRcQAAMHwtq1MGeOLZU8apRdDWvIEEhMdDpCpbyLJnrlGB8fO//+xx/tbJ1bb7UPYdWtq+P4ShUkTfTKLTRubCtnbt1q5+YvXWrH8du0gW++0adulboUmuiVW6lYEZ57Dv76C95+G/bsgV69bI38jz7SG7dKXQxN9MotBQTAoEG2aNrnn0Px4raEcq1adn7+kSNOR6iU59BEr9yanx/07WsLqf3wA9SoAQ89ZKtoPvecrnWrlCs00SuPIAKdO8OiRXb8vm1bePZZm/AfftiO7Sul8qaJXnmcNm3g229hzRro3dsug1izpq2ns3SpztRR6kya6JXHatAAPv4YNm+2VTN//BGuugpat7bj+idPOh2hUu5BE73yeFWqwH//CykpMHGiLaR22236xK1Sp2iiV14jMNCWRF6/HmbP/ueJ2/BwePBBu12pokgTvfI6Pj5w440wfz6sXg23325XxIqMtNuXLnU6QqUuL030yqs1bAjvvmsLqf3nPxAba8fxO3a0JRaUKgo00asiITQURo+2C5+8/jqsW2dLLFxzja2XrzN1lDfTRK+KlMBAO+9+82Z48037b6dOcOWVtpKmJnzljTTRqyIpIACGDoVNm+Cdd2DHDujaFVq0sHP0NeErb6KJXhVpxYvDAw/YGvjvv2+nZvbsCVFR8OWXWjVTeQdN9EoB/v4wYABs2ACffALHj8Mtt9jFzmfM0ISvPJsmeqVy8fODO++0q199+ilkZNgFURo1gmnTIDPT6QiVyj9N9ErlwdfXPl27Zo1N8AD9+tnpmp99pglfeRZN9Eqdh6+v7dGvXm2HcPz87ANYDRrA1Km2x6+Uu9NEr5QLfHzsmH1Cgr1JW7y4HeKJjLSF1TThK3emiV6pfPDxgX/9yy6E8vXXdl7+3XdDvXq2zIJWzFTuSBO9UhfBx8euZbtihZ13f8UVdtZOvXp2mqYmfOVONNErdQlE4KabIC4OZs2C4GC47z6oU8fW2NHFzJU70ESvVAEQge7d4Y8/4PvvbW2dgQOhdm2YNEkTvnKWS4leRDqLyEYRSRKRJ/LYX1pEvhORBBFZKyL35NqXLCKrRSReRGILMnil3I2ILaWwbJmtnRMWBoMGQa1a8Pbb9kEspS63CyZ6EfEFJgJdgEign4hEnnHYEGCdMaYJ0B54TUSK5drfwRgTZYyJLpiwlXJvItClC/z2G8ydaxc/GTzYjuFPmaLz8NXl5UqPvi
WQZIzZbIw5AUwDepxxjAFKiYgAQUAaoBPOVJEnAjfcAEuW2IQfHAz9+0PTpnaIR4unqcvBlURfGdiW63VK9rbc3gLqA6nAamC4MeZUdRADzBeROBEZeK4PEZGBIhIrIrF79uxxuQFKeYJTCT821j5pe/SoHdO/5hr7S0CpwuRKopc8tp3ZD7kBiAcqAVHAWyJyRfa+q4wxzbBDP0NE5Jq8PsQYM9kYE22MiQ4NDXUldqU8jo+PfdJ23To7Zp+UBG3b2pk7q1c7HZ3yVq4k+hSgSq7X4diee273AF8bKwnYAtQDMMakZv+7G5iJHQpSqkjz97c3aZOS4KWXYPFiWynzrrsgOdnp6JS3cSXR/wHUFpHq2TdY+wKzzjjmL6AjgIhUAOoCm0UkUERKZW8PBK4H1hRU8Ep5usBAGDXKrnQ1YoStp1O3Ljz0EOzd63R0yltcMNEbYzKAocA8YD0wwxizVkQGicig7MP+A1wpIquBBcDjxpi9QAXgVxFJAJYDs40xcwujIUp5spAQeOUVuwBK//52mcOaNW1vPz3d6eiUpxPjhrf9o6OjTWysTrlXRde6dbanP2sWVKoEzz1na+r4+TkdmXJXIhJ3rins+mSsUm4oMtLW0PnlF6haFe6/347hz5qlUzJV/mmiV8qNtW0LS5fCV1/ZUsg9ekC7dvbJW6VcpYleKTcnAjffbFe7evtt+PNPaN0aeve23yt1IZrolfIQuadkPvcczJtnV7oaPhz273c6OuXONNEr5WGCguDpp23Cv+8+eOstWyXznXe0ho7KmyZ6pTxUhQp2KGfFCrto+YMPQrNm8PPPTkem3I0meqU8XJMmEBMDX3wBBw9C+/bQpw9s3ep0ZMpdaKJXyguI2Juz69fD88/bypj16sEzz+gDV0oTvVJeJSAAnnoKNm6Enj1t0q9XD6ZP1/n3RZkmeqW8UJUq8PnntlhauXLQt6+df5+Q4HRkygma6JXyYldfbdexnTzZDus0awZDhkBamtORqctJE71SXs7X15ZQ+PNPm+TfeQfq1LHJX6djFg2a6JUqIoKDYcIEWLnSPmj1wAPQqpVd11Z5N030ShUxjRvDokXw2WewYwdceaWtjLlzp9ORqcKiiV6pIkgE+vWzs3OeeMIm/Tp14PXX4eRJp6NTBU0TvVJFWFAQvPwyrF1rK2U++qh9AGvBAqcjUwVJE71Sitq1YfZsW+/++HG47jo7JXP7dqcjUwVBE71SCrDDOd272979c8/ZhU/q1oWxY+HECaejU5dCE71S6jQlStjqmOvWQceO8NhjEBUFCxc6HZm6WJrolVJ5ql7d9uq/+w6OHbNJv18/Hc7xRJrolVLn1a2bHc559lmYOdPWznn1VZ2d40k00SulLiggwFbCXLfOlkEeOdIO5yxa5HBgyiWa6JVSLqtRww7lzJoFR49Chw4wYADs2+d0ZOp8NNErpfLt1OycUaNgyhSoX99Wy9RSyO5JE71S6qIEBMBLL0FcHEREwG23wY03QnKy05GpM2miV0pdksaNbWG08ePhl19swbTXX4eMDKcjU6dooldKXTJfXxg2zN6svfZaW0qhdWtbKVM5TxO9UqrAVK1qb9TOmAEpKdCihX3gStetdZYmeqVUgRKBW26xK1oNGGBLKDRsCD/95HRkRZdLiV5EOovIRhFJEpEn8thfWkS+E5EEEVkrIve4eq5SyjsFB9tVrH7+Gfz9oVMnu9LVwYNOR1b0XDDRi4gvMBHoAkQC/UQk8ozDhgDrjDFNgPbAayJSzMVzlVJe7JprID4eHn8cPvgAIiPh+++djqpocaVH3xJIMsZsNsacAKYBPc44xgClRESAICANyHDxXKWUlwsIgDFjYNkyKFvWzsO/4w590OpycSXRVwa25Xqdkr0tt7eA+kAqsBoYbozJcvFcAERkoIjEikjsnj17XAxfKeVJoqMhNtbWzZk+3fbuv/zS6ai8nyuJXvLYdubzbzcA8UAlIAp4S0SucPFcu9
GYycaYaGNMdGhoqAthKaU8UbFitm5OXBxUqWJv3P7rX7pmbWFyJdGnAFVyvQ7H9txzuwf42lhJwBagnovnKqWKoMaN4fff7ZDO7Nm2d//JJ1pGoTC4kuj/AGqLSHURKQb0BWadccxfQEcAEakA1AU2u3iuUqqI8vOzN2kTEmy9nLvugq5dYdu2C5+rXHfBRG+MyQCGAvOA9cAMY8xaERkkIoOyD/sPcKWIrAYWAI8bY/ae69zCaIhSynPVrQuLF8O4cXY6ZoMGMGkSZGU5HZl3EOOGfydFR0eb2NhYp8NQSjlg82Y7337hQlv7/r33oGZNp6NyfyISZ4yJzmufPhmrlHIrNWrYp2gnT4YVK6BRI3jjDcjMdDoyz6WJXinldkRsr37tWlsk7ZFHoG1bW1ZB5Z8meqWU2woPtytaTZ0Kf/5ply986SVdrza/NNErpdyaCNx+uy2B3KMH/N//QcuWtqyCco0meqWUR6hQwZY//uor2LHDPmX79NNw4oTTkbk/TfRKKY9y8822d3/bbfCf/0CrVrB6tdNRuTdN9EopjxMSYp+inTkTUlOheXN4+WVdvvBcNNErpTxWz56wZo0du3/ySTszZ+NGp6NyP5rolVIeLTTUjt1//jkkJtqZOePG6VO1uWmiV0p5PBHo29f27q+7Dh5+GDp0sE/ZKk30SikvEhZmFyf/4AM7/bJxY1szxw0rvVxWmuiVUl5FBO65x87EadMGBg2Czp2LdkVMTfRKKa9UtSrMnw//+x8sWQING8KHHxbN3r0meqWU1xKBBx+EVavsTdoBA+x6talFbPkjTfRKKa9XowbExMD48bb8cYMGtn5OUenda6JXShUJPj4wbJi9SRsZCXfeaZ+y3bXL6cgKnyZ6pVSRUqeOXc1q7Fj44Qfbu58xw+moCpcmeqVUkePrCyNGwMqVdljn1luhTx/Yu9fpyAqHJnqlVJFVvz4sXWpr3H/zje3dz5rldFQFTxO9UqpI8/ODUaMgLg4qVbJ1cx58ENLTnY6s4GiiV0op7Nq0v/9uh3TeecfWu/eWxU000SulVLbixe1N2h9/hIMH7UpWr73m+QXSNNErpdQZrrvOPmTVtavt4d9wg2c/ZKWJXiml8lC2LHz9NUyebG/YNmpkb9h6Ik30Sil1DiJw//2wYgVERECvXvDAA3DkiNOR5Y8meqWUuoC6deG33+Dxx+Hdd+3ShStWOB2V6zTRK6WUC4oVgzFjYMECOHzYLko+ZgxkZjod2YVpoldKqXzo0MHeqO3Vy86/79ABkpOdjur8NNErpVQ+hYTA9OnwySf/rGQ1ZYr7VsPURK+UUhdBxFbAXLUKmjSB/v1tzZy0NKcjO5tLiV5EOovIRhFJEpEn8tg/UkTis7/WiEimiIRk70sWkdXZ+2ILugFKKeWkiAhYtAhefhlmzrTTMH/6yemoTnfBRC8ivsBEoAsQCfQTkcjcxxhjxhpjoowxUcAo4GdjTO7fax2y90cXXOhKKeUefH3hiSdg2TK44gro1AkeeQSOHXM6MsuVHn1LIMkYs9kYcwKYBvQ4z/H9gM8LIjillPIkzZrZ4mhDhsAbb0CLFpCQ4HRUriX6ykDu9dNTsredRURKAp2Br3JtNsB8EYkTkYHn+hARGSgisSISu2fPHhfCUkop91OyJLz1FsyZY+vbt2wJb77p7I1aVxK95LHtXCF3B5acMWxzlTGmGXboZ4iIXJPXicaYycaYaGNMdGhoqAthKaWU++rSBVavhuuvt0sY9u0Lhw45E4sriT4FqJLrdThwrvI+fTlj2MYYk5r9725gJnYoSCmlvF65cvDtt/bBqi+/tEM5a9Zc/jhcSfR/ALVFpLqIFMMm87PWYBGR0kA74Ntc2wJFpNSp74HrAQeaqZRSzvDxsaUTFiyAAwfsUM6UKZc5hgsdYIzJAIYC84D1wAxjzFoRGSQig3Id2guYb4zJXe6nAvCriCQAy4HZxpi5BRe+Ukp5hvbt7Rq1LVvaOf
eDBl2+WTli3PBRrujoaBMbq1PulVLeJyMDnnrKDuc0a2aHdKpXv/T3FZG4c01h1ydjlVLqMvLzsw9XzZoFmzfbZF/YC5JroldKKQd0727n3NeoYRckf+IJ29svDJrolVLKITVqwJIldjGT//4XOna0JZALml/Bv6VSSilXlSgB77wDbdvamjmBgQX/GZrolVLKDdxxh/0qDDp0o5RSXk4TvVJKeTlN9Eop5eU00SullJfTRK+UUl5OE71SSnk5TfRKKeXlNNErpZSXc8vqlSKyB9iaa1M5YK9D4RQWb2uTt7UHvK9N3tYe8L42XUp7qhlj8lyezy0T/ZlEJPZc5Tc9lbe1ydvaA97XJm9rD3hfmwqrPTp0o5RSXk4TvVJKeTlPSfSTnQ6gEHhbm7ytPeB9bfK29oD3talQ2uMRY/RKKaUunqf06JVSSl0kTfRKKeXl3C7Ri8gHIrJbRNbk2hYiIj+KSGL2v8FOxphf52jTsyKyXUTis79udDLG/BCRKiISIyLrRWStiAzP3u6R1+k87fHka1RCRJaLSEJ2m57L3u6p1+hc7fHYawQgIr4islJEvs9+XSjXx+3G6EXkGuAw8IkxpmH2tleANGPMGBF5Agg2xjzuZJz5cY42PQscNsa86mRsF0NEwoAwY8wKESkFxAE9gbvxwOt0nvb0wXOvkQCBxpjDIuIP/AoMB27GM6/RudrTGQ+9RgAi8ggQDVxhjOlWWLnO7Xr0xpjFQNoZm3sAH2d//zH2f0KPcY42eSxjzA5jzIrs7w8B64HKeOh1Ok97PJaxTi0z7Z/9ZfDca3Su9ngsEQkHugLv5dpcKNfH7RL9OVQwxuwA+z8lUN7heArKUBFZlT204xF/Qp9JRCKApsAyvOA6ndEe8OBrlD0sEA/sBn40xnj0NTpHe8Bzr9E44DEgK9e2Qrk+npLovdHbQE0gCtgBvOZoNBdBRIKAr4CHjDF/Ox3PpcqjPR59jYwxmcaYKCAcaCkiDR0O6ZKcoz0eeY1EpBuw2xgTdzk+z1MS/a7scdRT46m7HY7nkhljdmX/h5sFvAu0dDqm/MgeJ/0K+NQY83X2Zo+9Tnm1x9Ov0SnGmAPAIux4tsdeo1Nyt8eDr9FVwE0ikgxMA64VkakU0vXxlEQ/C7gr+/u7gG8djKVAnLqY2XoBa851rLvJvjH2PrDeGPN6rl0eeZ3O1R4Pv0ahIlIm+/sA4DpgA557jfJsj6deI2PMKGNMuDEmAugLLDTG3EEhXR93nHXzOdAeW65zF/AM8A0wA6gK/AXcYozxmJub52hTe+yfmwZIBh44NTbn7kSkLfALsJp/xhefxI5re9x1Ok97+uG516gx9maeL7ZDN8MY87yIlMUzr9G52jMFD71Gp4hIe2BE9qybQrk+bpfolVJKFSxPGbpRSil1kTTRK6WUl9NEr5RSXk4TvVJKeTlN9Eop5eU00SullJfTRK+UUl7u/wEb2jfukM+2VwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"df = pd.DataFrame(\n",
" model.learning_process[10:], columns=[\"epoch\", \"train_RMSE\", \"test_RMSE\"]\n",
")\n",
"plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n",
"plt.plot(\"epoch\", \"test_RMSE\", data=df, color=\"green\", linestyle=\"dashed\")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Saving and evaluating recommendations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model.estimations()\n",
"\n",
"top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"\n",
"top_n.to_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_reco.csv\", index=False, header=False\n",
")\n",
"\n",
"estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\",\n",
" index=False,\n",
" header=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 8683.10it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.07236</td>\n",
" <td>0.104839</td>\n",
" <td>0.04897</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 0.914393 0.717199 0.101697 0.042334 0.051787 0.068811 \n",
"\n",
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.092489 0.07236 0.104839 0.04897 0.196117 0.517889 \n",
"\n",
" HR Reco in test Test coverage Shannon Gini \n",
"0 0.480382 0.867338 0.147186 3.852545 0.972694 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"estimations_df = pd.read_csv(\n",
" \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=None\n",
")\n",
"reco = np.loadtxt(\"Recommendations generated/ml-100k/Self_SVD_reco.csv\", delimiter=\",\")\n",
"\n",
"ev.evaluate(\n",
" test=pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None),\n",
" estimations_df=estimations_df,\n",
" reco=reco,\n",
" super_reactions=[4, 5],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 8505.85it/s]\n",
"943it [00:00, 9544.72it/s]\n",
"943it [00:00, 9154.80it/s]\n",
"943it [00:00, 8282.66it/s]\n",
"943it [00:00, 8432.23it/s]\n",
"943it [00:00, 9601.30it/s]\n",
"943it [00:00, 9158.89it/s]\n",
"943it [00:00, 12283.59it/s]\n",
"943it [00:00, 9500.43it/s]\n",
"943it [00:00, 10085.91it/s]\n",
"943it [00:00, 10260.90it/s]\n",
"943it [00:00, 9691.20it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
" <td>0.950347</td>\n",
" <td>0.749312</td>\n",
" <td>0.100636</td>\n",
" <td>0.050514</td>\n",
" <td>0.055794</td>\n",
" <td>0.070753</td>\n",
" <td>0.091202</td>\n",
" <td>0.082734</td>\n",
" <td>0.114054</td>\n",
" <td>0.053200</td>\n",
" <td>0.248803</td>\n",
" <td>0.521983</td>\n",
" <td>0.517497</td>\n",
" <td>0.992153</td>\n",
" <td>0.210678</td>\n",
" <td>4.418683</td>\n",
" <td>0.952848</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.072360</td>\n",
" <td>0.104839</td>\n",
" <td>0.048970</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
" <td>0.939472</td>\n",
" <td>0.739816</td>\n",
" <td>0.085896</td>\n",
" <td>0.036073</td>\n",
" <td>0.043528</td>\n",
" <td>0.057643</td>\n",
" <td>0.077039</td>\n",
" <td>0.057463</td>\n",
" <td>0.097753</td>\n",
" <td>0.045546</td>\n",
" <td>0.219839</td>\n",
" <td>0.514709</td>\n",
" <td>0.431601</td>\n",
" <td>0.997455</td>\n",
" <td>0.168831</td>\n",
" <td>4.217578</td>\n",
" <td>0.962577</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Ready_SVD 0.950347 0.749312 0.100636 0.050514 0.055794 \n",
"0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_SVDBiased 0.939472 0.739816 0.085896 0.036073 0.043528 \n",
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.070753 0.091202 0.082734 0.114054 0.053200 0.248803 \n",
"0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.057643 0.077039 0.057463 0.097753 0.045546 0.219839 \n",
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.521983 0.517497 0.992153 0.210678 4.418683 0.952848 \n",
"0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.514709 0.431601 0.997455 0.168831 4.217578 0.962577 \n",
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"super_reactions = [4, 5]\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>score</th>\n",
" <th>item_id</th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>genres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>321</td>\n",
" <td>1.000000</td>\n",
" <td>322</td>\n",
" <td>322</td>\n",
" <td>Murder at 1600 (1997)</td>\n",
" <td>Mystery, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>983</td>\n",
" <td>0.902748</td>\n",
" <td>984</td>\n",
" <td>984</td>\n",
" <td>Shadow Conspiracy (1997)</td>\n",
" <td>Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>985</td>\n",
" <td>0.894696</td>\n",
" <td>986</td>\n",
" <td>986</td>\n",
" <td>Turbulence (1997)</td>\n",
" <td>Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>778</td>\n",
" <td>0.890524</td>\n",
" <td>779</td>\n",
" <td>779</td>\n",
" <td>Drop Zone (1994)</td>\n",
" <td>Action</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>686</td>\n",
" <td>0.889220</td>\n",
" <td>687</td>\n",
" <td>687</td>\n",
" <td>McHale's Navy (1997)</td>\n",
" <td>Comedy, War</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>331</td>\n",
" <td>0.887596</td>\n",
" <td>332</td>\n",
" <td>332</td>\n",
" <td>Kiss the Girls (1997)</td>\n",
" <td>Crime, Drama, Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>987</td>\n",
" <td>0.886547</td>\n",
" <td>988</td>\n",
" <td>988</td>\n",
" <td>Beautician and the Beast, The (1997)</td>\n",
" <td>Comedy, Romance</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>1039</td>\n",
" <td>0.882845</td>\n",
" <td>1040</td>\n",
" <td>1040</td>\n",
" <td>Two if by Sea (1996)</td>\n",
" <td>Comedy, Romance</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1022</td>\n",
" <td>0.882782</td>\n",
" <td>1023</td>\n",
" <td>1023</td>\n",
" <td>Fathers' Day (1997)</td>\n",
" <td>Comedy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>929</td>\n",
" <td>0.877662</td>\n",
" <td>930</td>\n",
" <td>930</td>\n",
" <td>Chain Reaction (1996)</td>\n",
" <td>Action, Adventure, Thriller</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code score item_id id title \\\n",
"0 321 1.000000 322 322 Murder at 1600 (1997) \n",
"1 983 0.902748 984 984 Shadow Conspiracy (1997) \n",
"2 985 0.894696 986 986 Turbulence (1997) \n",
"3 778 0.890524 779 779 Drop Zone (1994) \n",
"4 686 0.889220 687 687 McHale's Navy (1997) \n",
"5 331 0.887596 332 332 Kiss the Girls (1997) \n",
"6 987 0.886547 988 988 Beautician and the Beast, The (1997) \n",
"7 1039 0.882845 1040 1040 Two if by Sea (1996) \n",
"8 1022 0.882782 1023 1023 Fathers' Day (1997) \n",
"9 929 0.877662 930 930 Chain Reaction (1996) \n",
"\n",
" genres \n",
"0 Mystery, Thriller \n",
"1 Thriller \n",
"2 Thriller \n",
"3 Action \n",
"4 Comedy, War \n",
"5 Crime, Drama, Thriller \n",
"6 Comedy, Romance \n",
"7 Comedy, Romance \n",
"8 Comedy \n",
"9 Action, Adventure, Thriller "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item = random.choice(list(set(train_ui.indices)))\n",
"\n",
"embeddings_norm = (\n",
" model.Qi / np.linalg.norm(model.Qi, axis=1)[:, None]\n",
") # we do not mean-center here\n",
"# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n",
"\n",
"similarity_scores = np.dot(embeddings_norm, embeddings_norm[item].T)\n",
"top_similar_items = pd.DataFrame(\n",
" enumerate(similarity_scores), columns=[\"code\", \"score\"]\n",
").sort_values(by=[\"score\"], ascending=[False])[:10]\n",
"\n",
"top_similar_items[\"item_id\"] = top_similar_items[\"code\"].apply(\n",
" lambda x: item_code_id[x]\n",
")\n",
"\n",
"items = pd.read_csv(\"./Datasets/ml-100k/movies.csv\")\n",
"\n",
"result = pd.merge(top_similar_items, items, left_on=\"item_id\", right_on=\"id\")\n",
"\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# project task 5: implement SVD on top baseline (as it is in Surprise library)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# making changes to our implementation by considering additional parameters in the gradient descent procedure\n",
"# seems to be the fastest option\n",
"# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
"# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'\n",
"\n",
"# link to the relevant Surprise documentation https://surprise.readthedocs.io/en/stable/matrix_factorization.html#matrix-factorization-based-algorithms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ready-made SVD - Surprise implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import helpers\n",
"import surprise as sp\n",
"\n",
"algo = sp.SVD(biased=False) # to use unbiased version\n",
"\n",
"helpers.ready_made(\n",
" algo,\n",
" reco_path=\"Recommendations generated/ml-100k/Ready_SVD_reco.csv\",\n",
" estimations_path=\"Recommendations generated/ml-100k/Ready_SVD_estimations.csv\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD biased - on top baseline"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"algo = sp.SVD() # default is biased=True\n",
"\n",
"helpers.ready_made(\n",
" algo,\n",
" reco_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv\",\n",
" estimations_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 11456.53it/s]\n",
"943it [00:00, 11932.50it/s]\n",
"943it [00:00, 10853.07it/s]\n",
"943it [00:00, 9426.44it/s]\n",
"943it [00:00, 8757.09it/s]\n",
"943it [00:00, 9999.67it/s]\n",
"943it [00:00, 11323.49it/s]\n",
"943it [00:00, 9764.72it/s]\n",
"943it [00:00, 9692.41it/s]\n",
"943it [00:00, 9052.77it/s]\n",
"943it [00:00, 8645.18it/s]\n",
"943it [00:00, 10594.54it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
" <td>0.951652</td>\n",
" <td>0.750975</td>\n",
" <td>0.096394</td>\n",
" <td>0.047252</td>\n",
" <td>0.052870</td>\n",
" <td>0.067257</td>\n",
" <td>0.085515</td>\n",
" <td>0.074754</td>\n",
" <td>0.109578</td>\n",
" <td>0.051562</td>\n",
" <td>0.235567</td>\n",
" <td>0.520341</td>\n",
" <td>0.496288</td>\n",
" <td>0.995546</td>\n",
" <td>0.208514</td>\n",
" <td>4.455755</td>\n",
" <td>0.951624</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.914393</td>\n",
" <td>0.717199</td>\n",
" <td>0.101697</td>\n",
" <td>0.042334</td>\n",
" <td>0.051787</td>\n",
" <td>0.068811</td>\n",
" <td>0.092489</td>\n",
" <td>0.072360</td>\n",
" <td>0.104839</td>\n",
" <td>0.048970</td>\n",
" <td>0.196117</td>\n",
" <td>0.517889</td>\n",
" <td>0.480382</td>\n",
" <td>0.867338</td>\n",
" <td>0.147186</td>\n",
" <td>3.852545</td>\n",
" <td>0.972694</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
" <td>0.940413</td>\n",
" <td>0.739571</td>\n",
" <td>0.086002</td>\n",
" <td>0.035478</td>\n",
" <td>0.043196</td>\n",
" <td>0.057507</td>\n",
" <td>0.075751</td>\n",
" <td>0.053460</td>\n",
" <td>0.094897</td>\n",
" <td>0.043361</td>\n",
" <td>0.209124</td>\n",
" <td>0.514405</td>\n",
" <td>0.428420</td>\n",
" <td>0.997349</td>\n",
" <td>0.177489</td>\n",
" <td>4.212509</td>\n",
" <td>0.962656</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.521845</td>\n",
" <td>1.225949</td>\n",
" <td>0.047190</td>\n",
" <td>0.020753</td>\n",
" <td>0.024810</td>\n",
" <td>0.032269</td>\n",
" <td>0.029506</td>\n",
" <td>0.023707</td>\n",
" <td>0.050075</td>\n",
" <td>0.018728</td>\n",
" <td>0.121957</td>\n",
" <td>0.506893</td>\n",
" <td>0.329799</td>\n",
" <td>0.986532</td>\n",
" <td>0.184704</td>\n",
" <td>5.099706</td>\n",
" <td>0.907217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>1.030712</td>\n",
" <td>0.820904</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Ready_SVD 0.951652 0.750975 0.096394 0.047252 0.052870 \n",
"0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_SVDBiased 0.940413 0.739571 0.086002 0.035478 0.043196 \n",
"0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.067257 0.085515 0.074754 0.109578 0.051562 0.235567 \n",
"0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.057507 0.075751 0.053460 0.094897 0.043361 0.209124 \n",
"0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.520341 0.496288 0.995546 0.208514 4.455755 0.951624 \n",
"0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.514405 0.428420 0.997349 0.177489 4.212509 0.962656 \n",
"0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir_path = \"Recommendations generated/ml-100k/\"\n",
"super_reactions = [4, 5]\n",
"test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}