warsztaty2/P4. Matrix Factorization.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self made SVD"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import helpers\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scipy.sparse as sparse\n",
"from collections import defaultdict\n",
"from itertools import chain\n",
"import random\n",
"\n",
"train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n",
"test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
"from tqdm import tqdm\n",
"\n",
"class SVD():\n",
" \n",
" def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
" self.train_ui=train_ui\n",
" self.uir=list(zip(*[train_ui.nonzero()[0],train_ui.nonzero()[1], train_ui.data]))\n",
" \n",
" self.learning_rate=learning_rate\n",
" self.regularization=regularization\n",
" self.iterations=iterations\n",
" self.nb_users, self.nb_items=train_ui.shape\n",
" self.nb_ratings=train_ui.nnz\n",
" self.nb_factors=nb_factors\n",
" \n",
" self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
" self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
"\n",
" def train(self, test_ui=None):\n",
" if test_ui!=None:\n",
" self.test_uir=list(zip(*[test_ui.nonzero()[0],test_ui.nonzero()[1], test_ui.data]))\n",
" \n",
" self.learning_process=[]\n",
" pbar = tqdm(range(self.iterations))\n",
" for i in pbar:\n",
" pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
" np.random.shuffle(self.uir)\n",
" self.sgd(self.uir)\n",
" if test_ui==None:\n",
" self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
" else:\n",
" self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
" \n",
" def sgd(self, uir):\n",
" \n",
" for u, i, score in uir:\n",
" # Computer prediction and error\n",
" prediction = self.get_rating(u,i)\n",
" e = (score - prediction)\n",
" \n",
" # Update user and item latent feature matrices\n",
" Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
" Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
" \n",
" self.Pu[u] += Pu_update\n",
" self.Qi[i] += Qi_update\n",
" \n",
" def get_rating(self, u, i):\n",
" prediction = self.Pu[u].dot(self.Qi[i].T)\n",
" return prediction\n",
" \n",
" def RMSE_total(self, uir):\n",
" RMSE=0\n",
" for u,i, score in uir:\n",
" prediction = self.get_rating(u,i)\n",
" RMSE+=(score - prediction)**2\n",
" return np.sqrt(RMSE/len(uir))\n",
" \n",
" def estimations(self):\n",
" self.estimations=\\\n",
" np.dot(self.Pu,self.Qi.T)\n",
"\n",
" def recommend(self, user_code_id, item_code_id, topK=10):\n",
" \n",
" top_k = defaultdict(list)\n",
" for nb_user, user in enumerate(self.estimations):\n",
" \n",
" user_rated=self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]]\n",
" for item, score in enumerate(user):\n",
" if item not in user_rated and not np.isnan(score):\n",
" top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
" result=[]\n",
" # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
" for uid, item_scores in top_k.items():\n",
" item_scores.sort(key=lambda x: x[1], reverse=True)\n",
" result.append([uid]+list(chain(*item_scores[:topK])))\n",
" return result\n",
" \n",
" def estimate(self, user_code_id, item_code_id, test_ui):\n",
" result=[]\n",
" for user, item in zip(*test_ui.nonzero()):\n",
" result.append([user_code_id[user], item_code_id[item], \n",
" self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n",
" return result"
]
},
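{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `sgd` step above is the classic (unbiased) matrix-factorization update. For every observed rating $r_{ui}$ the model predicts $\\hat{r}_{ui}=p_u \\cdot q_i$, computes the error $e_{ui}=r_{ui}-\\hat{r}_{ui}$ and moves both latent vectors along the gradient of the regularized squared error:\n",
"\n",
"$$p_u \\leftarrow p_u + \\eta\\,(e_{ui}\\,q_i - \\lambda\\,p_u), \\qquad q_i \\leftarrow q_i + \\eta\\,(e_{ui}\\,p_u - \\lambda\\,q_i),$$\n",
"\n",
"where $\\eta$ is `learning_rate` and $\\lambda$ is `regularization`."
]
},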
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 39 RMSE: 0.7503496297849689. Training epoch 40...: 100%|██████████| 40/40 [01:12<00:00, 1.81s/it]\n"
]
}
],
"source": [
"model=SVD(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n",
"model.train(test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x1491c088>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAb/ElEQVR4nO3dfXRV9Z3v8fc3IRAgEciDmiEg2DIznaIGjEErdlC8dwm3XqnLhbimo3VcwxrFglPRW2m1taNrbBkfhmuVeq9ae29XdTo+VF1WaxUHbVVKMCBPCr1qCTDlMUAEBJLv/WOfQ07CSXJOcpJ9svfntdZvnX323uecr9vw2fv89m/vY+6OiIgMfAVhFyAiIrmhQBcRiQgFuohIRCjQRUQiQoEuIhIRg8L64IqKCh83blxYHy8iMiDV19fvcvfKdMtCC/Rx48axcuXKsD5eRGRAMrNPOlumLhcRkYhQoIuIRIQCXUQkIkLrQxeRgefo0aM0NjZy+PDhsEuJvOLiYqqrqykqKsr4NQp0EclYY2MjpaWljBs3DjMLu5zIcnd2795NY2Mj48ePz/h16nIRkYwdPnyY8vJyhXkfMzPKy8uz/iakQBeRrCjM+0dPtvOAC/SPPoKbboKjR8OuREQkvwy4QH//ffjXf4WlS8OuREQkvwy4QL/0UrjoIvje92Dv3rCrEZH+1NTUxEMPPZT162bOnElTU1PWr/v617/O+PHjqamp4ayzzuK11147vmzatGmMHTuW1B8JmjVrFiUlJQC0trYyf/58Jk6cyBlnnME555zDRx99BARXyp9xxhnU1NRQU1PD/Pnzs64tnQE3ysUM7rsPJk2C738f7r8/7IpEpL8kA/2GG25oN7+lpYXCwsJOX/fSSy/1+DMXL17MFVdcwbJly5g7dy6bNm06vmzkyJH89re/ZerUqTQ1NbF9+/bjy5566im2bdvGmjVrKCgooLGxkeHDhx9fvmzZMioqKnpcVzoDLtABzjoLrrsOHnwQrr8e/vzPw65IJH5uugkaGnL7njU18MADnS//1re+xR/+8AdqamooKiqipKSEqqoqGhoaWL9+PbNmzWLLli0cPnyYBQsWMHfuXKDt3lHNzc3MmDGDqVOn8rvf/Y7Ro0fzy1/+kqFDh3Zb23nnncfWrVvbzZszZw5PPvkkU6dO5ZlnnuHyyy9n3bp1AGzfvp2qqioKCoKOkOrq6h5ulcwNuC6XpH/6JyguhltvDbsSEekv99xzD5/73OdoaGhg8eLFrFixgrvvvpv169cD8Nhjj1FfX8/KlStZsmQJu3fvPuE9Nm3axLx581i3bh0jR47k6aefzuizX375ZWbNmtVu3vTp01m+fDktLS08+eSTXHnllceXzZ49mxdeeIGamhpuvvlm3nvvvXavvfDCC493udyfo66GAXmEDnDqqbBoUdCWLYMLLwy7IpF46epIur/U1dW1u/BmyZIlPPvsswBs2bKFTZs2UV5e3u41yT5xgLPPPpuPP/64y8+45ZZbuPXWW9mxYwfvvPNOu2WFhYVMnTqVp556ikOHDpF6S/Dq6mo++OADXn/9dV5//XWmT5/OL37xC6ZPnw70TZfLgD1CB/jHf4TTTgseW1rCrkZE+ltqn/Qbb7zBb37zG95++21Wr17NpEmT0l6YM2TIkOPThYWFHDt2rMvPWLx4MZs3b+auu+7immuuOWH5nDlz+MY3vsHs2bPTftaMGTNYvHgxixYt4rnnnsvmPy9rAzrQi4vhBz+A1avhJz8JuxoR6WulpaUcOHAg7bJ9+/YxatQohg0bxsaNG084mu6NgoICFixYQGtrK6+88kq7ZRdccAG33XYbV111Vbv5q1atYtu2bUAw4mXNmjWcdtppOaspbZ19+u79YPZs+NKX4Nvfhk7+P4tIRJSXl3P++eczceJEbrnllnbLLrnkEo4dO8aZZ57J7bffzrnnnpvTzzYzvvOd7/DDH/7whPkLFy48oftkx44dXHrppUycOJEzzzyTQYMGceONNx5fntqHfvXVV+emxtQxlP2ptrbWc/WLRStWwJQpQX/63Xfn5C1FJI0NGzbwhS98IewyYiPd9jazenevTbf+gD9CB6irg7/5G7j3Xvik0x9nEhGJtm4D3cyKzWyFma02s3VmdmeadaaZ2T4za0i0O/qm3M798z9DQQHcdlt/f7KIDHTz5s073v2RbI8//njYZWUtk2GLnwEXuXuzmRUBb5nZr9y94xmHN939K7kvMTNjxsDChcH49PnzIcfdZyKS4O6Ru+Pij370o7BLOEFPusO7PUL3QHPiaVGihdPx3o1bb4WqKrj55rArEYmm4uJidu/e3aOwkcwlf+CiuLg4q9dldGGRmRUC9cDngR+5+7tpVjvPzFYD24CF7r4uzfvMBeYCjB07NqtCM1FSEtwS4K674NgxGDRgL5sSyU/V1dU0Njayc+fOsEuJvORP0GUjo8hz9xagxsxGAs+a2UR3X5uyyirgtES3zEzgOWBCmvd5BHgEglEuWVWaoVNOCR737oXKyr74BJH4Kioqyuon0aR/ZTXKxd2bgDeASzrM35/slnH3l4AiM8vtNa0ZKisLHtPcwkFEJNIyGeVSmTgyx8yGAhcDGzusc6olzpKYWV3ifUOJ1ORtG/bsCePTRUTCk0mXSxXwRKIfvQD4N3d/0cz+AcDdlwJXANeb2THgEDDHQzprkgx0HaGLSNx0G+juvgaYlGb+0pTpB4EHc1tazyjQRSSuInGlaCoFuojEVeQCvbQ0GK6oQBeRuIlcoJsFR+kKdBGJm8gFOgRDFxXoIhI3kQx0HaGLSBxFNtA1Dl1E4iayga4jdBGJGwW6iEhERDbQDx+GgwfDrkREpP9ENtBBR+kiEi8KdBGRiIhkoOsWuiISR5EMdB2hi0gcRTrQNRZdROIk0oGuI3QRiZNIBvrgwcEPRivQRSROIhnooIuLRCR+FOgiIhER2UDXLXRFJG4iG+g6QheRuIl0oGvYoojESaQDfe9eaG0NuxIRkf4R6UBvbYWmprArERHpH5EOdFA/uojER7eBbmbFZrbCzFab2TozuzPNOmZmS8xss5mtMbPJfVNu5hToIhI3gzJY5zPgIndvNrMi4C0z+5W7v5OyzgxgQqJNAR5OPIZGd1wUkbjp9gjdA82Jp0WJ5h1Wuwz4aWLdd4CRZlaV21KzoyN0EYmbjPrQzazQzBqAHcCr7v5uh1VGA1tSnjcm5nV8n7lmttLMVu7cubOnNWdEgS4icZNRoLt7i7vXANVAnZlN7LCKpXtZmvd5xN1r3b22srIy+2qzMGIEFBRoLLqIxEdWo1zcvQl4A7ikw6JGYEzK82pgW68q66WCAl3+LyLxkskol0ozG5mYHgpcDGzssNrzwNWJ0S7nAvvcfXvOq82SLv8XkTjJZJRLFfCEmRUS7AD+zd1fNLN/AHD3pcBLwExgM3AQuLaP6s2KAl1E4qTbQHf3NcCkNPOXpkw7MC+3pfVeeTls2dL9eiIiURDZK0VBfegiEi+RD
nR1uYhInEQ+0A8ehMOHw65ERKTvRT7QQWPRRSQeYhHo6nYRkThQoIuIRIQCXUQkIiId6LqFrojESaQDXUfoIhInkQ70oUODpkAXkTiIdKCDLi4SkfiIRaBrHLqIxEEsAl1H6CISBwp0EZGIUKCLiERE5AO9rCzoQ/cTfuFURCRaIh/o5eXQ0gL79oVdiYhI34pFoIO6XUQk+mIT6Bq6KCJRF5tA1xG6iESdAl1EJCIU6CIiERH5QB85EswU6CISfZEP9MLCINQV6CISdd0GupmNMbNlZrbBzNaZ2YI060wzs31m1pBod/RNuT2jq0VFJA4GZbDOMeBmd19lZqVAvZm96u7rO6z3prt/Jfcl9p4CXUTioNsjdHff7u6rEtMHgA3A6L4uLJd0C10RiYOs+tDNbBwwCXg3zeLzzGy1mf3KzL6Yg9pyRkfoIhIHmXS5AGBmJcDTwE3uvr/D4lXAae7ebGYzgeeACWneYy4wF2Ds2LE9LjpbCnQRiYOMjtDNrIggzH/m7s90XO7u+929OTH9ElBkZhVp1nvE3WvdvbaysrKXpWeurAwOHIAjR/rtI0VE+l0mo1wMeBTY4O73dbLOqYn1MLO6xPvmzTGx7uciInGQSZfL+cDfAu+bWUNi3iJgLIC7LwWuAK43s2PAIWCOe/7cgTz1atFTTw23FhGRvtJtoLv7W4B1s86DwIO5KirXdPm/iMRB5K8UBQW6iMRDrAJdfegiEmWxCnQdoYtIlMUi0IcNgyFDFOgiEm2xCHSzYCy6Al1EoiwWgQ66WlREok+BLiISEQp0EZGIUKCLiERErAJ9zx7InxsSiIjkVqwC/ehRaG4OuxIRkb4Rm0AvKwse1e0iIlEVm0DX1aIiEnUKdBGRiFCgi4hEhAJdRCQiYhPoOikqIlEXm0AfNAhGjNA90UUkumIT6KA7LopItMUq0HX5v4hEmQJdRCQiFOgiIhGhQBcRiYjYBfq+fXDsWNiViIjkXuwCHWDv3nDrEBHpC90GupmNMbNlZrbBzNaZ2YI065iZLTGzzWa2xswm9025vaOrRUUkygZlsM4x4GZ3X2VmpUC9mb3q7utT1pkBTEi0KcDDice8oqtFRSTKuj1Cd/ft7r4qMX0A2ACM7rDaZcBPPfAOMNLMqnJebS/pCF1EoiyrPnQzGwdMAt7tsGg0sCXleSMnhj5mNtfMVprZyp07d2ZXaQ4o0EUkyjIOdDMrAZ4GbnL3/R0Xp3nJCb/e6e6PuHutu9dWVlZmV2kOKNBFJMoyCnQzKyII85+5+zNpVmkExqQ8rwa29b683CotDW7SpUAXkSjKZJSLAY8CG9z9vk5Wex64OjHa5Vxgn7tvz2GdOWGmi4tEJLoyGeVyPvC3wPtm1pCYtwgYC+DuS4GXgJnAZuAgcG3uS82N8nLdQldEoqnbQHf3t0jfR566jgPzclVUXyorg127wq5CRCT3YnWlKMCECbBuHfgJp2xFRAa22AV6XV1whP7xx2FXIiKSW7EL9CmJ61dXrAi3DhGRXItdoE+cCMXFCnQRiZ7YBXpREUyerEAXkeiJXaBD0I9eX6/7ootItMQ20A8dCka7iIhERWwDHeDdjrcYExEZwGIZ6KefHlxgpH50EYmSWAa6WXCUrkAXkSiJZaBDMB593Tpobg67EhGR3IhtoNfVQWsrrFoVdiUiIrkR20A/55zgUd0uIhIVsQ30ykoYP16BLiLREdtAB50YFZFoiX2gf/IJ/OlPYVciItJ7sQ900FG6iERDrAN98mQoLFSgi0g0xDrQhw2DM85QoItINMQ60KHtxKh+kk5EBjoFeh00NcHmzWFXIiLSOwp0nRgVkYiIfaD/1V/B8OEKdBEZ+GIf6IWFcPbZuje6iAx83Qa6mT1mZjvMbG0ny6eZ2T4za0i0O3JfZt+qq4P33oMjR8KuRESk5zI5Qv8JcEk367zp7jWJ9v3el9W/6uqCMF+zJuxKRER6rttAd/flwJ5+qCU0U6YEj+pHF5GBLFd96OeZ2Woz+5WZfbGzlcxsrpmtNLOVO3fuzNFH996YMXDKKQp0ERnYchHoq4DT3P0s4H8Cz3W2ors/4u617l5bWVmZg4/ODf0knYhEQa8D3d33u3tzYvoloMjMKnpdWT+rq4ONG2HfvrArERHpmV4HupmdamaWmK5LvOfu3r5vf6urCy7/r68PuxIRkZ4Z1N0KZvZzYBpQYWaNwHeBIgB3XwpcAVxvZseAQ8Ac94F3Z5Ta2uDx3XfhoovCrUVEpCe6DXR3v6qb5Q8CD+asopCUlcGECepHF5GBK/ZXiqbSiVERGcgU6CmmTIFt22Dr1rArERHJngI9he68KCIDmQI9xVlnQVERLFsWdiUiItlToKcoLoYrr4RHHoGPPgq7GhGR7CjQO7jnHhg0CBYuDLsSEZHsKNA7GD0aFi2CZ56B118PuxoRkcwp0NP45jfh9NNhwQI4dizsakREMqNAT6O4GO69F9auhaVLw65GRCQzCvROXHYZXHwx3HEH7B5wd6YRkThSoHfCDB54APbvh9tvD7saEZHuKdC78MUvwg03wI9/DKtXh12NiEjXFOjduPNOGDUqOEE68O4hKSJxokDvxqhRcNdd8B//Af/+72FXIyLSOQV6Bv7+74PbAixcCAcPhl2NiEh6CvQMFBbCkiXwxz/C4sVhVyMikp4CPUNf/jLMng0/+AFs2hR2NSIiJ1KgZ2HxYhgyBM49F155JexqRETaU6BnYexY+P3voboaZsyA738fWlvDrkpEJKBAz9LnPw9vvw1f+xp897tw6aWwZ0/YVYmIKNB7ZNgweOIJeOghePVVOPtsWLUq7KpEJO4U6D1kBtdfD2++GdyR8UtfgsceC7sqEYkzBXovTZkSHJ1fcAFcd13Qtm8PuyoRiSMFeg5UVsLLLwc/jPH448HJ0yuvhOXLdbsAEek/3Qa6mT1mZjvMbG0ny83MlpjZZjNbY2aTc19m/isshLvvhg8+gPnz4de/hr/+azjzTHj4YThwIOwKRSTqMjlC/wlwSRfLZwATEm0u8HDvyxq4JkwIfhxj61Z49FEYPDi4Y+Of/RnMmwcNDRrqKCJ9o9tAd/flQFcD8y4DfuqBd4CRZlaVqwIHqmHD4O/+DlauhHfegcsvDwJ+0iQ4+WT46lfh/vuhvl4/cyciuTEoB+8xGtiS8rwxMU+nBglGw0yZErR774UXXgj61pcvh+eeC9YpLYXzzw9uLzBlCvzlX0JVVfBaEZFM5SLQ08VO2lOBZjaXoFuGsWPH5uCjB5aKCrj22qABNDYGwx6TAb9oUdu6paVBsP/FXwSPyenx42H48HDqF5H8Zp7BMAwzGwe86O4T0yz7MfCGu/888fwDYJq7d3mEXltb6ytXruxJzZG1cyesWQMbN7a1Dz6ALVvarzdiRNAnP3p0W0s+P/nkYNRNRUWwXoHGMYlEipnVu3ttumW5OEJ/HrjRzJ4EpgD7ugtzSa+yEqZPD1qq5mb48MMg4P/4x+CE69atsG0bvPZaMO69peXE9yssDIK9oiJ47/Ly4Ac7Ro2CkSPTtxEj
4KSToKREXT4iA023gW5mPwemARVm1gh8FygCcPelwEvATGAzcBC4tq+KjauSEpg8OWjptLTAjh1BwO/cGbRdu4KWOr1+PTQ1Be3Qoa4/0ywI9mTAJ1tJSdCGD2+bTp3XVRs2TN8YRPpSt4Hu7ld1s9yBeTmrSLJWWBicRK3KYmzRZ5+1hXtq278f9u0LWnI6+bhrF3zySfCNobkZPv0UjhzJrtahQ9uHfOqOYNiwtseOreMOpLS0/fSwYfpGIZKLLhcZgIYMgVNOCVpvHDkSBHsy4NO1gwfTL0+d19gYfGs4eDBon34Khw9nXodZ+m8NHXcCyR1Bcjr1ecdvHkOHaichA4sCXXpl8OCgjRqV+/dubW0L+eQOoLk5uOq24/SBA+3XSbbdu+Hjj9vWOXAg8wu7zNrvDJLdTsnp1MfuWklJsJ1E+pICXfJWQUFbd0xlZW7e0z048k/dESSnO+4Qks+T6xw4EHQ/ffRR2/T+/XD0aGafPXhw+m8G6XYSHadHjGh/PmOQ/uVKGvqzkFgxC7pShg4NhnjmwmeftQ/9dK3jjiHZmpqCYan797fNy+SGbsOHtw/6ZEuOVEo3L/XxpJN0gjqKFOgivTRkSNAqKnr/Xq2tQRdTMuCT3wJST053PGmdbJ980jbd3SgmCEJ95Mi2oaypraysbTq5E0gd2qrzC/lJgS6SRwoK2vrte+PIkSDwm5raQj45nfrY1AR79wbtww/bprvbIRQVte0MysuDHUBZWefTyR2ELnbrWwp0kQgaPLjtorKe+OyzINg7DmtN3RE0NQW/p7tnT3ANxNq1wUno5ubO37egINgRpAZ9x+BPnZe86rm0VN8IMqFAF5ETDBkCp54atGwdORLsDHbvDsJ+79624E+25PJdu4JvBnv2BDuIzhQVtb/qOTldXt75Yxyvdlagi0hODR7cs2scWlqCUE/uCHbtagv91Cued+0KflcguV5nJ5GHDAnCv7OWDP7kTqCsLNhxDGQKdBHJC4WFbQGbqeROIDX8U3cCqW3TpuCxqy6hk05qC/jU8E/e9K7jdL7d+VSBLiIDVk92AocPB8G+e3f7nUDHHcJ//ie8/36wbmdXLQ8dmj74Tz75xFZZGazflxToIhIrxcUwZkzQMuEeHNV3POJPth072qY3bAgeDx5M/14lJUG433AD3Hxz7v6bkhToIiJdMGu7uvf00zN7zaeftoV9upbNjfSyoUAXEcmx5C0rxo3r38/VEH8RkYhQoIuIRIQCXUQkIhToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEeaZ/N5VX3yw2U7gky5WqQB29VM52VJtPaPaeka19UxUazvN3dP+ym5ogd4dM1vp7rVh15GOausZ1dYzqq1n4libulxERCJCgS4iEhH5HOiPhF1AF1Rbz6i2nlFtPRO72vK2D11ERLKTz0foIiKSBQW6iEhE5F2gm9klZvaBmW02s2+FXU8qM/vYzN43swYzWxlyLY+Z2Q4zW5syr8zMXjWzTYnHUXlU2/fMbGti2zWY2cyQahtjZsvMbIOZrTOzBYn5oW+7LmoLfduZWbGZrTCz1Yna7kzMz4ft1lltoW+3lBoLzew9M3sx8bxPtlte9aGbWSHwIfBfgEbg98BV7r4+1MISzOxjoNbdQ79Ywcy+DDQDP3X3iYl5PwT2uPs9iZ3hKHf/H3lS2/eAZnf/l/6up0NtVUCVu68ys1KgHpgFfJ2Qt10Xtc0m5G1nZgYMd/dmMysC3gIWAJcT/nbrrLZLyIO/OQAz+yZQC5zk7l/pq3+r+XaEXgdsdvf/5+5HgCeBy0KuKS+5+3JgT4fZlwFPJKafIAiDftdJbXnB3be7+6rE9AFgAzCaPNh2XdQWOg80J54WJZqTH9uts9rygplVA/8N+N8ps/tku+VboI8GtqQ8byRP/qATHPi1mdWb2dywi0njFHffDkE4ACeHXE9HN5rZmkSXTCjdQanMbBwwCXiXPNt2HWqDPNh2iW6DBmAH8Kq7581266Q2yIPtBjwA3Aq0pszrk+2Wb4FuaeblzZ4WON/dJwMzgHmJrgXJzMPA54AaYDtwb5jFmFkJ8DRwk7vvD7OWjtLUlhfbzt1b3L0GqAbqzGxiGHWk00ltoW83M/sKsMPd6/vj8/It0BuBMSnPq4FtIdVyAnfflnjcATxL0EWUT/6U6IdN9sfuCLme49z9T4l/dK3A/yLEbZfoZ30a+Jm7P5OYnRfbLl1t+bTtEvU0AW8Q9FHnxXZLSq0tT7bb+cB/T5x/exK4yMz+L3203fIt0H8PTDCz8WY2GJgDPB9yTQCY2fDEiSrMbDjwX4G1Xb+q3z0PXJOYvgb4ZYi1tJP84034KiFtu8QJtEeBDe5+X8qi0LddZ7Xlw7Yzs0ozG5mYHgpcDGwkP7Zb2tryYbu5+23uXu3u4wjy7HV3/xp9td3cPa8aMJNgpMsfgG+HXU9KXacDqxNtXdi1AT8n+Bp5lOCbzXVAOfAasCnxWJZHtf0f4H1gTeKPuSqk2qYSdOOtARoSbWY+bLsuagt92wFnAu8lalgL3JGYnw/brbPaQt9uHeqcBrzYl9str4YtiohIz+Vbl4uIiPSQAl1EJCIU6CIiEaFAFxGJCAW6iEhEKNBFRCJCgS4iEhH/H5aRnDSMJbJNAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"df=pd.DataFrame(model.learning_process).iloc[:,:2]\n",
"df.columns=['epoch', 'train_RMSE']\n",
"plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
"plt.legend()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x40e64d8>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de5zOdf7/8cfLUI5FonXKISpyGJmkRZHKIZUOK9pODmttRG0pVNrdajtoN9nUZqXjbtS3dE4qSmcNhhx/hpTBRtqUSsL798frGi5jhmuYmeswz/vtdt1m5vO5PnO9Pj7mdb2v9+f9fr0thICIiKSuMvEOQEREipcSvYhIilOiFxFJcUr0IiIpToleRCTFlY13APk58sgjQ4MGDeIdhohI0pg7d+7XIYQa+e1LyETfoEEDMjMz4x2GiEjSMLMvCtq3364bM5tsZhvMbFEB+83MxptZtpktNLMTo/Z1M7PlkX0jDyx8ERE5GLH00T8GdNvH/u5Ak8hjEPAQgJmlARMi+5sBfc2s2cEEKyIihbffRB9CmA18s4+nnAc8EdzHQFUzqwW0BbJDCKtCCNuAKZHniohICSqKPvo6wJqon3Mi2/LbfnJBv8TMBuGfCDj66KOLICwRKSm//PILOTk5bN26Nd6hpLzy5ctTt25dypUrF/MxRZHoLZ9tYR/b8xVCmAhMBMjIyFABHpEkkpOTQ5UqVWjQoAFm+f3pS1EIIbBp0yZycnJo2LBhzMcVxTj6HKBe1M91gXX72C4iKWbr1q1Ur15dSb6YmRnVq1cv9Cenokj0LwGXR0bftAM2hxDWA58CTcysoZkdAvSJPFdEUpCSfMk4kH/n/XbdmNnTQCfgSDPLAW4FygGEEP4JvAb0ALKBH4F+kX3bzWwo8AaQBkwOISwudISFshOYC5xUvC8jIpJE9pvoQwh997M/AEMK2Pca/kZQQv4PuBg4G/gr0LLkXlpEJEGlWK2bnsCdwAdAOnApsCquEYlI8fv222958MEHC31cjx49+Pbbbwt93JVXXknDhg1JT0+nVatWvP3227v2derUiaOPPproRZ169epF5cqVAdi5cyfDhg2jefPmtGjRgpNOOonPP/8c8KoALVq0ID09nfT0dIYNG1bo2PKTkCUQDlxFYCTwe+Ae4H5gHrCY/AcBiUgqyE30V1111R7bd+zYQVpaWoHHvfbagXc4jB07losuuohZs2YxaNAgVqxYsWtf1apV+eCDD+jQoQPffvst69ev37Vv6tSprFu3joULF1KmTBlycnKoVKnSrv2zZs3iyCOPPOC48pNiiT5XNbxlfzU++Mfw2wd/i2yrGr/QRFLcNddAVlbR/s70dBg3ruD9I0eOZOXKlaSnp1OuXDkqV65MrVq1yMrKYsmSJfTq1Ys1a9awdetWhg8fzqBBg4DddbW2bNlC9+7d6dChAx9++CF16tThxRdfpEKFCvuN7ZRTTmHt2rV7bOvTpw9TpkyhQ4cOPP/881xwwQUsXuy3KNevX0+tWrUoU8Y7VOrWrXuA/yqxS7Gum7xq4xN0Ad4ExgCN8Nb+j/EKSkSK2F133cUxxxxDVlYWY8eOZc6cOdxxxx0sWbIEgMmTJzN37lwyMzMZP348mzZt2ut3rFixgiFDhrB48WKqVq3Kc889F9NrT58+nV69eu2xrUuXLsyePZsdO3YwZcoULr744l37evfuzcsvv0x6ejrXXXcd8+fP3+PYzp077+q6ue+++wr7T5GvFG3R5+c8YD5wE3AjcB/exTOGlH+/EylB+2p5l5S2bdvuMaFo/PjxTJs2DYA1a9awYsUKqlevvscxuX3uAG3atGH16tX7fI0RI0Zwww03sGHDBj7++OM99qWlpdGhQwemTp3KTz/9RHTZ9bp167J8+XJmzpzJzJkz6dKlC88++yxdunQBiqfrppRluHTgVeBdoHXka+4/wQr2MXFXRJJIdJ/3O++8w1tvvcVHH33EggULaN26db4Tjg499NBd36elpbF9+/Z9vsbYsWPJzs7m9ttv54orrthrf58+fbj66qvp3bt3vq/VvXt3xo4dy+jRo3nhhRcKc3qFVsoSfa5T8VGfb0R+/i9wAnAi8Ajq1hFJLlWqVOH777/Pd9/mzZupVq0aFStWZNmyZXu1vg9GmTJlGD58ODt37uSNN97YY1/Hjh0ZNWoUffvuOUJ93rx5rFvnRQJ27tzJwoULqV+/fpHFlG+cxfrbE94hka+HAQ8A24GBeLWGEfgbgIgkuurVq9O+fXuaN2/OiBEj9tjXrVs3tm/fTsuWLbnlllto165dkb62mXHzzTdzzz337LX9+uuv36sbZsOGDZxzzjk0b96cli1bUrZsWYYOHbprf3Qf/eWXX140MUaP9UwUGRkZIT4rTAVgNp70XwKWAQ3xSb+HA/mu0iVS6i1dupSmTZvGO4xSI79/bzObG0LIyO/5pbxFn5cBpwHPAuvxJA9wHVATaAOMAmYCP8cjQBGRQlOiL9ARUd+PAW4DKgH3Al2AM6L2f4lu5IqkniFDhuzqRsl9PProo/EOq9BK0fDKg9Em8rgZ+B54J2rfVuA4/I2hB9AfaIdm4ookvwkTJsQ7hCKhFn2hVQHOiTzAK2Y+AHTAV0v8NdAKeC8u0YmI5KVEf9AqAgOAqfi6Kg/jVZxzu36W4KX51bUjIvGhRF+kquDL3s7Fx+UD3IWXYWiDvwnkP9ZXRKS4KNEXu38AE4AdwGC8/s5NcY1IREoXJfpidzhwFZAFfARcBOROtf4FX2HxYeCLuEQnkgoOtB49wLhx4/jxx33Phs+tE9+yZUtOO+00vvhi99+rmXHZZZft+nn79u3UqFGDnj17AvDVV1/Rs2dPWrVqRbNmzejRowcAq1evpkKFCnuM6HniiScO6Bz2R4m+xBg+GudRfLgmwBo8+Q8GGgBNgT8Cy+MQn0jyKu5ED15sbOHChXTq1Inbb7991/ZKlSqxaNEifvrpJwDefPNN6tSps2v/mDFjOPPMM1mwYAFLlizhrrvu2rUvt+Jm7qOoZsLmpUQfV42A1cBS4O/A0cCD+GQt8E8B4/A3g5/iEJ/IgeqUzyM3Ef9YwP7HIvu/zmffvkXXox8xYgRjx47lpJNOomXLltx6660A/PDDD5x99tm0atWK5s2bM3XqVMaPH8+6devo3LkznTt3junM8qs/3717d1599VUAnn766T3q26xfv36PmvMtW5b8EqdK9HFnwPHAtXiRtW/woZrglTavxYdsVsGrbw4Eviv5MEUSWHQ9+jPPPJMVK1YwZ84csrKymDt3LrNnz2b69OnUrl2bBQsWsGjRIrp168awYcOoXbs2s2bNYtasWTG9Vn7153MXGtm6dSsLFy7k5JNP3rVvyJAhDBgwgM6dO3PHHXfsKmgG7Hpzyn28917xDMvWhKmEUzHq+5uAK4FMfIhmJl5+oXJk/zDgYyAj8ugINEaTtST+3tnHvor72X/kfvbv24wZM5gxYwatW7cGYMuWLaxYsYKOHTty/fXXc+ONN
9KzZ086duxYqN/buXNnvvrqK2rWrLlH1w14K3316tU8/fTTu/rgc3Xt2pVVq1Yxffp0Xn/9dVq3bs2iRYuA3V03xU0t+oRXB1805XZgOr7Yee5la4KXZXgKH8t/LLs/DQD8UHJhiiSIEAKjRo3a1e+dnZ3NgAEDOPbYY5k7dy4tWrRg1KhR/OUvfynU7501axZffPEFJ5xwAmPGjNlr/7nnnsv111+/V1ligCOOOIJLLrmEJ598kpNOOonZs2cf8PkdCCX6pHY1MAv4Fp+YNQHI/U8W8MTfArgGeBl1+Uiqiq5H37VrVyZPnsyWLVsAWLt2LRs2bGDdunVUrFiRSy+9lOuvv5558+btdez+VKhQgXHjxvHEE0/wzTff7LGvf//+jBkzhhYtWuyxfebMmbtu9n7//fesXLmSo48++qDOt7DUdZMSyuAjdqLLlv6Cd+28jQ/fvB9IwxdNH4GXbtiJ/gtIKoiuR9+9e3cuueQSTjnlFAAqV67MU089RXZ2NiNGjKBMmTKUK1eOhx56CIBBgwbRvXt3atWqFVM/fa1atejbty8TJkzglltu2bW9bt26DB8+fK/nz507l6FDh1K2bFl27tzJwIEDOemkk1i9evWuPvpc/fv3Z9iwYQf7z7EX1aMvFX7GR+68BZyJl2Kei1fgPBPoDnTFJ3OJFJ7q0ZeswtajV3OuVDiUvYepVQDOx/v9n41sawU8g3f5iEiqUKIvtZoBk/G+/M/whP8WvowiwFjgQ7zF3wVP/hrNI6nt5JNP5uef91xU6Mknn9yr3z3ZKNGXega0jDxuyLNvHpC7On0doBdekllkbyEEzJK7MfDJJ5/EO4T9OpDudo26kQKMwGftrgD+iU/aih61cwFew+d5fJKXlGbly5dn06ZNB5SEJHYhBDZt2kT58uULdZxa9LIPhk/Aagz8Pmr7L8A24AngocjzWuOzeC8t4RglEdStW5ecnBw2btwY71BSXvny5fcoqRCLmBK9mXVj9/i8SSGEu/Lsr4Z3+B6Dr63XP4SwKLJvNV6EfQewvaC7wpJMygGv4Al/Dj6Ec1bkZ4DPgZ746J5TIw+N6Ell5cqVo2HDhvEOQwqw30RvZmn4TJwzgRzgUzN7KYSwJOppo4GsEML5ZnZ85PldovZ3DiF8XYRxS0IoB7SPPKJnCm7BC7Q9hbf4wdsA/8EXYdmJeg1FSk4sf21tgewQwqoQwjZ8YdTz8jynGd6sI4SwDGhgZkcVaaSSRFoAr+N995l4Zc4W7B7R8wC+oPpQ4EU0Y1ekeMWS6OvghdNz5US2RVuA353DzNoC9dn9Vx2AGWY218wGFfQiZjbIzDLNLFP9fKmiLL6E4rXANHZ33zTAW/iP4iN5jsC7d7aXfIgipUAsiT6/8VJ5b63fBVQzsyy8AMt8dv/Vtg8hnIhPvxxiZqfm9yIhhIkhhIwQQkaNGjVii36P4+G3v4XJk2HHjkIfLiXqXOA1vMU/C7gRH6ef25N4If6h8QG8Vr9GcogcjFgSfQ5QL+rnusC66CeEEL4LIfQLIaQDlwM18DtyhBDWRb5uwJt1bYsg7r1s3gyrVsGAAZCeDq+/7slfElnujN07gElR2+vik7iuxnsFa+NtCRE5ELEk+k+BJmbW0MwOwRc5fSn6CWZWNbIPfGWM2SGE78yskplViTynEnAWsKjowt+talX48EN49ln46Sfo0QPOPBMiBeokqdyPl2NeCfwL6IwvvAKwGS/PPBC/ubs+v18gIlH2m+hDCNvxu2Zv4J+jnwkhLDazwWY2OPK0psBiM1uGd9HklnA7CnjfzBbg4/BeDSFML+qTyGUGF10ES5bA+PGQlQVt2sBll8EXWns7CTVid0IfEtn2P/zG7nPAb/HWfjMgt773d/gHTn2cE8mV0tUrN2+Gu++G++7zbpxhw2D0aG/9S7LbgY8BmIn38/8VL8r2FHAZfoO3BV7aoQXwG0AXXlLXvqpXpnSiz7VmDdxyCzzxBFSr5t//4Q9w6KFF9hKSMLLxoZ0L8X7+RfhKW2vx1v+jeAG33OUXTwQOj0ukIkVpX4m+VMxaqVcPHnvM++vbtIFrr4WmTeGRR+CXX/Z7uCSVxvhN3H/h6+l+h/f114rs34z3It4AnI638k/APyGA1/eJbbUhkWRRKhJ9rvR0mDEDpk+H6tVh4EBo0gQmToRt2+IdnRSPMnhff+4o4WvwAWEb8dtOd+CLrqRF9vfDW/jN8IXZHwYWl1y4IsWgVCX6XF27wpw58Npr8Ktfwe9/D40bw0MPQZ5S1JKyjsQHgY3GZ+7muhn4M/7J4DVgMHBd1P5xeJ0fVfSQ5FEq+uj3JQRv5f/5z/DRR1CnDowc6a39QlYClZQT8GGeP+I3dH/Ab/LmfvxrArQD+uPzAXL/lpK7Jrskp1LfR78vZt7C/+ADePNNaNgQrr4aGjWC++/3MflSWhleqiF3daFK+Gzed/AJXM3w7p/PIvtX4p8UOgCD8Nb/DHxIqEj8lPoWfV4hwDvveAv/3XfhqKNg6FDo3x9qq9Ku7CXg1T7K4X3/dwNLIo9Nkec8h5eCmosv0VgXn2ye+2iOr+ErcuBK/fDKA/Xuu3DHHd7ST0uDc86BQYPgrLP8Z5F924gn/OZAdXxC+XV4VZGtUc/LxIu/PQv8A0/+9fF5iM3wuQDlSixqSU5K9AdpxQqYNAkefRQ2boT69b0PX618OTABv5m7JvLoAlTGW/7jo7bn1gX8Lz7JfCrwCZ78m0YeR5Rk4JLAlOiLyLZt8MILPhzz7be9Vd+zp4/aUStfitYv+OSv5XglT8NHCI0Dom8cNcRvGIMv7ZiDTwyrE/VVM4JLAyX6YpC3lX/00V4584orvMUvUjx2Al/gXUJL8fsAd0b2nUeeeoP4Ai/LIt//Bf800QE4Gf8UIalCib4YbdsGL74IDz/srXyA00+HK6+ECy6ASpXiGp6UOj/hRd3W4WUfygC9I/vOAt7Ck30avqD7ZcCwkg9TipwSfQn5/HN48kl4/HGvjV+5MvzmN97K79gRypT6wawSf5uBj4D3I4/2+OzgbfgN4Qy8xX8su7t+NKEkGSjRl7AQ4P33vb7OM8/Ali0+Pv+KK+Dyy/17kcTyX3wW8PvsHhYKfk9gOH4fYDCe/KMfp7C7jpDEkxJ9HP3wA0yb5q38t9/2N4HTTvOunYsu8la/SOII+E3g1Xj3T1t8dM9nwO8i29aze0RQ7hyBmfjicvXweQK5j4sjX3/Gu5E0TLS4KNEniC+/hKee8pb+ihWe5Hv3hn79oH17n6Urkvh24sND1+Hj/avhy0SPx0f95ODDQ3/AK4WeBDwGDMC7gupHPa7BVx79Hr9vULHkTiPFKNEnmBB82cPJk3d37TRp4gn/8su93o5Icgt4ieiKeCt+PvA8PmIo95H7pvArfETQrXjSbxB5NAT+hM8a/ha/V6D7BQVRok9gW7bAc8950p8922/Ydu3qSf/cc7U4iqSy7Xgr3vAbxDPZ/SbwOd5FtBnv8hmErzFQG38DaAgcD9wU+V3L8U8aNfFPGKVv5IMSfZLIzvZunccf
h5wcOOII6NsXLrkETjlFXTtS2uxkd8J+C/gQfwPIfZTFC8kBdMMLzIG/eRyJdxm9HNn2d/yNowL+KaMC/qnhvMj+9/E3nor4QvSH428YyVODSIk+yezYAW+95ZOxXnjBa+TXrw8XXwx9+vgCKkr6IjvYvWDMHDzpb4g8NuKJ+u7I/s6R5/wYdfxpeCVS8OGkK/L8/h7Aq5HvO0SOPTzq0RlfnAZ80poBh0Y9TsDLWO/E34QMf+PKfdTHq6Nuj8RWBl/3+MDeXJTok9h33/mErClTvG7+9u1w3HGe8Pv29e9FJFYBHwH0I56Aj4xsn4/fB/gRv7ewGe8mOjeyfxC7u5JyHz2BByL7D8HLVkQbEtm/DU/8eY3E3yA2RcWxFO+SKjwl+hTx9dfw/POe9N95x2/qpqd70u/TR6UXROJnB57Qf8Yrk/6MdwPVwN9QPo183Ym/2ezEh502ihw3i93lKQ5szLUSfQpatw6efdaT/scf+7bOnWHIEDjvPChbNr7xiUjJ0gpTKah2bRg+3Jc/XLUKbr8dVq70SVgNG/rPX30V7yhFJBEo0aeAhg3hpps84b/wAjRtCrfcAvXq+YidDz7wbh4RKZ2U6FNIWpp328yYAcuXw1VXwWuvQYcO0Lo1/OtfXpJBREoXJfoUdeyxMG4crF3rJZR37vRlEOvUgWuv9TcCESkdlOhTXKVKnuAXLPCZt926wQMPwPHHQ5cufkP3l7yjwkQkpSjRlxJmXhN/yhRYs8YXPV+50ouqHX209+l/+WW8oxSR4qBEXwr96lcwerQn+ldegTZtPPE3bOj1dV5/3WfnikhqiCnRm1k3M1tuZtlmNjKf/dXMbJqZLTSzOWbWPNZjJX7S0uDssz3Zf/45jBwJn3wCPXpA48Zw112+Hq6IJLf9JnozSwMmAN2BZkBfM2uW52mjgawQQkt89YH7C3GsJID69b1Vv2aNd+80aACjRvkQzf79ISsr3hGKyIGKpUXfFsgOIawKIWwDprC75FuuZsDbACGEZUADMzsqxmMlgRxyiBdPmzULFi/2JD91qg/PPO00L6m8ffv+f4+IJI5YEn0dfLmYXDmRbdEW4OuJYWZt8bJsdWM8lshxg8ws08wyN6q/ICE0awYPPuglk++912/WXnQRHHMM3HMPfPNNvCMUkVjEkujzK4ibd57lXUA1M8sCrsZLwW2P8VjfGMLEEEJGCCGjRo0aMYQlJaVaNbjuOq+XP20aNGoEN94IdevC73/vLX8RSVyxJPocfMXfXHXxxSJ3CSF8F0LoF0JIx/voa+ArA+z3WEkeaWnQq5d362RleXmFJ56A5s13j8nfti3eUYpIXrEk+k+BJmbW0MwOAfoAL0U/wcyqRvYBDARmhxC+i+VYSU6tWsGkSX7z9q9/9cXOc8fkjxrldXdEJDHsN9GHELYDQ/ElUpYCz4QQFpvZYDMbHHlaU2CxmS3DR9gM39exRX8aEi9HHumJ/fPPfZhm27bef9+4sc/CnTZNM29F4k316KXIrVkDjzziLf61a6FWLRgwAAYO1OIoIsVF9eilRNWrB3/6E6xe7csgtm69e+bt2WfDG2+obLJISVKil2JTtqyXVHj1Ve/auekmmDfPu3Rat4b//Edj8kVKghK9lIj69eG227yVP3myj8757W+9L3/8eNXJFylOSvRSog49FPr1g0WL4KWXfCz+8OE+WmfMGNiwId4RiqQeJXqJizJl4Jxz4P33/dGxo7f469f3lbFWrox3hCKpQ4le4q59e1/rdulS786ZNMlXyOrd2xc/F5GDo0QvCeP44z3Jr14N11/va9/++tfQrp1X1NR4fJEDo0QvCad2bbj7bi+m9sADXjytb18fnnnXXSqmJlJYSvSSsCpXhiFDYNkyePllb/GPGuU3cP/wB+/qEZH9U6KXhFemDPTsCW+9BQsXeuv+0Ue9jHL37pqAJbI/SvSSVFq08PIKX34Jf/kLzJ/vE7BOOAEmToSffop3hCKJR4leklLNmnDLLfDFF/D441C+vNfGr1fPZ+CuUzFskV2U6CWpHXooXH45zJ0L777r4/HvvNPH4196qW8XKe2U6CUlmMGpp3pZ5Oxsv4n74ouQkeHJ//nnYceOeEcpEh9K9JJyGjWCceN8eObf/+5fL7zQ6+qMGwfffx/vCEVKlhK9pKzDD4drr/UW/vPPe//9tdd6XZ3Ro+G//413hCIlQ4leUl5aGpx/PsyeDZ98Amec4ROv6teH3/0Oli+Pd4QixUuJXkqVtm19EfP/9/981aunnoKmTX3R8w8/jHd0IsVDiV5KpcaN4cEHfXjmzTd7a799e+jQwW/i7twZ7whFio4SvZRqNWv6xKsvv4T77/cbt716+azbxx/XCliSGpToRfC6OsOG+Y3b//zHJ2BdeaXX11HCl2SnRC8SpWxZr6Uzf77XyK9SZXfCf+wxJXxJTkr0Ivkwg/PO88XMX3wRDjvMl0BUwpdkpEQvsg9mcO65XkrhxRd9bH6/fnDccV5BU4uhSDJQoheJQW7Cz8z0Rc2rVoX+/b2F/+ijauFLYlOiFykEM1/UPDPTF0OpVs0TfrNmfhNX9XQkESnRixwAM18M5dNPvUunQgVf2LxVKy+3oIVQJJEo0YschNwunfnzYepU78K58EKvmvnaa0r4khiU6EWKQJky0Ls3LFrk4+7/9z84+2yfbTtzZryjk9JOiV6kCJUt6wuhLF8ODz8Ma9ZAly5w+unwwQfxjk5Kq5gSvZl1M7PlZpZtZiPz2X+4mb1sZgvMbLGZ9Yvat9rMPjOzLDPLLMrgRRJVuXIwaBCsWOGlFZYs8To6PXp4N49ISdpvojezNGAC0B1oBvQ1s2Z5njYEWBJCaAV0Av5mZodE7e8cQkgPIWQUTdgiyaF8eS+tsHKll0b++GM48US4+GKVR5aSE0uLvi2QHUJYFULYBkwBzsvznABUMTMDKgPfABpZLBJRqRLceCOsWuXVMl99FU44AQYO9IJqIsUplkRfB1gT9XNOZFu0B4CmwDrgM2B4CCG30GsAZpjZXDMbVNCLmNkgM8s0s8yNGzfGfAIiyaRqVbjtNk/4V18NTz4JTZrANdfAhg3xjk5SVSyJ3vLZlnfQWFcgC6gNpAMPmNlhkX3tQwgn4l0/Q8zs1PxeJIQwMYSQEULIqFGjRmzRiySpmjXhvvu8D/+yy+Af//C1bm+5Bb79Nt7RSaqJJdHnAPWifq6Lt9yj9QOeDy4b+Bw4HiCEsC7ydQMwDe8KEhF8/dpJk/xm7dlnw+23e8K/5x74+ed4RyepIpZE/ynQxMwaRm6w9gFeyvOcL4EuAGZ2FHAcsMrMKplZlcj2SsBZwKKiCl4kVRx3nE+4mjcP2rXz/vwTTvAyC5p0JQdrv4k+hLAdGAq8ASwFngkhLDazwWY2OPK024Bfm9lnwNvAjSGEr4GjgPfNbAEwB3g1hDC9OE5EJBW0bu0zat94w4donnuuD8nUCB05GBYSsLmQkZERMjM15F5Kt19+gQkT4NZb4ccfYfhwGDPGa+OL5GVmcwsawq6ZsSIJqlw5H42zYgV
ccQX8/e9w7LFeFlmLl0thKNGLJLiaNf2G7Zw5fqO2f3845RT45JN4RybJQoleJElkZMD778MTT3gNnXbtfD3bdXnHwInkoUQvkkTKlPFx98uX+8ic//wHGjeG0aM1/l4KpkQvkoSqVPHaOcuWwfnnw513erfOvffC1q3xjk4SjRK9SBJr1Aj+/W+viHnyyTBihJdUmDxZ69jKbkr0IikgPR1efx1mzYLatWHAAGjZEl54QROuRIleJKV06uSlkJ97zodgnn++r3I1e3a8I5N4UqIXSTFmcMEFvqzhxInwxRdw2mk+y/bzz+MdncSDEr1IiipbFn73O59wdeed3q1zwgleMO2XX+IdnZQkJXqRFFexIowc6RUyzzrLh2W2aeNdPFI6KNGLlBL16vnN2WnT4H//g1//Gq66SuPvSwMlepFSplcvb90PHw4PPwxNm8Izz2h0TipTohcphapU8R8ea0wAAAyiSURBVBWu5szx4ZgXX+wLn+hmbWpSohcpxdq08eJo48bBe+/pZm2qUqIXKeXKlvVunCVLoGtXv1nburXG3qcSJXoRAfxm7bRpfsN2yxYfe3/55fDVV/GOTA6WEr2I7OG887x1P3o0TJni69k++CDs2BHvyORAKdGLyF4qVoQ77oCFC70O/pAhXjRtzpx4RyYHQoleRAp0/PHw5pvesl+3zhc7GTwYvvkm3pFJYSjRi8g+mfnwy2XLfA3bSZO8O0dr1yYPJXoRiclhh/kC5fPmeaLv3x9OPRU++yzekcn+KNGLSKG0bOlDLydP9lb+iSf6kMwffoh3ZFIQJXoRKbQyZaBfP1+79vLLfZLVCSfAK6/EOzLJjxK9iByw6tXhkUe8hV+pEpxzDlx4IeTkxDsyiaZELyIHrWNHX7f2r3+F117zQmnjxmnd2kShRC8iReKQQ2DUKFi82BP/tddC27bw6afxjkyU6EWkSDVqBK++6qWP//tfn2g1dCh89128Iyu9lOhFpMiZwW9+46Nyhg71EgotW8I778Q7stJJiV5Eis1hh8H48fDhh1CuHHTuDH/8I2zdGu/ISpeYEr2ZdTOz5WaWbWYj89l/uJm9bGYLzGyxmfWL9VgRSX3t2kFWli9deN99Xgd/7tx4R1V67DfRm1kaMAHoDjQD+ppZszxPGwIsCSG0AjoBfzOzQ2I8VkRKgUqVYMIEeOMN2LzZk/9tt2lkTkmIpUXfFsgOIawKIWwDpgDn5XlOAKqYmQGVgW+A7TEeKyKlyFlnedmE3r1hzBho394nXknxiSXR1wHWRP2cE9kW7QGgKbAO+AwYHkLYGeOxAJjZIDPLNLPMjRs3xhi+iCSjatXg3/+GqVMhOxvS0+Ef/1CRtOISS6K3fLblXS++K5AF1AbSgQfM7LAYj/WNIUwMIWSEEDJq1KgRQ1gikux694ZFi+D002HYMG/tr1mz/+OkcGJJ9DlAvaif6+It92j9gOeDywY+B46P8VgRKcVq1fIaORMnwscfQ/PmXlYh5NsklAMRS6L/FGhiZg3N7BCgD/BSnud8CXQBMLOjgOOAVTEeKyKlnBn87ne+otWJJ8LAgdCtG3z5ZbwjSw37TfQhhO3AUOANYCnwTAhhsZkNNrPBkafdBvzazD4D3gZuDCF8XdCxxXEiIpL8GjWCt9/20TkffOAVMf/5T/XdHywLCfj5KCMjI2RmZsY7DBGJo9WrvZX/1ls+0eqRR6Bhw3hHlbjMbG4IISO/fZoZKyIJqUEDmDHD++4zM6FFC3jgAbXuD4QSvYgkrNy++9yKmFdf7a377Ox4R5ZclOhFJOHVq+d17h99FBYs8AJp990HO3bEO7LkoEQvIknBDK680lv3Xbp4cbQuXTTuPhZK9CKSVOrUgZde8tZ9Zia0agX/93/xjiqxKdGLSNLJbd1nZUHjxl77fuBA2LIl3pElJiV6EUlajRv7ePtRo2DyZJ9spfLHe1OiF5GkVq6cL0o+cyb89BOccgrcc4+GYUZToheRlNCpk4/IOfdcuPFGL5C2dm28o0oMSvQikjKOOAKefRYmTYKPPvJhmC+8EO+o4k+JXkRSihkMGADz5vns2vPPh0GD4Lvv4h1Z/CjRi0hKOu44b9XfcIO38Js3h9dfj3dU8aFELyIp65BD4O674cMPoXJl6NEDrrgCvvkm3pGVLCV6EUl57drB/Plw882+hGGzZvD88/GOquQo0YtIqXDooXDbbT6btnZtuPBCn2j11Vfxjqz4KdGLSKmSng6ffOJj7196yVv3Tz2V2ksXKtGLSKlTrpzPps3K8pu2l10GPXtCTk68IyseSvQiUmo1bQrvveclj2fN8tb944+nXuteiV5ESrW0NLjmGli0CFq39mJpl1wCmzfHO7Kio0QvIoIvTD5zpt+wffZZ78v/6KN4R1U0lOhFRCLS0nwI5nvv+c8dO8Lttyf/SlZK9CIieZxyit+o7d0bbrkFTj89uVeyUqIXEcnH4Yf75KrHHvMa961aJe8kKyV6EZECmHnJhPnz4ZhjfJLVoEHw44/xjqxwlOhFRPajSRNfyeqGG+Bf/4I2bbxrJ1ko0YuIxCC3QNqbb/rQy5NPhvvvT44x90r0IiKFcMYZvpJV164+/v7ssxO/Xo4SvYhIIdWoAS++CBMm+Izali1h+vR4R1UwJXoRkQNgBlddBZ9+CjVrQvfucO218PPP8Y5sb0r0IiIHoXlzmDMHrr4axo3zvvulS+Md1Z5iSvRm1s3MlptZtpmNzGf/CDPLijwWmdkOMzsism+1mX0W2ZdZ1CcgIhJvFSrA+PHw8suwdq2Pynn44cS5UbvfRG9macAEoDvQDOhrZs2inxNCGBtCSA8hpAOjgHdDCNGLdXWO7M8owthFRBJKz56wcCF06ACDB/u4+02b4h1VbC36tkB2CGFVCGEbMAU4bx/P7ws8XRTBiYgkm1q1/MbsvffCK6/4jNp3341vTLEk+jpAdJWHnMi2vZhZRaAb8FzU5gDMMLO5ZjaooBcxs0FmlmlmmRs3bowhLBGRxFSmDFx3HXz8MVSq5LVybrstfsXRYkn0ls+2gnqezgE+yNNt0z6EcCLe9TPEzE7N78AQwsQQQkYIIaNGjRoxhCUikthOPNHXqO3bF8aM8bH38RhzH0uizwHqRf1cF1hXwHP7kKfbJoSwLvJ1AzAN7woSESkVqlSBJ5+ESZO8jEJ6ute9L0mxJPpPgSZm1tDMDsGT+Ut5n2RmhwOnAS9GbatkZlVyvwfOAhYVReAiIsnCDAYM8GGYVav67No//ankunL2m+hDCNuBocAbwFLgmRDCYjMbbGaDo556PjAjhPBD1LajgPfNbAEwB3g1hJDA88dERIpPixY+weqyy+DPf4Yzz4T164v/dS0kykDPKBkZGSEzU0PuRSR1PfaYz6ytUsXr3p9xxsH9PjObW9AQds2MFRGJgyuv9Bu1Rx4JZ53lK1lt3148r6VELyISJ82aeVdOv36+Nm2XLrBlS9G/Ttmi/5UiIhKrihXhkUegUyefWFWpUtG/hhK9iEgCuOwyfxQHdd2IiKQ4JXoRkRSnRC8iku
KU6EVEUpwSvYhIilOiFxFJcUr0IiIpToleRCTFJWRRMzPbCHwRtelI4Os4hVNcUu2cUu18IPXOKdXOB1LvnA7mfOqHEPJdtSkhE31eZpaZaguLp9o5pdr5QOqdU6qdD6TeORXX+ajrRkQkxSnRi4ikuGRJ9BPjHUAxSLVzSrXzgdQ7p1Q7H0i9cyqW80mKPnoRETlwydKiFxGRA6RELyKS4hIu0ZvZZDPbYGaLorYdYWZvmtmKyNdq8YyxsAo4pz+Z2Vozy4o8esQzxsIws3pmNsvMlprZYjMbHtmelNdpH+eTzNeovJnNMbMFkXP6c2R7sl6jgs4naa8RgJmlmdl8M3sl8nOxXJ+E66M3s1OBLcATIYTmkW33AN+EEO4ys5FAtRDCjfGMszAKOKc/AVtCCPfGM7YDYWa1gFohhHlmVgWYC/QCriQJr9M+zqc3yXuNDKgUQthiZuWA94HhwAUk5zUq6Hy6kaTXCMDM/ghkAIeFEHoWV65LuBZ9CGE28E2ezecBj0e+fxz/I0waBZxT0gohrA8hzIt8/z2wFKhDkl6nfZxP0goud5npcpFHIHmvUUHnk7TMrC5wNjApanOxXJ+ES/QFOCqEsB78jxKoGed4ispQM1sY6dpJio/QeZlZA6A18AkpcJ3ynA8k8TWKdAtkARuAN0MISX2NCjgfSN5rNA64AdgZta1Yrk+yJPpU9BBwDJAOrAf+Ft9wCs/MKgPPAdeEEL6LdzwHK5/zSeprFELYEUJIB+oCbc2sebxjOhgFnE9SXiMz6wlsCCHMLYnXS5ZE/1WkHzW3P3VDnOM5aCGEryL/cXcC/wLaxjumwoj0kz4H/DuE8Hxkc9Jep/zOJ9mvUa4QwrfAO3h/dtJeo1zR55PE16g9cK6ZrQamAKeb2VMU0/VJlkT/EnBF5PsrgBfjGEuRyL2YEecDiwp6bqKJ3Bh7BFgaQvh71K6kvE4FnU+SX6MaZlY18n0F4AxgGcl7jfI9n2S9RiGEUSGEuiGEBkAfYGYI4VKK6fok4qibp4FOeLnOr4BbgReAZ4CjgS+B34QQkubmZgHn1An/uBmA1cDvc/vmEp2ZdQDeAz5jd//iaLxfO+mu0z7Opy/Je41a4jfz0vAG3TMhhL+YWXWS8xoVdD5PkqTXKJeZdQKuj4y6KZbrk3CJXkREilaydN2IiMgBUqIXEUlxSvQiIilOiV5EJMUp0YuIpDglehGRFKdELyKS4v4/sUJL0UwiXmIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"df=pd.DataFrame(model.learning_process[10:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n",
"plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
"plt.plot('epoch', 'test_RMSE', data=df, color='yellow', linestyle='dashed')\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Saving and evaluating recommendations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model.estimations()\n",
"\n",
"top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"\n",
"top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv', index=False, header=False)\n",
"\n",
"estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 10387.88it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.91633</td>\n",
" <td>0.720153</td>\n",
" <td>0.103393</td>\n",
" <td>0.044455</td>\n",
" <td>0.053177</td>\n",
" <td>0.070073</td>\n",
" <td>0.093884</td>\n",
" <td>0.079366</td>\n",
" <td>0.107792</td>\n",
" <td>0.051281</td>\n",
" <td>0.20021</td>\n",
" <td>0.518957</td>\n",
" <td>0.47508</td>\n",
" <td>0.853022</td>\n",
" <td>0.147186</td>\n",
" <td>3.911356</td>\n",
" <td>0.971196</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 0.91633 0.720153 0.103393 0.044455 0.053177 0.070073 \n",
"\n",
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.093884 0.079366 0.107792 0.051281 0.20021 0.518957 \n",
"\n",
" HR Reco in test Test coverage Shannon Gini \n",
"0 0.47508 0.853022 0.147186 3.911356 0.971196 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n",
"reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n",
"\n",
"ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n",
" estimations_df=estimations_df, \n",
" reco=reco,\n",
" super_reactions=[4,5])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 10057.03it/s]\n",
"943it [00:00, 10623.82it/s]\n",
"943it [00:00, 9952.71it/s]\n",
"943it [00:00, 10166.90it/s]\n",
"943it [00:00, 10805.10it/s]\n",
"943it [00:00, 10623.82it/s]\n",
"943it [00:00, 10390.37it/s]\n",
"943it [00:00, 10166.95it/s]\n",
"943it [00:00, 11057.98it/s]\n",
"943it [00:00, 10170.01it/s]\n",
"943it [00:00, 10994.41it/s]\n",
"943it [00:00, 10390.45it/s]\n",
"943it [00:00, 10166.77it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.916330</td>\n",
" <td>0.720153</td>\n",
" <td>0.103393</td>\n",
" <td>0.044455</td>\n",
" <td>0.053177</td>\n",
" <td>0.070073</td>\n",
" <td>0.093884</td>\n",
" <td>0.079366</td>\n",
" <td>0.107792</td>\n",
" <td>0.051281</td>\n",
" <td>0.200210</td>\n",
" <td>0.518957</td>\n",
" <td>0.475080</td>\n",
" <td>0.853022</td>\n",
" <td>0.147186</td>\n",
" <td>3.911356</td>\n",
" <td>0.971196</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_GlobalAvg</td>\n",
" <td>1.125760</td>\n",
" <td>0.943534</td>\n",
" <td>0.061188</td>\n",
" <td>0.025968</td>\n",
" <td>0.031383</td>\n",
" <td>0.041343</td>\n",
" <td>0.040558</td>\n",
" <td>0.032107</td>\n",
" <td>0.067695</td>\n",
" <td>0.027470</td>\n",
" <td>0.171187</td>\n",
" <td>0.509546</td>\n",
" <td>0.384942</td>\n",
" <td>1.000000</td>\n",
" <td>0.025974</td>\n",
" <td>2.711772</td>\n",
" <td>0.992003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.510030</td>\n",
" <td>1.211848</td>\n",
" <td>0.050053</td>\n",
" <td>0.022367</td>\n",
" <td>0.025984</td>\n",
" <td>0.033727</td>\n",
" <td>0.030687</td>\n",
" <td>0.023255</td>\n",
" <td>0.055392</td>\n",
" <td>0.021602</td>\n",
" <td>0.137690</td>\n",
" <td>0.507713</td>\n",
" <td>0.338282</td>\n",
" <td>0.987911</td>\n",
" <td>0.187590</td>\n",
" <td>5.111878</td>\n",
" <td>0.906685</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNWithZScore</td>\n",
" <td>0.957701</td>\n",
" <td>0.752387</td>\n",
" <td>0.003712</td>\n",
" <td>0.001994</td>\n",
" <td>0.002380</td>\n",
" <td>0.002919</td>\n",
" <td>0.003433</td>\n",
" <td>0.002401</td>\n",
" <td>0.005137</td>\n",
" <td>0.002158</td>\n",
" <td>0.016458</td>\n",
" <td>0.497349</td>\n",
" <td>0.027572</td>\n",
" <td>0.389926</td>\n",
" <td>0.067821</td>\n",
" <td>2.475747</td>\n",
" <td>0.992793</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNWithMeans</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Self_SVD 0.916330 0.720153 0.103393 0.044455 0.053177 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n",
"0 Ready_Random 1.510030 1.211848 0.050053 0.022367 0.025984 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNWithZScore 0.957701 0.752387 0.003712 0.001994 0.002380 \n",
"0 Ready_I-KNNWithMeans 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.070073 0.093884 0.079366 0.107792 0.051281 0.200210 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n",
"0 0.033727 0.030687 0.023255 0.055392 0.021602 0.137690 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.002919 0.003433 0.002401 0.005137 0.002158 0.016458 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.518957 0.475080 0.853022 0.147186 3.911356 0.971196 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n",
"0 0.507713 0.338282 0.987911 0.187590 5.111878 0.906685 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.497349 0.027572 0.389926 0.067821 2.475747 0.992793 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"dir_path=\"Recommendations generated/ml-100k/\"\n",
"super_reactions=[4,5]\n",
"test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 2],\n",
" [3, 4]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[0.4472136 , 0.89442719],\n",
" [0.6 , 0.8 ]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x=np.array([[1,2],[3,4]])\n",
"display(x)\n",
"x/np.linalg.norm(x, axis=1)[:,None]"
]
},
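{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dividing every row of the embedding matrix by its norm (as in the toy example above) turns a plain dot product of two rows into their cosine similarity:\n",
"\n",
"$$\\cos(q_i, q_j) = \\frac{q_i \\cdot q_j}{\\lVert q_i \\rVert \\, \\lVert q_j \\rVert}.$$\n",
"\n",
"The next cell uses exactly this to list the items whose embeddings are closest to a randomly chosen item."
]
},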
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>score</th>\n",
" <th>item_id</th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>genres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1311</td>\n",
" <td>1.000000</td>\n",
" <td>1312</td>\n",
" <td>1312</td>\n",
" <td>Pompatus of Love, The (1996)</td>\n",
" <td>Comedy, Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1395</td>\n",
" <td>0.992638</td>\n",
" <td>1396</td>\n",
" <td>1396</td>\n",
" <td>Stonewall (1995)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1448</td>\n",
" <td>0.991954</td>\n",
" <td>1449</td>\n",
" <td>1449</td>\n",
" <td>Pather Panchali (1955)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1201</td>\n",
" <td>0.991726</td>\n",
" <td>1202</td>\n",
" <td>1202</td>\n",
" <td>Maybe, Maybe Not (Bewegte Mann, Der) (1994)</td>\n",
" <td>Comedy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1405</td>\n",
" <td>0.991419</td>\n",
" <td>1406</td>\n",
" <td>1406</td>\n",
" <td>When Night Is Falling (1995)</td>\n",
" <td>Drama, Romance</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1250</td>\n",
" <td>0.991108</td>\n",
" <td>1251</td>\n",
" <td>1251</td>\n",
" <td>A Chef in Love (1996)</td>\n",
" <td>Comedy</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1263</td>\n",
" <td>0.991100</td>\n",
" <td>1264</td>\n",
" <td>1264</td>\n",
" <td>Nothing to Lose (1994)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>1123</td>\n",
" <td>0.991069</td>\n",
" <td>1124</td>\n",
" <td>1124</td>\n",
" <td>Farewell to Arms, A (1932)</td>\n",
" <td>Romance, War</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1110</td>\n",
" <td>0.991040</td>\n",
" <td>1111</td>\n",
" <td>1111</td>\n",
" <td>Double Happiness (1994)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1124</td>\n",
" <td>0.990889</td>\n",
" <td>1125</td>\n",
" <td>1125</td>\n",
" <td>Innocents, The (1961)</td>\n",
" <td>Thriller</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code score item_id id title \\\n",
"0 1311 1.000000 1312 1312 Pompatus of Love, The (1996) \n",
"1 1395 0.992638 1396 1396 Stonewall (1995) \n",
"2 1448 0.991954 1449 1449 Pather Panchali (1955) \n",
"3 1201 0.991726 1202 1202 Maybe, Maybe Not (Bewegte Mann, Der) (1994) \n",
"4 1405 0.991419 1406 1406 When Night Is Falling (1995) \n",
"5 1250 0.991108 1251 1251 A Chef in Love (1996) \n",
"6 1263 0.991100 1264 1264 Nothing to Lose (1994) \n",
"7 1123 0.991069 1124 1124 Farewell to Arms, A (1932) \n",
"8 1110 0.991040 1111 1111 Double Happiness (1994) \n",
"9 1124 0.990889 1125 1125 Innocents, The (1961) \n",
"\n",
" genres \n",
"0 Comedy, Drama \n",
"1 Drama \n",
"2 Drama \n",
"3 Comedy \n",
"4 Drama, Romance \n",
"5 Comedy \n",
"6 Drama \n",
"7 Romance, War \n",
"8 Drama \n",
"9 Thriller "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"item=random.choice(list(set(train_ui.indices)))\n",
"\n",
"embeddings_norm=model.Qi/np.linalg.norm(model.Qi, axis=1)[:,None] # we do not mean-center here\n",
"# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n",
"\n",
"similarity_scores=np.dot(embeddings_norm,embeddings_norm[item].T)\n",
"top_similar_items=pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\\\n",
".sort_values(by=['score'], ascending=[False])[:10]\n",
"\n",
"top_similar_items['item_id']=top_similar_items['code'].apply(lambda x: item_code_id[x])\n",
"\n",
"items=pd.read_csv('./Datasets/ml-100k/movies.csv')\n",
"\n",
"result=pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n",
"\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# project task 5: implement SVD on top baseline (as it is in Surprise library)from tqdm import tqdm\n",
"\n",
"class SVDBaseline():\n",
"\n",
" def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
" self.train_ui=train_ui\n",
" self.uir=list(zip(*[train_ui.nonzero()[0],train_ui.nonzero()[1], train_ui.data]))\n",
"\n",
" self.learning_rate=learning_rate\n",
" self.regularization=regularization\n",
" self.iterations=iterations\n",
" self.nb_users, self.nb_items=train_ui.shape\n",
" self.nb_ratings=train_ui.nnz\n",
" self.nb_factors=nb_factors\n",
"\n",
" self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
" self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
"\n",
" self.b_u = np.zeros(self.nb_users)\n",
" self.b_i = np.zeros(self.nb_items)\n",
"\n",
" def train(self, test_ui=None):\n",
" if test_ui!=None:\n",
" self.test_uir=list(zip(*[test_ui.nonzero()[0],test_ui.nonzero()[1], test_ui.data]))\n",
"\n",
" self.learning_process=[]\n",
" pbar = tqdm(range(self.iterations))\n",
" for i in pbar:\n",
" pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
" np.random.shuffle(self.uir)\n",
" self.sgd(self.uir)\n",
" if test_ui==None:\n",
" self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
" else:\n",
" self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
"\n",
" def sgd(self, uir):\n",
"\n",
" for u, i, score in uir:\n",
" prediction = self.get_rating(u,i)\n",
" e = (score - prediction)\n",
" \n",
" \n",
" b_u_update = self.learning_rate * (e - self.regularization * self.b_u[u])\n",
" b_i_update = self.learning_rate * (e - self.regularization * self.b_i[i])\n",
" \n",
" self.b_u[u] += b_u_update\n",
" self.b_i[i] += b_i_update\n",
" Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
" Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
" \n",
" self.Pu[u] += Pu_update\n",
" self.Qi[i] += Qi_update\n",
" \n",
" def get_rating(self, u, i):\n",
" prediction = self.b_u[u] + self.b_i[i] + self.Pu[u].dot(self.Qi[i].T)\n",
" return prediction\n",
" \n",
" def RMSE_total(self, uir):\n",
" RMSE=0\n",
" for u,i, score in uir:\n",
" prediction = self.get_rating(u,i)\n",
" RMSE+=(score - prediction)**2\n",
" return np.sqrt(RMSE/len(uir))\n",
" \n",
" def estimations(self):\n",
" self.estimations= self.b_u[:,np.newaxis] + self.b_i[np.newaxis:,] + np.dot(self.Pu,self.Qi.T)\n",
"\n",
" def recommend(self, user_code_id, item_code_id, topK=10):\n",
" \n",
" top_k = defaultdict(list)\n",
" for nb_user, user in enumerate(self.estimations):\n",
" \n",
" user_rated=self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]]\n",
" for item, score in enumerate(user):\n",
" if item not in user_rated and not np.isnan(score):\n",
" top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
" result=[]\n",
" for uid, item_scores in top_k.items():\n",
" item_scores.sort(key=lambda x: x[1], reverse=True)\n",
" result.append([uid]+list(chain(*item_scores[:topK])))\n",
" return result\n",
" \n",
" def estimate(self, user_code_id, item_code_id, test_ui):\n",
" result=[]\n",
" for user, item in zip(*test_ui.nonzero()):\n",
" result.append([user_code_id[user], item_code_id[item], \n",
" self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n",
" return result"
]
},
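{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class above extends the plain model with user and item bias terms, so the prediction becomes\n",
"\n",
"$$\\hat{r}_{ui} = b_u + b_i + p_u \\cdot q_i,$$\n",
"\n",
"and the biases are trained with the analogous SGD step $b_u \\leftarrow b_u + \\eta\\,(e_{ui} - \\lambda\\,b_u)$ (and likewise for $b_i$). Note that Surprise's biased `SVD` additionally adds the global mean $\\mu$ to the prediction; this sketch leaves it out."
]
},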
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# making changes to our implementation by considering additional parameters in the gradient descent procedure \n",
"# seems to be the fastest option\n",
"# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
"# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'"
]
},
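{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of how the task could be run, reusing the `SVDBaseline` class defined above. The hyperparameters simply mirror those used for `Self_SVD` and are illustrative, not tuned."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: train the baseline-augmented model with the same illustrative\n",
"# hyperparameters as Self_SVD and save its recommendations and estimations.\n",
"model_baseline=SVDBaseline(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n",
"model_baseline.train(test_ui)\n",
"model_baseline.estimations()\n",
"\n",
"top_n=pd.DataFrame(model_baseline.recommend(user_code_id, item_code_id, topK=10))\n",
"top_n.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv', index=False, header=False)\n",
"\n",
"estimations=pd.DataFrame(model_baseline.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv', index=False, header=False)"
]
},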
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ready-made SVD - Surprise implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import helpers\n",
"import surprise as sp\n",
"import imp\n",
"imp.reload(helpers)\n",
"\n",
"algo = sp.SVD(biased=False) # to use unbiased version\n",
"\n",
"helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n",
" estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')"
]
},
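{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `biased=False` Surprise optimizes the same plain $\\hat{r}_{ui} = p_u \\cdot q_i$ model as our `Self_SVD` above, which makes the two directly comparable in the evaluation below."
]
},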
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD biased - on top baseline"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import helpers\n",
"import surprise as sp\n",
"import imp\n",
"imp.reload(helpers)\n",
"\n",
"algo = sp.SVD() # default is biased=True\n",
"\n",
"helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n",
" estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 9346.25it/s]\n",
"943it [00:00, 10163.97it/s]\n",
"943it [00:00, 10745.39it/s]\n",
"943it [00:00, 10228.45it/s]\n",
"943it [00:00, 10502.49it/s]\n",
"943it [00:00, 10627.14it/s]\n",
"943it [00:00, 10620.29it/s]\n",
"943it [00:00, 10991.05it/s]\n",
"943it [00:00, 9955.97it/s]\n",
"943it [00:00, 10510.36it/s]\n",
"943it [00:00, 9746.48it/s]\n",
"943it [00:00, 10164.36it/s]\n",
"943it [00:00, 10693.82it/s]\n",
"943it [00:00, 10927.37it/s]\n",
"943it [00:00, 10620.60it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
" <td>0.949165</td>\n",
" <td>0.746667</td>\n",
" <td>0.093955</td>\n",
" <td>0.044969</td>\n",
" <td>0.051197</td>\n",
" <td>0.065474</td>\n",
" <td>0.083906</td>\n",
" <td>0.073996</td>\n",
" <td>0.104672</td>\n",
" <td>0.048211</td>\n",
" <td>0.220757</td>\n",
" <td>0.519187</td>\n",
" <td>0.483563</td>\n",
" <td>0.997985</td>\n",
" <td>0.204906</td>\n",
" <td>4.408913</td>\n",
" <td>0.954288</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.916330</td>\n",
" <td>0.720153</td>\n",
" <td>0.103393</td>\n",
" <td>0.044455</td>\n",
" <td>0.053177</td>\n",
" <td>0.070073</td>\n",
" <td>0.093884</td>\n",
" <td>0.079366</td>\n",
" <td>0.107792</td>\n",
" <td>0.051281</td>\n",
" <td>0.200210</td>\n",
" <td>0.518957</td>\n",
" <td>0.475080</td>\n",
" <td>0.853022</td>\n",
" <td>0.147186</td>\n",
" <td>3.911356</td>\n",
" <td>0.971196</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
" <td>0.938146</td>\n",
" <td>0.739917</td>\n",
" <td>0.086532</td>\n",
" <td>0.037067</td>\n",
" <td>0.044832</td>\n",
" <td>0.058877</td>\n",
" <td>0.078004</td>\n",
" <td>0.057865</td>\n",
" <td>0.094583</td>\n",
" <td>0.043013</td>\n",
" <td>0.202391</td>\n",
" <td>0.515202</td>\n",
" <td>0.433722</td>\n",
" <td>0.996076</td>\n",
" <td>0.166667</td>\n",
" <td>4.168354</td>\n",
" <td>0.964092</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_GlobalAvg</td>\n",
" <td>1.125760</td>\n",
" <td>0.943534</td>\n",
" <td>0.061188</td>\n",
" <td>0.025968</td>\n",
" <td>0.031383</td>\n",
" <td>0.041343</td>\n",
" <td>0.040558</td>\n",
" <td>0.032107</td>\n",
" <td>0.067695</td>\n",
" <td>0.027470</td>\n",
" <td>0.171187</td>\n",
" <td>0.509546</td>\n",
" <td>0.384942</td>\n",
" <td>1.000000</td>\n",
" <td>0.025974</td>\n",
" <td>2.711772</td>\n",
" <td>0.992003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.510030</td>\n",
" <td>1.211848</td>\n",
" <td>0.050053</td>\n",
" <td>0.022367</td>\n",
" <td>0.025984</td>\n",
" <td>0.033727</td>\n",
" <td>0.030687</td>\n",
" <td>0.023255</td>\n",
" <td>0.055392</td>\n",
" <td>0.021602</td>\n",
" <td>0.137690</td>\n",
" <td>0.507713</td>\n",
" <td>0.338282</td>\n",
" <td>0.987911</td>\n",
" <td>0.187590</td>\n",
" <td>5.111878</td>\n",
" <td>0.906685</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNWithZScore</td>\n",
" <td>0.957701</td>\n",
" <td>0.752387</td>\n",
" <td>0.003712</td>\n",
" <td>0.001994</td>\n",
" <td>0.002380</td>\n",
" <td>0.002919</td>\n",
" <td>0.003433</td>\n",
" <td>0.002401</td>\n",
" <td>0.005137</td>\n",
" <td>0.002158</td>\n",
" <td>0.016458</td>\n",
" <td>0.497349</td>\n",
" <td>0.027572</td>\n",
" <td>0.389926</td>\n",
" <td>0.067821</td>\n",
" <td>2.475747</td>\n",
" <td>0.992793</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNWithMeans</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Ready_SVD 0.949165 0.746667 0.093955 0.044969 0.051197 \n",
"0 Self_SVD 0.916330 0.720153 0.103393 0.044455 0.053177 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_SVDBiased 0.938146 0.739917 0.086532 0.037067 0.044832 \n",
"0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n",
"0 Ready_Random 1.510030 1.211848 0.050053 0.022367 0.025984 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNWithZScore 0.957701 0.752387 0.003712 0.001994 0.002380 \n",
"0 Ready_I-KNNWithMeans 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.065474 0.083906 0.073996 0.104672 0.048211 0.220757 \n",
"0 0.070073 0.093884 0.079366 0.107792 0.051281 0.200210 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.058877 0.078004 0.057865 0.094583 0.043013 0.202391 \n",
"0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n",
"0 0.033727 0.030687 0.023255 0.055392 0.021602 0.137690 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.002919 0.003433 0.002401 0.005137 0.002158 0.016458 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR Reco in test Test coverage Shannon Gini \n",
"0 0.555546 0.765642 1.000000 0.038961 3.159079 0.987317 \n",
"0 0.519187 0.483563 0.997985 0.204906 4.408913 0.954288 \n",
"0 0.518957 0.475080 0.853022 0.147186 3.911356 0.971196 \n",
"0 0.515501 0.437964 1.000000 0.033911 2.836513 0.991139 \n",
"0 0.515202 0.433722 0.996076 0.166667 4.168354 0.964092 \n",
"0 0.509546 0.384942 1.000000 0.025974 2.711772 0.992003 \n",
"0 0.507713 0.338282 0.987911 0.187590 5.111878 0.906685 \n",
"0 0.499885 0.154825 0.402333 0.434343 5.133650 0.877999 \n",
"0 0.497349 0.027572 0.389926 0.067821 2.475747 0.992793 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496724 0.021209 0.482821 0.059885 2.232578 0.994487 \n",
"0 0.496441 0.007423 0.602121 0.010823 2.089186 0.995706 \n",
"0 0.496433 0.009544 0.699046 0.005051 1.945910 0.995669 \n",
"0 0.496424 0.009544 0.600530 0.005051 1.803126 0.996380 \n",
"0 0.496391 0.003181 0.392153 0.115440 4.174741 0.965327 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import imp\n",
"imp.reload(ev)\n",
"\n",
"import evaluation_measures as ev\n",
"dir_path=\"Recommendations generated/ml-100k/\"\n",
"super_reactions=[4,5]\n",
"test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}