{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Self-made SVD\n",
    "\n",
    "Matrix factorization (no bias terms) trained with stochastic gradient descent on MovieLens 100k."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import helpers\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "from collections import defaultdict\n",
    "from itertools import chain\n",
    "import random\n",
    "\n",
    "train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n",
    "test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
    "train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
    "from tqdm import tqdm\n",
    "\n",
    "class SVD():\n",
    "    \"\"\"Matrix factorization (no bias terms) trained with stochastic gradient descent.\n",
    "\n",
    "    A rating of item i by user u is approximated by Pu[u].dot(Qi[i]); both factor\n",
    "    matrices are updated once per observed training rating, with L2 regularization.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations, seed=None):\n",
    "        \"\"\"\n",
    "        train_ui: sparse user x item matrix of training ratings (rows/columns are user/item codes).\n",
    "        seed: optional int. When given, factor initialization and per-epoch shuffling use a private\n",
    "            np.random.RandomState, making runs reproducible; None keeps using the global np.random.\n",
    "        \"\"\"\n",
    "        self.train_ui=train_ui\n",
    "        self.uir=self._to_uir(train_ui)\n",
    "\n",
    "        self.learning_rate=learning_rate\n",
    "        self.regularization=regularization\n",
    "        self.iterations=iterations\n",
    "        self.nb_users, self.nb_items=train_ui.shape\n",
    "        self.nb_ratings=train_ui.nnz\n",
    "        self.nb_factors=nb_factors\n",
    "\n",
    "        # the np.random module and RandomState expose the same normal()/shuffle() API\n",
    "        self.rng=np.random if seed is None else np.random.RandomState(seed)\n",
    "        self.Pu=self.rng.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
    "        self.Qi=self.rng.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_uir(ui):\n",
    "        \"\"\"Return [(user, item, rating), ...] for the non-zero stored entries of sparse matrix ui.\n",
    "\n",
    "        Rows, columns and values come from one COO view filtered by one mask, so they stay\n",
    "        aligned even if the matrix stores explicit zeros (nonzero() zipped with .data would not).\n",
    "        \"\"\"\n",
    "        coo=ui.tocoo()\n",
    "        mask=coo.data!=0\n",
    "        return list(zip(coo.row[mask], coo.col[mask], coo.data[mask]))\n",
    "\n",
    "    def train(self, test_ui=None):\n",
    "        \"\"\"Run `iterations` SGD epochs over the shuffled training triples.\n",
    "\n",
    "        After every epoch appends [epoch, train_RMSE] to self.learning_process,\n",
    "        or [epoch, train_RMSE, test_RMSE] when test_ui is given.\n",
    "        \"\"\"\n",
    "        # identity check: != / == on a sparse or dense matrix are element-wise operators\n",
    "        if test_ui is not None:\n",
    "            self.test_uir=self._to_uir(test_ui)\n",
    "\n",
    "        self.learning_process=[]\n",
    "        pbar = tqdm(range(self.iterations))\n",
    "        for i in pbar:\n",
    "            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
    "            self.rng.shuffle(self.uir)\n",
    "            self.sgd(self.uir)\n",
    "            if test_ui is None:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
    "            else:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
    "\n",
    "    def sgd(self, uir):\n",
    "        \"\"\"One pass of SGD updates over the given (user, item, rating) triples, in order.\"\"\"\n",
    "        for u, i, score in uir:\n",
    "            # Compute prediction and error\n",
    "            prediction = self.get_rating(u,i)\n",
    "            e = (score - prediction)\n",
    "\n",
    "            # Both updates are computed from the pre-update factors before either is applied\n",
    "            Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
    "            Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
    "\n",
    "            self.Pu[u] += Pu_update\n",
    "            self.Qi[i] += Qi_update\n",
    "\n",
    "    def get_rating(self, u, i):\n",
    "        \"\"\"Predicted rating of item code i by user code u.\"\"\"\n",
    "        return self.Pu[u].dot(self.Qi[i])\n",
    "\n",
    "    def RMSE_total(self, uir):\n",
    "        \"\"\"Root-mean-square error of the current factors over (user, item, rating) triples; NaN if empty.\"\"\"\n",
    "        if len(uir)==0:\n",
    "            return float('nan')\n",
    "        RMSE=0\n",
    "        for u,i, score in uir:\n",
    "            prediction = self.get_rating(u,i)\n",
    "            RMSE+=(score - prediction)**2\n",
    "        return np.sqrt(RMSE/len(uir))\n",
    "\n",
    "    def estimations(self):\n",
    "        \"\"\"Compute the full user x item score matrix Pu . Qi^T.\n",
    "\n",
    "        NOTE: the result is stored as self.estimations, which shadows this method on the\n",
    "        instance, so call it only once per trained model; recommend() and estimate() read that array.\n",
    "        \"\"\"\n",
    "        self.estimations=np.dot(self.Pu,self.Qi.T)\n",
    "\n",
    "    def recommend(self, user_code_id, item_code_id, topK=10):\n",
    "        \"\"\"Top-`topK` items not rated in training, per user, from self.estimations.\n",
    "\n",
    "        Requires estimations() to have been called. Returns rows\n",
    "        [user_id, item_id_1, score_1, ..., item_id_K, score_K], highest score first.\n",
    "        \"\"\"\n",
    "        top_k = defaultdict(list)\n",
    "        for nb_user, user in enumerate(self.estimations):\n",
    "            # set lookup instead of scanning the numpy slice once per candidate item\n",
    "            user_rated=set(self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]])\n",
    "            for item, score in enumerate(user):\n",
    "                if item not in user_rated and not np.isnan(score):\n",
    "                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
    "        result=[]\n",
    "        # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
    "        for uid, item_scores in top_k.items():\n",
    "            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
    "            result.append([uid]+list(chain(*item_scores[:topK])))\n",
    "        return result\n",
    "\n",
    "    def estimate(self, user_code_id, item_code_id, test_ui):\n",
    "        \"\"\"Predicted score for every non-zero (user, item) entry of test_ui; NaN predictions fall back to 1.\"\"\"\n",
    "        result=[]\n",
" for user, item in zip(*test_ui.nonzero()):\n", " result.append([user_code_id[user], item_code_id[item], \n", " self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n", " return result" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 39 RMSE: 0.7476169311403564. Training epoch 40...: 100%|██████████| 40/40 [01:47<00:00, 2.68s/it]\n" ] } ], "source": [ "model=SVD(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n", "model.train(test_ui)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAcHElEQVR4nO3de3RV9Z338feXJBCQlJDLKEPAYHWWVtCoMeiIDsqqg4wXpg+PpX1mWsfOylOVQVvLU2Qcrc64RktbK73o0NbLdPqIdrxbq9MKjpcqNGBAKBRQtARpCWBCqIAEvvPHPoechJPkJDnJPtn781prr7Nv55wve5HP/p3fvpm7IyIig9+QsAsQEZHsUKCLiESEAl1EJCIU6CIiEaFAFxGJiPywvrisrMwrKyvD+noRkUFp5cqVO929PN2y0AK9srKSurq6sL5eRGRQMrP3OlumLhcRkYhQoIuIRIQCXUQkIkLrQxeRwefgwYM0NDSwf//+sEuJvMLCQioqKigoKMj4PQp0EclYQ0MDRUVFVFZWYmZhlxNZ7s6uXbtoaGhgwoQJGb9PXS4ikrH9+/dTWlqqMO9nZkZpaWmPfwkp0EWkRxTmA6M323nQBfqWLXDDDXDwYNiViIjklkEX6GvXwj33wOLFYVciIpJbBl2gX3opTJ0KX/saNDeHXY2IDKSmpia+//3v9/h9M2bMoKmpqcfvu+qqq5gwYQJVVVWcfvrpvPjii0eWTZ06lfHjx5P6kKCZM2cycuRIAA4fPszcuXOZOHEikyZN4uyzz2bLli1AcKX8pEmTqKqqoqqqirlz5/a4tnQG3VkuZvCNb0B1Ndx5J/zrv4ZdkYgMlGSgX3vtte3mt7a2kp/feZw999xzvf7OhQsXMmvWLJYtW0ZtbS2bNm06sqy4uJjXXnuNKVOm0NTUxPbt248se+SRR3j//fdZs2YNQ4YMoaGhgWOOOebI8mXLllFWVtbrutIZdIEOcNZZ8Dd/A3ffDddcA+PHh12RSPzccAPU12f3M6uq4Nvf7nz5/Pnzefvtt6mqqqKgoIDCwkJGjx7Nhg0b2LhxIzNnzmTr1q3s37+f66+/ntraWqDt3lF79+7lkksuYcqUKfzqV79i7NixPPXUUwwfPrzb2s4991y2bdvWbt7s2bNZsmQJU6ZM4fHH
H+dTn/oU69atA2D79u2MGTOGIUOCjpCKiopebpXMDboul6Q77gheb7453DpEZODceeedfPzjH6e+vp6FCxeyatUq7rnnHjZu3AjA/fffz8qVK6mrq2PRokXs2rXrqM/YtGkT1113HevWraO4uJjHHnsso+9+/vnnmTlzZrt506ZN4+WXX+bQoUMsWbKET3/600eWXXnllTzzzDNUVVVx44038uabb7Z774UXXniky+Xuu+/u6aZIa1C20CFolX/pS0G3yw03wJlnhl2RSLx01ZIeKDU1Ne0uvFm0aBFPPPEEAFu3bmXTpk2Ulpa2e0+yTxzgrLPO4t133+3yO+bNm8eCBQtoaGjg9ddfb7csLy+PKVOmsGTJEvbt20fqLcErKir47W9/y9KlS1m6dCnTpk3jpz/9KdOmTQP6p8tl0LbQAebPh7IyuPFGSDkuISIxkdon/dJLL/HLX/6S119/ndWrV3PGGWekvTBn2LBhR8bz8vJobW3t8jsWLlzIxo0bueuuu7j66quPWj579mzmzp3LlVdemfa7LrnkEhYuXMiCBQt48skne/LP67FBHeijRgVnu7z0EvzsZ2FXIyL9raioiJaWlrTLmpubGT16NCNGjGDDhg288cYbWf3uOXPmcPjwYV544YV2888//3xuuukmPvOZz7Sbv2rVKt5//30gOONlzZo1HH/88VmtqaNBHegAtbXwZ38G8+ZBNztaERnkSktLOe+885g4cSLz5s1rt2z69Om0trZyyimnMH/+fM4555ysfreZcfPNN/P1r3/9qPlf+cpXjuo+2bFjB5dddhkTJ07ktNNOIz8/nzlz5hxZntqH/rnPfS47NXpIfRXV1dWerScWPfUUzJwJ994LX/xiVj5SRNJYv349p5xySthlxEa67W1mK929Ot36g76FDnD55XDBBXDrrbBnT9jViIiEo9tAN7NCM1thZqvNbJ2Z3ZZmnavMrNHM6hPD3/dPuZ3VGFxstGMHdPg1JCLSreuuu+5I90dyeOCBB8Iuq8cyOW3xAHCRu+81swLgVTP7ubt3POLwiLvPSfP+AXH22fDZz8I3vxl0uwzAOfwiseTukbvj4ve+972wSzhKb7rDu22he2BvYrIgMeTkSYJ33BGcvqiLjUT6R2FhIbt27epV2Ejmkg+4KCws7NH7MrqwyMzygJXAicD33H15mtX+l5ldAGwEvuTuW9N8Ti1QCzC+H67Xr6yEq6+GH/4QHngg6IoRkeypqKigoaGBxsbGsEuJvOQj6Hoio0B390NAlZkVA0+Y2UR3X5uyyjPAw+5+wMz+L/AQcFGaz1kMLIbgLJceVZqhE08M7pW+Z09wnrqIZE9BQUGPHokmA6tHZ7m4exOwDJjeYf4udz+QmPwhcFZ2yuu55FW+aW7hICISaZmc5VKeaJljZsOBTwIbOqwzJmXycmB9NovsCQW6iMRVJl0uY4CHEv3oQ4BH3f1ZM7sdqHP3p4G5ZnY50ArsBq7qr4K7o0AXkbjqNtDdfQ1wRpr5t6SM3wTclN3SekeBLiJxFYkrRVMp0EUkriIX6MXFwemKCnQRiZvIBXpeXhDqu3eHXYmIyMCKXKBD0O2iFrqIxI0CXUQkIhToIiIRoUAXEYkIBbqISERENtBbWuCjj8KuRERk4EQy0EtKgleduigicRLJQNfVoiISR5EOdLXQRSROIh3oaqGLSJwo0EVEIkKBLiISEZEM9BEjYNgwBbqIxEskA91MFxeJSPxEMtAhOBddgS4icRLZQFcLXUTiJtKBrvPQRSROIh3oaqGLSJxEPtDdw65ERGRgRDrQDx6EvXvDrkREZGBEOtBB3S4iEh/dBrqZFZrZCjNbbWbrzOy2NOsMM7NHzGyzmS03s8r+KLYnFOgiEjeZtNAPABe5++lAFTDdzM7psM4XgA/c/UTgbuCu7JbZcwp0EYmbbgPdA8me6ILE0PFQ4xXAQ4nx/wSmmZllrcpeSD7kQoEuInGRUR+6meWZWT2w
A/iFuy/vsMpYYCuAu7cCzUBpms+pNbM6M6trbGzsW+XdUAtdROImo0B390PuXgVUADVmNrE3X+bui9292t2ry8vLe/MRGdNj6EQkbnp0lou7NwHLgOkdFm0DxgGYWT4wCgi1bZyfD6NGqYUuIvGRyVku5WZWnBgfDnwS2NBhtaeBzyfGZwFL3cO/pEdXi4pInORnsM4Y4CEzyyPYATzq7s+a2e1Anbs/DfwI+LGZbQZ2A7P7reIeUKCLSJx0G+juvgY4I838W1LG9wP/O7ul9Z0CXUTiJLJXioICXUTiJdKBrodciEicRDrQS0uhuRlaW8OuRESk/0U+0EHnootIPCjQRUQiIhaBrn50EYkDBbqISEQo0EVEIkKBLiISEZEO9JEjg5t0KdBFJA4iHehmulpUROIj0oEOCnQRiY9YBLrOQxeROIhFoKuFLiJxoEAXEYmI2AR6+M9PEhHpX7EI9AMH4MMPw65ERKR/RT7QS0qCV3W7iEjURT7QdbWoiMSFAl1EJCIU6CIiERGbQNfFRSISdZEPdB0UFZG4iHygDx0KRUUKdBGJvm4D3czGmdkyM/uNma0zs+vTrDPVzJrNrD4x3NI/5faOrhYVkTjIz2CdVuBGd19lZkXASjP7hbv/psN6r7j7pdkvse9KShToIhJ93bbQ3X27u69KjLcA64Gx/V1YNqmFLiJx0KM+dDOrBM4AlqdZfK6ZrTazn5vZqZ28v9bM6sysrrGxscfF9pYCXUTiIONAN7ORwGPADe6+p8PiVcDx7n468B3gyXSf4e6L3b3a3avLy8t7W3OPKdBFJA4yCnQzKyAI85+4++Mdl7v7Hnffmxh/Digws7KsVtoHpaXQ1ASHDoVdiYhI/8nkLBcDfgSsd/dvdbLOcYn1MLOaxOfmTJu4tDS4fW5TU9iViIj0n0zOcjkP+FvgLTOrT8xbAIwHcPf7gFnANWbWCuwDZrvnzh3IUy//T46LiERNt4Hu7q8C1s063wW+m62isk33cxGROIj8laKgy/9FJB5iEehqoYtIHCjQRUQiIhaBPmoU5OUp0EUk2mIR6Ga6n4uIRF8sAh10taiIRF+sAl1PLRKRKItVoKuFLiJRFptAVx+6iERdbAJdLXQRibpYBfq+fcEgIhJFsQp0UCtdRKJLgS4iEhEKdBGRiIhdoOtcdBGJqtgFulroIhJVsQl03RNdRKIuNoFeWAgjRijQRSS6YhPooIuLRCTaFOgiIhGhQBcRiQgFuohIRCjQRUQiInaB/sEHcPhw2JWIiGRfrAK9pCQI8+bmsCsREcm+bgPdzMaZ2TIz+42ZrTOz69OsY2a2yMw2m9kaMzuzf8rtG10tKiJRlkkLvRW40d0/AZwDXGdmn+iwziXASYmhFrg3q1VmiQJdRKKs20B39+3uviox3gKsB8Z2WO0K4N898AZQbGZjsl5tHynQRSTKetSHbmaVwBnA8g6LxgJbU6YbODr0MbNaM6szs7rGxsaeVZoFCnQRibKMA93MRgKPATe4+57efJm7L3b3anevLi8v781H9IkCXUSiLKNAN7MCgjD/ibs/nmaVbcC4lOmKxLycUlwMQ4Yo0EUkmjI5y8WAHwHr3f1bnaz2NPC5xNku5wDN7r49i3VmxZAhMHq0Al1Eoik/g3XOA/4WeMvM6hPzFgDjAdz9PuA5YAawGfgQ+Lvsl5odpaV6apGIRFO3ge7urwLWzToOXJetovpTSYla6CISTbG6UhRg/HjYtCnsKkREsi92gV5TA++9Bzt2hF2JiEh2xS7QJ08OXpd3PJNeRGSQi12gn3km5OXBihVhVyIikl2xC/QRI2DSJLXQRSR6YhfoEHS7rFih+6KLSLTEMtBraoJ7outsFxGJklgGug6MikgUxTLQTz4ZRo7UgVERiZZYBnpeHpx9tlroIhItsQx0CLpdVq+G/fvDrkREJDtiG+g1NXDwINTXd7+uiMhgENtA14FREYma2Ab6n/4pVFTowKiIREdsAx2C
bhe10EUkKmId6JMnw9tvw86dYVciItJ3sQ70mprg9de/DrcOEZFsiHWgV1cHzxlVt4uIREGsA33kSDj1VAW6iERDrAMdgm6XFSvAPexKRET6JvaBPnky7N4dHBwVERnMYh/oyQOjOh9dRAa72Af6qacGTzFSP7qIDHaxD/T8/OBsFwW6iAx23Qa6md1vZjvMbG0ny6eaWbOZ1SeGW7JfZv+qqYE334SPPgq7EhGR3sukhf4gML2bdV5x96rEcHvfyxpYkycHYb56ddiViIj0XreB7u4vA7sHoJbQ6MCoiERBtvrQzzWz1Wb2czM7tbOVzKzWzOrMrK6xsTFLX91348bBccepH11EBrdsBPoq4Hh3Px34DvBkZyu6+2J3r3b36vLy8ix8dXaYBd0uCnQRGcz6HOjuvsfd9ybGnwMKzKysz5UNsJoa2LgRPvgg7EpERHqnz4FuZseZmSXGaxKfuauvnzvQkk8w0p0XRWSwyu9uBTN7GJgKlJlZA3ArUADg7vcBs4BrzKwV2AfMdh98d0aprg66XpYvh4svDrsaEZGe6zbQ3f0z3Sz/LvDdrFUUklGj4OSTdaaLiAxesb9SNFXywOjg+30hIqJAb6emBhob4b33wq5ERKTnFOgpkgdGdfqiiAxGCvQUkybB8OHwwgthVyIi0nMK9BQFBfCFL8BDD8GaNWFXIyLSMwr0Dm67DYqL4frrdXBURAYXBXoHJSXwL/8CL70Ejz0WdjUiIplToKdRWwunnQY33ggffhh2NSIimVGgp5GXB4sWwe9+BwsXhl2NiEhmFOid+Iu/gCuvhLvuCoJdRCTXKdC7kGydz5sXbh0iIplQoHdh/Hj46lfh0Ufhv/877GpERLqmQO/GvHlBsM+dC62tYVcjItI5BXo3RoyAb3wjuNDoBz8IuxoRkc4p0DMwaxZMnQo33wy7I/24bBEZzBToGTCDe+6Bpia49dawqxERSU+BnqHTToMvfhHuvRdefTXsakREjqZA74Hbbw8OkE6bBv/2b7rXi4jkFgV6D5SWBg+RvvDCoLV+9dWwb1/YVYmIBBToPVRaCj/7GfzTP8GDD8J558GWLWFXJSKiQO+VvLyg++WZZ+Cdd+Css+D558OuSkTiToHeB5deCnV1MG4czJgB//zPcPhw2FWJSFwp0PvoxBPh9dfhs5+FW26Byy6Dt94KuyoRiSMFehaMGAE//jF85zuwdGlwiuP558PDD8OBA2FXJyJx0W2gm9n9ZrbDzNZ2stzMbJGZbTazNWZ2ZvbLzH1mMGcONDQEtwrYvj1otY8bBwsWwHvvhV2hiERdJi30B4HpXSy/BDgpMdQC9/a9rMGrtDR40tHGjcGB0j//8+Ce6hMmBN0xTz0VXHEqIpJt3Qa6u78MdHUHkyuAf/fAG0CxmY3JVoGD1ZAh8Jd/CU8+GZzWuGBBcA77zJnBc0urquAf/iG4Ne/27WFXKyJRYJ7B5Y5mVgk86+4T0yx7FrjT3V9NTL8IfNXd69KsW0vQimf8+PFnvRezfoiPPoLXXoNXXoGXXw4OpiafWXriiUG/e00NnHIKnHwy/MmfBF05IiJJZrbS3avTLcsfyELcfTGwGKC6ujp2F84PHRpcZXrhhcH0wYPw5ptBwL/yCjz9NDzwQNv6xcVBsKcOEybA2LFBK19hLyKpshHo24BxKdMViXnSjYKCoEVeUxP0ux8+HBxU3bCh/fDCC8FVqakKC4NgTw4VFcHrsccGLfvy8uC1tBTyB3S3LSJhycaf+tPAHDNbAkwGmt1dvcK9MGRIcPOv8ePh4ovbL2tuDsL9d7+DbduC4N+2LRiWL4fHH09/iqRZ0JovLw+G0aODobj46NfiYvjYx9oPQ4cOzL9dRPqu20A3s4eBqUCZmTUAtwIFAO5+H/AcMAPYDHwI/F1/FRtno0bB5MnBkI578PCNHTvahsbG9uONjfDuu1BfH5xps2dP999bWNgW7kVFwTByZNvQcXrkSDjmmPSvyXH9
YhDpHxkdFO0P1dXVXld31HFTGUCtrUGof/BBEPDJkE8dmpvbT//xj9DSAnv3tg0tLT275UFh4dHhn8nOILkDSTcUFPTfdhLJJTlzUFRyS35+0B1TUtK3z3GH/fuDsN+7t+21q/GO0y0tsHPn0Z+RaXtj2LCjfwl0HE/3ayJ1fuprUZG6m2TwUaBLn5nB8OHBUFaWvc91D+43nwz3lpa2Yc+e9tOpO4HUHUVDQ/udSE92EgUF7buZOvt1kMkwcqR+RUj/U6BLzjIL7pMzYkRwQDcbkjuJ1G6j5HjHHUS66T17ggPRqctbWzP77qFD0x+DKCoKjpEkj1V0Np56LEPHISQd/beQWEndSRx7bN8/zz04uyg1/DsOXR17aGkJrhROPU6RyS+IESPSh39yOvmaPHspOSTPaBo1Krivv0SLAl2kD8yCg7yFhdn5FXH4cHD1cPJgdHNz204h3YHq1PHf/75tuqWl+x1DUVH7U1bTncaaOi91fMQIXdiWixToIjlkyJC27pixY3v/OYcPB78CmpqCkE89k6mz8S1b2ua1tHT9+QUF7YO+41Bc3HbAvbS0/asONvcfBbpIBA0Z0tYV0xutrcGOIDX0P/jg6PHksGsXbN7cNt3VaaxFRW0Bnxr2nU2XlQVdRPpF0D0FuogcJT+/LVB7yj1o4e/e3Tbs2tU2pE7v3h1c7LZrV7Aj6KybKC8vCPqysvZB39UQx52AAl1Essqs7ddBZWXm7zt0KPhVkBr+yWHnzvbjb78NK1YE4wcPpv+8vLz2Lf3UHUHytby8/etg3wko0EUkJyRb4SUlcNJJmb0n+Wtg58620N+5M7jNRcedwTvvBM8k2LkzuJV1Ovn5bQGfvMFd6s3uOo5/7GO5tQNQoIvIoJX6a+CEEzJ7j3twGmky+Dt7bWyEurrgXkid3fdo6NC2cE8dUkM/dd7w4dn7t6ejQBeRWDFrO5Mo0y6hAwfabnbX2Ah/+EP76eRN8NavD1737Uv/OUVFQbhfey18+ctZ+ycdoUAXEenGsGHBMwcqKjJb/49/bH/n0453QT3uuP6pU4EuIpJlxxwTPF1swoSB/d5uHxItIiKDgwJdRCQiFOgiIhGhQBcRiQgFuohIRCjQRUQiQoEuIhIRCnQRkYgwz/SJudn+YrNG4L0uVikDdg5QOT2l2npHtfWOauudqNZ2vLunfT5WaIHeHTOrc/fqsOtIR7X1jmrrHdXWO3GsTV0uIiIRoUAXEYmIXA70xWEX0AXV1juqrXdUW+/Errac7UMXEZGeyeUWuoiI9IACXUQkInIu0M1supn91sw2m9n8sOtJZWbvmtlbZlZvZnUh13K/me0ws7Up80rM7BdmtinxOjqHavuamW1LbLt6M5sRUm3jzGyZmf3GzNaZ2fWJ+aFvuy5qC33bmVmhma0ws9WJ2m5LzJ9gZssTf6+PmNnQHKrtQTPbkrLdqga6tpQa88zsTTN7NjHdP9vN3XNmAPKAt4ETgKHAauATYdeVUt+7QFnYdSRquQA4E1ibMu/rwPzE+Hzgrhyq7WvAV3Jgu40BzkyMFwEbgU/kwrbrorbQtx1gwMjEeAGwHDgHeBSYnZh/H3BNDtX2IDAr7P9zibq+DPx/4NnEdL9st1xrodcAm939HXf/CFgCXBFyTTnJ3V8GdneYfQXwUGL8IWDmgBaV0EltOcHdt7v7qsR4C7AeGEsObLsuagudB/YmJgsSgwMXAf+ZmB/WduustpxgZhXAXwE/TEwb/bTdci3QxwJbU6YbyJH/0AkO/JeZrTSz2rCLSeNYd9+eGP89cGyYxaQxx8zWJLpkQukOSmVmlcAZBC26nNp2HWqDHNh2iW6DemAH8AuCX9NN7t6aWCW0v9eOtbl7crvdkdhud5vZsDBqA74N/D/gcGK6lH7abrkW6LluirufCVwCXGdmF4RdUGc8+C2XM60U4F7g40AVsB34ZpjFmNlI4DHgBnffk7os7G2Xprac2Hbufsjdq4AKgl/T
J4dRRzodazOzicBNBDWeDZQAXx3ouszsUmCHu68ciO/LtUDfBoxLma5IzMsJ7r4t8boDeILgP3Uu+YOZjQFIvO4IuZ4j3P0PiT+6w8APCHHbmVkBQWD+xN0fT8zOiW2XrrZc2naJepqAZcC5QLGZ5ScWhf73mlLb9EQXlrv7AeABwtlu5wGXm9m7BF3IFwH30E/bLdcC/dfASYkjwEOB2cDTIdcEgJkdY2ZFyXHgYmBt1+8acE8Dn0+Mfx54KsRa2kmGZcJfE9K2S/Rf/ghY7+7fSlkU+rbrrLZc2HZmVm5mxYnx4cAnCfr4lwGzEquFtd3S1bYhZQdtBH3UA77d3P0md69w90qCPFvq7v+H/tpuYR/9TXM0eAbB0f23gX8Mu56Uuk4gOOtmNbAu7NqAhwl+fh8k6IP7AkHf3IvAJuCXQEkO1fZj4C1gDUF4jgmptikE3SlrgPrEMCMXtl0XtYW+7YDTgDcTNawFbknMPwFYAWwGfgoMy6Halia221rgP0icCRPWAEyl7SyXftluuvRfRCQicq3LRUREekmBLiISEQp0EZGIUKCLiESEAl1EJCIU6CIiEaFAFxGJiP8BeZT+egds+vkAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process).iloc[:,:2]\n", "df.columns=['epoch', 'train_RMSE']\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXhU5fXA8e+BAGFxYYkVCEtkadlDDSAFFKTIKlLcQFERLFqhYBUqIG4oFaUK0iJuP6CgJeCKW0GEWKwbhFX2sEoABdkUBRF4f3+cOzCEQCZhkjszOZ/nmSeZe2cm5zJ65s57z3tecc5hjDEmdhXxOwBjjDH5yxK9McbEOEv0xhgT4yzRG2NMjLNEb4wxMS7O7wCyqlChgqtevbrfYRhjTFRZvHjxd865hOz2RVyir169Ounp6X6HYYwxUUVEtp5pX45DNyIySUR2icjKM+wXERkvIhtEZIWI/DZo320ikuHdbstb+MYYY85FKGP0U4AOZ9nfEajl3foBEwFEpBzwMNAMaAo8LCJlzyVYY4wxuZdjonfOLQD2nuUh1wBTnfoCuFBEKgLtgbnOub3OuX3AXM7+gWGMMSYfhGOMvjKwLeh+prftTNtPIyL90G8DVK1aNQwhGWMK0i+//EJmZiaHDx/2O5SYFx8fT2JiIsWKFQv5ORFxMdY59yLwIkBKSoo13zEmymRmZnLeeedRvXp1RMTvcGKWc449e/aQmZlJUlJSyM8LRx39dqBK0P1Eb9uZthtjYszhw4cpX768Jfl8JiKUL18+19+cwpHo3wFu9apvLgMOOOd2AnOAq0SkrHcR9ipvmzEmBlmSLxh5+XfOcehGRKYDrYEKIpKJVtIUA3DOPQ98AHQCNgA/Abd7+/aKyGPAIu+lRjrnznZRNwwOAhuBRvn7Z4wxJorkmOidcz1z2O+A/mfYNwmYlLfQ8mIi8FfgOvTzqH7B/WljjIlQMdbr5g7gQXSEqCHQA1jja0TGmPy3f/9+nnvuuVw/r1OnTuzfvz/Xz+vduzdJSUkkJyfTqFEj5s2bd2Jf69atqVq1KsGLOnXr1o0yZcoAcPz4cQYOHEj9+vVp0KABTZo0YfPmzYB2BmjQoAHJyckkJyczcODAXMeWnYiougmfssBIYBDwDPAssBuYd7YnGWOiXCDR33333adsP3r0KHFxZ05zH3zwQZ7/5pgxY7juuutIS0ujX79+ZGRknNh34YUX8umnn9KyZUv279/Pzp07T+ybMWMGO3bsYMWKFRQpUoTMzExKly59Yn9aWhoVKlTIc1zZibFEH1AeGAXcAwQ+rbcDDwHDgRo+xWVM7LvnHli2LLyvmZwM48adef/QoUPZuHEjycnJFCtWjPj4eMqWLcvatWtZv3493bp1Y9u2bRw+fJhBgwbRr18/4GRvrYMHD9KxY0datmzJZ599RuXKlZk1axYlS5bMMbbmz
ZuzffupBYU9evQgNTWVli1b8uabb9K9e3dWrVoFwM6dO6lYsSJFiuiASmJiYh7/VUIXY0M3WSWgnRkAFgL/Bn6NDvFs9isoY0yYjR49mho1arBs2TLGjBnDkiVLePbZZ1m/fj0AkyZNYvHixaSnpzN+/Hj27Nlz2mtkZGTQv39/Vq1axYUXXsgbb7wR0t+ePXs23bp1O2Vb27ZtWbBgAceOHSM1NZUbb7zxxL4bbriBd999l+TkZO677z6WLl16ynPbtGlzYuhm7Nixuf2nyFaMntFn5w/AJuBJ4Hn0GnEXYBZgZWHGhMvZzrwLStOmTU+ZUDR+/HjeeustALZt20ZGRgbly5c/5TmBMXeASy+9lC1btpz1bwwZMoThw4eTmZnJ559/fsq+okWL0rJlS1JTUzl06BDBrdcTExNZt24d8+fPZ/78+bRt25bXXnuNtm3bAvkzdBPjZ/RZVQTGoSWYD6FDOIEkPw5Y5VNcxphwCh7z/vjjj/noo4/4/PPPWb58OY0bN852wlGJEiVO/F60aFGOHj161r8xZswY1q9fz5NPPkmfPn1O29+jRw8GDhzIDTfckO3f6tixI2PGjGH48OG8/fbbuTm8XCtkiT6gMvAIEPhatAMty6wPXAa8BPzgS2TGmNw777zz+OGH7P+fPXDgAGXLlqVUqVKsXbuWL774Iqx/e8CAARw/fpw5c06dD9qqVSuGDRtGz56nVqgvWbKEHTt2AFqBs2LFCqpVqxbWmLIqpIk+q0roxdqn0QTfDz37n+9nUMaYEJUvX54WLVpQv359hgwZcsq+Dh06cPToUerUqcPQoUO57LLLwvq3RYQRI0bw1FNPnbZ98ODBpw3D7Nq1i6uvvpr69evTsGFD4uLiGDBgwIn9wWP0t956a3hiDK71jAQpKSnO3xWmHPAlMBkdz78QeBNtxHkLUM6/0IyJUGvWrKFOnTp+h1FoZPfvLSKLnXMp2T3ezuhPI+jwzQtokgd4Hy3VrIQm+0/QDwRjjIl8luhD8n/AMrQs8x3gcuBmXyMyxuS//v37nxhGCdwmT57sd1i5VojKK89VI+Cf6HDOa8BF3vY96Nn+HegHgJVqGhMrJkyY4HcIYWGJPtdKA72D7q8A3gVeAaqjqyW29245z6ozxpj8ZkM356wNWp45BS3PnIZOzgrMvFsCpAPH/QjOGGMs0YdHKeA29Mx+L/A5uqAWwGNAE3Sopyf6gbCj4EM0xhRalujDrjhatRPwAjqs0wlIQ9dl6Ry0/2usgscYk58s0ee7i9AKnanomfxSYLy37xBQB7gEvaD7MXD2adfGmNPltR89wLhx4/jpp5/O+phAn/iGDRtyxRVXsHXr1hP7RIRevXqduH/06FESEhLo0qULAN9++y1dunShUaNG1K1bl06dOgGwZcsWSpYseUpFz9SpU/N0DDmxRF+gigDJQCvvvkN75tdHG621AX4FhNY1zxij8jvRgzYbW7FiBa1bt+bxxx8/sb106dKsXLmSQ4cOATB37lwqV658Yv9DDz1Eu3btWL58OatXr2b06NEn9gU6bgZu4ZoJm5Ulel+VQssy3wW+QxN8Z072y5+LDvk8i66UZUM8Jlq0zuYWSMQ/nWH/FG//d9nsO7vgfvRDhgxhzJgxNGnShIYNG/Lwww8D8OOPP9K5c2caNWpE/fr1mTFjBuPHj2fHjh20adOGNm3ahHRk2fWf79SpE++//z4A06dPP6W/zc6dO0/pOd+wYcOQ/k44WaKPGGWA7ugQT7K3bT+65vo9QF2gGvrBYA3XjAkW3I++Xbt2ZGRksHDhQpYtW8bixYtZsGABs2fPplKlSixfvpyVK1fSoUMHBg4cSKVKlUhLSyMtLS2kv5Vd//nAQiOHDx9mxYoVNGvW7MS+/v3707dvX9q0acOoUaNONDQDTnw4BW6ffPJJeP5BsrA6+oh2vXfbjJ7dzwH+h9byAzyFJv2r0AvAxXyI0ZjsfHyWfaVy2F8hh/1n9+GHH
/Lhhx/SuHFjAA4ePEhGRgatWrXivvvu4/7776dLly60atUqh1c6VZs2bdi7dy9lypThscceO2Vfw4YN2bJlC9OnTz8xBh/Qvn17Nm3axOzZs/nPf/5D48aNWblyJXBy6Ca/2Rl9VEhCO2q+gQ7hBN62xcDf0Bm55dH6/Rl+BGhMxHDOMWzYsBPj3hs2bKBv377Url2bJUuW0KBBA0aMGMHIkSNz9bppaWls3bqV5OTkE8NBwbp27crgwYNPa0sMUK5cOW666SamTZtGkyZNWLBgQZ6PLy8s0Ued4BYLM9CJWW+gNfpLgI+8fQ540Lv/c0EGaEyBC+5H3759eyZNmsTBgwcB2L59O7t27WLHjh2UKlWKXr16MWTIEJYsWXLac3MSFxfHuHHjmDp1Knv37j1lX58+fXj44Ydp0KDBKdvnz59/4mLvDz/8wMaNG6lateo5HW9u2dBN1LsQHdvvjib3Q972zejQzuPoV+Ur0fYM3dFe+8bEjuB+9B07duSmm26iefPmAJQpU4ZXXnmFDRs2MGTIEIoUKUKxYsWYOHEiAP369aNDhw4nxupzUrFiRXr27MmECRN48MEHT2xPTExk4MCBpz1+8eLFDBgwgLi4OI4fP84dd9xBkyZN2LJly4kx+oA+ffpk+xrnyvrRx7Qf0Ulas4H/oGvmvgNcDWz17rfExvbNubJ+9AUrt/3o7Yw+ppVGF0Dvgp7tb0CXUQT4F/AwcD56pt8F6IheCDPGxBJL9IWGALWC7t8LNATeQxdWmYl+MOwBSqC1zOWxtsumMGnWrBk//3zqNa1p06adNu4ebSzRF1plgG7e7TjammENmuRBz/J3oaWbV6Kzdm1s35yZcw6R6D4x+PLLL/0OIUd5GW63qhuD/mdwKRDo1+GAu4EUtKLnZnQZxT8HPedAQQZoIlx8fDx79uzJUxIyoXPOsWfPHuLj43P1vJDO6EWkAzoPvyjwsnNudJb91YBJQALap7eXcy7T23cM+Mp76NfOua65itD4QIA+3u0YuoziPLQBG8B2oCrQGGiLnvG35ORELlPYJCYmkpmZye7du/0OJebFx8ef0lIhFDlW3YhIUWA90A7IBBYBPZ1zq4Me8xrwnnPuXyJyJXC7c+4Wb99B51yZUAOyqpto8C3ahG0+2nv/F/Qc4G30ou5e9Iy/OjbGb0zBOFvVTShDN02BDc65Tc65I0AqcE2Wx9RF/68HrefLut/ElF+hFTv/Bfah5ZtD0XV1QS/sXoIO93QHxgCfoh8IxpiCFkqirwxsC7qfyckavYDl6P/RoPPwzxOR8t79eBFJF5EvRKQb2RCRft5j0u2rX7Qpja6P+zhQxdvWHu1U+Ht0Td2/okM7B739n6DnAzZj15iCEK6LsYOBK0RkKXAFOoh7zNtXzfs6cRMwTkRqZH2yc+5F51yKcy4lISEhTCEZ/yQBf0LXz90AfAN8CJT19v8NHdcvh7ZhHgusPv1ljDFhEUqi387JUzXQxVBPacbsnNvhnOvunGsMPOBt2+/93O793IS2pGt87mGf7tgx6NMHPv00P17dnJtfoZd4AmYAs9CLvRvRmv67gvbPQa8DGGPCIZREvwioJSJJIlIc6IHOoz9BRCqISOC1hqEVOIhIWREpEXgM0IJ8OnXbvBlmz4aWLeHaayEjIz/+igmP84GuwD+AdcAWtKgLtO1yF+Bi9JzgfrTix4Z5jMmrHBO9c+4oMAA9zVoDzHTOrRKRkSISKJVsDawTkfXo6dsob3sdIF1ElqODsqODq3XCqWZNTe4jR8KcOVC3LgwcCN99lx9/zYRXNU5+0SuNVvKMQhu2jUXH+v/h7f8BPVewem1jQhWTTc2++QYeeQReegnKlIHhwzXplywZnhhNQfoBre5piNbuvwbcgNYDXIVe+G2L9egxhd3ZyitjMtEHrF4N998P770HVarAqFFw881QxOYDR7Fv0P48H6K99vd523eiwz0z0A+GRO9WxftZE6vpN7HsXOvoo
1bduvDuu5CWBhddBLfeCikpMH9+zs81kepidN3cmcBu4EtgNHCRt38tmuwfAG5Dq3vqcXKoZyy6WtdzwGecLPk0JnbFdKIPaN0aFi6EV16BPXugbVvo2BE++8zvyMy5KYrO57ufk/8pP4x24PwRndA9D/h30P5taP+e/mhtwPloRXDAEmAHdg3AxJJCkehBh2tuvhnWrYMnn4RFi6BFC/0QmDsXImwEy5yzUmhb5iuB64K2P4O2YP4aLR57FB3nD7gRHf+/GK0MehJN/sZEr0KT6APi4+Gvf4WtW2HsWK3UueoqaNYM3n4bjh/3O0KT/wQdu78aXVd3eNC+SWipZ0e09HMoMNHbdxytHn4TvVZgTHSI6Yuxofj5Z5g6FUaPhk2bdFx/2DDo0QPirFu/YTdwGP1g2AL8hpM1/Zegwz8D0CEkY/xTaC/GhqJECfjjH3VI59VXQQRuuQVq14YXXtAPAlOYJXByYnh1tCvn58DfgWS0+ment/+/QG20/PNvwAfYeL+JBIU+0QfExcFNN8GKFTBrFiQkwF13wSWXaFnmtzYj3wC6AtdlwH3oRd2d6ExegHi03n8JWvXTGR3vD3xDXQq8gi7PYJ08TcGxRJ9FkSLQtSt88QV89JEO5YwYoXX4PXvCggV24dYEE7T6B6AZ8DrayO0AsAAYj5Z3gpZ93oJ+GJRBZwP3Bg55+48WSMSm8Cn0Y/ShWLcOnn8eJk+GAwegXj24+27o1QvOP9/v6Ez0+AW9wLs86LYNWIV+YNyC9v1rhC7xkIQOBbX1IVYTbQrtzNhw++knSE2FCRNgyRJtr9Crlyb9KF8k3kSEqeiY/zIgAziCnv0v9/b/AR0qSkKvFySh3xZaFHSgJgJZog8z57QO/7nnNPH//LN2zbzrLu2cmct1e43JxnG0hPMAJ9fqHYqO928BtqJDPZ3RlhCga/+cj34bCNyqYyO0hYMl+ny0Zw9MmQITJ8LGjVC2rJ7l33EHNGzod3Qmdh1DK3p+Qcs8j6O1/yu97QF3ofMAjqEdQQP9fwI3W9A9VliiLwDHj2tPnZdfhjffhCNHoEkTTfg9ethYvilI+9GO4qvR2cGXo98CkrJ57NPowi/foO0jqqBdQmui1wfKY83gooMl+gK2Z4/21XnpJVi1CkqVghtv1KTfvLnW6htT8I6gi8Nt826ZaIuIJuh1gXZoe4hgr6KrgG4AUtHkXxv9ALFvA5HEEr1PnNNmai+/DNOnw48/Qp06mvBvvRUqWAt1E3EOoR8CG9CmcFcDNdCy0euzPLYyen0gGe0auhL9JlADOK+A4jUBlugjwA8/wMyZmvS/+AKKF9cLt3feCZdfbmf5Jhr8xMkPgPVoqeiTaAO40WgfoIBfoUn/bXRRmLXALqCidytTYFEXFpboI8zKlfDii9pj58AB+M1voF8/uO02KFfO7+iMyYuD6IdA4LbRu80BiqFtoZ8LenwZ9HrASrQq6A1gM1AJvS5QGq0gClQ0/AjEAcWxawbZs0QfoX76Sc/yX3hBz/JLlIDrr9ez/BYt7CzfxJIt6NyAnd5tB/oN4SVv/3Vosg9WBW0nDdAB/dAogragLo2WjwZWEZrEyeZzgQvKZSlMHwqW6KPA8uV6lj9tmg7z1KunZ/m33KIlm8bENgd8j34I7EXP4AVdGB70GkGGt/1H9EOiElopBDqbeEWW12yHTkAD+Av6QXCedysD1PceA9qoroS3vaR3K4P2L4oOluijyI8/6iSsF17QSVnx8XDdddC3L1xxhZ3lG5O948C36DeAbd7Pi4Be3v4UdJLZQTThA9yMNpkD/ZYQ6DkUcCfwvPfaZdGkH49+CMQDtwOD0LbVw7zHlPNuZdFvHFXRD7HjnOyJlD8s0UepJUv04u2//61j+TVrQp8+0Ls3VKzod3TGRKujnFwr+ELv53zgB+92GE36ddE+Q0eBIUHbD3m/d0fXJd6FVhplXX/4CXQ282Z0Utv5nPwQKIt+y+iCfkBN9rb1R
r9Z5J4l+ij300/wxhua9BcsgKJFoXNnPcvv1MkWSDEmMvwC7EOHnvah5adV0cVrJnjbgm9DgGvQReoD/Yp+Ri84554l+hiyfj1MmqRtF779Vs/sb7tNk37Nmn5HZ4zJm0No8q+U51ewFaZiSO3auuzhtm26xu2ll8JTT0GtWtChA8yebf3yjYk+JTmXJJ8TS/RRqlgxuOYaePdd+PprGDlSV8fq2FErdl54QYd8jDHGEn0MqFwZHnwQtmzR8sySJbVlcpUqMHw4bN/ud4TGGD9Zoo8hxYtri+T0dL1o27o1PPkkVK+u6+EuWuR3hMYYP1iij0Ei0KqVVups2AB//jO89x40baozbmfM0MVSjDGFgyX6GJeUBM88A5mZMG4cfPON9sevXBnuvVfbKBtjYltIiV5EOojIOhHZICJDs9lfTUTmicgKEflYRBKD9t0mIhne7bZwBm9Cd/75MGiQlmfOng1t2sA//wn168PvfqclmwezzvcwxsSEHBO9iBRFq/07olPFeopI3SwP+zsw1TnXEBiJTglDRMqhzSiaAU2Bh0XEOrf4qGhRaN8eXntNz/L//nfYt0/r8CtW1P46CxdaiaYxsSSUM/qmwAbn3Cbn3BF0mZlrsjwmuI1cWtD+9sBc59xe59w+YC7ahs5EgIsugvvug9Wr4X//0546r74KzZpBo0bw7LP6IWCMiW6hJPrKaJeggExvW7DlaOMHgD8A54lI+RCfi4j0E5F0EUnfvXt3qLGbMBHRi7STJ8OOHfD889pM7Z57dCy/b19YvNjvKI0xeRWui7GDgStEZClwBbow5bFQn+yce9E5l+KcS0lISAhTSCYvLrhA++EvXAhLl2qb5NRUSEnRqp0pU+BQ1iZ/xpiIFkqi34528g9I9Lad4Jzb4Zzr7pxrDDzgbdsfynNN5EpO1hm2O3bA+PHaJ//22/Usf/BgLd00xkS+UBL9IqCWiCSJSHGgB/BO8ANEpIKIBF5rGLrcC+iSMFeJSFnvIuxV3jYTRS64QGvxV6+GtDT4/e91/D7QX2fWLDh61O8ojTFnkmOid84dBQagCXoNMNM5t0pERopIV+9hrYF1IrIeXRV4lPfcvcBj6IfFImCkt81EIRGdbTtzJmzdCo8+Cl99Bd26QY0aOgt3zx6/ozTGZGVtis05OXoU3nkHJkyA+fP1Iu7NN+s3gEaN/I7OmMLD2hSbfBMXB927w7x5enZ/6626IlZysi59+MYbNqxjjN8s0ZuwqV9fL95mZsKYMdo++brr4JJLtIf+d9/5HaExhZMlehN25cqdrMqZNUsXSxk2DBITtSZ/+XK/IzSmcLFEb/JN0aLQtSt89BGsXKmlmampOqzTpo1+CBwLebaFMSavLNGbAlGvHkyceHJYZ9MmrdapXVu7an7/vd8RGhO7LNGbAlW2rA7rbNwIr78OlSrBX/6iwzqDBtkkLGPygyV644u4OLj2WvjkE10Rq1s3PeOvXVuHe+bPtw6axoSLJXrju0svhalTdRLWgw/CF19A27Zahz9pEhw+7HeExkQ3S/QmYlSsqLNtv/5aE7yIVulUrQoPPQQ7d/odoTHRyRK9iTjx8Vqhs2yZDuE0bw6PPw7VqumErCVL/I7QmOhiid5ELJGTZZjr18Of/gRvvaVDPZdfDm++aeWZxoTCEr2JCjVrasfMzEx4+mkd3rn2Wt3+zDOwf7/fERoTuSzRm6hywQVw771ahvnGG1Clii6HWKkS/PGPOtxjjDmVJXoTlQLN1BYs0GUOb7pJ17tt3Bh+9zt45RX4+We/ozQmMliiN1Hvt7+Fl1+G7dth7FhtnnbLLToJa9gw2LLF7wiN8ZclehMzypbVBc3XroUPP4SWLeGpp7R7ZteuMHs2HD/ud5TGFDxL9CbmFCkC7dpphc7mzTB8OHz5JXTsCHXq6ALnv/zid5TGFBxL9CamVa2qNfhff61j+KVKaY1+rVracsFm3ZrCwBK9KRRKlNALtkuWwHvv6Szcu
+/WYZ2nn4aDB/2O0Jj8Y4neFCoi0LkzfPaZLn9Yp45206xeXc/8rR7fxCJL9KZQEoErr9Rk/9lncNll2lCtWjV44AHYvdvvCI0JH0v0ptBr3lyHc5Ysgauugiee0DP8e+/Vkk1jop0lemM8jRvDa6/BqlXaXmH8eEhK0hm3GRl+R2dM3lmiNyaLOnW0P35Ghib5adPgN7+BHj1sYXMTnSzRG3MGSUkwYYLOrB0yBD74QBc279IFPv3U7+iMCZ0lemNycPHFMHq0roD12GM6+aplS7jiCpgzx5Y8NJHPEr0xISpbFkaM0DP8ceNg0ybo0AGaNNGe+ZbwTaSyRG9MLpUuDYMGwcaN2kxt/35d3LxxY10MxfrpmEhjid6YPCpeXNe0XbsW/vUv+OknrdZJTtbqHUv4JlJYojfmHMXF6Vq2q1drH/wjR+CGG6BBA0hNteUOjf9CSvQi0kFE1onIBhEZms3+qiKSJiJLRWSFiHTytlcXkUMissy7PR/uAzAmUsTFwc03ax3+9Om6rWdPqF9fG6pZwjd+yTHRi0hRYALQEagL9BSRulkeNgKY6ZxrDPQAngvat9E5l+zd7gpT3MZErKJFteb+q69gxgz9AOjVC+rWhZkz7aKtKXihnNE3BTY45zY5544AqcA1WR7jgPO93y8AdoQvRGOiU5EiOoSzfLmub1u8ONx4o7Zc+N///I7OFCahJPrKwLag+5netmCPAL1EJBP4APhz0L4kb0jnvyLSKrs/ICL9RCRdRNJ3WzcpE2OKFNH1bZctg//7P9i2DVq1gj/8Adat8zs6UxiE62JsT2CKcy4R6ARME5EiwE6gqjekcy/wbxE5P+uTnXMvOudSnHMpCQkJYQrJmMhStCj06QPr12tL5I8+gnr1oH9/2LXL7+hMLAsl0W8HqgTdT/S2BesLzARwzn0OxAMVnHM/O+f2eNsXAxuB2ucatDHRrHRpbYW8cSPceSe88ALUrAmjRmmJpjHhFkqiXwTUEpEkESmOXmx9J8tjvgbaAohIHTTR7xaRBO9iLiJyCVAL2BSu4I2JZhddpL10Vq2Ctm111m3t2jB5slXomPDKMdE7544CA4A5wBq0umaViIwUka7ew+4D/igiy4HpQG/nnAMuB1aIyDLgdeAu59ze/DgQY6LVr3+tC5kvWACJiTq807ix9tExJhzERVitV0pKiktPT/c7DGN84Ry8/joMHaq9dNq1gzFjoFEjvyMzkU5EFjvnUrLbZzNjjYkgInD99bBmjTZOW7xYz+5794bMTL+jM9HKEr0xEah48ZON0wYP1lYKtWrB8OHw/fd+R2eijSV6YyLYhRfCU09pvf211+p6tjVq6EXcX37xOzoTLSzRGxMFqlXThmmLFmnvnAED9Odbb1lLBZMzS/TGRJGUFJg/H957Tydgde+uK11Z/YI5G0v0xkQZEejcGVasgOef12GdJk20VbJdsDXZsURvTJSKi9OZtRkZWo45c6ZOuHroITh40O/oTCSxRG9MlDv/fL1Iu3YtXHONLmBeqxZMmmQzbI2yRG9MjKheXRc8+ewz/b1vX7j0Uh3TN4WbJXpjYkzz5prsU1N14fK2baFrV2uJXJhZojcmBonoIidr18Lo0fDxx1qO+Ze/wL59fkdnCgC5cRcAAA2nSURBVJolemNiWHw83H8/bNigzdKefVbH7597Do4e9Ts6U1As0RtTCFx0kfa9X7oUGjTQxU6Sk2HuXL8jMwXBEr0xhUijRnpx9s034dAhuOoqHb9fv97vyEx+skRvTCEjouvVrl6t4/dpaTp+f999evHWxB5L9MYUUiVK6Ph9RobOqh07Vsfvn3/e6u9jjSV6Ywq5iy+Gl1/Wfjl16sCf/qQlmsuW+R2ZCRdL9MYYAH77W/jvf+Hf/4atW7WB2v3324LlscASvTHmBBHo2VNXuOrdW3vh168PH37od2TmXFiiN8acplw5Hc5JS4NixaB9e+jVC3bt8jsykxeW6I0xZ9S6N
SxfDg8+qN0x69SBKVNssZNoY4neGHNW8fEwcqRenK1TB26/XfvnZGT4HZkJlSV6Y0xI6taFBQu0/HLxYp1hO2oUHDnid2QmJ5bojTEhK1JEFztZswauvhpGjNBqnc8/9zsyczaW6I0xuVapErz2GrzzDnz/PbRoAXffDQcO+B2ZyY4lemNMnl19NaxaBQMH6pBOnTraR8cu1kYWS/TGmHNy3nkwbhx8+aV2ybz2WujWDbZt8zsyE2CJ3hgTFk2awKJFOslq7ly9ePuPf1jfnEhgid4YEzbFisGQITqc06KFDun87ndai2/8E1KiF5EOIrJORDaIyNBs9lcVkTQRWSoiK0SkU9C+Yd7z1olI+3AGb4yJTElJ8J//aN+czZt1kfLhw+HwYb8jK5xyTPQiUhSYAHQE6gI9RaRuloeNAGY65xoDPYDnvOfW9e7XAzoAz3mvZ4yJcYG+OWvXwi23wBNP6KpWn37qd2SFTyhn9E2BDc65Tc65I0AqcE2WxzjgfO/3C4Ad3u/XAKnOuZ+dc5uBDd7rGWMKiXLlYPJkmDNHz+hbtdIhnYMH/Y6s8Agl0VcGgq+fZ3rbgj0C9BKRTOAD4M+5eK4xphC46ipYuRIGDIB//tO6YhakcF2M7QlMcc4lAp2AaSIS8muLSD8RSReR9N27d4cpJGNMpClTBsaP11YK8fHaFfP222HfPr8ji22hJOPtQJWg+4netmB9gZkAzrnPgXigQojPxTn3onMuxTmXkpCQEHr0xpio1LKlNkkbNgymTdNSzLfe8juq2BVKol8E1BKRJBEpjl5cfSfLY74G2gKISB000e/2HtdDREqISBJQC1gYruCNMdErPh7+9jdYuFCXM+zeHW64Ab791u/IYk+Oid45dxQYAMwB1qDVNatEZKSIdPUedh/wRxFZDkwHeju1Cj3TXw3MBvo752z6hDHmhN/+VpP9qFEwaxbUqwczZvgdVWwRF2FNKVJSUlx6errfYRhjfBBYwnDhQrj+epgwAWw0NzQistg5l5LdPpsZa4yJGHXqaJ39E0+cPLu3sftzZ4neGBNR4uJg6FBd3CQxUcfue/WCvXv9jix6WaI3xkSk+vW1I+ajj+qYff368P77fkcVnSzRG2MiVrFi8NBDOmZfoQJ06QJ9+tgCJ7llid4YE/EaN4b0dHjgAZg61WbV5pYlemNMVCheHB5/XNenPe88nVV7xx2wf7/fkUU+S/TGmKjSpAksWQL33w9Tpuis2lmz/I4qslmiN8ZEnfh4GD1aL9YmJOjShT16wK5dfkcWmSzRG2Oi1qWX6tj9Y49pvX3duvDqq7Y4eVaW6I0xUa1YMRgxApYuhVq1tOb+6qttcfJgluiNMTGhbl343/9g3DhIS9NZtS+8AMeP+x2Z/yzRG2NiRtGiMGgQfPUVNG0Kd90FV14JGzb4HZm/LNEbY2LOJZfA3Lnw0ks6pNOoEUycWHjH7i3RG2NikojW2a9erQud3H03dOoEO3bk/NxYY4neGBPTKleG2bN1ndr//hcaNIDXXvM7qoJlid4YE/NEoH9/HcapUUNXsurVq/DMqrVEb4wpNH79a+13/8gjkJqqZ/fz5vkdVf6zRG+MKVSKFYOHH9aeOaVLw+9/D/fcA4cO+R1Z/rFEb4wplAI9cwYMgGef1Vm2ixf7HVX+sERvjCm0SpWCf/xDWx4fOACXXQZPPhl7k6ws0RtjCr127WDlSm2ONnQodOgA33zjd1ThY4neGGOAsmVh5kxtm/DJJzrJas4cv6MKD0v0xhjjEYF+/bQjZkKCntn/9a9w5IjfkZ0bS/TGGJNFvXqwaJH2yhkzRmfWbtrkd1R5Z4neGGOyUbKk9sd5/XXIyIDkZJg+3e+o8sYSvTHGnMW118KyZTq56qaboG9f+PFHv6PKHUv0xhiTg2rVtE/OAw/A5MmQkgLLl/sdVegs0RtjTAji4uDxx+Gjj7TmvlkzeO656Gh9bIneGGNy4cor9Wy+bVttlHbdd
bBvn99RnZ0lemOMyaWEBHj3XXj6af2ZnAyffeZ3VGcWUqIXkQ4isk5ENojI0Gz2jxWRZd5tvYjsD9p3LGjfO+EM3hhj/FKkCNx7r3bDjIuDyy+HJ56IzPYJcTk9QESKAhOAdkAmsEhE3nHOrQ48xjn3l6DH/xloHPQSh5xzyeEL2RhjIkegOdqdd8Lw4TB/PkybBhdf7HdkJ4VyRt8U2OCc2+ScOwKkAtec5fE9gSitNjXGmNy74AKtsX/pJT3Db9RIG6VFilASfWVgW9D9TG/baUSkGpAEzA/aHC8i6SLyhYh0O8Pz+nmPSd+9e3eIoRtjTOQIrFG7aJGO4bdvD8OGwS+/+B1Z+C/G9gBed84dC9pWzTmXAtwEjBORGlmf5Jx70TmX4pxLSUhICHNIxhhTcOrVg4ULtWfO6NE6dr9li78xhZLotwNVgu4netuy04MswzbOue3ez03Ax5w6fm+MMTGnVCntgpmaCqtXQ+PG8Oab/sUTSqJfBNQSkSQRKY4m89OqZ0TkN0BZ4POgbWVFpIT3ewWgBbA663ONMSYW3XijLkhes6a2UhgwAA4fLvg4ckz0zrmjwABgDrAGmOmcWyUiI0Wka9BDewCpzp0yT6wOkC4iy4E0YHRwtY4xxsS6Sy7RC7T33gsTJugqVuvWFWwM4iJs/m5KSopLT0/3OwxjjAm799+H227Ts/qJE+GWW8L32iKy2LseehqbGWuMMQWkc2fthHnppXDrrdC7Nxw8mP9/1xK9McYUoMREmDcPHnoIpk4tmE6YluiNMaaAxcXBo49qwv/++/zvhGmJ3hhjfNKmjQ7ltGmjnTBvuCF/euXk2OvGGGNM/rnoIr1I+8wzsH+/NksLN0v0xhjjsyJFYPDgfHz9/HtpY4wxkcASvTHGxDhL9MYYE+Ms0RtjTIyzRG+MMTHOEr0xxsQ4S/TGGBPjLNEbY0yMi7g2xSKyG9iaZXMF4DsfwslPsXZMsXY8EHvHFGvHA7F3TOdyPNWcc9muxRpxiT47IpJ+pj7L0SrWjinWjgdi75hi7Xgg9o4pv47Hhm6MMSbGWaI3xpgYFy2J/kW/A8gHsXZMsXY8EHvHFGvHA7F3TPlyPFExRm+MMSbvouWM3hhjTB5ZojfGmBgXcYleRCaJyC4RWRm0rZyIzBWRDO9nWT9jzI0zHM8jIrJdRJZ5t05+xphbIlJFRNJEZLWIrBKRQd72qHyfznI8Ufs+iUi8iCwUkeXeMT3qbU8SkS9FZIOIzBCR4n7HGoqzHM8UEdkc9B4l+x1rbohIURFZKiLveffz5f2JuEQPTAE6ZNk2FJjnnKsFzPPuR4spnH48AGOdc8ne7YMCjulcHQXuc87VBS4D+otIXaL3fTrT8UD0vk8/A1c65xoByUAHEbkMeBI9pprAPqCvjzHmxpmOB2BI0Hu0zL8Q82QQsCbofr68PxGX6J1zC4C9WTZfA/zL+/1fQLcCDeocnOF4oppzbqdzbon3+w/of6iVidL36SzHE7WcOujdLebdHHAl8Lq3PZreozMdT9QSkUSgM/Cyd1/Ip/cn4hL9GfzKObfT+/0b4Fd+BhMmA0RkhTe0ExVDHNkRkepAY+BLYuB9ynI8EMXvkzcssAzYBcwFNgL7nXNHvYdkEkUfaFmPxzkXeI9Gee/RWBEp4WOIuTUO+Ctw3Ltfnnx6f6Il0Z/gtB40qj/JgYlADfQr6E7gaX/DyRsRKQO8AdzjnPs+eF80vk/ZHE9Uv0/OuWPOuWQgEWgK/MbnkM5J1uMRkfrAMPS4mgDlgPt9DDFkItIF2OWcW1wQfy9aEv23IlIRwPu5y+d4zolz7lvvP9rjwEvo/4RRRUSKoUnxVefcm97mqH2fsjueWHifAJxz+4E0oDlwoYjEebsSge2+BZZHQcfTwRt2c865n4HJRM971ALoKiJbgFR0yOZZ8un9iZZE/w5wm/f7bcAsH2M5Z4Fk6PkDsPJMj41E3lji/wFrnHPPBO2KyvfpTMcTz
e+TiCSIyIXe7yWBdui1hzTgOu9h0fQeZXc8a4NOLAQdz46K98g5N8w5l+icqw70AOY7524mn96fiJsZKyLTgdZou85vgYeBt4GZQFW0hfENzrmouMB5huNpjQ4HOGALcGfQ2HbEE5GWwCfAV5wcXxyOjmtH3ft0luPpSZS+TyLSEL2YVxQ9oZvpnBspIpegZ5DlgKVAL+9sOKKd5XjmAwmAAMuAu4Iu2kYFEWkNDHbOdcmv9yfiEr0xxpjwipahG2OMMXlkid4YY2KcJXpjjIlxluiNMSbGWaI3xpgYZ4neGGNinCV6Y4yJcf8PWGIUcmNiHUIAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process[10:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.plot('epoch', 'test_RMSE', data=df, color='yellow', linestyle='dashed')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "model.estimations()\n", "\n", "top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n", "\n", "top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv', index=False, header=False)\n", "\n", "estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n", "estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 5237.55it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.9138610.7170630.104030.0441090.0533390.070510.0943130.0758140.1076920.0510470.2012730.5187820.4814420.872110.1464653.8814170.972029
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.913861 0.717063 0.10403 0.044109 0.053339 0.07051 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.094313 0.075814 0.107692 0.051047 0.201273 0.518782 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.481442 0.87211 0.146465 3.881417 0.972029 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n", "reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n", "\n", "ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n", " estimations_df=estimations_df, \n", " reco=reco,\n", " super_reactions=[4,5])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 6996.89it/s]\n", "943it [00:00, 5574.45it/s]\n", "943it [00:00, 5909.51it/s]\n", "943it [00:00, 6568.51it/s]\n", "943it [00:00, 5488.25it/s]\n", "943it [00:00, 5363.29it/s]\n", "943it [00:00, 6280.36it/s]\n", "943it [00:00, 5709.71it/s]\n", "943it [00:00, 6279.20it/s]\n", "943it [00:00, 5819.38it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Self_SVD0.9138610.7170630.1040300.0441090.0533390.0705100.0943130.0758140.1076920.0510470.2012730.5187820.4814420.8721100.1464653.8814170.972029
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5177871.2179530.0478260.0178610.0227110.0310800.0282190.0169820.0511540.0195510.1256930.5054480.3181340.9864260.1868695.0917300.908288
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Self_SVD 0.913861 0.717063 0.104030 0.044109 0.053339 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.517787 1.217953 0.047826 0.017861 0.022711 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.070510 0.094313 0.075814 0.107692 0.051047 0.201273 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.031080 0.028219 0.016982 0.051154 0.019551 0.125693 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.518782 0.481442 0.872110 0.146465 3.881417 0.972029 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.505448 0.318134 0.986426 0.186869 5.091730 0.908288 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496424 0.009544 0.600530 
0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2],\n", " [3, 4]])" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([[0.4472136 , 0.89442719],\n", " [0.6 , 0.8 ]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=np.array([[1,2],[3,4]])\n", "display(x)\n", "x/np.linalg.norm(x, axis=1)[:,None]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
010511.00000010521052Dracula: Dead and Loving It (1995)Comedy, Horror
111770.95130311781178Major Payne (1994)Comedy
212900.95048912911291Celtic Pride (1996)Comedy
313750.94986413761376Meet Wally Sparks (1997)Comedy
414890.94737514901490Fausto (1993)Comedy
514950.94736814961496Carpool (1996)Comedy, Crime
614970.94734714981498Farmer & Chase (1995)Comedy
714900.94682914911491Tough and Deadly (1995)Action, Drama, Thriller
813200.94615213211321Open Season (1996)Comedy
914870.94542514881488Germinal (1993)Drama
\n", "
" ], "text/plain": [ " code score item_id id title \\\n", "0 1051 1.000000 1052 1052 Dracula: Dead and Loving It (1995) \n", "1 1177 0.951303 1178 1178 Major Payne (1994) \n", "2 1290 0.950489 1291 1291 Celtic Pride (1996) \n", "3 1375 0.949864 1376 1376 Meet Wally Sparks (1997) \n", "4 1489 0.947375 1490 1490 Fausto (1993) \n", "5 1495 0.947368 1496 1496 Carpool (1996) \n", "6 1497 0.947347 1498 1498 Farmer & Chase (1995) \n", "7 1490 0.946829 1491 1491 Tough and Deadly (1995) \n", "8 1320 0.946152 1321 1321 Open Season (1996) \n", "9 1487 0.945425 1488 1488 Germinal (1993) \n", "\n", " genres \n", "0 Comedy, Horror \n", "1 Comedy \n", "2 Comedy \n", "3 Comedy \n", "4 Comedy \n", "5 Comedy, Crime \n", "6 Comedy \n", "7 Action, Drama, Thriller \n", "8 Comedy \n", "9 Drama " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item=random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm=model.Qi/np.linalg.norm(model.Qi, axis=1)[:,None] # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores=np.dot(embeddings_norm,embeddings_norm[item].T)\n", "top_similar_items=pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\\\n", ".sort_values(by=['score'], ascending=[False])[:10]\n", "\n", "top_similar_items['item_id']=top_similar_items['code'].apply(lambda x: item_code_id[x])\n", "\n", "items=pd.read_csv('./Datasets/ml-100k/movies.csv')\n", "\n", "result=pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n", "\n", "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# making changes to our implementation by considering additional parameters in the gradient descent procedure \n", "# 
seems to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 5528.02it/s]\n", "943it [00:00, 6531.06it/s]\n", "943it [00:00, 5593.54it/s]\n", "943it [00:00, 5845.59it/s]\n", "943it [00:00, 5997.91it/s]\n", "943it [00:00, 6080.78it/s]\n", "943it 
[00:00, 6121.00it/s]\n", "943it [00:00, 5934.94it/s]\n", "943it [00:00, 5026.29it/s]\n", "943it [00:00, 5850.46it/s]\n", "943it [00:00, 5530.89it/s]\n", "943it [00:00, 6004.16it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9498920.7492920.1046660.0486110.0556560.0719000.0928110.0782410.1177300.0574640.2440970.5210540.5015910.9980910.2171724.4580010.951551
0Self_SVD0.9138610.7170630.1040300.0441090.0533390.0705100.0943130.0758140.1076920.0510470.2012730.5187820.4814420.8721100.1464653.8814170.972029
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9429940.7434920.0833510.0340970.0416770.0556730.0732830.0529100.0918660.0425980.1982100.5137050.4231180.9975610.1688314.1952340.963319
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5177871.2179530.0478260.0178610.0227110.0310800.0282190.0169820.0511540.0195510.1256930.5054480.3181340.9864260.1868695.0917300.908288
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.949892 0.749292 0.104666 0.048611 0.055656 \n", "0 Self_SVD 0.913861 0.717063 0.104030 0.044109 0.053339 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.942994 0.743492 0.083351 0.034097 0.041677 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.517787 1.217953 0.047826 0.017861 0.022711 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.071900 0.092811 0.078241 0.117730 0.057464 0.244097 \n", "0 0.070510 0.094313 0.075814 0.107692 0.051047 0.201273 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.055673 0.073283 0.052910 0.091866 0.042598 0.198210 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.031080 0.028219 0.016982 0.051154 0.019551 0.125693 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.521054 0.501591 0.998091 0.217172 4.458001 0.951551 \n", "0 0.518782 0.481442 0.872110 0.146465 3.881417 0.972029 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.513705 0.423118 
0.997561 0.168831 4.195234 0.963319 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.505448 0.318134 0.986426 0.186869 5.091730 0.908288 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }