{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
    "# Self-made SVD\n",
    "\n",
    "Latent-factor model trained with stochastic gradient descent on the MovieLens 100k train split: a predicted rating is the dot product of a user factor vector and an item factor vector (no bias terms). RMSE is tracked on train and test after every epoch, then top-10 recommendations and test-set estimations are saved and evaluated."
   ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
    "import helpers\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "from collections import defaultdict\n",
    "from itertools import chain\n",
    "import random\n",
    "\n",
    "train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n",
    "test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
    "train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)"
   ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
    "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
    "from tqdm import tqdm\n",
    "\n",
    "class SVD():\n",
    "    \"\"\"Matrix factorization without bias terms, trained with SGD.\n",
    "\n",
    "    A rating r(u, i) is predicted as Pu[u] . Qi[i]; both factor rows are\n",
    "    updated on every observed training rating with an L2 penalty.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_ui : scipy.sparse CSR matrix, users x items, stored entries are ratings.\n",
    "    learning_rate : float, SGD step size.\n",
    "    regularization : float, L2 penalty weight.\n",
    "    nb_factors : int, number of latent factors.\n",
    "    iterations : int, number of epochs (passes over all training ratings).\n",
    "    seed : int or None, optional. None (default) draws from the global\n",
    "        np.random state, as before; an int makes initialization and shuffling\n",
    "        reproducible through a private np.random.RandomState.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations, seed=None):\n",
    "        self.train_ui=train_ui\n",
    "        self.uir=self._to_uir(train_ui)\n",
    "\n",
    "        self.learning_rate=learning_rate\n",
    "        self.regularization=regularization\n",
    "        self.iterations=iterations\n",
    "        self.nb_users, self.nb_items=train_ui.shape\n",
    "        self.nb_ratings=train_ui.nnz\n",
    "        self.nb_factors=nb_factors\n",
    "\n",
    "        self.seed=seed\n",
    "        self._rng=None if seed is None else np.random.RandomState(seed)\n",
    "\n",
    "        self.Pu=self._random().normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
    "        self.Qi=self._random().normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
    "\n",
    "    def _random(self):\n",
    "        \"\"\"Random source: the global np.random module, or the private RandomState if a seed was given.\"\"\"\n",
    "        return np.random if self._rng is None else self._rng\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_uir(ui):\n",
    "        \"\"\"List of (user_code, item_code, rating) triples for the non-zero stored entries of ui.\"\"\"\n",
    "        # A single tocoo() keeps row, col and data aligned; pairing nonzero() with .data\n",
    "        # misaligns them as soon as the matrix stores an explicit zero.\n",
    "        coo=ui.tocoo()\n",
    "        mask=coo.data!=0\n",
    "        return list(zip(coo.row[mask], coo.col[mask], coo.data[mask]))\n",
    "\n",
    "    def train(self, test_ui=None):\n",
    "        \"\"\"Run self.iterations epochs of SGD over the shuffled training ratings.\n",
    "\n",
    "        After each epoch appends to self.learning_process either\n",
    "        [epoch, train_RMSE] or, when test_ui is given, [epoch, train_RMSE, test_RMSE].\n",
    "        \"\"\"\n",
    "        # identity test: '!= None' on a matrix/array may be evaluated element-wise\n",
    "        if test_ui is not None:\n",
    "            self.test_uir=self._to_uir(test_ui)\n",
    "\n",
    "        self.learning_process=[]\n",
    "        pbar = tqdm(range(self.iterations))\n",
    "        for i in pbar:\n",
    "            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
    "            self._random().shuffle(self.uir)\n",
    "            self.sgd(self.uir)\n",
    "            if test_ui is None:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
    "            else:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
    "\n",
    "    def sgd(self, uir):\n",
    "        \"\"\"One SGD pass over uir, updating Pu and Qi in place after every rating.\"\"\"\n",
    "        for u, i, score in uir:\n",
    "            # Compute prediction and error\n",
    "            prediction = self.get_rating(u,i)\n",
    "            e = (score - prediction)\n",
    "\n",
    "            # Both updates are computed from the pre-update Pu[u] and Qi[i]\n",
    "            Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
    "            Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
    "\n",
    "            self.Pu[u] += Pu_update\n",
    "            self.Qi[i] += Qi_update\n",
    "\n",
    "    def get_rating(self, u, i):\n",
    "        \"\"\"Predicted rating of item code i by user code u.\"\"\"\n",
    "        return self.Pu[u].dot(self.Qi[i])\n",
    "\n",
    "    def RMSE_total(self, uir):\n",
    "        \"\"\"RMSE of the current factors over (user, item, rating) triples; nan when uir is empty.\"\"\"\n",
    "        if len(uir)==0:\n",
    "            return np.nan\n",
    "        RMSE=0\n",
    "        for u,i, score in uir:\n",
    "            RMSE+=(score - self.get_rating(u,i))**2\n",
    "        return np.sqrt(RMSE/len(uir))\n",
    "\n",
    "    def estimations(self):\n",
    "        \"\"\"Compute the dense users x items prediction matrix Pu @ Qi.T and return it.\n",
    "\n",
    "        The matrix is stored in self.estimations (read by recommend and estimate), which\n",
    "        shadows this method on the instance; call SVD.estimations(model) to recompute it.\n",
    "        \"\"\"\n",
    "        self.estimations=np.dot(self.Pu, self.Qi.T)\n",
    "        return self.estimations\n",
    "\n",
    "    def recommend(self, user_code_id, item_code_id, topK=10):\n",
    "        \"\"\"Top-K items per user that the user has not rated in train_ui.\n",
    "\n",
    "        Reads self.estimations, so estimations() must be called first. Returns one list per\n",
    "        user: [user_id, item_id1, score1, item_id2, score2, ...] by descending score;\n",
    "        NaN scores are skipped.\n",
    "        \"\"\"\n",
    "        top_k = defaultdict(list)\n",
    "        for nb_user, user in enumerate(self.estimations):\n",
    "            # set for O(1) membership tests instead of scanning the index array for every item\n",
    "            user_rated=set(self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]])\n",
    "            for item, score in enumerate(user):\n",
    "                if item not in user_rated and not np.isnan(score):\n",
    "                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
    "        result=[]\n",
    "        # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
    "        for uid, item_scores in top_k.items():\n",
    "            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
    "            result.append([uid]+list(chain(*item_scores[:topK])))\n",
    "        return result\n",
    "\n",
    "    def estimate(self, user_code_id, item_code_id, test_ui):\n",
    "        \"\"\"[user_id, item_id, prediction] for every non-zero entry of test_ui; NaN predictions become 1.\"\"\"\n",
    "        result=[]\n",
" for user, item in zip(*test_ui.nonzero()):\n", " result.append([user_code_id[user], item_code_id[item], \n", " self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n", " return result" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 39 RMSE: 0.7503496297849689. Training epoch 40...: 100%|██████████| 40/40 [01:12<00:00, 1.81s/it]\n" ] } ], "source": [ "model=SVD(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n", "model.train(test_ui)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAb/ElEQVR4nO3dfXRV9Z3v8fc3IRAgEciDmiEg2DIznaIGjEErdlC8dwm3XqnLhbimo3VcwxrFglPRW2m1taNrbBkfhmuVeq9ae29XdTo+VF1WaxUHbVVKMCBPCr1qCTDlMUAEBJLv/WOfQ07CSXJOcpJ9svfntdZvnX323uecr9vw2fv89m/vY+6OiIgMfAVhFyAiIrmhQBcRiQgFuohIRCjQRUQiQoEuIhIRg8L64IqKCh83blxYHy8iMiDV19fvcvfKdMtCC/Rx48axcuXKsD5eRGRAMrNPOlumLhcRkYhQoIuIRIQCXUQkIkLrQxeRgefo0aM0NjZy+PDhsEuJvOLiYqqrqykqKsr4NQp0EclYY2MjpaWljBs3DjMLu5zIcnd2795NY2Mj48ePz/h16nIRkYwdPnyY8vJyhXkfMzPKy8uz/iakQBeRrCjM+0dPtvOAC/SPPoKbboKjR8OuREQkvwy4QH//ffjXf4WlS8OuREQkvwy4QL/0UrjoIvje92Dv3rCrEZH+1NTUxEMPPZT162bOnElTU1PWr/v617/O+PHjqamp4ayzzuK11147vmzatGmMHTuW1B8JmjVrFiUlJQC0trYyf/58Jk6cyBlnnME555zDRx99BARXyp9xxhnU1NRQU1PD/Pnzs64tnQE3ysUM7rsPJk2C738f7r8/7IpEpL8kA/2GG25oN7+lpYXCwsJOX/fSSy/1+DMXL17MFVdcwbJly5g7dy6bNm06vmzkyJH89re/ZerUqTQ1NbF9+/bjy5566im2bdvGmjVrKCgooLGxkeHDhx9fvmzZMioqKnpcVzoDLtABzjoLrrsOHnwQrr8e/vzPw65IJH5uugkaGnL7njU18MADnS//1re+xR/+8AdqamooKiqipKSEqqoqGhoaWL9+PbNmzWLLli0cPnyYBQsWMHfuXKDt3lHNzc3MmDGDqVOn8rvf/Y7Ro0fzy1/+kqFDh3Zb23nnncfWrVvbzZszZw5P
PvkkU6dO5ZlnnuHyyy9n3bp1AGzfvp2qqioKCoKOkOrq6h5ulcwNuC6XpH/6JyguhltvDbsSEekv99xzD5/73OdoaGhg8eLFrFixgrvvvpv169cD8Nhjj1FfX8/KlStZsmQJu3fvPuE9Nm3axLx581i3bh0jR47k6aefzuizX375ZWbNmtVu3vTp01m+fDktLS08+eSTXHnllceXzZ49mxdeeIGamhpuvvlm3nvvvXavvfDCC493udyfo66GAXmEDnDqqbBoUdCWLYMLLwy7IpF46epIur/U1dW1u/BmyZIlPPvsswBs2bKFTZs2UV5e3u41yT5xgLPPPpuPP/64y8+45ZZbuPXWW9mxYwfvvPNOu2WFhYVMnTqVp556ikOHDpF6S/Dq6mo++OADXn/9dV5//XWmT5/OL37xC6ZPnw70TZfLgD1CB/jHf4TTTgseW1rCrkZE+ltqn/Qbb7zBb37zG95++21Wr17NpEmT0l6YM2TIkOPThYWFHDt2rMvPWLx4MZs3b+auu+7immuuOWH5nDlz+MY3vsHs2bPTftaMGTNYvHgxixYt4rnnnsvmPy9rAzrQi4vhBz+A1avhJz8JuxoR6WulpaUcOHAg7bJ9+/YxatQohg0bxsaNG084mu6NgoICFixYQGtrK6+88kq7ZRdccAG33XYbV111Vbv5q1atYtu2bUAw4mXNmjWcdtppOaspbZ19+u79YPZs+NKX4Nvfhk7+P4tIRJSXl3P++eczceJEbrnllnbLLrnkEo4dO8aZZ57J7bffzrnnnpvTzzYzvvOd7/DDH/7whPkLFy48oftkx44dXHrppUycOJEzzzyTQYMGceONNx5fntqHfvXVV+emxtQxlP2ptrbWc/WLRStWwJQpQX/63Xfn5C1FJI0NGzbwhS98IewyYiPd9jazenevTbf+gD9CB6irg7/5G7j3Xvik0x9nEhGJtm4D3cyKzWyFma02s3VmdmeadaaZ2T4za0i0O/qm3M798z9DQQHcdlt/f7KIDHTz5s073v2RbI8//njYZWUtk2GLnwEXuXuzmRUBb5nZr9y94xmHN939K7kvMTNjxsDChcH49PnzIcfdZyKS4O6Ru+Pij370o7BLOEFPusO7PUL3QHPiaVGihdPx3o1bb4WqKrj55rArEYmm4uJidu/e3aOwkcwlf+CiuLg4q9dldGGRmRUC9cDngR+5+7tpVjvPzFYD24CF7r4uzfvMBeYCjB07NqtCM1FSEtwS4K674NgxGDRgL5sSyU/V1dU0Njayc+fOsEuJvORP0GUjo8hz9xagxsxGAs+a2UR3X5uyyirgtES3zEzgOWBCmvd5BHgEglEuWVWaoVNOCR737oXKyr74BJH4Kioqyuon0aR/ZTXKxd2bgDeASzrM35/slnH3l4AiM8vtNa0ZKisLHtPcwkFEJNIyGeVSmTgyx8yGAhcDGzusc6olzpKYWV3ifUOJ1ORtG/bsCePTRUTCk0mXSxXwRKIfvQD4N3d/0cz+AcDdlwJXANeb2THgEDDHQzprkgx0HaGLSNx0G+juvgaYlGb+0pTpB4EHc1tazyjQRSSuInGlaCoFuojEVeQCvbQ0GK6oQBeRuIlcoJsFR+kKdBGJm8gFOgRDFxXoIhI3kQx0HaGLSBxFNtA1Dl1E4iayga4jdBGJGwW6iEhERDbQDx+GgwfDrkREpP9ENtBBR+kiEi8KdBGRiIhkoOsWuiISR5EMdB2hi0gcRTrQNRZdROIk0oGuI3QRiZNIBvrgwcEPRivQRSROIhnooIuLRCR+FOgiIhER2UDXLXRFJG4iG+g6QheRuIl0oGvYoojESaQDfe9eaG0NuxIRkf4R6UBvbYWmprArERHpH5EOdFA/uojER7eBbmbFZrbCzFab2TozuzPNOmZmS8xss5mtMbPJfVNu5hToIhI3gzJY5zPgIndvNrMi4C0z+5W7v5OyzgxgQqJNAR5OPIZGd1wUkbjp9gjdA82Jp0WJ5h1W
uwz4aWLdd4CRZlaV21KzoyN0EYmbjPrQzazQzBqAHcCr7v5uh1VGA1tSnjcm5nV8n7lmttLMVu7cubOnNWdEgS4icZNRoLt7i7vXANVAnZlN7LCKpXtZmvd5xN1r3b22srIy+2qzMGIEFBRoLLqIxEdWo1zcvQl4A7ikw6JGYEzK82pgW68q66WCAl3+LyLxkskol0ozG5mYHgpcDGzssNrzwNWJ0S7nAvvcfXvOq82SLv8XkTjJZJRLFfCEmRUS7AD+zd1fNLN/AHD3pcBLwExgM3AQuLaP6s2KAl1E4qTbQHf3NcCkNPOXpkw7MC+3pfVeeTls2dL9eiIiURDZK0VBfegiEi+RDnR1uYhInEQ+0A8ehMOHw65ERKTvRT7QQWPRRSQeYhHo6nYRkThQoIuIRIQCXUQkIiId6LqFrojESaQDXUfoIhInkQ70oUODpkAXkTiIdKCDLi4SkfiIRaBrHLqIxEEsAl1H6CISBwp0EZGIUKCLiERE5AO9rCzoQ/cTfuFURCRaIh/o5eXQ0gL79oVdiYhI34pFoIO6XUQk+mIT6Bq6KCJRF5tA1xG6iESdAl1EJCIU6CIiERH5QB85EswU6CISfZEP9MLCINQV6CISdd0GupmNMbNlZrbBzNaZ2YI060wzs31m1pBod/RNuT2jq0VFJA4GZbDOMeBmd19lZqVAvZm96u7rO6z3prt/Jfcl9p4CXUTioNsjdHff7u6rEtMHgA3A6L4uLJd0C10RiYOs+tDNbBwwCXg3zeLzzGy1mf3KzL6Yg9pyRkfoIhIHmXS5AGBmJcDTwE3uvr/D4lXAae7ebGYzgeeACWneYy4wF2Ds2LE9LjpbCnQRiYOMjtDNrIggzH/m7s90XO7u+929OTH9ElBkZhVp1nvE3WvdvbaysrKXpWeurAwOHIAjR/rtI0VE+l0mo1wMeBTY4O73dbLOqYn1MLO6xPvmzTGx7uciInGQSZfL+cDfAu+bWUNi3iJgLIC7LwWuAK43s2PAIWCOe/7cgTz1atFTTw23FhGRvtJtoLv7W4B1s86DwIO5KirXdPm/iMRB5K8UBQW6iMRDrAJdfegiEmWxCnQdoYtIlMUi0IcNgyFDFOgiEm2xCHSzYCy6Al1EoiwWgQ66WlREok+BLiISEQp0EZGIUKCLiERErAJ9zx7InxsSiIjkVqwC/ehRaG4OuxIRkb4Rm0AvKwse1e0iIlEVm0DX1aIiEnUKdBGRiFCgi4hEhAJdRCQiYhPoOikqIlEXm0AfNAhGjNA90UUkumIT6KA7LopItMUq0HX5v4hEmQJdRCQiFOgiIhGhQBcRiYjYBfq+fXDsWNiViIjkXuwCHWDv3nDrEBHpC90GupmNMbNlZrbBzNaZ2YI065iZLTGzzWa2xswm9025vaOrRUUkygZlsM4x4GZ3X2VmpUC9mb3q7utT1pkBTEi0KcDDice8oqtFRSTKuj1Cd/ft7r4qMX0A2ACM7rDaZcBPPfAOMNLMqnJebS/pCF1EoiyrPnQzGwdMAt7tsGg0sCXleSMnhj5mNtfMVprZyp07d2ZXaQ4o0EUkyjIOdDMrAZ4GbnL3/R0Xp3nJCb/e6e6PuHutu9dWVlZmV2kOKNBFJMoyCnQzKyII85+5+zNpVmkExqQ8rwa29b683CotDW7SpUAXkSjKZJSLAY8CG9z9vk5Wex64OjHa5Vxgn7tvz2GdOWGmi4tEJLoyGeVyPvC3wPtm1pCYtwgYC+DuS4GXgJnAZuAgcG3uS82N8nLdQldEoqnbQHf3t0jfR566jgPzclVUXyorg127wq5CRCT3YnWlKMCECbBuHfgJp2xFRAa22AV6XV1whP7xx2FXIiKSW7EL9CmJ61dXrAi3DhGRXItdoE+cCMXFCnQRiZ7YBXpREUyerEAXkeiJXaBD0I9eX6/7ootItMQ20A8dCka7iIhERWwDHeDdjrcYExEZwGIZ6KefHlxgpH50EYmSWAa6WXCUrkAXkSiJZaBD
MB593Tpobg67EhGR3IhtoNfVQWsrrFoVdiUiIrkR20A/55zgUd0uIhIVsQ30ykoYP16BLiLREdtAB50YFZFoiX2gf/IJ/OlPYVciItJ7sQ900FG6iERDrAN98mQoLFSgi0g0xDrQhw2DM85QoItINMQ60KHtxKh+kk5EBjoFeh00NcHmzWFXIiLSOwp0nRgVkYiIfaD/1V/B8OEKdBEZ+GIf6IWFcPbZuje6iAx83Qa6mT1mZjvMbG0ny6eZ2T4za0i0O3JfZt+qq4P33oMjR8KuRESk5zI5Qv8JcEk367zp7jWJ9v3el9W/6uqCMF+zJuxKRER6rttAd/flwJ5+qCU0U6YEj+pHF5GBLFd96OeZ2Woz+5WZfbGzlcxsrpmtNLOVO3fuzNFH996YMXDKKQp0ERnYchHoq4DT3P0s4H8Cz3W2ors/4u617l5bWVmZg4/ODf0knYhEQa8D3d33u3tzYvoloMjMKnpdWT+rq4ONG2HfvrArERHpmV4HupmdamaWmK5LvOfu3r5vf6urCy7/r68PuxIRkZ4Z1N0KZvZzYBpQYWaNwHeBIgB3XwpcAVxvZseAQ8Ac94F3Z5Ta2uDx3XfhoovCrUVEpCe6DXR3v6qb5Q8CD+asopCUlcGECepHF5GBK/ZXiqbSiVERGcgU6CmmTIFt22Dr1rArERHJngI9he68KCIDmQI9xVlnQVERLFsWdiUiItlToKcoLoYrr4RHHoGPPgq7GhGR7CjQO7jnHhg0CBYuDLsSEZHsKNA7GD0aFi2CZ56B118PuxoRkcwp0NP45jfh9NNhwQI4dizsakREMqNAT6O4GO69F9auhaVLw65GRCQzCvROXHYZXHwx3HEH7B5wd6YRkThSoHfCDB54APbvh9tvD7saEZHuKdC78MUvwg03wI9/DKtXh12NiEjXFOjduPNOGDUqOEE68O4hKSJxokDvxqhRcNdd8B//Af/+72FXIyLSOQV6Bv7+74PbAixcCAcPhl2NiEh6CvQMFBbCkiXwxz/C4sVhVyMikp4CPUNf/jLMng0/+AFs2hR2NSIiJ1KgZ2HxYhgyBM49F155JexqRETaU6BnYexY+P3voboaZsyA738fWlvDrkpEJKBAz9LnPw9vvw1f+xp897tw6aWwZ0/YVYmIKNB7ZNgweOIJeOghePVVOPtsWLUq7KpEJO4U6D1kBtdfD2++GdyR8UtfgsceC7sqEYkzBXovTZkSHJ1fcAFcd13Qtm8PuyoRiSMFeg5UVsLLLwc/jPH448HJ0yuvhOXLdbsAEek/3Qa6mT1mZjvMbG0ny83MlpjZZjNbY2aTc19m/isshLvvhg8+gPnz4de/hr/+azjzTHj4YThwIOwKRSTqMjlC/wlwSRfLZwATEm0u8HDvyxq4JkwIfhxj61Z49FEYPDi4Y+Of/RnMmwcNDRrqKCJ9o9tAd/flQFcD8y4DfuqBd4CRZlaVqwIHqmHD4O/+DlauhHfegcsvDwJ+0iQ4+WT46lfh/vuhvl4/cyciuTEoB+8xGtiS8rwxMU+nBglGw0yZErR774UXXgj61pcvh+eeC9YpLYXzzw9uLzBlCvzlX0JVVfBaEZFM5SLQ08VO2lOBZjaXoFuGsWPH5uCjB5aKCrj22qABNDYGwx6TAb9oUdu6paVBsP/FXwSPyenx42H48HDqF5H8Zp7BMAwzGwe86O4T0yz7MfCGu/888fwDYJq7d3mEXltb6ytXruxJzZG1cyesWQMbN7a1Dz6ALVvarzdiRNAnP3p0W0s+P/nkYNRNRUWwXoHGMYlEipnVu3ttumW5OEJ/HrjRzJ4EpgD7ugtzSa+yEqZPD1qq5mb48MMg4P/4x+CE69atsG0bvPZaMO69peXE9yssDIK9oiJ47/Ly4Ac7Ro2CkSPTtxEj4KSToKREXT4iA023gW5mPwemARVm1gh8FygCcPelwEvATGAzcBC4tq+KjauSEpg8OWjptLTAjh1BwO/cGbRdu4KWOr1+
PTQ1Be3Qoa4/0ywI9mTAJ1tJSdCGD2+bTp3XVRs2TN8YRPpSt4Hu7ld1s9yBeTmrSLJWWBicRK3KYmzRZ5+1hXtq278f9u0LWnI6+bhrF3zySfCNobkZPv0UjhzJrtahQ9uHfOqOYNiwtseOreMOpLS0/fSwYfpGIZKLLhcZgIYMgVNOCVpvHDkSBHsy4NO1gwfTL0+d19gYfGs4eDBon34Khw9nXodZ+m8NHXcCyR1Bcjr1ecdvHkOHaichA4sCXXpl8OCgjRqV+/dubW0L+eQOoLk5uOq24/SBA+3XSbbdu+Hjj9vWOXAg8wu7zNrvDJLdTsnp1MfuWklJsJ1E+pICXfJWQUFbd0xlZW7e0z048k/dESSnO+4Qks+T6xw4EHQ/ffRR2/T+/XD0aGafPXhw+m8G6XYSHadHjGh/PmOQ/uVKGvqzkFgxC7pShg4NhnjmwmeftQ/9dK3jjiHZmpqCYan797fNy+SGbsOHtw/6ZEuOVEo3L/XxpJN0gjqKFOgivTRkSNAqKnr/Xq2tQRdTMuCT3wJST053PGmdbJ980jbd3SgmCEJ95Mi2oaypraysbTq5E0gd2qrzC/lJgS6SRwoK2vrte+PIkSDwm5raQj45nfrY1AR79wbtww/bprvbIRQVte0MysuDHUBZWefTyR2ELnbrWwp0kQgaPLjtorKe+OyzINg7DmtN3RE0NQW/p7tnT3ANxNq1wUno5ubO37egINgRpAZ9x+BPnZe86rm0VN8IMqFAF5ETDBkCp54atGwdORLsDHbvDsJ+79624E+25PJdu4JvBnv2BDuIzhQVtb/qOTldXt75Yxyvdlagi0hODR7cs2scWlqCUE/uCHbtagv91Cued+0KflcguV5nJ5GHDAnCv7OWDP7kTqCsLNhxDGQKdBHJC4WFbQGbqeROIDX8U3cCqW3TpuCxqy6hk05qC/jU8E/e9K7jdL7d+VSBLiIDVk92AocPB8G+e3f7nUDHHcJ//ie8/36wbmdXLQ8dmj74Tz75xFZZGazflxToIhIrxcUwZkzQMuEeHNV3POJPth072qY3bAgeDx5M/14lJUG433AD3Hxz7v6bkhToIiJdMGu7uvf00zN7zaeftoV9upbNjfSyoUAXEcmx5C0rxo3r38/VEH8RkYhQoIuIRIQCXUQkIhToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEeaZ/N5VX3yw2U7gky5WqQB29VM52VJtPaPaeka19UxUazvN3dP+ym5ogd4dM1vp7rVh15GOausZ1dYzqq1n4libulxERCJCgS4iEhH5HOiPhF1AF1Rbz6i2nlFtPRO72vK2D11ERLKTz0foIiKSBQW6iEhE5F2gm9klZvaBmW02s2+FXU8qM/vYzN43swYzWxlyLY+Z2Q4zW5syr8zMXjWzTYnHUXlU2/fMbGti2zWY2cyQahtjZsvMbIOZrTOzBYn5oW+7LmoLfduZWbGZrTCz1Yna7kzMz4ft1lltoW+3lBoLzew9M3sx8bxPtlte9aGbWSHwIfBfgEbg98BV7r4+1MISzOxjoNbdQ79Ywcy+DDQDP3X3iYl5PwT2uPs9iZ3hKHf/H3lS2/eAZnf/l/6up0NtVUCVu68ys1KgHpgFfJ2Qt10Xtc0m5G1nZgYMd/dmMysC3gIWAJcT/nbrrLZLyIO/OQAz+yZQC5zk7l/pq3+r+XaEXgdsdvf/5+5HgCeBy0KuKS+5+3JgT4fZlwFPJKafIAiDftdJbXnB3be7+6rE9AFgAzCaPNh2XdQWOg80J54WJZqTH9uts9rygplVA/8N+N8ps/tku+VboI8GtqQ8byRP/qATHPi1mdWb2dywi0njFHffDkE4ACeHXE9HN5rZmkSXTCjdQanMbBwwCXiXPNt2HWqDPNh2iW6DBmAH8Kq7581266Q2yIPtBjwA3Aq0pszrk+2Wb4FuaeblzZ4WON/dJwMzgHmJrgXJzMPA54AaYDtwb5jFmFkJ8DRwk7vv
D7OWjtLUlhfbzt1b3L0GqAbqzGxiGHWk00ltoW83M/sKsMPd6/vj8/It0BuBMSnPq4FtIdVyAnfflnjcATxL0EWUT/6U6IdN9sfuCLme49z9T4l/dK3A/yLEbZfoZ30a+Jm7P5OYnRfbLl1t+bTtEvU0AW8Q9FHnxXZLSq0tT7bb+cB/T5x/exK4yMz+L3203fIt0H8PTDCz8WY2GJgDPB9yTQCY2fDEiSrMbDjwX4G1Xb+q3z0PXJOYvgb4ZYi1tJP84034KiFtu8QJtEeBDe5+X8qi0LddZ7Xlw7Yzs0ozG5mYHgpcDGwkP7Zb2tryYbu5+23uXu3u4wjy7HV3/xp9td3cPa8aMJNgpMsfgG+HXU9KXacDqxNtXdi1AT8n+Bp5lOCbzXVAOfAasCnxWJZHtf0f4H1gTeKPuSqk2qYSdOOtARoSbWY+bLsuagt92wFnAu8lalgL3JGYnw/brbPaQt9uHeqcBrzYl9str4YtiohIz+Vbl4uIiPSQAl1EJCIU6CIiEaFAFxGJCAW6iEhEKNBFRCJCgS4iEhH/H5aRnDSMJbJNAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process).iloc[:,:2]\n", "df.columns=['epoch', 'train_RMSE']\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de5zOdf7/8cfLUI5FonXKISpyGJmkRZHKIZUOK9pODmttRG0pVNrdajtoN9nUZqXjbtS3dE4qSmcNhhx/hpTBRtqUSsL798frGi5jhmuYmeswz/vtdt1m5vO5PnO9Pj7mdb2v9+f9fr0thICIiKSuMvEOQEREipcSvYhIilOiFxFJcUr0IiIpToleRCTFlY13APk58sgjQ4MGDeIdhohI0pg7d+7XIYQa+e1LyETfoEEDMjMz4x2GiEjSMLMvCtq3364bM5tsZhvMbFEB+83MxptZtpktNLMTo/Z1M7PlkX0jDyx8ERE5GLH00T8GdNvH/u5Ak8hjEPAQgJmlARMi+5sBfc2s2cEEKyIihbffRB9CmA18s4+nnAc8EdzHQFUzqwW0BbJDCKtCCNuAKZHniohICSqKPvo6wJqon3Mi2/LbfnJBv8TMBuGfCDj66KOLICwRKSm//PILOTk5bN26Nd6hpLzy5ctTt25dypUrF/MxRZHoLZ9tYR/b8xVCmAhMBMjIyFABHpEkkpOTQ5UqVWjQoAFm+f3pS1EIIbBp0yZycnJo2LBhzMcVxTj6HKBe1M91gXX72C4iKWbr1q1Ur15dSb6YmRnVq1cv9Cenokj0LwGXR0bftAM2hxDWA58CTcysoZkdAvSJPFdEUpCSfMk4kH/n/XbdmNnTQCfgSDPLAW4FygGEEP4JvAb0ALKBH4F+kX3bzWwo8AaQBkwOISwudISFshOYC5xUvC8jIpJE9pvoQwh997M/AEMK2Pca/kZQQv4PuBg4G/gr0LLkXlpEJEGlWK2bnsCdwAdAOnApsCquEYlI8fv222958MEHC31cjx49+Pbbbwt93JVXXknDhg1JT0+nVatWvP3227v2derUiaOPPproRZ169epF5cqVAdi5cyfDhg2jefPmtGjRgpNOOonPP/8c8KoALVq0ID09nfT0dIYNG1bo2PKTkCUQDlxFYCTwe+Ae4H5gHrCY/AcBiUgqyE30V1111R7bd+zYQVpaWoHHvfbagXc4jB07losuuohZs2YxaNAgVqxYsWtf1apV+eCDD+jQoQPffvst69ev37Vv6tSprFu3joULF1KmTBlycnKoVKnSrv2zZs3iyCOPPOC48pNiiT5XNbxlfzU++Mfw2wd/i2yrGr/QRFLcNddAVlbR/s70dBg3ruD9I0eOZOXKlaSnp1OuXDkqV65MrVq1yMrKYsmSJfTq1Ys1a9awdetWhg8fzqBBg4DddbW2bNlC9+7d6dChAx9++CF16tThx
RdfpEKFCvuN7ZRTTmHt2rV7bOvTpw9TpkyhQ4cOPP/881xwwQUsXuy3KNevX0+tWrUoU8Y7VOrWrXuA/yqxS7Gum7xq4xN0Ad4ExgCN8Nb+j/EKSkSK2F133cUxxxxDVlYWY8eOZc6cOdxxxx0sWbIEgMmTJzN37lwyMzMZP348mzZt2ut3rFixgiFDhrB48WKqVq3Kc889F9NrT58+nV69eu2xrUuXLsyePZsdO3YwZcoULr744l37evfuzcsvv0x6ejrXXXcd8+fP3+PYzp077+q6ue+++wr7T5GvFG3R5+c8YD5wE3AjcB/exTOGlH+/EylB+2p5l5S2bdvuMaFo/PjxTJs2DYA1a9awYsUKqlevvscxuX3uAG3atGH16tX7fI0RI0Zwww03sGHDBj7++OM99qWlpdGhQwemTp3KTz/9RHTZ9bp167J8+XJmzpzJzJkz6dKlC88++yxdunQBiqfrppRluHTgVeBdoHXka+4/wQr2MXFXRJJIdJ/3O++8w1tvvcVHH33EggULaN26db4Tjg499NBd36elpbF9+/Z9vsbYsWPJzs7m9ttv54orrthrf58+fbj66qvp3bt3vq/VvXt3xo4dy+jRo3nhhRcKc3qFVsoSfa5T8VGfb0R+/i9wAnAi8Ajq1hFJLlWqVOH777/Pd9/mzZupVq0aFStWZNmyZXu1vg9GmTJlGD58ODt37uSNN97YY1/Hjh0ZNWoUffvuOUJ93rx5rFvnRQJ27tzJwoULqV+/fpHFlG+cxfrbE94hka+HAQ8A24GBeLWGEfgbgIgkuurVq9O+fXuaN2/OiBEj9tjXrVs3tm/fTsuWLbnlllto165dkb62mXHzzTdzzz337LX9+uuv36sbZsOGDZxzzjk0b96cli1bUrZsWYYOHbprf3Qf/eWXX140MUaP9UwUGRkZIT4rTAVgNp70XwKWAQ3xSb+HA/mu0iVS6i1dupSmTZvGO4xSI79/bzObG0LIyO/5pbxFn5cBpwHPAuvxJA9wHVATaAOMAmYCP8cjQBGRQlOiL9ARUd+PAW4DKgH3Al2AM6L2f4lu5IqkniFDhuzqRsl9PProo/EOq9BK0fDKg9Em8rgZ+B54J2rfVuA4/I2hB9AfaIdm4ookvwkTJsQ7hCKhFn2hVQHOiTzAK2Y+AHTAV0v8NdAKeC8u0YmI5KVEf9AqAgOAqfi6Kg/jVZxzu36W4KX51bUjIvGhRF+kquDL3s7Fx+UD3IWXYWiDvwnkP9ZXRKS4KNEXu38AE4AdwGC8/s5NcY1IREoXJfpidzhwFZAFfARcBOROtf4FX2HxYeCLuEQnkgoOtB49wLhx4/jxx33Phs+tE9+yZUtOO+00vvhi99+rmXHZZZft+nn79u3UqFGDnj17AvDVV1/Rs2dPWrVqRbNmzejRowcAq1evpkKFCnuM6HniiScO6Bz2R4m+xBg+GudRfLgmwBo8+Q8GGgBNgT8Cy+MQn0jyKu5ED15sbOHChXTq1Inbb7991/ZKlSqxaNEifvrpJwDefPNN6tSps2v/mDFjOPPMM1mwYAFLlizhrrvu2rUvt+Jm7qOoZsLmpUQfV42A1cBS4O/A0cCD+GQt8E8B4/A3g5/iEJ/IgeqUzyM3Ef9YwP7HIvu/zmffvkXXox8xYgRjx47lpJNOomXLltx6660A/PDDD5x99tm0atWK5s2bM3XqVMaPH8+6devo3LkznTt3junM8qs/3717d1599VUAnn766T3q26xfv36PmvMtW5b8EqdK9HFnwPHAtXiRtW/woZrglTavxYdsVsGrbw4Eviv5MEUSWHQ9+jPPPJMVK1YwZ84csrKymDt3LrNnz2b69OnUrl2bBQsWsGjRIrp168awYcOoXbs2s2bNYtasWTG9Vn7153MXGtm6dSsLFy7k5JNP3rVvyJAhDBgwgM6dO3PHHXfsKmgG7Hpzyn28917xDMvWhKmEUzHq+5uAK4FMfIhmJl5+oXJk/zDgYyAj8
ugINEaTtST+3tnHvor72X/kfvbv24wZM5gxYwatW7cGYMuWLaxYsYKOHTty/fXXc+ONN9KzZ086duxYqN/buXNnvvrqK2rWrLlH1w14K3316tU8/fTTu/rgc3Xt2pVVq1Yxffp0Xn/9dVq3bs2iRYuA3V03xU0t+oRXB1805XZgOr7Yee5la4KXZXgKH8t/LLs/DQD8UHJhiiSIEAKjRo3a1e+dnZ3NgAEDOPbYY5k7dy4tWrRg1KhR/OUvfynU7501axZffPEFJ5xwAmPGjNlr/7nnnsv111+/V1ligCOOOIJLLrmEJ598kpNOOonZs2cf8PkdCCX6pHY1MAv4Fp+YNQHI/U8W8MTfArgGeBl1+Uiqiq5H37VrVyZPnsyWLVsAWLt2LRs2bGDdunVUrFiRSy+9lOuvv5558+btdez+VKhQgXHjxvHEE0/wzTff7LGvf//+jBkzhhYtWuyxfebMmbtu9n7//fesXLmSo48++qDOt7DUdZMSyuAjdqLLlv6Cd+28jQ/fvB9IwxdNH4GXbtiJ/gtIKoiuR9+9e3cuueQSTjnlFAAqV67MU089RXZ2NiNGjKBMmTKUK1eOhx56CIBBgwbRvXt3atWqFVM/fa1atejbty8TJkzglltu2bW9bt26DB8+fK/nz507l6FDh1K2bFl27tzJwIEDOemkk1i9evWuPvpc/fv3Z9iwYQf7z7EX1aMvFX7GR+68BZyJl2Kei1fgPBPoDnTFJ3OJFJ7q0ZeswtajV3OuVDiUvYepVQDOx/v9n41sawU8g3f5iEiqUKIvtZoBk/G+/M/whP8WvowiwFjgQ7zF3wVP/hrNI6nt5JNP5uef91xU6Mknn9yr3z3ZKNGXega0jDxuyLNvHpC7On0doBdekllkbyEEzJK7MfDJJ5/EO4T9OpDudo26kQKMwGftrgD+iU/aih61cwFew+d5fJKXlGbly5dn06ZNB5SEJHYhBDZt2kT58uULdZxa9LIPhk/Aagz8Pmr7L8A24AngocjzWuOzeC8t4RglEdStW5ecnBw2btwY71BSXvny5fcoqRCLmBK9mXVj9/i8SSGEu/Lsr4Z3+B6Dr63XP4SwKLJvNV6EfQewvaC7wpJMygGv4Al/Dj6Ec1bkZ4DPgZ746J5TIw+N6Ell5cqVo2HDhvEOQwqw30RvZmn4TJwzgRzgUzN7KYSwJOppo4GsEML5ZnZ85PldovZ3DiF8XYRxS0IoB7SPPKJnCm7BC7Q9hbf4wdsA/8EXYdmJeg1FSk4sf21tgewQwqoQwjZ8YdTz8jynGd6sI4SwDGhgZkcVaaSSRFoAr+N995l4Zc4W7B7R8wC+oPpQ4EU0Y1ekeMWS6OvghdNz5US2RVuA353DzNoC9dn9Vx2AGWY218wGFfQiZjbIzDLNLFP9fKmiLL6E4rXANHZ33zTAW/iP4iN5jsC7d7aXfIgipUAsiT6/8VJ5b63fBVQzsyy8AMt8dv/Vtg8hnIhPvxxiZqfm9yIhhIkhhIwQQkaNGjVii36P4+G3v4XJk2HHjkIfLiXqXOA1vMU/C7gRH6ef25N4If6h8QG8Vr9GcogcjFgSfQ5QL+rnusC66CeEEL4LIfQLIaQDlwM18DtyhBDWRb5uwJt1bYsg7r1s3gyrVsGAAZCeDq+/7slfElnujN07gElR2+vik7iuxnsFa+NtCRE5ELEk+k+BJmbW0MwOwRc5fSn6CWZWNbIPfGWM2SGE78yskplViTynEnAWsKjowt+talX48EN49ln46Sfo0QPOPBMiBeokqdyPl2NeCfwL6IwvvAKwGS/PPBC/ubs+v18gIlH2m+hDCNvxu2Zv4J+jnwkhLDazwWY2OPK0psBiM1uGd9HklnA7CnjfzBbg4/BeDSFML+qTyGUGF10ES5bA+PGQlQVt2sBll8EXWns7CTVid0IfEtn2P/zG7nPAb/HWfjMgt773d/gHTn2cE8mV0tUrN2+Gu++G++7zbpxhw2D0aG/9S7Lbg
Y8BmIn38/8VL8r2FHAZfoO3BV7aoQXwG0AXXlLXvqpXpnSiz7VmDdxyCzzxBFSr5t//4Q9w6KFF9hKSMLLxoZ0L8X7+RfhKW2vx1v+jeAG33OUXTwQOj0ukIkVpX4m+VMxaqVcPHnvM++vbtIFrr4WmTeGRR+CXX/Z7uCSVxvhN3H/h6+l+h/f114rs34z3It4AnI638k/APyGA1/eJbbUhkWRRKhJ9rvR0mDEDpk+H6tVh4EBo0gQmToRt2+IdnRSPMnhff+4o4WvwAWEb8dtOd+CLrqRF9vfDW/jN8IXZHwYWl1y4IsWgVCX6XF27wpw58Npr8Ktfwe9/D40bw0MPQZ5S1JKyjsQHgY3GZ+7muhn4M/7J4DVgMHBd1P5xeJ0fVfSQ5FEq+uj3JQRv5f/5z/DRR1CnDowc6a39QlYClZQT8GGeP+I3dH/Ab/LmfvxrArQD+uPzAXL/lpK7Jrskp1LfR78vZt7C/+ADePNNaNgQrr4aGjWC++/3MflSWhleqiF3daFK+Gzed/AJXM3w7p/PIvtX4p8UOgCD8Nb/DHxIqEj8lPoWfV4hwDvveAv/3XfhqKNg6FDo3x9qq9Ku7CXg1T7K4X3/dwNLIo9Nkec8h5eCmosv0VgXn2ye+2iOr+ErcuBK/fDKA/Xuu3DHHd7ST0uDc86BQYPgrLP8Z5F924gn/OZAdXxC+XV4VZGtUc/LxIu/PQv8A0/+9fF5iM3wuQDlSixqSU5K9AdpxQqYNAkefRQ2boT69b0PX618OTABv5m7JvLoAlTGW/7jo7bn1gX8Lz7JfCrwCZ78m0YeR5Rk4JLAlOiLyLZt8MILPhzz7be9Vd+zp4/aUStfitYv+OSv5XglT8NHCI0Dom8cNcRvGIMv7ZiDTwyrE/VVM4JLAyX6YpC3lX/00V4584orvMUvUjx2Al/gXUJL8fsAd0b2nUeeeoP4Ai/LIt//Bf800QE4Gf8UIalCib4YbdsGL74IDz/srXyA00+HK6+ECy6ASpXiGp6UOj/hRd3W4WUfygC9I/vOAt7Ck30avqD7ZcCwkg9TipwSfQn5/HN48kl4/HGvjV+5MvzmN97K79gRypT6wawSf5uBj4D3I4/2+OzgbfgN4Qy8xX8su7t+NKEkGSjRl7AQ4P33vb7OM8/Ali0+Pv+KK+Dyy/17kcTyX3wW8PvsHhYKfk9gOH4fYDCe/KMfp7C7jpDEkxJ9HP3wA0yb5q38t9/2N4HTTvOunYsu8la/SOII+E3g1Xj3T1t8dM9nwO8i29aze0RQ7hyBmfjicvXweQK5j4sjX3/Gu5E0TLS4KNEniC+/hKee8pb+ihWe5Hv3hn79oH17n6Urkvh24sND1+Hj/avhy0SPx0f95ODDQ3/AK4WeBDwGDMC7gupHPa7BVx79Hr9vULHkTiPFKNEnmBB82cPJk3d37TRp4gn/8su93o5Icgt4ieiKeCt+PvA8PmIo95H7pvArfETQrXjSbxB5NAT+hM8a/ha/V6D7BQVRok9gW7bAc8950p8922/Ydu3qSf/cc7U4iqSy7Xgr3vAbxDPZ/SbwOd5FtBnv8hmErzFQG38DaAgcD9wU+V3L8U8aNfFPGKVv5IMSfZLIzvZunccfh5wcOOII6NsXLrkETjlFXTtS2uxkd8J+C/gQfwPIfZTFC8kBdMMLzIG/eRyJdxm9HNn2d/yNowL+KaMC/qnhvMj+9/E3nor4QvSH428YyVODSIk+yezYAW+95ZOxXnjBa+TXrw8XXwx9+vgCKkr6IjvYvWDMHDzpb4g8NuKJ+u7I/s6R5/wYdfxpeCVS8OGkK/L8/h7Aq5HvO0SOPTzq0RlfnAZ80poBh0Y9TsDLWO/E34QMf+PKfdTHq6Nuj8RWBl/3+MDeXJTok9h33/mErClTvG7+9u1w3HGe8Pv29e9FJFYBHwH0I56Aj4xsn4/fB/gRv7ewGe8mOjeyfxC7u5JyHz2BByL7D8HLVkQbEtm/DU/8eY3E3yA2R
cWxFO+SKjwl+hTx9dfw/POe9N95x2/qpqd70u/TR6UXROJnB57Qf8Yrk/6MdwPVwN9QPo183Ym/2ezEh502ihw3i93lKQ5szLUSfQpatw6efdaT/scf+7bOnWHIEDjvPChbNr7xiUjJ0gpTKah2bRg+3Jc/XLUKbr8dVq70SVgNG/rPX30V7yhFJBEo0aeAhg3hpps84b/wAjRtCrfcAvXq+YidDz7wbh4RKZ2U6FNIWpp328yYAcuXw1VXwWuvQYcO0Lo1/OtfXpJBREoXJfoUdeyxMG4crF3rJZR37vRlEOvUgWuv9TcCESkdlOhTXKVKnuAXLPCZt926wQMPwPHHQ5cufkP3l7yjwkQkpSjRlxJmXhN/yhRYs8YXPV+50ouqHX209+l/+WW8oxSR4qBEXwr96lcwerQn+ldegTZtPPE3bOj1dV5/3WfnikhqiCnRm1k3M1tuZtlmNjKf/dXMbJqZLTSzOWbWPNZjJX7S0uDssz3Zf/45jBwJn3wCPXpA48Zw112+Hq6IJLf9JnozSwMmAN2BZkBfM2uW52mjgawQQkt89YH7C3GsJID69b1Vv2aNd+80aACjRvkQzf79ISsr3hGKyIGKpUXfFsgOIawKIWwDprC75FuuZsDbACGEZUADMzsqxmMlgRxyiBdPmzULFi/2JD91qg/PPO00L6m8ffv+f4+IJI5YEn0dfLmYXDmRbdEW4OuJYWZt8bJsdWM8lshxg8ws08wyN6q/ICE0awYPPuglk++912/WXnQRHHMM3HMPfPNNvCMUkVjEkujzK4ibd57lXUA1M8sCrsZLwW2P8VjfGMLEEEJGCCGjRo0aMYQlJaVaNbjuOq+XP20aNGoEN94IdevC73/vLX8RSVyxJPocfMXfXHXxxSJ3CSF8F0LoF0JIx/voa+ArA+z3WEkeaWnQq5d362RleXmFJ56A5s13j8nfti3eUYpIXrEk+k+BJmbW0MwOAfoAL0U/wcyqRvYBDARmhxC+i+VYSU6tWsGkSX7z9q9/9cXOc8fkjxrldXdEJDHsN9GHELYDQ/ElUpYCz4QQFpvZYDMbHHlaU2CxmS3DR9gM39exRX8aEi9HHumJ/fPPfZhm27bef9+4sc/CnTZNM29F4k316KXIrVkDjzziLf61a6FWLRgwAAYO1OIoIsVF9eilRNWrB3/6E6xe7csgtm69e+bt2WfDG2+obLJISVKil2JTtqyXVHj1Ve/auekmmDfPu3Rat4b//Edj8kVKghK9lIj69eG227yVP3myj8757W+9L3/8eNXJFylOSvRSog49FPr1g0WL4KWXfCz+8OE+WmfMGNiwId4RiqQeJXqJizJl4Jxz4P33/dGxo7f469f3lbFWrox3hCKpQ4le4q59e1/rdulS786ZNMlXyOrd2xc/F5GDo0QvCeP44z3Jr14N11/va9/++tfQrp1X1NR4fJEDo0QvCad2bbj7bi+m9sADXjytb18fnnnXXSqmJlJYSvSSsCpXhiFDYNkyePllb/GPGuU3cP/wB+/qEZH9U6KXhFemDPTsCW+9BQsXeuv+0Ue9jHL37pqAJbI/SvSSVFq08PIKX34Jf/kLzJ/vE7BOOAEmToSffop3hCKJR4leklLNmnDLLfDFF/D441C+vNfGr1fPZ+CuUzFskV2U6CWpHXooXH45zJ0L777r4/HvvNPH4196qW8XKe2U6CUlmMGpp3pZ5Oxsv4n74ouQkeHJ//nnYceOeEcpEh9K9JJyGjWCceN8eObf/+5fL7zQ6+qMGwfffx/vCEVKlhK9pKzDD4drr/UW/vPPe//9tdd6XZ3Ro+G//413hCIlQ4leUl5aGpx/PsyeDZ98Amec4ROv6teH3/0Oli+Pd4QixUuJXkqVtm19EfP/9/981aunnoKmTX3R8w8/jHd0IsVDiV5KpcaN4cEHfXjmzTd7a799e+jQwW/i7twZ7whFio4SvZRqNWv6xKsvv4T77/cbt716+azbxx/XCliSG
pToRfC6OsOG+Y3b//zHJ2BdeaXX11HCl2SnRC8SpWxZr6Uzf77XyK9SZXfCf+wxJXxJTkr0Ivkwg/PO88XMX3wRDjvMl0BUwpdkpEQvsg9mcO65XkrhxRd9bH6/fnDccV5BU4uhSDJQoheJQW7Cz8z0Rc2rVoX+/b2F/+ijauFLYlOiFykEM1/UPDPTF0OpVs0TfrNmfhNX9XQkESnRixwAM18M5dNPvUunQgVf2LxVKy+3oIVQJJEo0YschNwunfnzYepU78K58EKvmvnaa0r4khiU6EWKQJky0Ls3LFrk4+7/9z84+2yfbTtzZryjk9JOiV6kCJUt6wuhLF8ODz8Ma9ZAly5w+unwwQfxjk5Kq5gSvZl1M7PlZpZtZiPz2X+4mb1sZgvMbLGZ9Yvat9rMPjOzLDPLLMrgRRJVuXIwaBCsWOGlFZYs8To6PXp4N49ISdpvojezNGAC0B1oBvQ1s2Z5njYEWBJCaAV0Av5mZodE7e8cQkgPIWQUTdgiyaF8eS+tsHKll0b++GM48US4+GKVR5aSE0uLvi2QHUJYFULYBkwBzsvznABUMTMDKgPfABpZLBJRqRLceCOsWuXVMl99FU44AQYO9IJqIsUplkRfB1gT9XNOZFu0B4CmwDrgM2B4CCG30GsAZpjZXDMbVNCLmNkgM8s0s8yNGzfGfAIiyaRqVbjtNk/4V18NTz4JTZrANdfAhg3xjk5SVSyJ3vLZlnfQWFcgC6gNpAMPmNlhkX3tQwgn4l0/Q8zs1PxeJIQwMYSQEULIqFGjRmzRiySpmjXhvvu8D/+yy+Af//C1bm+5Bb79Nt7RSaqJJdHnAPWifq6Lt9yj9QOeDy4b+Bw4HiCEsC7ydQMwDe8KEhF8/dpJk/xm7dlnw+23e8K/5x74+ed4RyepIpZE/ynQxMwaRm6w9gFeyvOcL4EuAGZ2FHAcsMrMKplZlcj2SsBZwKKiCl4kVRx3nE+4mjcP2rXz/vwTTvAyC5p0JQdrv4k+hLAdGAq8ASwFngkhLDazwWY2OPK024Bfm9lnwNvAjSGEr4GjgPfNbAEwB3g1hDC9OE5EJBW0bu0zat94w4donnuuD8nUCB05GBYSsLmQkZERMjM15F5Kt19+gQkT4NZb4ccfYfhwGDPGa+OL5GVmcwsawq6ZsSIJqlw5H42zYgVccQX8/e9w7LFeFlmLl0thKNGLJLiaNf2G7Zw5fqO2f3845RT45JN4RybJQoleJElkZMD778MTT3gNnXbtfD3bdXnHwInkoUQvkkTKlPFx98uX+8ic//wHGjeG0aM1/l4KpkQvkoSqVPHaOcuWwfnnw513erfOvffC1q3xjk4SjRK9SBJr1Aj+/W+viHnyyTBihJdUmDxZ69jKbkr0IikgPR1efx1mzYLatWHAAGjZEl54QROuRIleJKV06uSlkJ97zodgnn++r3I1e3a8I5N4UqIXSTFmcMEFvqzhxInwxRdw2mk+y/bzz+MdncSDEr1IiipbFn73O59wdeed3q1zwgleMO2XX+IdnZQkJXqRFFexIowc6RUyzzrLh2W2aeNdPFI6KNGLlBL16vnN2WnT4H//g1//Gq66SuPvSwMlepFSplcvb90PHw4PPwxNm8Izz2h0TipTohcphapU8R8ea0wAAAyiSURBVBWu5szx4ZgXX+wLn+hmbWpSohcpxdq08eJo48bBe+/pZm2qUqIXKeXKlvVunCVLoGtXv1nburXG3qcSJXoRAfxm7bRpfsN2yxYfe3/55fDVV/GOTA6WEr2I7OG887x1P3o0TJni69k++CDs2BHvyORAKdGLyF4qVoQ77oCFC70O/pAhXjRtzpx4RyYHQoleRAp0/PHw5pvesl+3zhc7GTwYvvkm3pFJYSjRi8g+mfnwy2XLfA3bSZO8O0dr1yYPJXoRiclhh/kC5fPmeaLv3x9OPRU++yzekcn+KNGLSKG0bOlDLydP9lb+iSf6kMwffoh3ZFIQJXoRKbQyZaBfP
1+79vLLfZLVCSfAK6/EOzLJjxK9iByw6tXhkUe8hV+pEpxzDlx4IeTkxDsyiaZELyIHrWNHX7f2r3+F117zQmnjxmnd2kShRC8iReKQQ2DUKFi82BP/tddC27bw6afxjkyU6EWkSDVqBK++6qWP//tfn2g1dCh89128Iyu9lOhFpMiZwW9+46Nyhg71EgotW8I778Q7stJJiV5Eis1hh8H48fDhh1CuHHTuDH/8I2zdGu/ISpeYEr2ZdTOz5WaWbWYj89l/uJm9bGYLzGyxmfWL9VgRSX3t2kFWli9deN99Xgd/7tx4R1V67DfRm1kaMAHoDjQD+ppZszxPGwIsCSG0AjoBfzOzQ2I8VkRKgUqVYMIEeOMN2LzZk/9tt2lkTkmIpUXfFsgOIawKIWwDpgDn5XlOAKqYmQGVgW+A7TEeKyKlyFlnedmE3r1hzBho394nXknxiSXR1wHWRP2cE9kW7QGgKbAO+AwYHkLYGeOxAJjZIDPLNLPMjRs3xhi+iCSjatXg3/+GqVMhOxvS0+Ef/1CRtOISS6K3fLblXS++K5AF1AbSgQfM7LAYj/WNIUwMIWSEEDJq1KgRQ1gikux694ZFi+D002HYMG/tr1mz/+OkcGJJ9DlAvaif6+It92j9gOeDywY+B46P8VgRKcVq1fIaORMnwscfQ/PmXlYh5NsklAMRS6L/FGhiZg3N7BCgD/BSnud8CXQBMLOjgOOAVTEeKyKlnBn87ne+otWJJ8LAgdCtG3z5ZbwjSw37TfQhhO3AUOANYCnwTAhhsZkNNrPBkafdBvzazD4D3gZuDCF8XdCxxXEiIpL8GjWCt9/20TkffOAVMf/5T/XdHywLCfj5KCMjI2RmZsY7DBGJo9WrvZX/1ls+0eqRR6Bhw3hHlbjMbG4IISO/fZoZKyIJqUEDmDHD++4zM6FFC3jgAbXuD4QSvYgkrNy++9yKmFdf7a377Ox4R5ZclOhFJOHVq+d17h99FBYs8AJp990HO3bEO7LkoEQvIknBDK680lv3Xbp4cbQuXTTuPhZK9CKSVOrUgZde8tZ9Zia0agX/93/xjiqxKdGLSNLJbd1nZUHjxl77fuBA2LIl3pElJiV6EUlajRv7ePtRo2DyZJ9spfLHe1OiF5GkVq6cL0o+cyb89BOccgrcc4+GYUZToheRlNCpk4/IOfdcuPFGL5C2dm28o0oMSvQikjKOOAKefRYmTYKPPvJhmC+8EO+o4k+JXkRSihkMGADz5vns2vPPh0GD4Lvv4h1Z/CjRi0hKOu44b9XfcIO38Js3h9dfj3dU8aFELyIp65BD4O674cMPoXJl6NEDrrgCvvkm3pGVLCV6EUl57drB/Plw882+hGGzZvD88/GOquQo0YtIqXDooXDbbT6btnZtuPBCn2j11Vfxjqz4KdGLSKmSng6ffOJj7196yVv3Tz2V2ksXKtGLSKlTrpzPps3K8pu2l10GPXtCTk68IyseSvQiUmo1bQrvveclj2fN8tb944+nXuteiV5ESrW0NLjmGli0CFq39mJpl1wCmzfHO7Kio0QvIoIvTD5zpt+wffZZ78v/6KN4R1U0lOhFRCLS0nwI5nvv+c8dO8Lttyf/SlZK9CIieZxyit+o7d0bbrkFTj89uVeyUqIXEcnH4Yf75KrHHvMa961aJe8kKyV6EZECmHnJhPnz4ZhjfJLVoEHw44/xjqxwlOhFRPajSRNfyeqGG+Bf/4I2bbxrJ1ko0YuIxCC3QNqbb/rQy5NPhvvvT44x90r0IiKFcMYZvpJV164+/v7ssxO/Xo4SvYhIIdWoAS++CBMm+Izali1h+vR4R1UwJXoRkQNgBlddBZ9+CjVrQvfucO218PPP8Y5sb0r0IiIHoXlzmDMHrr4axo3zvvulS+Md1Z5iSvRm1s3MlptZtpmNzGf/CDPLijwWmdkOMzsism+1mX0W2ZdZ1CcgIhJvFSrA+PHw8suwdq2Pynn44cS5UbvfRG9macAEoDvQD
OhrZs2inxNCGBtCSA8hpAOjgHdDCNGLdXWO7M8owthFRBJKz56wcCF06ACDB/u4+02b4h1VbC36tkB2CGFVCGEbMAU4bx/P7ws8XRTBiYgkm1q1/MbsvffCK6/4jNp3341vTLEk+jpAdJWHnMi2vZhZRaAb8FzU5gDMMLO5ZjaooBcxs0FmlmlmmRs3bowhLBGRxFSmDFx3HXz8MVSq5LVybrstfsXRYkn0ls+2gnqezgE+yNNt0z6EcCLe9TPEzE7N78AQwsQQQkYIIaNGjRoxhCUikthOPNHXqO3bF8aM8bH38RhzH0uizwHqRf1cF1hXwHP7kKfbJoSwLvJ1AzAN7woSESkVqlSBJ5+ESZO8jEJ6ute9L0mxJPpPgSZm1tDMDsGT+Ut5n2RmhwOnAS9GbatkZlVyvwfOAhYVReAiIsnCDAYM8GGYVav67No//ankunL2m+hDCNuBocAbwFLgmRDCYjMbbGaDo556PjAjhPBD1LajgPfNbAEwB3g1hJDA88dERIpPixY+weqyy+DPf4Yzz4T164v/dS0kykDPKBkZGSEzU0PuRSR1PfaYz6ytUsXr3p9xxsH9PjObW9AQds2MFRGJgyuv9Bu1Rx4JZ53lK1lt3148r6VELyISJ82aeVdOv36+Nm2XLrBlS9G/Ttmi/5UiIhKrihXhkUegUyefWFWpUtG/hhK9iEgCuOwyfxQHdd2IiKQ4JXoRkRSnRC8ikuKU6EVEUpwSvYhIilOiFxFJcUr0IiIpToleRCTFJWRRMzPbCHwRtelI4Os4hVNcUu2cUu18IPXOKdXOB1LvnA7mfOqHEPJdtSkhE31eZpaZaguLp9o5pdr5QOqdU6qdD6TeORXX+ajrRkQkxSnRi4ikuGRJ9BPjHUAxSLVzSrXzgdQ7p1Q7H0i9cyqW80mKPnoRETlwydKiFxGRA6RELyKS4hIu0ZvZZDPbYGaLorYdYWZvmtmKyNdq8YyxsAo4pz+Z2Vozy4o8esQzxsIws3pmNsvMlprZYjMbHtmelNdpH+eTzNeovJnNMbMFkXP6c2R7sl6jgs4naa8RgJmlmdl8M3sl8nOxXJ+E66M3s1OBLcATIYTmkW33AN+EEO4ys5FAtRDCjfGMszAKOKc/AVtCCPfGM7YDYWa1gFohhHlmVgWYC/QCriQJr9M+zqc3yXuNDKgUQthiZuWA94HhwAUk5zUq6Hy6kaTXCMDM/ghkAIeFEHoWV65LuBZ9CGE28E2ezecBj0e+fxz/I0waBZxT0gohrA8hzIt8/z2wFKhDkl6nfZxP0goud5npcpFHIHmvUUHnk7TMrC5wNjApanOxXJ+ES/QFOCqEsB78jxKoGed4ispQM1sY6dpJio/QeZlZA6A18AkpcJ3ynA8k8TWKdAtkARuAN0MISX2NCjgfSN5rNA64AdgZta1Yrk+yJPpU9BBwDJAOrAf+Ft9wCs/MKgPPAdeEEL6LdzwHK5/zSeprFELYEUJIB+oCbc2sebxjOhgFnE9SXiMz6wlsCCHMLYnXS5ZE/1WkHzW3P3VDnOM5aCGEryL/cXcC/wLaxjumwoj0kz4H/DuE8Hxkc9Jep/zOJ9mvUa4QwrfAO3h/dtJeo1zR55PE16g9cK6ZrQamAKeb2VMU0/VJlkT/EnBF5PsrgBfjGEuRyL2YEecDiwp6bqKJ3Bh7BFgaQvh71K6kvE4FnU+SX6MaZlY18n0F4AxgGcl7jfI9n2S9RiGEUSGEuiGEBkAfYGYI4VKK6fok4qibp4FOeLnOr4BbgReAZ4CjgS+B34QQkubmZgHn1An/uBmA1cDvc/vmEp2ZdQDeAz5jd//iaLxfO+mu0z7Opy/Je41a4jfz0vAG3TMhhL+YWXWS8xoVdD5PkqTXKJeZdQKuj4y6KZbrk3CJXkREilaydN2IiMgBUqIXEUlxSvQiIilOiV5EJMUp0YuIpDglehGRFKdELyKS4v4/sUJL0UwiXmIAAAAASUVORK5CYII=\n", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Skip the first epochs so the plot focuses on the later part of training\n",
    "SKIP_EPOCHS=10\n",
    "df=pd.DataFrame(model.learning_process[SKIP_EPOCHS:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
    "ax.plot('epoch', 'test_RMSE', data=df, color='orange', linestyle='dashed')\n",
    "ax.set(title='Train vs test RMSE', xlabel='epoch', ylabel='RMSE')\n",
    "ax.legend();"
   ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
    "# Call through the class: the method stores its result in model.estimations, which\n",
    "# shadows the method on the instance, so model.estimations() fails on a re-run.\n",
    "SVD.estimations(model)\n",
    "\n",
    "RECO_PATH='Recommendations generated/ml-100k/Self_SVD_reco.csv'\n",
    "ESTIMATIONS_PATH='Recommendations generated/ml-100k/Self_SVD_estimations.csv'\n",
    "\n",
    "top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
    "top_n.to_csv(RECO_PATH, index=False, header=False)\n",
    "\n",
    "estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
    "estimations.to_csv(ESTIMATIONS_PATH, index=False, header=False)"
   ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 10387.88it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.916330.7201530.1033930.0444550.0531770.0700730.0938840.0793660.1077920.0512810.200210.5189570.475080.8530220.1471863.9113560.971196
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.91633 0.720153 0.103393 0.044455 0.053177 0.070073 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.093884 0.079366 0.107792 0.051281 0.20021 0.518957 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.47508 0.853022 0.147186 3.911356 0.971196 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n", "reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n", "\n", "ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n", " estimations_df=estimations_df, \n", " reco=reco,\n", " super_reactions=[4,5])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 10057.03it/s]\n", "943it [00:00, 10623.82it/s]\n", "943it [00:00, 9952.71it/s]\n", "943it [00:00, 10166.90it/s]\n", "943it [00:00, 10805.10it/s]\n", "943it [00:00, 10623.82it/s]\n", "943it [00:00, 10390.37it/s]\n", "943it [00:00, 10166.95it/s]\n", "943it [00:00, 11057.98it/s]\n", "943it [00:00, 10170.01it/s]\n", "943it [00:00, 10994.41it/s]\n", "943it [00:00, 10390.45it/s]\n", "943it [00:00, 10166.77it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Self_SVD0.9163300.7201530.1033930.0444550.0531770.0700730.0938840.0793660.1077920.0512810.2002100.5189570.4750800.8530220.1471863.9113560.971196
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5100301.2118480.0500530.0223670.0259840.0337270.0306870.0232550.0553920.0216020.1376900.5077130.3382820.9879110.1875905.1118780.906685
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNWithZScore0.9577010.7523870.0037120.0019940.0023800.0029190.0034330.0024010.0051370.0021580.0164580.4973490.0275720.3899260.0678212.4757470.992793
0Ready_I-KNNWithMeans0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated2.5082582.2179090.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Self_SVD 0.916330 0.720153 0.103393 0.044455 0.053177 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.510030 1.211848 0.050053 0.022367 0.025984 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNWithZScore 0.957701 0.752387 0.003712 0.001994 0.002380 \n", "0 Ready_I-KNNWithMeans 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.070073 0.093884 0.079366 0.107792 0.051281 0.200210 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.033727 0.030687 0.023255 0.055392 0.021602 0.137690 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.002919 0.003433 0.002401 0.005137 0.002158 0.016458 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.518957 0.475080 0.853022 0.147186 3.911356 0.971196 \n", 
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.507713 0.338282 0.987911 0.187590 5.111878 0.906685 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.497349 0.027572 0.389926 0.067821 2.475747 0.992793 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2],\n", " [3, 4]])" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([[0.4472136 , 0.89442719],\n", " [0.6 , 0.8 ]])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=np.array([[1,2],[3,4]])\n", "display(x)\n", "x/np.linalg.norm(x, axis=1)[:,None]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
013111.00000013121312Pompatus of Love, The (1996)Comedy, Drama
113950.99263813961396Stonewall (1995)Drama
214480.99195414491449Pather Panchali (1955)Drama
312010.99172612021202Maybe, Maybe Not (Bewegte Mann, Der) (1994)Comedy
414050.99141914061406When Night Is Falling (1995)Drama, Romance
512500.99110812511251A Chef in Love (1996)Comedy
612630.99110012641264Nothing to Lose (1994)Drama
711230.99106911241124Farewell to Arms, A (1932)Romance, War
811100.99104011111111Double Happiness (1994)Drama
911240.99088911251125Innocents, The (1961)Thriller
\n", "
" ], "text/plain": [ " code score item_id id title \\\n", "0 1311 1.000000 1312 1312 Pompatus of Love, The (1996) \n", "1 1395 0.992638 1396 1396 Stonewall (1995) \n", "2 1448 0.991954 1449 1449 Pather Panchali (1955) \n", "3 1201 0.991726 1202 1202 Maybe, Maybe Not (Bewegte Mann, Der) (1994) \n", "4 1405 0.991419 1406 1406 When Night Is Falling (1995) \n", "5 1250 0.991108 1251 1251 A Chef in Love (1996) \n", "6 1263 0.991100 1264 1264 Nothing to Lose (1994) \n", "7 1123 0.991069 1124 1124 Farewell to Arms, A (1932) \n", "8 1110 0.991040 1111 1111 Double Happiness (1994) \n", "9 1124 0.990889 1125 1125 Innocents, The (1961) \n", "\n", " genres \n", "0 Comedy, Drama \n", "1 Drama \n", "2 Drama \n", "3 Comedy \n", "4 Drama, Romance \n", "5 Comedy \n", "6 Drama \n", "7 Romance, War \n", "8 Drama \n", "9 Thriller " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item=random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm=model.Qi/np.linalg.norm(model.Qi, axis=1)[:,None] # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores=np.dot(embeddings_norm,embeddings_norm[item].T)\n", "top_similar_items=pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\\\n", ".sort_values(by=['score'], ascending=[False])[:10]\n", "\n", "top_similar_items['item_id']=top_similar_items['code'].apply(lambda x: item_code_id[x])\n", "\n", "items=pd.read_csv('./Datasets/ml-100k/movies.csv')\n", "\n", "result=pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n", "\n", "result" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)from tqdm import tqdm\n", "\n", "class SVDBaseline():\n", "\n", " def __init__(self, train_ui, learning_rate, regularization, nb_factors, 
iterations):\n", " self.train_ui=train_ui\n", " self.uir=list(zip(*[train_ui.nonzero()[0],train_ui.nonzero()[1], train_ui.data]))\n", "\n", " self.learning_rate=learning_rate\n", " self.regularization=regularization\n", " self.iterations=iterations\n", " self.nb_users, self.nb_items=train_ui.shape\n", " self.nb_ratings=train_ui.nnz\n", " self.nb_factors=nb_factors\n", "\n", " self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n", " self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n", "\n", " self.b_u = np.zeros(self.nb_users)\n", " self.b_i = np.zeros(self.nb_items)\n", "\n", " def train(self, test_ui=None):\n", " if test_ui!=None:\n", " self.test_uir=list(zip(*[test_ui.nonzero()[0],test_ui.nonzero()[1], test_ui.data]))\n", "\n", " self.learning_process=[]\n", " pbar = tqdm(range(self.iterations))\n", " for i in pbar:\n", " pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. 
Training epoch {i+1}...')\n", " np.random.shuffle(self.uir)\n", " self.sgd(self.uir)\n", " if test_ui==None:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n", " else:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n", "\n", " def sgd(self, uir):\n", "\n", " for u, i, score in uir:\n", " prediction = self.get_rating(u,i)\n", " e = (score - prediction)\n", " \n", " \n", " b_u_update = self.learning_rate * (e - self.regularization * self.b_u[u])\n", " b_i_update = self.learning_rate * (e - self.regularization * self.b_i[i])\n", " \n", " self.b_u[u] += b_u_update\n", " self.b_i[i] += b_i_update\n", " Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n", " Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n", " \n", " self.Pu[u] += Pu_update\n", " self.Qi[i] += Qi_update\n", " \n", " def get_rating(self, u, i):\n", " prediction = self.b_u[u] + self.b_i[i] + self.Pu[u].dot(self.Qi[i].T)\n", " return prediction\n", " \n", " def RMSE_total(self, uir):\n", " RMSE=0\n", " for u,i, score in uir:\n", " prediction = self.get_rating(u,i)\n", " RMSE+=(score - prediction)**2\n", " return np.sqrt(RMSE/len(uir))\n", " \n", " def estimations(self):\n", " self.estimations= self.b_u[:,np.newaxis] + self.b_i[np.newaxis:,] + np.dot(self.Pu,self.Qi.T)\n", "\n", " def recommend(self, user_code_id, item_code_id, topK=10):\n", " \n", " top_k = defaultdict(list)\n", " for nb_user, user in enumerate(self.estimations):\n", " \n", " user_rated=self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]]\n", " for item, score in enumerate(user):\n", " if item not in user_rated and not np.isnan(score):\n", " top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n", " result=[]\n", " for uid, item_scores in top_k.items():\n", " item_scores.sort(key=lambda x: x[1], reverse=True)\n", " 
result.append([uid]+list(chain(*item_scores[:topK])))\n", " return result\n", " \n", " def estimate(self, user_code_id, item_code_id, test_ui):\n", " result=[]\n", " for user, item in zip(*test_ui.nonzero()):\n", " result.append([user_code_id[user], item_code_id[item], \n", " self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n", " return result" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# making changes to our implementation by considering additional parameters in the gradient descent procedure \n", "# seems to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", 
"imp.reload(helpers)\n", "\n", "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 9346.25it/s]\n", "943it [00:00, 10163.97it/s]\n", "943it [00:00, 10745.39it/s]\n", "943it [00:00, 10228.45it/s]\n", "943it [00:00, 10502.49it/s]\n", "943it [00:00, 10627.14it/s]\n", "943it [00:00, 10620.29it/s]\n", "943it [00:00, 10991.05it/s]\n", "943it [00:00, 9955.97it/s]\n", "943it [00:00, 10510.36it/s]\n", "943it [00:00, 9746.48it/s]\n", "943it [00:00, 10164.36it/s]\n", "943it [00:00, 10693.82it/s]\n", "943it [00:00, 10927.37it/s]\n", "943it [00:00, 10620.60it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_SVD0.9491650.7466670.0939550.0449690.0511970.0654740.0839060.0739960.1046720.0482110.2207570.5191870.4835630.9979850.2049064.4089130.954288
0Self_SVD0.9163300.7201530.1033930.0444550.0531770.0700730.0938840.0793660.1077920.0512810.2002100.5189570.4750800.8530220.1471863.9113560.971196
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9381460.7399170.0865320.0370670.0448320.0588770.0780040.0578650.0945830.0430130.2023910.5152020.4337220.9960760.1666674.1683540.964092
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5100301.2118480.0500530.0223670.0259840.0337270.0306870.0232550.0553920.0216020.1376900.5077130.3382820.9879110.1875905.1118780.906685
0Ready_I-KNN1.0303860.8130670.0260870.0069080.0105930.0160460.0211370.0095220.0242140.0089580.0480680.4998850.1548250.4023330.4343435.1336500.877999
0Ready_I-KNNWithZScore0.9577010.7523870.0037120.0019940.0023800.0029190.0034330.0024010.0051370.0021580.0164580.4973490.0275720.3899260.0678212.4757470.992793
0Ready_I-KNNWithMeans0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_I-KNNBaseline0.9353270.7374240.0025450.0007550.0011050.0016020.0022530.0009300.0034440.0013620.0117600.4967240.0212090.4828210.0598852.2325780.994487
0Ready_U-KNN1.0234950.8079130.0007420.0002050.0003050.0004490.0005360.0001980.0008450.0002740.0027440.4964410.0074230.6021210.0108232.0891860.995706
0Self_TopRated2.5082582.2179090.0009540.0001880.0002980.0004810.0006440.0002230.0010430.0003350.0033480.4964330.0095440.6990460.0050511.9459100.995669
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
0Self_IKNN1.0183630.8087930.0003180.0001080.0001400.0001890.0000000.0000000.0002140.0000370.0003680.4963910.0031810.3921530.1154404.1747410.965327
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall F_1 \\\n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n", "0 Ready_SVD 0.949165 0.746667 0.093955 0.044969 0.051197 \n", "0 Self_SVD 0.916330 0.720153 0.103393 0.044455 0.053177 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n", "0 Ready_SVDBiased 0.938146 0.739917 0.086532 0.037067 0.044832 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n", "0 Ready_Random 1.510030 1.211848 0.050053 0.022367 0.025984 \n", "0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n", "0 Ready_I-KNNWithZScore 0.957701 0.752387 0.003712 0.001994 0.002380 \n", "0 Ready_I-KNNWithMeans 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n", "0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n", "0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n", "0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n", "\n", " F_05 precision_super recall_super NDCG mAP MRR \\\n", "0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n", "0 0.065474 0.083906 0.073996 0.104672 0.048211 0.220757 \n", "0 0.070073 0.093884 0.079366 0.107792 0.051281 0.200210 \n", "0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n", "0 0.058877 0.078004 0.057865 0.094583 0.043013 0.202391 \n", "0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n", "0 0.033727 0.030687 0.023255 0.055392 0.021602 0.137690 \n", "0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n", "0 0.002919 0.003433 0.002401 0.005137 0.002158 0.016458 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n", "0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n", "0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n", "0 0.000463 0.000644 0.000189 0.000752 0.000168 
0.001677 \n", "0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n", "\n", " LAUC HR Reco in test Test coverage Shannon Gini \n", "0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n", "0 0.519187 0.483563 0.997985 0.204906 4.408913 0.954288 \n", "0 0.518957 0.475080 0.853022 0.147186 3.911356 0.971196 \n", "0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n", "0 0.515202 0.433722 0.996076 0.166667 4.168354 0.964092 \n", "0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n", "0 0.507713 0.338282 0.987911 0.187590 5.111878 0.906685 \n", "0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n", "0 0.497349 0.027572 0.389926 0.067821 2.475747 0.992793 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n", "0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n", "0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n", "0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n", "0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, 
"nbformat_minor": 4 }