{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Self made SVD" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import helpers\n", "import pandas as pd\n", "import numpy as np\n", "import scipy.sparse as sparse\n", "from collections import defaultdict\n", "from itertools import chain\n", "import random\n", "import matplotlib.pyplot as plt\n", "\n", "train_read = pd.read_csv(\"./Datasets/ml-100k/train.csv\", sep=\"\\t\", header=None)\n", "test_read = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n", "(\n", " train_ui,\n", " test_ui,\n", " user_code_id,\n", " user_id_code,\n", " item_code_id,\n", " item_id_code,\n", ") = helpers.data_to_csr(train_read, test_read)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n", "from tqdm import tqdm\n", "\n", "\n", "class SVD:\n", " def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n", " self.train_ui = train_ui\n", " self.uir = list(\n", " zip(*[train_ui.nonzero()[0], train_ui.nonzero()[1], train_ui.data])\n", " )\n", "\n", " self.learning_rate = learning_rate\n", " self.regularization = regularization\n", " self.iterations = iterations\n", " self.nb_users, self.nb_items = train_ui.shape\n", " self.nb_ratings = train_ui.nnz\n", " self.nb_factors = nb_factors\n", "\n", " self.Pu = np.random.normal(\n", " loc=0, scale=1.0 / self.nb_factors, size=(self.nb_users, self.nb_factors)\n", " )\n", " self.Qi = np.random.normal(\n", " loc=0, scale=1.0 / self.nb_factors, size=(self.nb_items, self.nb_factors)\n", " )\n", "\n", " def train(self, test_ui=None):\n", " if test_ui != None:\n", " self.test_uir = list(\n", " zip(*[test_ui.nonzero()[0], test_ui.nonzero()[1], test_ui.data])\n", " )\n", "\n", " self.learning_process = []\n", " pbar = tqdm(range(self.iterations))\n", " for i in pbar:\n", " pbar.set_description(\n", " f\"Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...\"\n", " )\n", " np.random.shuffle(self.uir)\n", " self.sgd(self.uir)\n", " if test_ui == None:\n", " self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n", " else:\n", " self.learning_process.append(\n", " [i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)]\n", " )\n", "\n", " def sgd(self, uir):\n", "\n", " for u, i, score in uir:\n", " # Computer prediction and error\n", " prediction = self.get_rating(u, i)\n", " e = score - prediction\n", "\n", " # Update user and item latent feature matrices\n", " Pu_update = self.learning_rate * (\n", " e * self.Qi[i] - self.regularization * self.Pu[u]\n", " )\n", " Qi_update = self.learning_rate * (\n", " e * self.Pu[u] - self.regularization * self.Qi[i]\n", " )\n", "\n", " self.Pu[u] += Pu_update\n", " self.Qi[i] += Qi_update\n", "\n", " def get_rating(self, u, i):\n", " prediction = self.Pu[u].dot(self.Qi[i].T)\n", " return prediction\n", "\n", " def RMSE_total(self, uir):\n", " RMSE = 0\n", " for u, i, score in uir:\n", " prediction = self.get_rating(u, i)\n", " RMSE += (score - prediction) ** 2\n", " return np.sqrt(RMSE / len(uir))\n", "\n", " def estimations(self):\n", " self.estimations = np.dot(self.Pu, self.Qi.T)\n", "\n", " def recommend(self, user_code_id, item_code_id, topK=10):\n", "\n", " top_k = defaultdict(list)\n", " for nb_user, user_scores in enumerate(self.estimations):\n", "\n", " user_rated = self.train_ui.indices[\n", " self.train_ui.indptr[nb_user] : self.train_ui.indptr[nb_user + 1]\n", " ]\n", " for item, score in enumerate(user_scores):\n", " if item not in user_rated and not np.isnan(score):\n", " top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n", " result = []\n", " # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n", " for uid, item_scores in top_k.items():\n", " item_scores.sort(key=lambda x: x[1], reverse=True)\n", " result.append([uid] + list(chain(*item_scores[:topK])))\n", " return result\n", "\n", " def estimate(self, user_code_id, item_code_id, test_ui):\n", " result = []\n", " for user, item in zip(*test_ui.nonzero()):\n", " result.append(\n", " [\n", " user_code_id[user],\n", " item_code_id[item],\n", " self.estimations[user, item]\n", " if not np.isnan(self.estimations[user, item])\n", " else 1,\n", " ]\n", " )\n", " return result" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 39 RMSE: 0.7489999966900885. Training epoch 40...: 100%|██████████| 40/40 [01:02<00:00, 1.57s/it]\n" ] } ], "source": [ "model = SVD(\n", " train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40\n", ")\n", "model.train(test_ui)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbmklEQVR4nO3de3RV9Z338fc3FwmQQCBEpIRbfewaO6BBIuqIXVQ7LkGtjGMtrjXeVldZtVhx6rXWtqPLdlwy1Y7VljoVpz61RTteqo629YJaa1EDBhTBB2ytRLJMDCQQlEvg+/yxzyEn4SQ5JznJPtn781rrt/Y5e+9z9jdb/Ox9fvtm7o6IiAx9BWEXICIiuaFAFxGJCAW6iEhEKNBFRCJCgS4iEhFFYS143LhxPnXq1LAWLyIyJK1evfojd69MNy20QJ86dSq1tbVhLV5EZEgys791N01dLiIiEaFAFxGJCAW6iEhEhNaHLiJDz759+6ivr2f37t1hlxJ5JSUlVFVVUVxcnPFnFOgikrH6+nrKysqYOnUqZhZ2OZHl7jQ3N1NfX8+0adMy/py6XEQkY7t376aiokJhPsDMjIqKiqx/CSnQRSQrCvPB0Zf1POQCfcsWuPJK2Lcv7EpERPLLkAv0NWvgP/8Tli4NuxIRkfwy5AL9nHPgS1+Cm2+GjRvDrkZEBlNLSws/+clPsv7c/PnzaWlpyfpzl1xyCdOmTaO6uppjjz2W55577uC0uXPnMnnyZFIfErRgwQJKS0sBOHDgAFdccQXTp09nxowZHH/88fz1r38FgivlZ8yYQXV1NdXV1VxxxRVZ15bOkDzL5cc/hmefha9+FV58EQqG3GZJRPoiGehf//rXO43fv38/hYWF3X7uqaee6vMyly5dynnnncfKlStZtGgRmzZtOjitvLycP/3pT8yZM4eWlhYaGhoOTnvwwQfZunUr69ato6CggPr6ekaOHHlw+sqVKxk3blyf60pnSAb6+PFwxx1wySWwbBl0+W8rIoPgyiuhri6331ldDT/6UffTr7/+et59912qq6spLi6mtLSUCRMmUFdXx9tvv82CBQvYsmULu3fvZsmSJSxatAjouHdUW1sb8+bNY86cObzyyitMnDiR3/72twwfPrzX2k466SQ++OCDTuMWLlzIihUrmDNnDo888gjnnnsu69evB6ChoYEJEyZQkNjjrKqq6tM6ycaQ3be96CI4/XS47jp4//2wqxGRwXDrrbdy5JFHUldXx9KlS3nttdf4/ve/z9tvvw3A8uXLWb16NbW1tdx55500Nzcf8h2bNm1i8eLFrF+/nvLych5++OGMlv273/2OBQsWdBp32mmn8dJLL7F//35WrFjBl7/85YPTzj//fJ544gmqq6u56qqreOONNzp99vOf//zBLpc77rgjyzWR3pDcQwcwg5/9DKZPh699Df73f4NxIjI4etqTHiyzZ8/udOHNnXfeyaOPPgrAli1b2LRpExUVFZ0+k+wTB5g1axbvvfdej8u45ppruPbaa2lsbGTVqlWdphUWFjJnzhwefPBBPvnkE1JvCV5VVcU777zD888/z/PPP89pp53Gb37zG0477TRgYLpchuweOsDUqfCDH8DTT8OvfhV2NSIy2FL7pF944QWeffZZ/vznP7N27VpmzpyZ9sKcYcOGHXxdWFhIe3t7j8tYunQpmzdv5pZbbuHiiy8+ZPrChQv5xje+wfnnn592WfPmzWPp0qXccMMNPPbYY1n8ddkb0oEOsHgxnHgiLFkCTU1hVyMiA6msrIydO3emndba2sqYMWMYMWIEGzduPGRvuj8KCgpYsmQJBw4c4Pe//32naaeccgrf+ta3uOCCCzqNX7NmDVu3bgWCM17WrVvHlClTclZT2joH9NsHQWEh3Hsv7NwZhLqIRFdFRQUnn3wy06dP55prruk07YwzzqC9vZ1jjjmG73znO5x44ok5XbaZceONN3LbbbcdMv7qq68+pPuksbGRs88+m+nTp3PMMcdQVFTE5ZdffnB6ah/6RRddlJsaU8+hHEw1NTWeyycW3XwzfO978MQTcNZZOftaEUmxYcMGjj766LDLiI1069vMVrt7Tbr5h/weetL11wcHSC+7DHbsCLsaEZHB12ugm1mJmb1mZmvNbL2Z3ZRmnrlm1mpmdYn23YEpt3uHHRZ0vWzdGoS7iEimFi9efLD7I9nuu+++sMvKWianLe4BTnX3NjMrBl42s6fdvesRhz+6e6idHbNnw6WXws9/DnffrdMYRQaCu0fujot333132CUcoi/d4b3uoXugLfG2ONHC6XjPwGc/G9yJUd0uIrlXUlJCc3Nzn8JGMpd8wEVJSUlWn8vowiIzKwRWA/8HuNvdX00z20lmthbYClzt7uvTfM8iYBHA5MmTsyo0U8lrCJqbYfToAVmESGxVVVVRX19Pk84RHnDJR9BlI6NAd/f9QLWZlQOPmtl0d38rZZY1wJREt8x84DHgqDTfcw9wDwRnuWRVaYZSA/3Tnx6IJYjEV3FxcVaPRJPBldVZLu7eArwAnNFl/I5kt4y7PwUUm1lur2nNUGqgi4jESSZnuVQm9swxs+HAF4CNXeY5whJHScxsduJ7Q4lUBbqIxFUmXS4TgF8k+tELgIfc/Ukz+xqAuy8DzgMuM7N24BNgoYd01GTs2GCoQBeRuOk10N19HTAzzfhlKa/vAu7KbWl9M2ZMcLqiAl1E4iYyV4omFRZCebkCXUTiJ3KBDkE/ugJdROJGgS4iEhEKdBGRiFCgi4hEhAJdRCQiIhvobW2wd2/YlYiIDJ7IBjrAtm3h1iEiMpgiHejqdhGROFGgi4hEhAJdRCQiIhnoukGXiMRRJANde+giEkeRDPQRI2DYMAW6iMRLJAPdTBcXiUj8RDLQQYEuIvGjQBcRiQgFuohIRCjQRUQiItKBvm0bhPOoahGRwRfpQG9vh507w65ERGRwRDrQQd0uIhIfCnQRkYhQoIuIRESvgW5mJWb2mpmtNbP1ZnZTmnnMzO40s81mts7MjhuYcjOnG3SJSNwUZTDPHuBUd28zs2LgZTN72t1XpcwzDzgq0U4AfpoYhkZ76CISN73uoXugLfG2ONG6ngx4DnB/Yt5VQLmZTchtqdkZMyYYKtBFJC4y6kM3s0IzqwMagWfc/dUus0wEtqS8r0+M6/o9i8ys1sxqm5qa+lhyZoqKoLxcgS4i8ZFRoLv7fnevBqqA2WY2vcsslu5jab7nHnevcfeaysrKrIvNlq4WFZE4yeosF3dvAV4AzugyqR6YlPK+Ctjan8JyQYEuInGSyVkulWZWnng9HPgCsLHLbI8DFyXOdjkRaHX3hlwXmy0FuojESSZnuUwAfmFmhQQbgIfc/Ukz+xqAuy8DngLmA5uBj4FLB6jerFRUwIYNYVchIjI4eg10d18HzEwzflnKawcW57a0/tMeuojESWSvFIUg0HfuhL17w65ERGTgRT7QAbZvD7cOEZHBEItAV7eLiMSBAl1EJCIiHei6QZeIxEmkA1176CISJwp0EZGIiHSgjxwJhx2mQBeReIh0oJvp4iIRiY9IBzoo0EUkPhToIiIRoUAXEYkIBbqISETEJtD9kOcniYhESywCvb0d2tp6n1dEZCiLRaCDul1EJPoU6CIiERH5QNcNukQkLiIf6NpDF5G4UKCLiERE5ANdXS4iEheRD/SiIhg9WoEuItEX+UAHXS0qIvGgQBcRiYheA93MJpnZSjPbYGbrzWxJmnnmmlmrmdUl2ncHpty+UaCLSBwUZTBPO3CVu68xszJgtZk94+5vd5nvj+5+Vu5L7L+KCnjnnbCrEBEZWL3uobt7g7uvSbzeCWwAJg50YbmkPXQRiYOs+tDNbCowE3g1zeSTzGytmT1tZn/fzecXmVmtmdU2NTVlX20fVVTAjh2wb9+gLVJEZNBlHOhmVgo8DFzp7ju6TF4DTHH3Y4EfA4+l+w53v8fda9y9prKyso8lZy95cdH27YO2SBGRQZdRoJtZMUGYP+Duj3Sd7u473L0t8fopoNjMxuW00n7Q1aIiEgeZnOViwL3ABne/vZt5jkjMh5nNTnxv3sSnrhYVkTjI5CyXk4ELgTfNrC4x7gZgMoC7LwPOAy4zs3bgE2Che/48I0h76CISB70Guru/DFgv89wF3JWronJNgS4icRCbK0VBgS4i0RaLQC8theJiBbqIRFssAt1MFxeJSPTFItBBgS4i0adAFxGJCAW6iEhEKNBFRCIidoGeP5c7iYjkVqwCfd8+2LUr7EpERAZGrAId1O0iItEVm0DXDbpEJOpiE+jaQxeRqFOgi4hEhAJdRCQiYhPo6kMXkaiLTaAXF8OoUQp0EYmu2AQ66GpREYk2BbqISEQo0EVEIkKBLiISEQp0EZGIiF2gt7ZCe3vYlYiI5F7sAh1g+/Zw6xARGQixCnRdXCQiURarQNfl/yISZb0GuplNMrOVZrbBzNab2ZI085iZ3Wlmm81snZkdNzDl9o8CXUSirCiDedqBq9x9jZmVAavN7Bl3fztlnnnAUYl2AvDTxDCvKNBFJMp63UN39wZ3X5N4vRPYAEzsMts5wP0eWAWUm9mEnFfbTwp0EYmyrPrQzWwqMBN4tcukicCWlPf1HBr6mNkiM6s1s9qmpqYsS+2/sjIoKlKgi0g0ZRzoZlYKPAxc6e47uk5O8xE/ZIT7Pe5e4+41lZWV2VWaA2a6uEhEoiujQDezYoIwf8DdH0kzSz0wKeV9FbC1/+XlngJdRKIqk7NcDLgX2ODut3cz2+PARYmzXU4EWt29IYd15kxFBXz0UdhViIjkXiZnuZwMXAi8aWZ1iXE3AJMB3H0Z8BQwH9gMfAxcmvNKc+Soo+Dxx8E96IIREYmKXgPd3V8mfR956jwOLM5VUQOppgaWL4f334cpU8KuRkQkd2J1pSjA8ccHw9dfD7cOEZFci12gz5gRPF+0tjbsSkREcit2gT5sGBx7rAJdRKIndoEOQT96bS0cOBB2JSIiuRPbQG9thXffDbsSEZHciWWgJw+MqttFRKIkloH+2c9CSYnOdBGRaIlloBcVwcyZ2kMXkWiJZaBD0O2yZg3s3x92JSIiuRHbQK+pgV27YOPGsCsREcmN2Aa6DoyKSNTENtA/8xkoLdWBURGJjtgGekEBzJqlPXQRiY7YBjoE3S51dbB3b9iViIj0X6wDvaYG9uyB9evDrkREpP9iHeg6MCoiURLrQJ82DcaM0YFREYmGWAe6WcedF0VEhrpYBzoE3S5vvgm7d4ddiYhI/8Q+0GtqoL0d1q4NuxIRkf5RoNcEQ3W7iMhQF/tAr6qC8eN1YFREhr7YB7oOjIpIVMQ+0CE4MLphA7S1hV2JiEjf9RroZrbczBrN7K1ups81s1Yzq0u07+a+zIFVUxM8MPqNN8KuRESk7zLZQ/9v4Ixe5vmju1cn2s39L2tw6cCoiERBr4Hu7i8B2wahltCMHw+TJunAqIgMbbnqQz/JzNaa2dNm9vfdzWRmi8ys1sxqm5qacrTo3NCBUREZ6nIR6GuAKe5+LPBj4LHuZnT3e9y9xt1rKisrc7Do3KmpgU2boKUl7EpERPqm34Hu7jvcvS3x+img2MzG9buyQZa88+Lq1eHWISLSV/0OdDM7wsws8Xp24jub+/u9g23WrGCobhcRGaqKepvBzH4NzAXGmVk98D2gGMDdlwHnAZeZWTvwCbDQ3X3AKh4gY8fCkUfqwKiIDF29Brq7X9DL9LuAu3JWUYhqamDVqrCrEBHpG10pmqKmBv72N8izE3BERDKiQE+hR9KJyFCmQE8xc2Zwsy51u4jIUKRATzFqFMydC8uWwY4dYVcjIpIdBXoXt90GjY3w7/8ediUiItlRoHdRUwMXXgh33AHvvRd2NSIimVOgp/GDH0BBAVx/fdiViIhkToGeRlUVXHstPPggvPJK2NWIiGRGgd6Na66BT30K/vVfg4dfiIjkOwV6N0aODLpeXnsNVqwIuxoRkd4p0Htw4YVw3HFBX/rHH4ddjYhIzxToPSgoCM522bIFbr897GpERHqmQO/F5z4H554Lt94KDQ1hVyMi0j0FegZuuw327oUbbwy7EhGR7inQM3DkkXDFFXDfffDGG2FXIyKSngI9QzfeCBUVcNVVMPQe3yEicaBAz1B5Odx0E6xcCUuXKtRFJP8o0LOwaBH88z/DddcFpzTqVEYRyScK9CwUFcFDD8Ett8CvfgVz5gRPOBIRyQcK9CwVFMC3vw1PPAF/+QvMmgXPPx92VSIiCvQ+O/NMeP11GD8eTj89uABJ/eoiEiYFej8cdVTwuLovfhG++U31q4tIuBTo/VRWBv/zPx396v/wD/Dkk7B/f9iViUjcKNBzILVfvakJzj4bPvMZ+OEPYdu2sKsTkbjoNdDNbLmZNZrZW91MNzO708w2m9k6Mzsu92UODWeeGTy27qGHgodkXH11MPzqV6GuLuzqRCTqMtlD/2/gjB6mzwOOSrRFwE/7X9bQVVwMX/oSvPgirF0b9Ks/8ADMnAmnnAL33w9bt4ZdpYhEUa+B7u4vAT11HJwD3O+BVUC5mU3IVYFD2THHwM9+Bh98EHS/bN0KF18MEycGB1S/8pUg4PUwahHJhaIcfMdEYEvK+/rEuENuNmtmiwj24pk8eXIOFj00jBkTnAVz5ZXBzb1efBFeegkefRSWLw/mmTw5uFXvCSfA0UcHbcIEMAu1dBEZQnIR6OkiJ+0Z2e5+D3APQE1NTezO2i4oCC5EmjUrCPgDB2D9+o6A/8Mf4Je/7Jh/1Cj4u78L2tFHB8Np04J++bFjFfYi0lkuAr0emJTyvgpQL3EGCgpgxoygXX55cGFSQwNs3AgbNgTDjRvhueeCrplUJSVB101VVefh4YdDZWXQxo0L2mGHhfP3icjgykWgPw5cbmYrgBOAVnfXs336wAw+9amgnXpq52k7dwbh/v77UF8f9MvX1wdt1apguHdv+u8dPToI9srKoPunvLz74ahRHW30aBg+XL8ERIaKXgPdzH4NzAXGmVk98D2gGMDdlwFPAfOBzcDHwKUDVWyclZXB8ccHLR13aG6GxsbgXPiPPgqGqa+Tw02bYPt2aGnp/QKowsKOgC8rg9LSoI0c2fE6ddyIEcEwXUudt7g456tIJPZ6DXR3v6CX6Q4szllF0idmHV0smXKHtrYg2JMBv2NHz23XruAzjY3BMNmyveXBsGHpNwjJYepGIN3GI10bOTLoxhKJq1x0ucgQZRbsdZeVwaRJvc/fkwMHglDftevQ9vHHQegnNwZtbUEXUuoGYefOYHpzc8d8yc8fOJB5HSNGpN9IlJZ2/K3J1nVc8n3qsKREXU4ydCjQJScKCjqCM5fcYc+ezuGfriU3CKkbjdRpDQ3BMNn27cts+YWF6TcGyS6o1Nepxx9Su6mSwxEjtHGQgaVAl7xmFuwll5Rk153Umz17Ogd8T78c0rXGxqALaufOYNje3vsykxu9dBuH0aM7hqmv040rLVXXkqSnQJdYGjYsaLnYSCR/RfR07KG7DUPqxqG1NRj2dl99s85nInUN/K7vy8s7WvK9NgrRpEAX6afUXxGHH96/70oeqE4GfDLkU4ddx7e2wocfBmcvJd/v2dPzcgoKDg37ri31lNaxY4Nhsg0b1r+/UwaGAl0kj6QeqJ44se/fk/zF0NISBHxLS0dLfb99e8f7TZs6xre19fz9I0Z0Dvp0w66vx44NfjXol8HAUaCLRNCwYR1XDPfFvn0dQb9tWxD827enf71tG7z7bsf7nk5hLSjo2MsfOxYqKoKW+rprGzdOB5QzpUAXkUMUF2d/XUPSnj1BsDc3dw7/1Jac3tgY3OZi27bgF0V3ksc7kgGfybCsLH4bAQW6iOTUsGFwxBFBy8a+fUGwNzd3bh99dOhw3bqO6d0dRE7dKCVvfZH6Pl0bPrz/f3+YFOgikheKi2H8+KBl6sCBoFsoXeh3bWvXBsNt27rfCIwY0XkDkK6l3gBv1Kj8+hWgQBeRIaugoOOAa6b27+/4JZAu+FPvg/TOO8Fw167035X8FZAu9A8/vGMDlXw9cmRu/u7uKNBFJFYKC7M/YPzxxx0hn9pSw7+pCWprg2Fra/rvGTkyCPfLLw+eiZBrCnQRkV6MGAFTpgQtE3v3Bgd8P/ywo6W+z/b4QqYU6CIiOXbYYcFDZ6qqBne5OsVfRCQiFOgiIhGhQBcRiQgFuohIRCjQRUQiQoEuIhIRCnQRkYhQoIuIRIR5b8+7GqgFmzUBf+thlnHAR4NUTrZUW9+otr5RbX0T1dqmuHvaGxeEFui9MbNad68Ju450VFvfqLa+UW19E8fa1OUiIhIRCnQRkYjI50C/J+wCeqDa+ka19Y1q65vY1Za3fegiIpKdfN5DFxGRLCjQRUQiIu8C3czOMLN3zGyzmV0fdj2pzOw9M3vTzOrMrDbkWpabWaOZvZUybqyZPWNmmxLDMXlU27+Z2QeJdVdnZvNDqm2Sma00sw1mtt7MliTGh77ueqgt9HVnZiVm9pqZrU3UdlNifD6st+5qC329pdRYaGZvmNmTifcDst7yqg/dzAqB/wf8I1APvA5c4O5vh1pYgpm9B9S4e+gXK5jZ54A24H53n54Ydxuwzd1vTWwMx7j7dXlS278Bbe7+H4NdT5faJgAT3H2NmZUBq4EFwCWEvO56qO18Ql53ZmbASHdvM7Ni4GVgCXAu4a+37mo7gzz4NwdgZt8EaoBR7n7WQP2/mm976LOBze7+F3ffC6wAzgm5przk7i8B27qMPgf4ReL1LwjCYNB1U1tecPcGd1+TeL0T2ABMJA/WXQ+1hc4DbYm3xYnm5Md66662vGBmVcCZwM9TRg/Iesu3QJ8IbEl5X0+e/INOcOAPZrbazBaFXUwa4929AYJwAA4PuZ6uLjezdYkumVC6g1KZ2VRgJvAqebbuutQGebDuEt0GdUAj8Iy7581666Y2yIP1BvwIuBY4kDJuQNZbvgW6pRmXN1ta4GR3Pw6YByxOdC1IZn4KHAlUAw3AD8MsxsxKgYeBK919R5i1dJWmtrxYd+6+392rgSpgtplND6OOdLqpLfT1ZmZnAY3uvnowlpdvgV4PTEp5XwVsDamWQ7j71sSwEXiUoIson3yY6IdN9sc2hlzPQe7+YeJ/ugPAfxHiukv0sz4MPODujyRG58W6S1dbPq27RD0twAsEfdR5sd6SUmvLk/V2MvDFxPG3FcCpZvZLBmi95Vugvw4cZWbTzOwwYCHweMg1AWBmIxMHqjCzkcDpwFs9f2rQPQ5cnHh9MfDbEGvpJPmPN+GfCGndJQ6g3QtscPfbUyaFvu66qy0f1p2ZVZpZeeL1cOALwEbyY72lrS0f1pu7f8vdq9x9KkGePe/u/8JArTd3z6sGzCc40+Vd4Nth15NS16eBtYm2PuzagF8T/IzcR/DL5itABfAcsCkxHJtHtf1f4E1gXeIf84SQaptD0I23DqhLtPn5sO56qC30dQccA7yRqOEt4LuJ8fmw3rqrLfT11qXOucCTA7ne8uq0RRER6bt863IREZE+UqCLiESEAl1EJCIU6CIiEaFAFxGJCAW6iEhEKNBFRCLi/wO1vKB6yNy13AAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df = pd.DataFrame(model.learning_process).iloc[:, :2]\n", "df.columns = [\"epoch\", \"train_RMSE\"]\n", "plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyp0lEQVR4nO3de3zO5f/A8dd7B8wmNoYxzJk5DXMqhaQIoa9EB5VKwpcOVPrp+O3gmw4o36JzVOigFKGYFEUbm7M2TGbOQ5jTtuv3xzVrGO6x+dz3vffz8djD7s/hvt/XPvXetetzfd6XGGNQSinlvXycDkAppVTh0kSvlFJeThO9Ukp5OU30Sinl5TTRK6WUl/NzOoC8lCtXzkRERDgdhlJKeYy4uLi9xpjQvPa5ZaKPiIggNjbW6TCUUspjiMjWc+3ToRullPJyF0z0IvKBiOwWkTXn2C8iMkFEkkRklYg0y7Wvs4hszN73REEGrpRSyjWu9Og/AjqfZ38XoHb210DgbQAR8QUmZu+PBPqJSOSlBKuUUir/LjhGb4xZLCIR5zmkB/CJsbUUfheRMiISBkQAScaYzQAiMi372HWXHLVSyq2cPHmSlJQUjh075nQoXq9EiRKEh4fj7+/v8jkFcTO2MrAt1+uU7G15bW91rjcRkYHYvwioWrVqAYSllLpcUlJSKFWqFBEREYiI0+F4LWMM+/btIyUlherVq7t8XkHcjM3rqprzbM+TMWayMSbaGBMdGprnDCGllJs6duwYZcuW1SRfyESEsmXL5vsvp4Lo0acAVXK9DgdSgWLn2K6U8kKa5C+Pi/k5F0SPfhbQP3v2TWvgoDFmB/AHUFtEqotIMaBv9rGFJstksWLHisL8CKWU8jiuTK/8HPgNqCsiKSJyr4gMEpFB2YfMATYDScC7wGAAY0wGMBSYB6wHZhhj1hZCG3J8sfYLmk9uTp8v+rB5/+bC/CillPIYF0z0xph+xpgwY4y/MSbcGPO+MeYdY8w72fuNMWaIMaamMaaRMSY217lzjDF1sve9WJgNAehapyvPtHuG2YmzqT+xPiPmj2D/0f2F/bFKKYcdOHCA//3vf/k+78Ybb+TAgQP5Pu/uu++mevXqREVF0aRJExYsWJCzr3379lStWpXcizr17NmToKAgALKyshg2bBgNGzakUaNGtGjRgi1btgC2KkCjRo2IiooiKiqKYcOG5Tu2vLhlCYSLFVQsiGfbP8v9ze5ndMxoXv/tdRZvXcyy+5bp+KFSXuxUoh88ePBp2zMzM/H19T3neXPmzLnozxw7diy9e/cmJiaGgQMHkpiYmLOvTJkyLFmyhLZt23LgwAF27NiRs2/69OmkpqayatUqfHx8SElJITAwMGd/TEwM5cqVu+i48uJVif6UyldU5sMeHzK81XAOHDuAiJB+Mp2FWxbStXZXTfpKFaKHHoL4+IJ9z6goGDfu3PufeOIJNm3aRFRUFP7+/gQFBREWFkZ8fDzr1q2jZ8+ebNu2jWPHjjF8+HAGDhwI/FNX6/Dhw3Tp0oW2bduydOlSKleuzLfffktAQMAFY2vTpg3bt28/bVvfvn2ZNm0abdu25euvv+bmm29m7Vo7cr1jxw7CwsLw8bEDKuHh4Rf1M8kPr651E1UxivYR7QH4YOUHdP+8Ox0+7qA3bJXyMmPGjKFmzZrEx8czduxYli9fzosvvsi6dfb5zA8++IC4uDhiY2OZMGEC+/btO+s9EhMTGTJkCGvXrqVMmTJ89dVXLn323Llz6dmz52nbOnbsyOLFi8nMzGTatGnceuutOfv69OnDd999R1RUFI8++igrV6487dwOHTrkDN288cYb+fxJ5M0re/R5eaD5A/iID88seoboydH0a9SPQc0HcXW1q50OTSmvcr6e9+XSsmXL0x4omjBhAjNnzgRg27ZtJCYmUrZs2dPOOTXmDtC8eXOSk5PP+xkjR47kscceY/fu3fz++++n7fP19aVt27ZMnz6do0ePkrvsenh4OBs3bmThwoUsXLiQjh078sUXX9CxY0egcIZuvLpHn5u/rz+DWwwm6d9JjLxyJN9u+JYXf/nn/nDK3ykORqeUKki5x7wXLVrETz/9xG+//UZCQgJNmzbN84Gj4sWL53zv6+tLRkbGeT9j7NixJCUl8cILL3DXXXedtb9v3778+9//pk+fPnl+VpcuXRg7dixPPvkk33zzTT5al39FJtGfUrpEaf7b6b/sGrGLd7q9A9gkX21cNVq914qJyyeyN32vw1EqpfKjVKlSHDp0KM99Bw8eJDg4mJIlS7Jhw4azet+XwsfHh+HDh5OVlcW8efNO23f11VczatQo+vXrd9r2FStWkJpqnx3Nyspi1apVVKtWrcBiyjPOQn13NxZYLJCIMhH2e/9AXrnuFY5lHGPoD0MJey2MHtN6sHHvRmeDVEq5pGzZslx11VU0bNiQkSNHnravc+fOZGRk0LhxY5566ilat25doJ8tIowePZpXXnnlrO0jRow4axhm9+7ddO/enYYNG9K4cWP8/PwYOnRozv7cY/T9+/cvmBhzz/V0F9HR0capFaYSdiYwZdUUpq+dzvL7lhNWKoyEnQkEFQuiZkhNR2JSyt2tX7+e+vXrOx1GkZHXz1tE4owx0XkdX2R79OfSpGITXr3+VbY+tJWwUmEAPDr/UWq/WZue03oSsyUGd/zlqJRS56KJ/hx85J8fzZReUxh9zWiWbFvCtZ9cS9SkKGaun+lgdEqpy2HIkCE5wyinvj788EOnw8q3IjO98lKElQrj+Q7PM6rtKD5f8znjfh/HXwf/AuBYxjH2H92f0/tXSnmPiRMnOh1CgdAefT4E+AcwoOkAEgYlMKTlEAA+W/0Z1cZVo//M/izfvlyHdZRSbkcT/UUQEfx87B9D7SPa82D0g8zcMJNW77Wi6riqDJ49mIys88/BVUqpy0UT/SWqEVyD8V3Gk/JwCh/c9AEtKrVgze41Ob8Ixvw6hg9WfsCuw7scjlQpVVTpGH0BKV2iNPc0vYd7mt6TM3yTZbL4KP4jNu7biCC0rNyS7nW60zuyN3XL1XU4YqVUUaE9+kJwqjqmj/iwfsh6Vj6wkufaP0eWyWJ0zGg+W/0ZACcyT5CUluRkqEp5hYutRw8wbtw40tPTz3vMqTrxjRs3pl27dmzdujVnn4hw55135rzOyMggNDSUbt26AbBr1y66detGkyZNiIyM5MYbbwQgOTmZgICA02b0fPLJJxfVhgsyxrjdV/PmzY23Sv071ew8tNMYY8ysDbMMz2JavdvKTPh9gtl1eJfD0Sl1cdatW+fo52/ZssU0aNDgos6tVq2a2bNnj8vHPP300+a+++7L2RcYGGiioqJMenq6McaYOXPmmCZNmpiuXbsaY4wZOHCgGTduXM7xCQkJlxxzXj9vINacI6fq0M1llnsaZovKLXjlulf4dPWnDJs7jIfnPUynmp2Y2msqZUuWPc+7KOXe2n/U/qxtfRr0YXCLwaSfTOfGT288a//dUXdzd9Td7E3fS+8ZvU/bt+juRef9vNz16Dt16kT58uWZMWMGx48fp1evXjz33HMcOXKEPn36kJKSQmZmJk899RS7du0iNTWVDh06UK5cOWJiYi7YtjZt2jBhwoTTtnXp0oXZs2fTu3dvPv/8c/r168cvv/wC2Prz119/fc6xjRs3vuBnFDQdunFQxaCKjLxqJPGD4lnz4Boeu+ox0k+mExwQDMAnCZ8wc/1Mdh/Z7XCkSrm33PXoO3XqRGJiIsuXLyc+Pp64uDgWL17M3LlzqVSpEgkJCaxZs4bOnTszbNgwKlWqRExMjEtJHvKuP39qoZFjx46xatUqWrVqlbNvyJAh3HvvvXTo0IEXX3wxp6AZkPPL6dTXqV8OBU179G6iQfkGvNTxpZzXxhjG/DqG9XvXA1C9THVahbeiR90e9G3Y16kwlXLJ+XrgJf1Lnnd/uZLlLtiDP5/58+czf/58mjZtCsDhw4dJTEzk6quvZsSIETz++ON069aNq6/O31oUHTp0YNeuXZQvX54XXnjhtH2NGzcmOTmZzz//PGcM/pQbbriBzZs3M3fuXH744QeaNm3KmjVrAHJ+ORU27dG7KREhflA8i+9ezNhOY2kW1oxf//qVn5N/BiAzK5N2H7Vj2A/D+HTVp2xK26QPaymF7SSNGjWK+Ph44uPjSUpK4t5776VOnTrExcXRqFEjRo0axfPPP5+v942JiWHr1q00aNCAp59++qz9N910EyNGjDirLDFASEgIt912G1OmTKFFixYsXrz4ott3MbRH78aK+Rbj6mpXn7YK1onMEwDsO7oPQXh/5fu8ufxNwPb6X7/hdXrW6+lEuEo5Jnc9+htuuIGnnnqK22+/naCgILZv346/vz8ZGRmEhIRwxx13EBQUxEcffXTaua6s6hQQEMC4ceNo1KgRo0ePJiQkJGffgAEDKF26NI0aNWLRokU52xcuXEjr1q0pWbIkhw4dYtOmTVStWrVA238hmug9TDHfYgCUDyzPorsXkZGVwdrda1m6bSnfJ35P2QB7E3f59uW8t+I9etXrxbXVr6W4X/Hzva1SHi13PfouXbpw22230aZNGwCCgoKYOnUqSUlJjBw5Eh8fH/z9/Xn77bcBGDhwIF26dCEsLMylcfqwsDD69evHxIkTeeqpp3K2h4eHM3z48LOOj4uLY+jQofj5+ZGVlcV9991HixYtSE5OzhmjP2XAgAEMGzbsEn8aZ9N69F5qSsIUBs8ZzOETh7mi+BV0rd2VXvV60aNej5xfFkoVFK1Hf3lpPXoFwJ1N7mTPyD183+97bom8hR83/8j9392fsz9mSwwJOxPIMlkORqmUuhx06MaLlfArQdc6XelapyuTsiaRmJaY05sf+sNQ1u1ZR7mS5egQ0YFrq1/LdTWuo1ZILYejVso5rVq14vjx46dtmzJlCo0aNXIoooKhib6I8PXxpV65ejmv594+l5jkGBZsWcCCzQv4Yt0X3NboNj69+VOMMUxfO52rqlxFldJVHIxaeRJjTE75D0+1bNkyp0O4oIsZbtdEX0RVKV2F/k36079Jf4wxJKUlkWkyAdi8fzP9vrJTxCqXqkzzSs1pVrEZtzS4hcjQSCfDVm6qRIkS7Nu3j7Jly3p8sndnxhj27dtHiRIl8nWeJnqFiFC7bO2c19WDq5MwKIGYLTH8kfoHK3as4LuN3xEZGklkaCRxqXGMjhlN87DmNAtrRrOwZlQrXU3/By/CwsPDSUlJYc+ePU6H4vVKlChBeHh4vs5xKdGLSGdgPOALvGeMGXPG/mDgA6AmcAwYYIxZk70vGTgEZAIZ57orrNyHj/jQuEJjGlf4pybH4ROH8RVfAA4cO0DqoVR+3PRjzl8B5QPL8+OdP552jio6/P39qV69utNhqHO4YKIXEV9gItAJSAH+EJFZxph1uQ57Eog3xvQSkXrZx3fMtb+DMWZvAcatLrOgYkE533es0ZGEQQkcPXmU1btXs2LHCpZsW0LtEPtXwYuLX2Rh8kK61e5G1zpdqVO2jlNhK6VwbXplSyDJGLPZGHMCmAb0OOOYSGABgDFmAxAhIhUKNFLldgL8A2hZuSWDogcxpdcUAvwDAAgOCGbX4V08Mv8R6r5Vlzpv1mH0wtEOR6tU0eVKoq8MbMv1OiV7W24JwM0AItISqAacGkQywHwRiRORgZcWrvIEg1sMZs3gNWwZvoW3urxFrZBaJKYl5uwf9sMw3lz2Juv2rNP6PEpdBq6M0ed1h+3M/zvHAONFJB5YDawETq2OfZUxJlVEygM/isgGY8xZFX2yfwkMBC66DsSrr0LHjpBdtE45LKJMBENaDmFIyyE5CT39ZDpzEuewaf8mAMKCwuhYoyP3Nb2PdhHtnAxXKa/lSo8+Bcg9mTocSM19gDHmb2PMPcaYKKA/EApsyd6Xmv3vbmAmdijoLMaYycaYaGNMdGhoaH7bwYED8Mor0Lw53HUXbNt2wVPUZXRqRk5J/5IkDUti87DNvNv9XdpFtGNe0jz+3PcnAH8d/IvBswfz9fqvSTua5mTISnmNC9a6ERE/4E/szdXtwB/AbcaYtbmOKQOkG2NOiMj9wNXGmP4iEgj4GGMOZX//I/C8MWbu+T7zYmvdHDgAL78M48eDCDzyCDz+OFxxRb7fSl1GWSaLjKwMivkWY27SXHrP6M2Rk0cQhKZhTelYvSMPtX6ISqUqOR2qUm7rkmrdGGMygKHAPGA9MMMYs1ZEBonIoOzD6gNrRWQD0AU4VcKtAvCriCQAy4HZF0ryl6JMGfjvf2HDBrj5ZnjpJahVC95+G06eLKxPVZfKR3xySjN0rtWZtMfT+OWeX3im3TMEFQti/LLx+Ij9T/WrdV/xTMwzLN66mOMZx8/3tkqpbF5dvTI2Fh59FBYvhnr17C+B7t1tb195jqMnj+bM6Hl03qOMWzaOLJNFgF8Abau25aa6NzG05VDAzvEvVawUvj6+Toas1GV3vh69Vyd6AGPgu+/gscdg40Zo187etI3Wx7Y81v6j+1m8dTELtixg4ZaFFPcrTtzAOABavtuSFTtWUDGoIpWvqEylUpW4MvxKRl41EoD4nfHUK1ePEn75e4RcKXdXpBP9KSdPwrvvwjPPwN690K8fjB4NkVq6xeMdPnE454GuKQlT2LhvI9sPbSf1UCqph1JpWrEpn/T6BIBKr1Ui7WgarcJb0b5ae9pFtKNNeJucvxiU8lSa6HM5eNAO4YwfD+np0KsXjBoFLVoUyscpNzP7z9nEJMfw89afWbFjBVkmi+GthjOu8zhOZp5k8dbFtKnShpL+JZ0OVal80USfh717YcIEePNNO1vnuuvgySehfXsdwy8qDh47yK9//UrV0lVpVKERy1KW0fr91vj7+NOycktah7emacWmXF/zekID8z/lV6nLSRP9efz9N0yaBK+9Brt2QevWtoffrRv46PpbRcqRE0dYvHUxi5IX8fPWn4nfGc/xzOPE3BVD+4j2/PrXr0xbMy2nYmdkaKQuy6jchiZ6Fxw7Bh9+aB+6Sk6Ghg1twu/TB/y0mHORdDLzJBv2bqBmSE1K+pfk/RXv89C8hzh84jBgF2pvWL4hP9z+A+UDy7Nm9xrSjqYRUSaCyqUq68wfdVlpos+HjAyYNg3GjIG1a6FGDRg2DO68E0JCHAlJuZEsk0VSWhIrd6xkxY4VrNmzhll9Z+Hr48uD3z/IO3HvAODn40eVK6pQK6QW8+6Yh4iwfPty/Hz8aBDagOJ+xR1uifI2mugvQlaWnZY5Zgz8/juUKAG33AIPPABXXqnj+Ops2w5uY8PeDSQfSCb5QDJbDmzhaMZRZt46E4Cun3VlTuIc/H38aVC+AU0rNqVt1bYMaDrA4ciVN9BEf4ni42HyZJg6FQ4dggYNYOBA28sPDnY6OuUpNu/fTGxqLCt2rGDlTvsXQcPyDYm5KwaAf834F8V9i9MsrBmNKzSmdkhtqpSugp+Pjh2qC9NEX0COHLHDOpMmwR9/2F5+nz62l9+mjfbyVf4YYzh84jClipfCGEOfL/uwLGUZ2/7+pyLfXU3u4qOeH2GM4ZF5j1CtTDVqBtekZkhNagTX0Ae/VA5N9IVg5Urby//0U9vLb9gQ7r/fPoh1EcU3lcqxN30va3avYVPaJqoHV+fa6tey/+h+IsZH8Pfxv0879uWOL/NE2yc4dPwQE5ZNoGZIzZxfBMElgnUd3yJEE30hOnzY9vInT7a9fD8/6NLFDut07257/UoVBGMM+47uY1PaJpLSkti0fxPtI9pzTbVrSNiZQNSkqNOOL128NO92f5dbGtxC6qFU5iTOoU7ZOtQpW4cKgRX0l4CX0UR/maxeDVOm2F5+aiqULm2Hdvr3h6uu0qEdVbiOnDjC5v2b2bR/k/03bRP3NruXZmHN+GbDN/Sa3ivn2KBiQdQOqc17N71Hs7BmbP97O9v+3kadsnUICdDpZZ5IE/1llpkJCxfapP/113Zsv3p1uOMO29OvXdvpCFVRk5mVyV8H/yIxLZE/9/1J4r5E/kz7k4k3TqRGcA3eXPYmw+YOA6BsQFnqlqtL/XL1ebnjy4QGhnLkxBFK+JXQZwPcmCZ6Bx0+DDNn2qT/00+2mmbr1nDvvXY8PzDQ6QiVgh2HdhCbGpvzi2DD3g1s2LuBTcM2EVgskMd/fJwJyydQp2wd6perb79C69M7sjc+4oMxRoeCHKaJ3k1s3w6ffQYff2wfxipd2g7rPPgg1K/vdHRKndv8TfOZv2k+6/euZ/2e9SQfSKZMiTLse2wfIkLfL/uycMtCKgRVoGJQRSoEVqBO2To83e5pABJ2JgBQPrA85UqWw9/X38nmeCVN9G7GGFiyxK589eWXcOKErZM/eDD07AnFtHyKcnNHTx5l+6Ht1AqpBcBH8R+xLGUZO4/sZNfhXew8vJMKQRX47d7fALjy/Sv5LeW3nPODSwRzbfVr+bLPlwCM+XUMxzKOUT6wPJVKVaJWSC1qBtfU8tH5oIneje3ebWvsTJoEW7ZAhQpw3332gayqVZ2OTqmCEZsay18H/2L3kd3sObKH3Ud2U6lUJUZdPQqAppOakrAzAcM/+ahr7a58f9v3AIz6aRShgaHUDqlN7bK1qRFcQwvKnUETvQfIyoJ58+B//4PZs+0Mna5d7bDO9deDr94DU14uIyuDfen7SPk7hcS0RIJLBHNDrRs4mXmSyq9XZk/6npxjfcSHUW1H8cK1L3Do+CGun3o9fj5++Pv44+/rj5+PH/0b9+fWhreyL30fj/34GCEBITlfZUuWpXlYc6oHVycjK4NjGccI9A/06PsM50v0+my1m/DxsfPvu3SBrVvtvPz33rP1dipXtmP5d98Ndeo4HalShcPPx48KQRWoEFSB5pWa52z39/Vn98jd7EvfR2JaIon7EklMS6RV5VYAGAylipUiIyuDk1knST+ZTkZWRk6V0aMZR5m7aS77j+7naMbRnPedeONEBrcYzLo962jyThOK+RYjJCCEikEVqVSqEo9d+RjtItqxN30vS7ctpVKpSlQqVYnygeU9riyF9ujd2IkTNtF/+CH88IPt9V91Fdxzj52fX6qU0xEq5VmOnjxK2tE00o6mUSGoAuUDy7Pj0A6mrppK2tE09qbvZeeRnaQeSmVMxzF0qtmJHxJ/4MbPbsx5Dx/xoUJgBb6+9Wtah7dmweYFvPjLixTzLUZxv+L2X9/iPN/heWoE12BZyjLmbZpHuZLlKFeyHKElQylXshx1y9Ut0OEn7dF7qGLF4F//sl87dtgpmh9+aMfwhw2D3r1t0r/mGl0kRSlXBPgHUNm/MpWvqJyzLaxUWM7i8Xm5utrV/HH/HzlrEJ/6Kl28NACZJpOMrAyOnDzCicwTHM84zonMExw9af96WLVrFc8seuas9036dxI1Q2oycflExi0bR7mS5Zjaayo1Q2oWcKu1R+9xjIFly2zCnzbNrpBVvbod1rnnHqhSxekIlVJnOpl5Mucvhj3pe9ibvpdudbpRwq8E3274lulrp7MnfQ8f9/yYSqUqXdRn6M1YL5Webh/G+vBD+yTuqRu4gwbBDTfoDVylipLzJXr9g9+DlSwJt99un7jdvNkufbh8uU32NWvCSy/Bzp1OR6mUcpomei8REQEvvADbtsEXX0CtWvB//2eHcm65BRYssDdzlVJFjyZ6L+Pvb2/S/vQTbNwIw4dDTAxcdx3UqwevvQZ79zodpVLqctJE78Xq1IFXX4WUFLsMYoUKMGKEnZd/++2weLG9uauU8m6a6IuAEiVsYv/lF1sz/4EH7NO37dpBZCS88Qbs2+d0lEqpwqKJvohp2BAmTLALo3z4oV3c/JFHbC//jjvsLwPt5SvlXVxK9CLSWUQ2ikiSiDyRx/5gEZkpIqtEZLmINHT1XOWMkiXt3PulS2HVKvsQ1nff2YevGjSAceMgLc3pKJVSBeGCiV5EfIGJQBcgEugnIpFnHPYkEG+MaQz0B8bn41zlsEaN4K23bC//gw9snfyHH4ZKlexDWPHxTkeolLoUrvToWwJJxpjNxpgTwDSgxxnHRAILAIwxG4AIEang4rnKTQQG2sT+2282uQ8YADNmQNOmdjz/66/tMolKKc/iSqKvDGzL9Tole1tuCcDNACLSEqgGhLt4LtnnDRSRWBGJ3bNnT16HqMuoSRNbMnn7djtzZ+tWW3OnVi07RfPAAacjVEq5ypVEn1eB5jNv140BgkUkHvg3sBLIcPFcu9GYycaYaGNMdGhoqAthqcuhTBl49FFISoKvvrKLoYwYAeHhMGSInauvlHJvriT6FCB3qaxwIDX3AcaYv40x9xhjorBj9KHAFlfOVZ7Bzw9uvhl+/hlWrrRP2773nn0Iq0sXmDtXn7xVyl25kuj/AGqLSHURKQb0BWblPkBEymTvA7gPWGyM+duVc5XniYqyUzO3bYPnn7fj+V26QN268PrrOltHKXdzwURvjMkAhgLzgPXADGPMWhEZJCKDsg+rD6wVkQ3YGTbDz3duwTdDOaF8eXjqKTt+/9lnULGiHeapXBnuvRfi4pyOUCkFWqZYFbCEBHsTd+pUW0a5VSsYPNiuiFWihNPRKeW9tEyxumyaNIFJk+yc/PHj7eycu+6yN28ffxy2bHE6QqWKHk30qlCULm2XO1y/3lbSbNfOTsusWRN69LA3dd3wj0mlvJImelWoRKBjRzs1MzkZnnwSliyB9u0hOho+/dQugq6UKjya6NVlEx7+z+IokybZMfw77rBr3o4Zo7N1lCosmujVZRcQAAMHwtq1MGeOLZU8apRdDWvIEEhMdDpCpbyLJnrlGB8fO//+xx/tbJ1bb7UPYdWtq+P4ShUkTfTKLTRubCtnbt1q5+YvXWrH8du0gW++0adulboUmuiVW6lYEZ57Dv76C95+G/bsgV69bI38jz7SG7dKXQxN9MotBQTAoEG2aNrnn0Px4raEcq1adn7+kSNOR6iU59BEr9yanx/07WsLqf3wA9SoAQ89ZKtoPvecrnWrlCs00SuPIAKdO8OiRXb8vm1bePZZm/AfftiO7Sul8qaJXnmcNm3g229hzRro3dsug1izpq2ns3SpztRR6kya6JXHatAAPv4YNm+2VTN//BGuugpat7bj+idPOh2hUu5BE73yeFWqwH//CykpMHGiLaR22236xK1Sp2iiV14jMNCWRF6/HmbP/ueJ2/BwePBBu12pokgTvfI6Pj5w440wfz6sXg23325XxIqMtNuXLnU6QqUuL030yqs1bAjvvmsLqf3nPxAba8fxO3a0JRaUKgo00asiITQURo+2C5+8/jqsW2dLLFxzja2XrzN1lDfTRK+KlMBAO+9+82Z48037b6dOcOWVtpKmJnzljTTRqyIpIACGDoVNm+Cdd2DHDujaFVq0sHP0NeErb6KJXhVpxYvDAw/YGvjvv2+nZvbsCVFR8OWXWjVTeQdN9EoB/v4wYABs2ACffALHj8Mtt9jFzmfM0ISvPJsmeqVy8fODO++0q199+ilkZNgFURo1gmnTIDPT6QiVyj9N9ErlwdfXPl27Zo1N8AD9+tnpmp99pglfeRZN9Eqdh6+v7dGvXm2HcPz87ANYDRrA1Km2x6+Uu9NEr5QLfHzsmH1Cgr1JW7y4HeKJjLSF1TThK3emiV6pfPDxgX/9yy6E8vXXdl7+3XdDvXq2zIJWzFTuSBO9UhfBx8euZbtihZ13f8UVdtZOvXp2mqYmfOVONNErdQlE4KabIC4OZs2C4GC47z6oU8fW2NHFzJU70ESvVAEQge7d4Y8/4PvvbW2dgQOhdm2YNEkTvnKWS4leRDqLyEYRSRKRJ/LYX1pEvhORBBFZKyL35NqXLCKrRSReRGILMnil3I2ILaWwbJmtnRMWBoMGQa1a8Pbb9kEspS63CyZ6EfEFJgJdgEign4hEnnHYEGCdMaYJ0B54TUSK5drfwRgTZYyJLpiwlXJvItClC/z2G8ydaxc/GTzYjuFPmaLz8NXl5UqPviWQZIzZbIw5AUwDepxxjAFKiYgAQUAaoBPOVJEnAjfcAEuW2IQfHAz9+0PTpnaIR4unqcvBlURfGdiW63VK9rbc3gLqA6nAamC4MeZUdRADzBeROBEZeK4PEZGBIhIrIrF79uxxuQFKeYJTCT821j5pe/SoHdO/5hr7S0CpwuRKopc8tp3ZD7kBiAcqAVHAWyJyRfa+q4wxzbBDP0NE5Jq8PsQYM9kYE22MiQ4NDXUldqU8jo+PfdJ23To7Zp+UBG3b2pk7q1c7HZ3yVq4k+hSgSq7X4diee273AF8bKwnYAtQDMMakZv+7G5iJHQpSqkjz97c3aZOS4KWXYPFiWynzrrsgOdnp6JS3cSXR/wHUFpHq2TdY+wKzzjjmL6AjgIhUAOoCm0UkUERKZW8PBK4H1hRU8Ep5usBAGDXKrnQ1YoStp1O3Ljz0EOzd63R0yltcMNEbYzKAocA8YD0wwxizVkQGicig7MP+A1wpIquBBcDjxpi9QAXgVxFJAJYDs40xcwujIUp5spAQeOUVuwBK//52mcOaNW1vPz3d6eiUpxPjhrf9o6OjTWysTrlXRde6dbanP2sWVKoEzz1na+r4+TkdmXJXIhJ3rins+mSsUm4oMtLW0PnlF6haFe6/347hz5qlUzJV/mmiV8qNtW0LS5fCV1/ZUsg9ekC7dvbJW6VcpYleKTcnAjffbFe7evtt+PNPaN0aeve23yt1IZrolfIQuadkPvcczJtnV7oaPhz273c6OuXONNEr5WGCguDpp23Cv+8+eOstWyXznXe0ho7KmyZ6pTxUhQp2KGfFCrto+YMPQrNm8PPPTkem3I0meqU8XJMmEBMDX3wBBw9C+/bQpw9s3ep0ZMpdaKJXyguI2Juz69fD88/bypj16sEzz+gDV0oTvVJeJSAAnnoKNm6Enj1t0q9XD6ZP1/n3RZkmeqW8UJUq8PnntlhauXLQt6+df5+Q4HRkygma6JXyYldfbdexnTzZDus0awZDhkBamtORqctJE71SXs7X15ZQ+PNPm+TfeQfq1LHJX6djFg2a6JUqIoKDYcIEWLnSPmj1wAPQqpVd11Z5N030ShUxjRvDokXw2WewYwdceaWtjLlzp9ORqcKiiV6pIkgE+vWzs3OeeMIm/Tp14PXX4eRJp6NTBU0TvVJFWFAQvPwyrF1rK2U++qh9AGvBAqcjUwVJE71Sitq1YfZsW+/++HG47jo7JXP7dqcjUwVBE71SCrDDOd272979c8/ZhU/q1oWxY+HECaejU5dCE71S6jQlStjqmOvWQceO8NhjEBUFCxc6HZm6WJrolVJ5ql7d9uq/+w6OHbNJv18/Hc7xRJrolVLn1a2bHc559lmYOdPWznn1VZ2d40k00SulLiggwFbCXLfOlkEeOdIO5yxa5HBgyiWa6JVSLqtRww7lzJoFR49Chw4wYADs2+d0ZOp8NNErpfLt1OycUaNgyhSoX99Wy9RSyO5JE71S6qIEBMBLL0FcHEREwG23wY03QnKy05GpM2miV0pdksaNbWG08ePhl19swbTXX4eMDKcjU6dooldKXTJfXxg2zN6svfZaW0qhdWtbKVM5TxO9UqrAVK1qb9TOmAEpKdCihX3gStetdZYmeqVUgRKBW26xK1oNGGBLKDRsCD/95HRkRZdLiV5EOovIRhFJEpEn8thfWkS+E5EEEVkrIve4eq5SyjsFB9tVrH7+Gfz9oVMnu9LVwYNOR1b0XDDRi4gvMBHoAkQC/UQk8ozDhgDrjDFNgPbAayJSzMVzlVJe7JprID4eHn8cPvgAIiPh+++djqpocaVH3xJIMsZsNsacAKYBPc44xgClRESAICANyHDxXKWUlwsIgDFjYNkyKFvWzsO/4w590OpycSXRVwa25Xqdkr0tt7eA+kAqsBoYbozJcvFcAERkoIjEikjsnj17XAxfKeVJoqMhNtbWzZk+3fbuv/zS6ai8nyuJXvLYdubzbzcA8UAlIAp4S0SucPFcu9GYycaYaGNMdGhoqAthKaU8UbFitm5OXBxUqWJv3P7rX7pmbWFyJdGnAFVyvQ7H9txzuwf42lhJwBagnovnKqWKoMaN4fff7ZDO7Nm2d//JJ1pGoTC4kuj/AGqLSHURKQb0BWadccxfQEcAEakA1AU2u3iuUqqI8vOzN2kTEmy9nLvugq5dYdu2C5+rXHfBRG+MyQCGAvOA9cAMY8xaERkkIoOyD/sPcKWIrAYWAI8bY/ae69zCaIhSynPVrQuLF8O4cXY6ZoMGMGkSZGU5HZl3EOOGfydFR0eb2NhYp8NQSjlg82Y7337hQlv7/r33oGZNp6NyfyISZ4yJzmufPhmrlHIrNWrYp2gnT4YVK6BRI3jjDcjMdDoyz6WJXinldkRsr37tWlsk7ZFHoG1bW1ZB5Z8meqWU2woPtytaTZ0Kf/5ply986SVdrza/NNErpdyaCNx+uy2B3KMH/N//QcuWtqyCco0meqWUR6hQwZY//uor2LHDPmX79NNw4oTTkbk/TfRKKY9y8822d3/bbfCf/0CrVrB6tdNRuTdN9EopjxMSYp+inTkTUlOheXN4+WVdvvBcNNErpTxWz56wZo0du3/ySTszZ+NGp6NyP5rolVIeLTTUjt1//jkkJtqZOePG6VO1uWmiV0p5PBHo29f27q+7Dh5+GDp0sE/ZKk30SikvEhZmFyf/4AM7/bJxY1szxw0rvVxWmuiVUl5FBO65x87EadMGBg2Czp2LdkVMTfRKKa9UtSrMnw//+x8sWQING8KHHxbN3r0meqWU1xKBBx+EVavsTdoBA+x6talFbPkjTfRKKa9XowbExMD48bb8cYMGtn5OUenda6JXShUJPj4wbJi9SRsZCXfeaZ+y3bXL6cgKnyZ6pVSRUqeOXc1q7Fj44Qfbu58xw+moCpcmeqVUkePrCyNGwMqVdljn1luhTx/Yu9fpyAqHJnqlVJFVvz4sXWpr3H/zje3dz5rldFQFTxO9UqpI8/ODUaMgLg4qVbJ1cx58ENLTnY6s4GiiV0op7Nq0v/9uh3TeecfWu/eWxU000SulVLbixe1N2h9/hIMH7UpWr73m+QXSNNErpdQZrrvOPmTVtavt4d9wg2c/ZKWJXiml8lC2LHz9NUyebG/YNmpkb9h6Ik30Sil1DiJw//2wYgVERECvXvDAA3DkiNOR5Y8meqWUuoC6deG33+Dxx+Hdd+3ShStWOB2V6zTRK6WUC4oVgzFjYMECOHzYLko+ZgxkZjod2YVpoldKqXzo0MHeqO3Vy86/79ABkpOdjur8NNErpVQ+hYTA9OnwySf/rGQ1ZYr7VsPURK+UUhdBxFbAXLUKmjSB/v1tzZy0NKcjO5tLiV5EOovIRhFJEpEn8tg/UkTis7/WiEimiIRk70sWkdXZ+2ILugFKKeWkiAhYtAhefhlmzrTTMH/6yemoTnfBRC8ivsBEoAsQCfQTkcjcxxhjxhpjoowxUcAo4GdjTO7fax2y90cXXOhKKeUefH3hiSdg2TK44gro1AkeeQSOHXM6MsuVHn1LIMkYs9kYcwKYBvQ4z/H9gM8LIjillPIkzZrZ4mhDhsAbb0CLFpCQ4HRUriX6ykDu9dNTsredRURKAp2Br3JtNsB8EYkTkYHn+hARGSgisSISu2fPHhfCUkop91OyJLz1FsyZY+vbt2wJb77p7I1aVxK95LHtXCF3B5acMWxzlTGmGXboZ4iIXJPXicaYycaYaGNMdGhoqAthKaWU++rSBVavhuuvt0sY9u0Lhw45E4sriT4FqJLrdThwrvI+fTlj2MYYk5r9725gJnYoSCmlvF65cvDtt/bBqi+/tEM5a9Zc/jhcSfR/ALVFpLqIFMMm87PWYBGR0kA74Ntc2wJFpNSp74HrAQeaqZRSzvDxsaUTFiyAAwfsUM6UKZc5hgsdYIzJAIYC84D1wAxjzFoRGSQig3Id2guYb4zJXe6nAvCriCQAy4HZxpi5BRe+Ukp5hvbt7Rq1LVvaOfeDBl2+WTli3PBRrujoaBMbq1PulVLeJyMDnnrKDuc0a2aHdKpXv/T3FZG4c01h1ydjlVLqMvLzsw9XzZoFmzfbZF/YC5JroldKKQd0727n3NeoYRckf+IJ29svDJrolVLKITVqwJIldjGT//4XOna0JZALml/Bv6VSSilXlSgB77wDbdvamjmBgQX/GZrolVLKDdxxh/0qDDp0o5RSXk4TvVJKeTlN9Eop5eU00SullJfTRK+UUl5OE71SSnk5TfRKKeXlNNErpZSXc8vqlSKyB9iaa1M5YK9D4RQWb2uTt7UHvK9N3tYe8L42XUp7qhlj8lyezy0T/ZlEJPZc5Tc9lbe1ydvaA97XJm9rD3hfmwqrPTp0o5RSXk4TvVJKeTlPSfSTnQ6gEHhbm7ytPeB9bfK29oD3talQ2uMRY/RKKaUunqf06JVSSl0kTfRKKeXl3C7Ri8gHIrJbRNbk2hYiIj+KSGL2v8FOxphf52jTsyKyXUTis79udDLG/BCRKiISIyLrRWStiAzP3u6R1+k87fHka1RCRJaLSEJ2m57L3u6p1+hc7fHYawQgIr4islJEvs9+XSjXx+3G6EXkGuAw8IkxpmH2tleANGPMGBF5Agg2xjzuZJz5cY42PQscNsa86mRsF0NEwoAwY8wKESkFxAE9gbvxwOt0nvb0wXOvkQCBxpjDIuIP/AoMB27GM6/RudrTGQ+9RgAi8ggQDVxhjOlWWLnO7Xr0xpjFQNoZm3sAH2d//zH2f0KPcY42eSxjzA5jzIrs7w8B64HKeOh1Ok97PJaxTi0z7Z/9ZfDca3Su9ngsEQkHugLv5dpcKNfH7RL9OVQwxuwA+z8lUN7heArKUBFZlT204xF/Qp9JRCKApsAyvOA6ndEe8OBrlD0sEA/sBn40xnj0NTpHe8Bzr9E44DEgK9e2Qrk+npLovdHbQE0gCtgBvOZoNBdBRIKAr4CHjDF/Ox3PpcqjPR59jYwxmcaYKCAcaCkiDR0O6ZKcoz0eeY1EpBuw2xgTdzk+z1MS/a7scdRT46m7HY7nkhljdmX/h5sFvAu0dDqm/MgeJ/0K+NQY83X2Zo+9Tnm1x9Ov0SnGmAPAIux4tsdeo1Nyt8eDr9FVwE0ikgxMA64VkakU0vXxlEQ/C7gr+/u7gG8djKVAnLqY2XoBa851rLvJvjH2PrDeGPN6rl0eeZ3O1R4Pv0ahIlIm+/sA4DpgA557jfJsj6deI2PMKGNMuDEmAugLLDTG3EEhXR93nHXzOdAeW65zF/AM8A0wA6gK/AXcYozxmJub52hTe+yfmwZIBh44NTbn7kSkLfALsJp/xhefxI5re9x1Ok97+uG516gx9maeL7ZDN8MY87yIlMUzr9G52jMFD71Gp4hIe2BE9qybQrk+bpfolVJKFSxPGbpRSil1kTTRK6WUl9NEr5RSXk4TvVJKeTlN9Eop5eU00SullJfTRK+UUl7u/wEb2jfukM+2VwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "df = pd.DataFrame(\n", " model.learning_process[10:], columns=[\"epoch\", \"train_RMSE\", \"test_RMSE\"]\n", ")\n", "plt.plot(\"epoch\", \"train_RMSE\", data=df, color=\"blue\")\n", "plt.plot(\"epoch\", \"test_RMSE\", data=df, color=\"green\", linestyle=\"dashed\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model.estimations()\n", "\n", "top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n", "\n", "top_n.to_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_reco.csv\", index=False, header=False\n", ")\n", "\n", "estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n", "estimations.to_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\",\n", " index=False,\n", " header=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 8683.10it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.9143930.7171990.1016970.0423340.0517870.0688110.0924890.072360.1048390.048970.1961170.5178890.4803820.8673380.1471863.8525450.972694
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.914393 0.717199 0.101697 0.042334 0.051787 0.068811 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.092489 0.07236 0.104839 0.04897 0.196117 0.517889 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.480382 0.867338 0.147186 3.852545 0.972694 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df = pd.read_csv(\n", " \"Recommendations generated/ml-100k/Self_SVD_estimations.csv\", header=None\n", ")\n", "reco = np.loadtxt(\"Recommendations generated/ml-100k/Self_SVD_reco.csv\", delimiter=\",\")\n", "\n", "ev.evaluate(\n", " test=pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None),\n", " estimations_df=estimations_df,\n", " reco=reco,\n", " super_reactions=[4, 5],\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 8505.85it/s]\n", "943it [00:00, 9544.72it/s]\n", "943it [00:00, 9154.80it/s]\n", "943it [00:00, 8282.66it/s]\n", "943it [00:00, 8432.23it/s]\n", "943it [00:00, 9601.30it/s]\n", "943it [00:00, 9158.89it/s]\n", "943it [00:00, 12283.59it/s]\n", "943it [00:00, 9500.43it/s]\n", "943it [00:00, 10085.91it/s]\n", "943it [00:00, 10260.90it/s]\n", "943it [00:00, 9691.20it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9503470.7493120.1006360.0505140.0557940.0707530.0912020.0827340.1140540.0532000.2488030.5219830.5174970.9921530.2106784.4186830.952848
0Self_SVD0.9143930.7171990.1016970.0423340.0517870.0688110.0924890.0723600.1048390.0489700.1961170.5178890.4803820.8673380.1471863.8525450.972694
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9394720.7398160.0858960.0360730.0435280.0576430.0770390.0574630.0977530.0455460.2198390.5147090.4316010.9974550.1688314.2175780.962577
0Ready_Random1.5218451.2259490.0471900.0207530.0248100.0322690.0295060.0237070.0500750.0187280.1219570.5068930.3297990.9865320.1847045.0997060.907217
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated1.0307120.8209040.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.950347 0.749312 0.100636 0.050514 0.055794 \n", "0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.939472 0.739816 0.085896 0.036073 0.043528 \n", "0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.070753 0.091202 0.082734 0.114054 0.053200 0.248803 \n", "0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.057643 0.077039 0.057463 0.097753 0.045546 0.219839 \n", "0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.521983 0.517497 0.992153 0.210678 4.418683 0.952848 \n", "0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.514709 0.431601 0.997455 0.168831 4.217578 0.962577 \n", "0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dir_path = \"Recommendations generated/ml-100k/\"\n", "super_reactions = [4, 5]\n", "test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
03211.000000322322Murder at 1600 (1997)Mystery, Thriller
19830.902748984984Shadow Conspiracy (1997)Thriller
29850.894696986986Turbulence (1997)Thriller
37780.890524779779Drop Zone (1994)Action
46860.889220687687McHale's Navy (1997)Comedy, War
53310.887596332332Kiss the Girls (1997)Crime, Drama, Thriller
69870.886547988988Beautician and the Beast, The (1997)Comedy, Romance
710390.88284510401040Two if by Sea (1996)Comedy, Romance
810220.88278210231023Fathers' Day (1997)Comedy
99290.877662930930Chain Reaction (1996)Action, Adventure, Thriller
\n", "
" ], "text/plain": [ " code score item_id id title \\\n", "0 321 1.000000 322 322 Murder at 1600 (1997) \n", "1 983 0.902748 984 984 Shadow Conspiracy (1997) \n", "2 985 0.894696 986 986 Turbulence (1997) \n", "3 778 0.890524 779 779 Drop Zone (1994) \n", "4 686 0.889220 687 687 McHale's Navy (1997) \n", "5 331 0.887596 332 332 Kiss the Girls (1997) \n", "6 987 0.886547 988 988 Beautician and the Beast, The (1997) \n", "7 1039 0.882845 1040 1040 Two if by Sea (1996) \n", "8 1022 0.882782 1023 1023 Fathers' Day (1997) \n", "9 929 0.877662 930 930 Chain Reaction (1996) \n", "\n", " genres \n", "0 Mystery, Thriller \n", "1 Thriller \n", "2 Thriller \n", "3 Action \n", "4 Comedy, War \n", "5 Crime, Drama, Thriller \n", "6 Comedy, Romance \n", "7 Comedy, Romance \n", "8 Comedy \n", "9 Action, Adventure, Thriller " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item = random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm = (\n", " model.Qi / np.linalg.norm(model.Qi, axis=1)[:, None]\n", ") # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores = np.dot(embeddings_norm, embeddings_norm[item].T)\n", "top_similar_items = pd.DataFrame(\n", " enumerate(similarity_scores), columns=[\"code\", \"score\"]\n", ").sort_values(by=[\"score\"], ascending=[False])[:10]\n", "\n", "top_similar_items[\"item_id\"] = top_similar_items[\"code\"].apply(\n", " lambda x: item_code_id[x]\n", ")\n", "\n", "items = pd.read_csv(\"./Datasets/ml-100k/movies.csv\")\n", "\n", "result = pd.merge(top_similar_items, items, left_on=\"item_id\", right_on=\"id\")\n", "\n", "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# making changes to our implementation by considering additional parameters in the gradient descent procedure\n", "# seems to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(\n", " algo,\n", " reco_path=\"Recommendations generated/ml-100k/Ready_SVD_reco.csv\",\n", " estimations_path=\"Recommendations generated/ml-100k/Ready_SVD_estimations.csv\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(\n", " algo,\n", " reco_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv\",\n", " estimations_path=\"Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv\",\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 11456.53it/s]\n", "943it [00:00, 11932.50it/s]\n", "943it [00:00, 10853.07it/s]\n", "943it [00:00, 9426.44it/s]\n", "943it [00:00, 8757.09it/s]\n", "943it [00:00, 9999.67it/s]\n", "943it [00:00, 11323.49it/s]\n", "943it [00:00, 9764.72it/s]\n", "943it [00:00, 9692.41it/s]\n", "943it [00:00, 9052.77it/s]\n", "943it [00:00, 8645.18it/s]\n", "943it [00:00, 10594.54it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9516520.7509750.0963940.0472520.0528700.0672570.0855150.0747540.1095780.0515620.2355670.5203410.4962880.9955460.2085144.4557550.951624
0Self_SVD0.9143930.7171990.1016970.0423340.0517870.0688110.0924890.0723600.1048390.0489700.1961170.5178890.4803820.8673380.1471863.8525450.972694
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9404130.7395710.0860020.0354780.0431960.0575070.0757510.0534600.0948970.0433610.2091240.5144050.4284200.9973490.1774894.2125090.962656
0Ready_Random1.5218451.2259490.0471900.0207530.0248100.0322690.0295060.0237070.0500750.0187280.1219570.5068930.3297990.9865320.1847045.0997060.907217
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated1.0307120.8209040.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.951652 0.750975 0.096394 0.047252 0.052870 \n", "0 Self_SVD 0.914393 0.717199 0.101697 0.042334 0.051787 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.940413 0.739571 0.086002 0.035478 0.043196 \n", "0 Ready_Random 1.521845 1.225949 0.047190 0.020753 0.024810 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 1.030712 0.820904 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.067257 0.085515 0.074754 0.109578 0.051562 0.235567 \n", "0 0.068811 0.092489 0.072360 0.104839 0.048970 0.196117 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.057507 0.075751 0.053460 0.094897 0.043361 0.209124 \n", "0 0.032269 0.029506 0.023707 0.050075 0.018728 0.121957 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.520341 0.496288 0.995546 0.208514 4.455755 0.951624 \n", "0 0.517889 0.480382 0.867338 0.147186 3.852545 0.972694 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.514405 0.428420 0.997349 0.177489 4.212509 0.962656 \n", "0 0.506893 0.329799 0.986532 0.184704 5.099706 0.907217 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dir_path = \"Recommendations generated/ml-100k/\"\n", "super_reactions = [4, 5]\n", "test = pd.read_csv(\"./Datasets/ml-100k/test.csv\", sep=\"\\t\", header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }