{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Self-made SVD" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import helpers\n", "import pandas as pd\n", "import numpy as np\n", "import scipy.sparse as sparse\n", "from collections import defaultdict\n", "from itertools import chain\n", "import random\n", "\n", "train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n", "test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n", "from tqdm import tqdm\n", "\n", "class SVD():\n", " \n", " def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n", " self.train_ui=train_ui\n", " self.uir=list(zip(*[train_ui.nonzero()[0],train_ui.nonzero()[1], train_ui.data]))\n", " \n", " self.learning_rate=learning_rate\n", " self.regularization=regularization\n", " self.iterations=iterations\n", " self.nb_users, self.nb_items=train_ui.shape\n", " self.nb_ratings=train_ui.nnz\n", " self.nb_factors=nb_factors\n", " \n", " self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n", " self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n", "\n", " def train(self, test_ui=None):\n", " if test_ui!=None:\n", " self.test_uir=list(zip(*[test_ui.nonzero()[0],test_ui.nonzero()[1], test_ui.data]))\n", " \n", " self.learning_process=[]\n", " pbar = tqdm(range(self.iterations))\n", " for i in pbar:\n", " pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. 
Training epoch {i+1}...')\n", " np.random.shuffle(self.uir)\n", " self.sgd(self.uir)\n", " if test_ui==None:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n", " else:\n", " self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n", " \n", " def sgd(self, uir):\n", " \n", " for u, i, score in uir:\n", " # Computer prediction and error\n", " prediction = self.get_rating(u,i)\n", " e = (score - prediction)\n", " \n", " # Update user and item latent feature matrices\n", " Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n", " Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n", " \n", " self.Pu[u] += Pu_update\n", " self.Qi[i] += Qi_update\n", " \n", " def get_rating(self, u, i):\n", " prediction = self.Pu[u].dot(self.Qi[i].T)\n", " return prediction\n", " \n", " def RMSE_total(self, uir):\n", " RMSE=0\n", " for u,i, score in uir:\n", " prediction = self.get_rating(u,i)\n", " RMSE+=(score - prediction)**2\n", " return np.sqrt(RMSE/len(uir))\n", " \n", " def estimations(self):\n", " self.estimations=\\\n", " np.dot(self.Pu,self.Qi.T)\n", "\n", " def recommend(self, user_code_id, item_code_id, topK=10):\n", " \n", " top_k = defaultdict(list)\n", " for nb_user, user in enumerate(self.estimations):\n", " \n", " user_rated=self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]]\n", " for item, score in enumerate(user):\n", " if item not in user_rated and not np.isnan(score):\n", " top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n", " result=[]\n", " # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n", " for uid, item_scores in top_k.items():\n", " item_scores.sort(key=lambda x: x[1], reverse=True)\n", " result.append([uid]+list(chain(*item_scores[:topK])))\n", " return result\n", " \n", " def estimate(self, user_code_id, item_code_id, test_ui):\n", " result=[]\n", 
" for user, item in zip(*test_ui.nonzero()):\n", " result.append([user_code_id[user], item_code_id[item], \n", " self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n", " return result" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 39 RMSE: 0.7481223595239049. Training epoch 40...: 100%|██████████| 40/40 [02:09<00:00, 3.25s/it]\n" ] } ], "source": [ "model=SVD(train_ui, learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n", "model.train(test_ui)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAcU0lEQVR4nO3de5gU9Z3v8fd3LjADgwxzicsyXDQaL1GcMUh0RVflSRSPUU4Ox0Dy7NF1TziPlwc9URO8rJpksxvlSbwkJj5uQtQ9GozrPWtiVPAx8XoGBARBwMiGQQ4zXAYGkMvA9/xR3UzP2DPTM9M91V31eT1PPV237v5SD/Op6l/9qsrcHRERKXxFYRcgIiLZoUAXEYkIBbqISEQo0EVEIkKBLiISESVhfXFNTY1PmDAhrK8XESlIixcv3uLutemWhRboEyZMoLGxMayvFxEpSGb2n90tU5OLiEhEKNBFRCJCgS4iEhGhtaGLSOE5cOAATU1N7N27N+xSIq+srIy6ujpKS0szfo8CXUQy1tTUxIgRI5gwYQJmFnY5keXubN26laamJo466qiM36cmFxHJ2N69e6murlaY55iZUV1d3edfQgp0EekThfng6M92LrhA37wZrrsO9u8PuxIRkfxScIH+xz/CvffClVeCbuUuItKh4AJ9xgy47TaYPx9+9KOwqxGRwdTa2srPfvazPr/vwgsvpLW1tc/vu/zyyznqqKOor6/nlFNO4ZVXXjm87JxzzmHcuHGkPiRo+vTpVFRUAHDo0CHmzJnDSSedxMknn8xpp53GRx99BARXyp988snU19dTX1/PnDlz+lxbOgXZy+X222H1avj2t+HYY+GSS8KuSEQGQzLQr7rqqk7z29vbKSnpPs5eeOGFfn/nvHnzmDFjBosWLWL27NmsXbv28LLKykpef/11pkyZQmtrK5s2bTq87PHHH+fjjz9m+fLlFBUV0dTUxPDhww8vX7RoETU1Nf2uK52CDPSiInjoIVi/Hr7+dXj9daivD7sqkXi57jpYujS7n1lfD/fc0/3yuXPn8uGHH1JfX09paSllZWWMGjWK1atXs2bNGqZPn86GDRvYu3cv1157LbNnzwY67h21a9cupk2bxpQpU3jjjTcYM2YMzz77LOXl5b3WdsYZZ7Bx48ZO82bOnMmC
BQuYMmUKTz31FF/96ldZuXIlAJs2bWL06NEUFQUNIXV1df3cKpkruCaXpPJyeOYZqKqCr3wFUnaMIhJRP/zhD/nsZz/L0qVLmTdvHkuWLOHee+9lzZo1AMyfP5/FixfT2NjIfffdx9atWz/1GWvXruXqq69m5cqVVFZW8uSTT2b03b///e+ZPn16p3lTp07ltdde4+DBgyxYsICvfe1rh5ddeumlPP/889TX13P99dfz7rvvdnrvueeee7jJ5e677+7rpkirII/Qk0aPhuefhylTgmaXV1+FYcPCrkokHno6kh4skydP7nThzX333cfTTz8NwIYNG1i7di3V1dWd3pNsEwf4whe+wPr163v8jhtvvJGbb76ZpqYm3nzzzU7LiouLmTJlCgsWLOCTTz4h9ZbgdXV1fPDBByxcuJCFCxcydepUnnjiCaZOnQrkpsmlYI/Qk+rr4bHHoLERLr8cDh0KuyIRGSypbdKvvvoqL7/8Mm+++SbLli2joaEh7YU5Q4cOPTxeXFxMe3t7j98xb9481qxZw5133skVV1zxqeUzZ85kzpw5XHrppWm/a9q0acybN4+bb76ZZ555pi//vD4r+EAHuPhiuOsueOIJuOOOsKsRkVwZMWIEbW1taZft2LGDUaNGMWzYMFavXs1bb72V1e++5pprOHToEC+++GKn+WeddRY33XQTs2bN6jR/yZIlfPzxx0DQ42X58uWMHz8+qzV1VdBNLqmuvx5WrYLvfx+OOw6+8Y2wKxKRbKuurubMM8/kpJNOory8nCOPPPLwsgsuuIAHHniAE044geOOO47TTz89q99tZtx6663cddddnH/++Z3m33DDDZ9av7m5mW9+85vs27cPCJqHrrnmmsPLzz33XIqLiwGYOHEijzzyyMBr9JCuzpk0aZJn+4lF+/fD1Knw3nuwfTvoCmWR7Fq1ahUnnHBC2GXERrrtbWaL3X1SuvUj0eSSNGRI0ONlxw7YvTvsakREBlevgW5mZWb2jpktM7OVZvbdNOtcbmYtZrY0MfzP3JTbu+RJ4zS9lURE0rr66qsPdyFMDr/61a/CLqvPMmlD3wec5+67zKwU+JOZ/c7du55xeNzdr0nz/kGV7KG0dSvk+PyDSCy5e+TuuHj//feHXcKn9Kc5vNcjdA/sSkyWJoa8vS1WaqCLSHaVlZWxdevWfoWNZC75gIuysrI+vS+jXi5mVgwsBo4B7nf3t9Os9t/M7GxgDfC/3X1Dms+ZDcwGGDduXJ8KzVQy0LdsycnHi8RaXV0dTU1NtLS0hF1K5CUfQdcXGQW6ux8E6s2sEnjazE5y9xUpqzwP/Nrd95nZ/wIeBs5L8zkPAg9C0MulT5VmSEfoIrlTWlrap0eiyeDqUy8Xd28FFgEXdJm/1d33JSZ/AXwhO+X1XVVV8KpAF5G4yaSXS23iyBwzKwe+BKzuss7olMmLgVXZLLIvSkqgslKBLiLxk0mTy2jg4UQ7ehHwG3f/rZl9D2h09+eAOWZ2MdAObAMuz1XBmaiuVqCLSPz0GujuvhxoSDP/tpTxm4Cbslta/1VX66SoiMRPpK4UTdIRuojEkQJdRCQiFOgiIhERyUCvqYG2tuDuiyIicRHJQE9eXLRtW7h1iIgMpkgHunq6iEicRDrQ1Y4uInGiQBcRiQgFuohIRCjQRUQiIpKBPmwYlJcr0EUkXiIZ6KD7uYhI/EQ60HWELiJxokAXEYkIBbqISERENtBrahToIhIvkQ306urgXi6HDoVdiYjI4Ih0oB86BK2tYVciIjI4Ih3ooGYXEYkPBbqISEQo0EVEIiKygV5TE7wq0EUkLiIb6DpCF5G4iWygjxwJxcW6n4uIxEevgW5mZWb2jpktM7OVZvbdNOsMNbPHzWydmb1tZhNyUWxfmEFVlY7QRSQ+MjlC3wec5+6nAPXABWZ2epd1/gHY7u7HAHcDd2a3zP7R5f8iEie9BroHdiUmSxODd1ntEuDhxPi/A1PNzLJWZT8p
0EUkTjJqQzezYjNbCjQDL7n7211WGQNsAHD3dmAHUJ3mc2abWaOZNba0tAys8gzofi4iEicZBbq7H3T3eqAOmGxmJ/Xny9z9QXef5O6Tamtr+/MRfaIjdBGJkz71cnH3VmARcEGXRRuBsQBmVgKMBEKP0uRTi7xrA5GISARl0sul1swqE+PlwJeA1V1Wew64LDE+A1joHn6MVlfDvn2wZ0/YlYiI5F5JBuuMBh42s2KCHcBv3P23ZvY9oNHdnwN+Cfybma0DtgEzc1ZxH6ReXDR8eLi1iIjkWq+B7u7LgYY0829LGd8L/PfsljZwqYE+bly4tYiI5FpkrxQF3c9FROIl0oGePELX5f8iEgexCHQdoYtIHEQ60KuqglcFuojEQaQDvbQUjjhCgS4i8RDpQAddLSoi8RH5QNf9XEQkLiIf6MnL/0VEoi4Wga4jdBGJAwW6iEhExCLQd+6EAwfCrkREJLdiEegA27aFW4eISK5FPtB1PxcRiYvIB7ru5yIicRGbQNcRuohEnQJdRCQiFOgiIhER+UAfNgyGDlWgi0j0RT7QzXQ/FxGJh8gHOuh+LiISD7EJdB2hi0jUKdBFRCJCgS4iEhGxCPSamuBeLu5hVyIikju9BrqZjTWzRWb2vpmtNLNr06xzjpntMLOlieG23JTbP9XVcPAg7NgRdiUiIrlTksE67cD17r7EzEYAi83sJXd/v8t6f3T3i7Jf4sCl3s+lsjLcWkREcqXXI3R33+TuSxLjbcAqYEyuC8smXS0qInHQpzZ0M5sANABvp1l8hpktM7Pfmdnnu3n/bDNrNLPGlpaWPhfbXwp0EYmDjAPdzCqAJ4Hr3H1nl8VLgPHufgrwE+CZdJ/h7g+6+yR3n1RbW9vfmvtMgS4icZBRoJtZKUGYP+ruT3Vd7u473X1XYvwFoNTMarJa6QDoIRciEgeZ9HIx4JfAKnf/cTfr/FViPcxscuJz8yY+R46EoiIFuohEWya9XM4E/g54z8yWJubdDIwDcPcHgBnAlWbWDnwCzHTPn17fRUVQVaX7uYhItPUa6O7+J8B6WeenwE+zVVQu6GpREYm6WFwpCgp0EYk+BbqISETEJtD1kAsRibrYBLqO0EUk6mIV6J98Anv2hF2JiEhuxCrQQUfpIhJdCnQRkYhQoIuIRERsAl33cxGRqItNoKc+5EJEJIpiE+hVVcGrjtBFJKpiE+hDhsCIEQp0EYmu2AQ66OIiEYk2BbqISETEKtB1PxcRibJYBXp1tXq5iEh0xS7QdYQuIlEVu0DfsQPa28OuREQk+2IX6ADbtoVbh4hILsQy0NXsIiJRFKtA1/1cRCTKYhXoup+LiERZLANdR+giEkUKdBGRiIhVoA8fHtykS4EuIlHUa6Cb2VgzW2Rm75vZSjO7Ns06Zmb3mdk6M1tuZqfmptyBMdPFRSISXSUZrNMOXO/uS8xsBLDYzF5y9/dT1pkGHJsYvgj8PPGad3Q/FxGJql6P0N19k7svSYy3AauAMV1WuwR4xANvAZVmNjrr1WZBTQ00N4ddhYhI9vWpDd3MJgANwNtdFo0BNqRMN/Hp0MfMZptZo5k1trS09K3SLPnc5+D998E9lK8XEcmZjAPdzCqAJ4Hr3H1nf77M3R9090nuPqm2trY/HzFgDQ3Q2grr14fy9SIiOZNRoJtZKUGYP+ruT6VZZSMwNmW6LjEv7zQ0BK/vvhtuHSIi2ZZJLxcDfgmscvcfd7Pac8D/SPR2OR3Y4e6bslhn1px8MhQXw9KlYVciIpJdmfRyORP4O+A9M0vG4M3AOAB3fwB4AbgQWAfsAf4++6VmR3k5HH+8jtBFJHp6DXR3/xNgvazjwNXZKirX6uvh1VfDrkJEJLtidaVoUkMDbNwIIXW0ERHJidgGOqjZRUSiJZaBXl8fvCrQRSRKYhnoVVUwfrx6uohItMQy0CFodtERuohESWwDvb4e1qyBXbvCrkREJDtiG+gNDcH9XJYvD7sSEZHsiHWg
g5pdRCQ6YhvodXXBwy4U6CISFbENdLPgKF09XUQkKmIb6BAE+nvvwYEDYVciIjJwsQ70+nrYvx9WrQq7EhGRgYt1oOvEqIhESawD/XOfg2HDFOgiEg2xDvTiYpg4UYEuItEQ60CHjp4uemi0iBQ6BXoD7NwJH30UdiUiIgMT+0DXrXRFJCpiH+jJh0Yr0EWk0MU+0MvK4IQTFOgiUvhiH+ige6OLSDQo0AkCfdMm2Lw57EpERPpPgU7HiVHdqEtECpkCHfV0EZFo6DXQzWy+mTWb2Ypulp9jZjvMbGliuC37ZebWqFEwYYICXUQKW0kG6zwE/BR4pId1/ujuF2WlopDoxKiIFLpej9Dd/TVg2yDUEqqGBli7Ftrawq5ERKR/stWGfoaZLTOz35nZ57tbycxmm1mjmTW2tLRk6auzI3krXT00WkQKVTYCfQkw3t1PAX4CPNPdiu7+oLtPcvdJtbW1Wfjq7NGJUREpdAMOdHff6e67EuMvAKVmVjPgygbZmDFQU6NAF5HCNeBAN7O/MjNLjE9OfObWgX7uYEs+NFqBLiKFqtdeLmb2a+AcoMbMmoDbgVIAd38AmAFcaWbtwCfATPfCvLt4QwPcfXfwnNEhQ8KuRkSkb3oNdHef1cvynxJ0ayx4DQ1w4AC8/35Hm7qISKHQlaIpkj1ddAsAESlECvQUxxwDFRWwcGHYlYiI9J0CPUVxMVxxBTz6KKxeHXY1IiJ9o0Dv4pZbYNiw4FVEpJAo0Lv4zGfghhvgqafgnXfCrkZEJHMK9DS+9S2orYW5c6EwO2CKSBwp0NMYMQL+8R9h0SL4wx/CrkZEJDMK9G7Mnh3cI33uXDh0KOxqRER6p0DvxtCh8P3vB33SH3887GpERHqnQO/B178OEyfCrbcGtwMQEclnCvQeFBXBv/wL/PnP8ItfhF2NiEjPFOi9mDYNzj4bvvc92LUr7GpERLqnQO+FGdx5J2zeDPfcE3Y1IiLdU6Bn4PTTYfp0uOsu2LIl7GpERNJToGfoBz+A3bvhn/857EpERNJToGfoxBPhssvg/vth3bqwqxER+TQFeh9897tQXg5nnaX7vIhI/lGg98HYsfD660Go/+3fwoIFYVckItJBgd5Hn/88vP02TJoEs2bB7bfr1gAikh8U6P1QWwsvvwyXXx70T581C/bsCbsqEYk7BXo/DR0K8+cHXRmfeCJogvn447CrEpE4U6APgBnceCM88wysWgWTJ8OSJWFXJSJxpUDPgosvhjfeCJ5JeuaZcNVVQcCLiAwmBXqWTJwYdGWcNStoijnxRDj/fPiP/9BJUxEZHL0GupnNN7NmM1vRzXIzs/vMbJ2ZLTezU7NfZmE48sggzDdsgH/6J1ixAi66CI4/Hn7yE2hrC7tCEYmyTI7QHwIu6GH5NODYxDAb+PnAyypstbVwyy2wfj089hhUV8OcOTBmTNAc8+yz0NoadpUiEjW9Brq7vwZs62GVS4BHPPAWUGlmo7NVYCErLQ2aYN58M+i7/pWvwEMPBTf6qq6G006D73wHXnwxuE+MiMhAlGThM8YAG1KmmxLzNnVd0cxmExzFM27cuCx8deGYPBkefRT27QvCfeHCYLj77qDrY2kpfPGL8Dd/EzTRHH88HHccVFWFXbmIFIpsBHrG3P1B4EGASZMm+WB+d74YOjR4YMbZZ8MddwRH5q+/3hHw99zT+XF3tbWdA/7oo2HcuOA2BLW1QddJERHITqBvBMamTNcl5kkGhg+HL385GADa24O299Wr4YMPgtfVq4O+7i0tnd9bVgZ1dR0BP3ZscGK2piYI+5qajmHo0EH/p4nIIMtGoD8HXGNmC4AvAjvc/VPNLZKZkhI45phguOiizsu2bQvCfsMG+MtfOr++8kpwpWp3XSRHjAiCfdSojqGqqvP0qFFwxBHBukcc0TGMGBH0sReR/NZroJvZr4FzgBozawJuB0oB3P0B4AXgQmAdsAf4+1wVG3dVVcFwajcdQ9vbYfv2
4Eh+y5Zg6Dq+fXswbNwY7CC2b4cDB3r/7mHDoKIi+EUxbFjHkDpdURGEf/I1dUjOGz48GK+oCN5TpCshRLKm10B391m9LHfg6qxVJP1WUhI0tdTWZv4e9+DGYsmgb2uDnTuDoet4Wxt88knQ7r9nTzA0Nwevu3cHQ1tbcOI3U8mdQnJn0fW167yexlPnaUchcTSoJ0Ul/5h1BGddXXY+88CBjh3Arl2dx3ftCoI/OZ6cbmvr2Cns3h38mkid3rUr2PlkKl3Qpw7JXw1d53X3C6OsTCegJf8p0CXrSks7moeyxR327k2/U0jdIaTuOFJ3KLt3B79ANmzovDyT5iYIziGka1LqGv697RiS5ye0g5BcUKBLQTALnhRVXt63JqXe7N/feWeQ7pdFT9NbtnSe3rs3s+8tLu4c8CNHdh4ymXfEEcHOUyRJgS6xNmRIcNVudXV2Pq+9vfMvgK5D6vmI1PMUO3YE5yPWrQvGd+zI7FxEeXnnkK+s/PSOoLKyY37qa2VlsEPR+YboUKCLZFFJSUdYDtS+fUGwJwM/dUg3Lzn85S8d4709ScssONJP1pxuSHZvTTajpXZ51S+E/KJAF8lTQ4fCZz4TDP114EBHuLe2BkPqeLrpjz4KXpO9nnpSUfHpoO861NR0/Aqqrg7maUeQGwp0kQgrLe24Wrg/2ts7wj153UK61+SwalXwunVrzyecR47sCPfka7rx6uqOHcLIkTqR3BsFuoh0q6SkfzsE9+BE87ZtwYnjrVs7D8l5yR3Bhx8G062t3XdPLS7ufKSfvMVF1yF1flnZwLdBIVGgi0jWmXX07+/LjVUPHgyagLrbASTHt2yBtWuDW1Nv2RK8L53hwzt2SF3vb9R1h1BTE/wyKOSTxAp0EckbxcUdTS7HHpvZew4dCo7sW1o6D8ngTx0++KCjq2k6RUXB0X9tbcf5i+SQbl6+NQMp0EWkoBUVdewEjjsus/fs3Rsc7aeGf+p4c3MwvmxZML59e/rPKS3tCPp0gZ86HHlk0M00lxToIhI7ZWXBIyHHjMls/f37O4K+uRk2b+7YASTDv7k5OBfQ3Bxch5BORUUQ7lddBddfn71/T5ICXUSkF0OGwF//dTBkYs+ejpBP7gCS483NMDpHD+lUoIuIZNmwYTB+fDAMpgI+nysiIqkU6CIiEaFAFxGJCAW6iEhEKNBFRCJCgS4iEhEKdBGRiFCgi4hEhHlfHqWezS82awH+s4dVaoAtg1ROX6m2/lFt/aPa+ieqtY1397RP1g0t0HtjZo3uPinsOtJRbf2j2vpHtfVPHGtTk4uISEQo0EVEIiKfA/3BsAvogWrrH9XWP6qtf2JXW962oYuISN/k8xG6iIj0gQJdRCQi8i7QzewCM/vAzNaZ2dyw60llZuvN7D0zW2pmjSHXMt/Mms1sRcq8KjN7yczWJl5H5VFtd5jZxsS2W2pmF4ZU21gzW2Rm75vZSjO7NjE/9G3XQ22hbzszKzOzd8xsWaK27ybmH2Vmbyf+Xh83syF5VNtDZvZRynarH+zaUmosNrN3zey3iencbDd3z5sBKAY+BI4GhgDLgBPDriulvvVATdh1JGo5GzgVWJEy7y5gbmJ8LnBnHtV2B3BDHmy30cCpifERwBrgxHzYdj3UFvq2AwyoSIyXAm8DpwO/AWYm5j8AXJlHtT0EzAj7/1yirm8BjwG/TUznZLvl2xH6ZGCdu//Z3fcDC4BLQq4pL7n7a8C2LrMvAR5OjD8MTB/UohK6qS0vuPsmd1+SGG8DVgFjyINt10NtofNA8tHHpYnBgfOAf0/MD2u7dVdbXjCzOuC/AL9ITBs52m75FuhjgA0p003kyX/oBAf+YGaLzWx22MWkcaS7b0qM/z/gyDCLSeMaM1ueaJIJpTkolZlNABoIjujyatt1qQ3yYNslmg2WAs3ASwS/plvdvT2xSmh/r11rc/fkdvtBYrvdbWZDw6gNuAf4NnAoMV1NjrZbvgV6vpvi7qcC04Cr
zezssAvqjge/5fLmKAX4OfBZoB7YBPwozGLMrAJ4ErjO3XemLgt726WpLS+2nbsfdPd6oI7g1/TxYdSRTtfazOwk4CaCGk8DqoDvDHZdZnYR0Ozuiwfj+/It0DcCY1Om6xLz8oK7b0y8NgNPE/ynziebzWw0QOK1OeR6DnP3zYk/ukPAvxLitjOzUoLAfNTdn0rMzottl662fNp2iXpagUXAGUClmZUkFoX+95pS2wWJJix3933Arwhnu50JXGxm6wmakM8D7iVH2y3fAv3/AscmzgAPAWYCz4VcEwBmNtzMRiTHgS8DK3p+16B7DrgsMX4Z8GyItXSSDMuE/0pI2y7RfvlLYJW7/zhlUejbrrva8mHbmVmtmVUmxsuBLxG08S8CZiRWC2u7pattdcoO2gjaqAd9u7n7Te5e5+4TCPJsobt/g1xtt7DP/qY5G3whwdn9D4Fbwq4npa6jCXrdLANWhl0b8GuCn98HCNrg/oGgbe4VYC3wMlCVR7X9G/AesJwgPEeHVNsUguaU5cDSxHBhPmy7HmoLfdsBE4F3EzWsAG5LzD8aeAdYBzwBDM2j2hYmttsK4P+Q6AkT1gCcQ0cvl5xsN136LyISEfnW5CIiIv2kQBcRiQgFuohIRCjQRUQiQoEuIhIRCnQRkYhQoIuIRMT/B89L4DlKs5hnAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process).iloc[:,:2]\n", "df.columns=['epoch', 'train_RMSE']\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de5zNdf7A8dd7xj0VMRXGLUlsxqhBaoQkQ5LKipKKWCHdaCkSrV2lLdmf7qnUrst21U3ZkFYqM24NwpAyFJYUuYXP74/3d2aOMczFmfmey/v5eHwfM+f7/Z45728n7/M9n8v7I845jDHGRK4YvwMwxhhTvCzRG2NMhLNEb4wxEc4SvTHGRDhL9MYYE+FK+R1AblWrVnV16tTxOwxjjAkraWlp/3POxeV1LOQSfZ06dUhNTfU7DGOMCSsi8v3xjlnTjTHGRLh8E72ITBGRbSKSfpzjIiKTRCRDRFaIyIUBx24RkXXedkswAzfGGFMwBbmjfwVIOcHxjkB9b+sPPAMgImcAo4EWQHNgtIhUPplgjTHGFF6+bfTOuQUiUucEp1wDTHVaS+FLEakkItWANsAc59xOABGZg35gTDvZoI0xoeX3338nMzOT/fv3+x1KxCtXrhzx8fGULl26wM8JRmdsDWBTwONMb9/x9h9DRPqj3waoVatWEEIyxpSkzMxMTj31VOrUqYOI+B1OxHLOsWPHDjIzM6lbt26BnxcSnbHOueedc0nOuaS4uDxHBxljQtj+/fupUqWKJfliJiJUqVKl0N+cgpHoNwM1Ax7He/uOt98YE4EsyZeMovx3DkainwX09kbfXAz84pz7EfgYuFJEKnudsFd6+4rRPmB18b6EMcaEmYIMr5wGLAIaiEimiPQVkQEiMsA75UNgA5ABvAAMBPA6YR8BFnvb2KyO2eIzEWgM3APsKt6XMsaYMJFvonfO9XTOVXPOlXbOxTvnXnLOPeuce9Y77pxzg5xz9ZxzjZ1zqQHPneKcO9fbXi7OC1G3A32Bp4Dz0M+dw8X/ssYYX+3atYunn3660M/r1KkTu3YV/qbw1ltvpW7duiQmJtKkSRM+/fTT7GNt2rShVq1aBC7q1LVrVypWrAjAkSNHGDJkCBdccAGNGzemWbNmfPfdd4BWBmjcuDGJiYkkJiYyZMiQQseWl5ArgXBy4oDngAHAEHQgz2LgeT+DMsYUs6xEP3DgwKP2Hzp0iFKljp/mPvzwwyK/5oQJE+jWrRvz5s2jf//+rFu3LvtYpUqVWLhwIcnJyezatYsff/wx+9iMGTPYsmULK1asICYmhszMTE455ZTs4/PmzaNq1apFjisvEZboszQFFgAz0Dt7gO3AAbRP2BhTXO6+G5YtC+7fTEyEiROPf3z48OGsX7+exMRESpcuTbly5ahcuTLffvsta9eupWvXrmzatIn9+/dz11130b9/fyCntta
ePXvo2LEjycnJfPHFF9SoUYN3332X8uXL5xtby5Yt2bz56HEmPXr0YPr06SQnJ/PWW29x3XXXsXLlSgB+/PFHqlWrRkyMNqjExxd/TgqJ4ZXFQ4AeQFZFhgeBBsA4wCZ1GBNJxo8fT7169Vi2bBkTJkxgyZIlPPXUU6xduxaAKVOmkJaWRmpqKpMmTWLHjh3H/I1169YxaNAgVq5cSaVKlXjzzTcL9NqzZ8+ma9euR+1r164dCxYs4PDhw0yfPp0bbrgh+1j37t157733SExM5L777mPp0qVHPbdt27bZTTdPPvlkYf9T5ClC7+jzMgLYAYwEXgLGAJ0Bq8pgTDCd6M67pDRv3vyoCUWTJk3i7bffBmDTpk2sW7eOKlWqHPWcrDZ3gIsuuoiNGzee8DWGDRvGAw88QGZmJosWLTrqWGxsLMnJyUyfPp19+/YRWHo9Pj6eNWvWMHfuXObOnUu7du3497//Tbt27YDiabqJ4Dv63OoCbwL/ASoAvYHHvGMH0IFFh/wJzRgTVIFt3vPnz+c///kPixYtYvny5TRt2jTPCUdly5bN/j02NpZDh06cDyZMmMDatWt59NFH6dOnzzHHe/TowZAhQ+jevXuer9WxY0cmTJjAAw88wDvvvFOYyyu0KEr0WdoBy4H/oqN08H6/BO3M/SPwIkdXbzDGhLJTTz2V3bt353nsl19+oXLlylSoUIFvv/2WL7/8MqivPXjwYI4cOcLHHx89TahVq1aMGDGCnj17HrV/yZIlbNmyBdAROCtWrKB27dpBjSm3KGq6CRQLXBrwOAntuP3Y297w9qcCFwEHgTIlGaAxphCqVKnCpZdeygUXXED58uU566yzso+lpKTw7LPP0rBhQxo0aMDFF18c1NcWEUaOHMljjz1Ghw4djto/dOjQY87ftm0b/fr148CBA4A2Mw0ePDj7eNu2bYmNjQUgISGBqVOnnnyMgWM9Q0FSUpLzd4UpB6xCm3gGoZ+FdwMLgTvQDt4KvkVnTChavXo1DRs29DuMqJHXf28RSXPOJeV1fhQ23eRHgD8Ad5HzhacJsBedjFUDnXm7xpfojDGmsCzRF8htQDrwGdABmAw8GnDcOnGNiUSDBg3KHuqYtb38cglM8g+yKG2jLwoBLvO2rWi7PWg7/jVAP/SOv2aezzbGhJ/Jkyf7HUJQ2B19kZxFTkKPARKAsUAtoBnwF+BXf0IzxphcLNGftAuBj4B16KzbWGACkDUm9wNgHvC7L9EZY4wl+qCpBzwAfImOwc9K9KOBy9FvAb2AmdjdvjGmJFmiLxanBfz+GfAW0AWYDdyAtuVn+QAtuGaMMcXDEn2xOwW4FngF7cRdAPzZO/YDWm/nTLTgWh+0Do+tuGhMYRS1Hj3AxIkT2bt37wnPyaoTn5CQQOvWrfn++++zj4kIvXr1yn586NAh4uLi6Ny5MwBbt26lc+fONGnShEaNGtGpUycANm7cSPny5Y8a0ROMyVF5sURfomKBVuhMXICzgc+B8Wiin4WWZfjMO56BLqKytWTDNCbMFHeiBy02tmLFCtq0acNf/vKX7P2nnHIK6enp7Nu3D4A5c+ZQo0aN7OMPPfQQ7du3Z/ny5axatYrx48dnH8uquJm19e7du0jXkB8bXumrMkCyt4HOyl0DVPMeL0Rn5d4HpAC3AFcD5Uo2TGMKrU0e+7qjK43uBTrlcfxWb/sf0C3XsfknfLXAevTt27fnzDPPZObMmRw4cIBrr72WMWPG8Ntvv9G9e3cyMzM5fPgwo0aNYuvWrWzZsoW2bdtStWpV5s2bl++VtWzZkkmTJh21r1OnTnzwwQd069aNadOm0bNnTz7//HNA689feeWV2ecmJCTk+xrBZnf0IUWA84HTvce3oOUYhgHL0H8oNdF/KMaYLIH16Nu3b8+6dev4+uuvWbZsGWlpaSxYsIDZs2dTvXp1li9fTnp6OikpKQwZMoTq1aszb968AiV5yLv+fNZCI/v372fFihW0aNE
i+9igQYPo27cvbdu2Zdy4cdkFzYDsD6esLevDIdjsjj7kNQT+ho7Nn4tW3syqtXM7Wn75ZnQMvzGhYv4JjlXI53jVfI6f2CeffMInn3xC06ZNAdizZw/r1q2jVatW3Hffffz5z3+mc+fOtGrVqlB/t23btuzcuZOKFSvyyCOPHHUsISGBjRs3Mm3atOw2+CwdOnRgw4YNzJ49m48++oimTZuSnp4O5DTdFDe7ow8bsUB7IKsa3kFgPbqQSh20/PLjWA0eE+2cc4wYMSK73TsjI4O+ffty3nnnsWTJEho3bszIkSMZO3Zsof7uvHnz+P7770lMTGT06NHHHO/SpQtDhw49piwxwBlnnMGNN97Ia6+9RrNmzViwYEGRr68oLNGHrTLoRKz16Fj9zWgTzxzv+Bb0Q2A28IsfARpTYgLr0Xfo0IEpU6awZ88eADZv3sy2bdvYsmULFSpUoFevXgwbNowlS5Yc89z8lCpViokTJzJ16lR27tx51LE+ffowevRoGjdufNT+uXPnZnf27t69m/Xr11OrVsl+A7emm7B3DproR6Ojc7Lq5i9DR/McRtv+E9BO32FA8S5yYExJC6xH37FjR2688UZatmwJQMWKFXn99dfJyMhg2LBhxMTEULp0aZ555hkA+vfvT0pKSnZbfX6qVatGz549mTx5MqNGjcreHx8fz5AhQ445Py0tjcGDB1OqVCmOHDnC7bffTrNmzdi4cWN2G32WPn365Pk3TpbVo49ovwFfoUM4/4sul7ga7dD9EO3ovR5t5zem6KwefckqbD16u6OPaKeg5Rcu9x4fIuct/xiYhN7hNwWuQ5O+/WM1JtJYoo8qgW/3U+jiKm+hi6aP8n4u9Y5/j47kkZIM0BhftWjRInuJvyyvvfbaMe3u4cYSfVQ7Bx3FMxTtzP3R278XaITO3O2Clmloha2ba07EOYdIeN8YfPXVV36HkK+iNLfbqBvjqUFOaQbQO/4GwDPAFejY5hk+xGXCQbly5dixY0eRkpApOOccO3bsoFy5ws2Otzt6k4cK6GSs29EO3U+B99FZu6AduY+gd/pXA42xJp7oFh8fT2ZmJtu3WyXW4lauXDni4+ML9ZwCJXoRSUFv8WKBF51z43Mdrw1MAeKAnUAv51ymd+ww8I136g/OuS6FitD47BS0+SbwbXPosM2R3lYTveufDJTn6E5fEw1Kly5N3bo2eitU5fuvUURi0X/B7YFMYLGIzHLOrQo47XFgqnPuVRG5HJ2zf7N3bJ9zLhETQa7ytp/Qu/v3gC/IKbZ2GzqFvQk6fj8BSCTnG4ExpiQV5LarOZDhnNsAICLT0dWwAxN9I+Be7/d5wDvBDNKEqrPRGvp9cu1vj971r0CHcR4CLiDni90T6J3/heiHQPmSCNaYqFWQztga6Np4WTK9fYGWowOxQVfZOFVEqniPy4lIqoh8KSJdKUY//1ycf90UXG/gdTTR70Fn6QbWCn8eLVd7MXAq2sY/IeC4Vec0JpiCNepmKNBaRJYCrdGxeoe9Y7W92Vo3AhNFpF7uJ4tIf+/DILWonTlbt8I550D//mD9QaGkLNqEE1gpcDU6Tv9tdJ3d2uT877IPqIQ28/RCR/18AxwpoXiNiTwFSfSb0d62LPHkWuvOObfFOXedc64p8KC3b5f3c7P3cwPacNs09ws45553ziU555Li4uKKch2UKwe33govvwz168OTT8LBg0X6U6bYCToZqyswFh3RM9w7dhCdvHU+OtpnINq885R3fBf6v5Hd9RtTUAVJ9IuB+iJSV0TKAD3QNe+yiUhVEcn6WyPQETiISGURKZt1DnApR7ftB83pp2tyX7ECLr4Y7r0XEhLgo4+K49VM8TkdTfTvoBU41wOvosM4QWvyt/XOa4F2Db0JFKz6oDHRKN9E75w7BAxGe9VWAzOdcytFZKyIZI25awOsEZG1wFnAOG9/QyBVRJajnbTjc43WCbqGDTW5v/8+HDkCnTrBVVfB2rXF+aq
meAg6e7c3cK637wrgA7RGTzm0aacbOgIItKTDALRP4HP0G4Ax0S2iq1cePAj/+AeMGQP79sFdd8GoUXr3byLFQbQ+TxI6zWMC8FeOTvC1gbVoCYfVQEWObo00JvydqHplRJdAKFMG7rsP1q2DW26BJ57Q9vsXX4TDh/N/vgkHZdAmnFjv8TB0zt4m9M5/PDoaOKtOz1C0fyAe+CM61HNxCcZrTMmL6Dv63NLS9K5+4UJtv//zn6F7dyhlkzijyDJgAVqbfxE6+qeVtw9gIvoh0JJjRxEbE7qi9o4+t4sugs8/h2nT4Pff4aab4Nxz4amnwFt1zES8RGAIMA3YiHb4Zo3x/x1dqeuPaLI/B53l+2mJR2lMMEVVogcQgR49ID0dZs2CmjXh7ruhVi1tv9+2ze8ITcmqhs7aBSgNbEdX5XoS/VCYRU6N/v+hlT1eBDLQ2b/GhL6oaro5nkWLYMIEeOcdbde/9VZt269fv0TDMCHpCNrhWw79ALga/TAAqI7ODxwD1PfOjbp7JxMirOkmHy1bwltvwbffaqftK69AgwZw/fUQBusQmGIVQ06xthboAuyr0GGdl6GTt373jr+I1v9ph67e9QJa7O13jPGT3dHn4aefdFjm00/Drl3QuLE299xwA9Q7poCDiW5Z/34EnSryOpAOrERr+QP8jJZ1eB34Gm37r+f9PAcr6maC4UR39JboT2D3bpg6VTtvFy7UfUlJmvS7d9f2fWPydgQd0bMGSPH2DUULugXO4j0d/SAQ4Dn0G8M5aAmIxmitIGPyZ4k+CH74AWbOhOnTdZgmQHKyJv1u3eCss/yNz4QLB+xASztsQJN+f+9YV7TzN+vfZBm0T+AN7/EmtF8ga86AMTks0QdZRgbMmKFJPz0dYmKgbVsdrtmjB5S3b+KmyPajwz6/QSdynY5XJxBN8r+idfybedul2CxfA5boi1V6ek7Sz8iAKlW0VPIdd1jTjgmmI8A/0eS/GB3yeQDt9J2IfkB0R2sC1Q/4WRP7BhAdLNGXAOdg/nyYNEnH54vAtdfCkCHaxCO2drYJqt/RTt+KaELPBDqh4/v3BZz3BHAPWln8cbQT+Fxvq43OHTCRwBJ9Cdu4UUfsvPCCjtpJTNSE37On1s03pvg4dLZvBrAOuARd6fO/aKfwbwHnxgLvouv/rkHX/836EDgH6wgOL5boffLbb/DPf+pQzfR0qFo1p1knPt7v6Ez0ceionoyArQ+a1F/m6LV/Y9A7/g/QauNr0L6D89CicNYcFGos0fsssFnn3Xe18/bqq6FfP+jQAWLt34zxnUOrfmZ9E8ja/g84A60BNNY7twzaBHQe8Bq67u8cdO7AKWhzUkVvfxvvObuACuRUETXBZok+hHz3HTzzjM6+3b5dO2z79NGtVi2/ozPmeHagiXwdWtt/LfAD2jEcgw4RfSHXcyqQ01R0E/AvIA4dPVQD/aB40ju+3PtZHaiKziswhWGJPgQdPKidti+8AHPm6L6UFL3L79wZSlsfmQkrvwN7vO037+cBdPgnwCfAl2j/wWZvq4D2HYAuDznf+/00dGnp9uQMLXVY8j8xS/QhbuNGeOklmDIFtmyBs8/Wwmq3324lF0y0WIZOItuM9gcsQfsCZnjHE9BmoQsDtj9gTUE5LNGHiUOHdL3bF1+EDz7QVbAuvxxuuw2uuw4qVPA7QmP8cAQtH5GGzh/IKiHRHy0bsR+4HO1LqBzwsz36jeIgWmPo1FxbWSLpW4Il+jC0ebO240+ZAhs2wGmn6azbPn2geXMbl2+i1RH0zn8JOiroYrSj949oZ/LP3s9fgL8Bw4Hv0JFFuU0C7kT7G64l5wOgvLcNRlcf+w79QCkfsJVDh6vWREcjvQ8c9rZD3s8bvRjXoU1XNQK2swj2yCVL9GHsyBFdFWvKFHjjDdi7Fxo21ITfq5c28xhjcstKtmXRPoMv0G8CgduVwEVozaH7A/bv87YJaK2hBcAVHFtu+l2gC5rkr84jhrl
o30PuoaugST4VXdxmLvA2+gEwBO27KDxL9BHi11+1sNrLL8MXX+iwzKuu0qadq66yDlxjitdhtJlon/ezMtpvcACtQRQLlAr4WRodkXQE2EZOJ3TWdg9QBV3bYAT6LeQARe13sEQfgb79Vpt2Xn1V6+fHxUHv3tC3r97xG2PCzW/oB0fR2ApTEej882H8eNi0Cd5/H1q10kXOGzXS2jqvvKIzc40x4aLoST4/lujDXKlS2mzz5puQmQmPPaYTsW67DapXhwEDIDVVZ+caY6KTJfoIctZZMGyYNussWABdu+oKWc2aQdOm8H//Bz//7HeUxpiSZok+AoloU86rr+oErKef1o7bO+/Uu/ybb9alEe0u35joYIk+wlWqpNUy09J0u+02LayWnAxNmmjdnd278/87xpjwZYk+ilx4od7db9kCzz2nd/kDB+pd/h13wIoVfkdojCkOluijUMWKWhd/yRL48ku4/nodpdOkCVx6Kbz+Ouzf73eUxphgKVCiF5EUEVkjIhkiMjyP47VF5FMRWSEi80UkPuDYLSKyzttuCWbw5uSIQIsWmuQ3b4a//11H7Nx8sy6Mcv/9ug6uMSa85ZvoRSQWmAx0RNck6ykijXKd9jgw1TmXgK5O8DfvuVkrFrQAmgOjRaRy8MI3wXLGGXDvvTpiZ84caNMGnngC6teHK6+Et96C33PPADfGhIWC3NE3BzKccxuccweB6cA1uc5phBZsAJgXcLwDMMc5t9M59zO6DE3KyYdtiktMDFxxhdbV+eEHeOQRTf7XX68Lo4waBd9/73eUxpjCKEiirwFsCnic6e0LtBy4zvv9WuBUEalSwOciIv1FJFVEUrdv317Q2E0xq14dRo7UVbHeew8uugjGjYO6dXVxlPff11LKxpjQFqzO2KFAaxFZCrRGK/YUOAU45553ziU555Li4uKCFJIJltjYnMT+3Xfw4IM6VPPqqzXpP/KIjuQxxoSmgiT6zWjR5Szx3r5szrktzrnrnHNN8db+cs7tKshzTXipXVsT+w8/aPPO+efDQw/p/h49tKqmTcQyJrQUJNEvBuqLSF0RKQP0AGYFniAiVUUk62+NAKZ4v38MXCkilb1O2Cu9fSbMlS6t7faffALr1sGQITB7tg7PTErSWbk2RNOY0JBvonfOHUKXWvkYWA3MdM6tFJGxItLFO60NsEZE1qJLp4zznrsTeAT9sFgMjPX2mQhy7rk6NDMzU2fa7t+va97WrKlt/JmZfkdoTHSzevQm6JyDuXPhH/+AWbN0JM911+ld/6WX2jKIxhQHq0dvSpQItGsH77wD69fDPffo2PxWrbQMw8svw759fkdpTPSwRG+KVd26MGGCzrx9/nk4dEjXu61ZE0aM0IVTjDHFyxK9KREVKkC/flo4bd48uOwyXSSlbl3o1k3r54dYK6IxEcMSvSlRIlpe4a23YMMGuO8+bc9v3VoXR3npJWvWMSbYLNEb39SuDY8+qqNyXngBjhyB22/XgmrDh+tYfWPMybNEb3xXoYIm+OXLYf58veOfMEGbda67Tpt6rFnHmKKzRG9Chog24bz5pjbrDBsGn30Gl18OCQm6WMpvv/kdpTHhxxK9CUm1a8P48dqsM2WKzsQdMABq1NByyuvX+x2hMeHDEr0JaeXL6zq3aWnw3/9CSopOxKpfXwutffyxtu0bY47PEr0JCyI6q3b6dK2HP2oUpKZq4m/YUNfCtWYdY/Jmid6EnerVYcwYTfivvw6VKsGgQToJ64EHrGSyMblZojdhq2xZuOkmXeB84ULttH30UahTB3r3hmXL/I7QmNBgid6EPRG45BKtj79uHQwcCG+/rROwLr9cF0yxdnwTzSzRm4hyzjkwcaLW0JkwATIydCWshg3h2Wdh716/IzSm5FmiNxGpUiUYOlSHYU6bBqedBnfckTMbd88evyM0puRYojcRrXRpXeLw66+1cFpSkpZXqFMH/vY32L3b7wiNKX6W6E1UENF6+B99pJ2
3zZvrCJ06dWDcOPj1V78jNKb4WKI3UadFC/jwQ/jqK2jZUpc7rFNHFz3/5Re/ozMm+CzRm6jVvLmOyFm8GJKT4aGHNOGPGQO7dvkdnTHBY4neRL2kJF3bNi1Ni6o9/LAm/NGj4eef/Y7OmJNnid4Yz4UX6jq3S5fq+PuxYzXhjxoFO3f6HZ0xRWeJ3phcEhN1Bazly6F9e/jLXzThjxxpCd+EJ0v0xhxHQoLOtl2+HDp00NE5derAgw/Cjh1+R2dMwVmiNyYfCQnw73/DN99Ax446/r5OHRgxAv73P7+jMyZ/luiNKaALLoAZMzThX3VVTgG14cMt4ZvQZonemEL6wx+0Ln56utbReewxXd/2gQesSceEJkv0xhRRo0ZaR+ebb6BTJ136sG5dHaVjwzJNKLFEb8xJ+sMftElnxQrttM0apfPwwzbxyoQGS/TGBMkFF2in7fLl0K6dzrCtW1fH41tpBeOnAiV6EUkRkTUikiEiw/M4XktE5onIUhFZISKdvP11RGSfiCzztmeDfQHGhJqEBB2Hv3SpzrQdPVoT/rhxVi3T+CPfRC8iscBkoCPQCOgpIo1ynTYSmOmcawr0AJ4OOLbeOZfobQOCFLcxIS8xUWfapqVpLZ2RI6FePV0YZf9+v6Mz0aQgd/TNgQzn3Abn3EFgOnBNrnMccJr3++mALc9sjOfCC7WWzldfQZMmcM89cN558NJLcOiQ39GZaFCQRF8D2BTwONPbF+hhoJeIZAIfAncGHKvrNel8JiKtTiZYY8JZ8+YwZw58+ilUqwa3357Trm9r2priFKzO2J7AK865eKAT8JqIxAA/ArW8Jp17gX+JyGm5nywi/UUkVURSt2/fHqSQjAlNl1+ui5+88w6UKgXdu2sFzdmzwTm/ozORqCCJfjNQM+BxvLcvUF9gJoBzbhFQDqjqnDvgnNvh7U8D1gPn5X4B59zzzrkk51xSXFxc4a/CmDAjAtdcoyN0XntNh2F27KidtwsX+h2diTQFSfSLgfoiUldEyqCdrbNynfMD0A5ARBqiiX67iMR5nbmIyDlAfWBDsII3JtzFxkKvXvDttzB5Mqxbpx23nTvruHxjgiHfRO+cOwQMBj4GVqOja1aKyFgR6eKddh/QT0SWA9OAW51zDrgMWCEiy4A3gAHOOSv0akwuZcrAwIGwfr3OsF24UEft3HwzfPed39GZcCcuxBoFk5KSXGpqqt9hGOOrn3/WomlPPQWHD8Mdd2h55DPP9DsyE6pEJM05l5TXMZsZa0wIqlxZ7+wzMuC227RZp149Latgk65MYVmiNyaE1agBzz0HK1dCSoqWVahXDyZNggMH/I7OhAtL9MaEgQYNdLz9119D48Zw111w/vk6YsfG4Jv8WKI3Jow0awb/+Q988gmccQb07g0tWui4fGOOxxK9MWFGRBctX7wYXn8dtmyBli3h1lvhp5/8js6EIkv0xoSpmBi46SZYs0aXM5w2TWvoPP44HDzod3QmlFiiNybMVayoC5anp+vM2mHDtB1/9my/IzOhwhK9MRGifn147z344AOtmdOxI3TpokM0TXSzRG9MhOnUSe/uH30U5s3TpQ4ffBD27PE7MuMXS/TGRKAyZeD++7X9/oYb4K9/1eGYM2ZYhcxoZInemIWfGBgAAA2aSURBVAhWvTpMnaq1c+LioEcPuOIKWLXK78hMSbJEb0wUuOQSSE3VUgpLluhKV0OHWjmFaGGJ3pgoERurFTLXrtUx9088oTNu//Uva86JdJbojYkycXHwwgs6m7ZGDR2L36YNfPON35GZ4mKJ3pgo1by5Jvvnn9eiaU2bwt13wy+/+B2ZCTZL9MZEsdhY6NdPR+f066dVMRs0gOnTrTknkliiN8ZQpQo884x22NaqBT176mSrzEy/IzPBYIneGJPtwgth0SLtqP30U2jUSD8ArBRyeLNEb4w5Smws3HOPzq5t0UJH6rRpo807JjxZojfG5Omcc7Tu/ZQpOiKnSROdYfv
7735HZgrLEr0x5rhEdM3a1avh6qu1Zk6zZpCW5ndkpjAs0Rtj8nX22bqU4dtvw7ZtOjTz/vth716/IzMFYYneGFNgXbtqnZy+fWHCBKt7Hy4s0RtjCqVSJZ1kNXculC6tde+7d9clDU1oskRvjCmStm1h+XJ45BGYNUvLIP/jH3D4sN+Rmdws0RtjiqxsWRg5UoditmwJQ4bokMzUVL8jM4Es0RtjTtq552pb/YwZ2oTTvDnceafVzQkVluiNMUEhom31q1fD4MFa+/78861uTiiwRG+MCarTT9fiaF9/rWWQe/aEDh1g/Xq/I4teluiNMcUiKQm++ko7aL/8UodiPvmkddb6oUCJXkRSRGSNiGSIyPA8jtcSkXkislREVohIp4BjI7znrRGRDsEM3hgT2mJjtRln1Spo1w7uvVeXNUxP9zuy6JJvoheRWGAy0BFoBPQUkUa5ThsJzHTONQV6AE97z23kPf4DkAI87f09Y0wUiY/XIZj/+hds2KBVMh9+GA4e9Duy6FCQO/rmQIZzboNz7iAwHbgm1zkOOM37/XQga+rENcB059wB59x3QIb394wxUUZE2+tXrYI//hHGjNGE/9VXfkcW+QqS6GsAmwIeZ3r7Aj0M9BKRTOBD4M5CPBcR6S8iqSKSun379gKGbowJR3Fx8M9/wnvvwa5dOv7+3nvht9/8jixyBasztifwinMuHugEvCYiBf7bzrnnnXNJzrmkuLi4IIVkjAllnTvr3f2f/qSdtI0b62InJvgKkow3AzUDHsd7+wL1BWYCOOcWAeWAqgV8rjEmSp12mq5gNX8+lCoFV1yha9fu3u13ZJGlIIl+MVBfROqKSBm0c3VWrnN+ANoBiEhDNNFv987rISJlRaQuUB/4OljBG2MiQ+vWWjdn2DB46SVITNQlDU1w5JvonXOHgMHAx8BqdHTNShEZKyJdvNPuA/qJyHJgGnCrUyvRO/1VwGxgkHPORtEaY45Rvjw89hh89pmOtU9OhoceshWtgkFciM1NTkpKcqlWEcmYqPbLL1ogbepUXdHqtdegQQO/owptIpLmnEvK65jNjDXGhJzTT4dXX9VVrTIyoGlTbcsPsfvSsGGJ3hgTsrp104XJk5Nh4EBdt/ann/yOKvxYojfGhLQaNbQE8lNP6fDLxo3h3Xf9jiq8WKI3xoS8mBhts09L03IKXbvC7bfbMMyCskRvjAkbjRppyYThw2HKFEhI0DH45sQs0RtjwkqZMvC3v8GCBVods21bvdvfu9fvyEKXJXpjTFhKTtZJVoMHa837Jk3giy/8jio0WaI3xoStU07RJD93rk6sSk7W2bX79/sdWWixRG+MCXtt2+owzH794PHHtfzx4sV+RxU6LNEbYyLCqafCc8/pUMxff9XyxyNH2uImYIneGBNhOnTQpQp79YJx47SEwrJlfkflL0v0xpiIU6kSvPKKTqzauhWaN4dHH43ehckt0RtjIlaXLrByJVxzjY69b9NG16yNNpbojTERrUoVmDlTK2CuWKHDMF96KboKpFmiN8ZEPBFts//mG0hK0vIJ114L27b5HVnJsERvjIkatWppYbS//x0++kgLpM3KvV5eBLJEb4yJKjExcO+9WiCtWjVtv4/0dWot0RtjotIFF+QUSMtap3bhQr+jKh6W6I0xUats2ZwCaUeOwGWXwcMPR94wTEv0xpiol5ysI3JuugnGjIF27WDzZr+jCh5L9MYYg5ZQmDpV16pNTdWmnA8/9Duq4LBEb4wxAXr31o7a6tXhqqtg6NDwr5djid4YY3Jp0EA7agcO1KGYycnhPaPWEr0xxuShXDmYPBneeAPWroWmTXWGbTiyRG+MMSdw/fVa/bJhQ7jhBhgwAPbt8zuqwrFEb4wx+ahTBz7/HO6/X2veN28Oq1b5HVXBWaI3xpgCKF1aSx1/9JGWPk5KghdeCI/iaJbojTGmEFJSdFHySy+F/v2he3f4+We/ozoxS/TGGFNI1arBxx/rHf4
774R++YQCJXoRSRGRNSKSISLD8zj+pIgs87a1IrIr4NjhgGNRUCfOGBMNYmK0zX7hQm3WuewyGDs2NMsnlMrvBBGJBSYD7YFMYLGIzHLOZXdFOOfuCTj/TqBpwJ/Y55xLDF7IxhgTOpo3hyVLYNAgGD1ayyC//jrUrOl3ZDkKckffHMhwzm1wzh0EpgPXnOD8nsC0YARnjDHh4LTTdAWrqVM16TdpAm+95XdUOQqS6GsAmwIeZ3r7jiEitYG6wNyA3eVEJFVEvhSRrkWO1BhjQtzNN8PSpVCvno6/HzAA9u71O6rgd8b2AN5wzgW2UtV2ziUBNwITRaRe7ieJSH/vwyB1+/btQQ7JGGNKzrnnart91pj7Zs0gPd3fmAqS6DcDga1N8d6+vPQgV7ONc26z93MDMJ+j2++zznneOZfknEuKi4srQEjGGBO6ypTRETmffAI7dmg7/pQp/o25L0iiXwzUF5G6IlIGTebHjJ4RkfOBysCigH2VRaSs93tV4FIgjOaTGWNM0bVvr+UTLrkE+vbVpp09e0o+jnwTvXPuEDAY+BhYDcx0zq0UkbEi0iXg1B7AdOeO+sxqCKSKyHJgHjA+cLSOMcZEurPP1jH3Y8fCtGlw0UW6yElJEhdi83eTkpJcamqq32EYY0zQffYZ9OwJO3fCpEm6KLlIcP62iKR5/aHHsJmxxhhTQlq31qac1q3hT3+CG2+EX38t/te1RG+MMSXozDO1MNpf/wr//rc25SxdWryvaYneGGNKWEwMjBgB8+drbfuLL9ZFToqrJd0SvTHG+CQ5WZtyrrgCBg/WhU2OHAn+6+Rb68YYY0zxqVoV3ntP16b95Re92w82S/TGGOOzmBgYNqwY/37x/WljjDGhwBK9McZEOEv0xhgT4SzRG2NMhLNEb4wxEc4SvTHGRDhL9MYYE+Es0RtjTIQLuTLFIrId+D7X7qrA/3wIpzhF2jVF2vVA5F1TpF0PRN41ncz11HbO5blEX8gl+ryISOrx6iyHq0i7pki7Hoi8a4q064HIu6biuh5rujHGmAhnid4YYyJcuCT65/0OoBhE2jVF2vVA5F1TpF0PRN41Fcv1hEUbvTHGmKILlzt6Y4wxRWSJ3hhjIlzIJXoRmSIi20QkPWDfGSIyR0TWeT8r+xljYRzneh4Wkc0isszbOvkZY2GJSE0RmSciq0RkpYjc5e0Py/fpBNcTtu+TiJQTka9FZLl3TWO8/XVF5CsRyRCRGSJSxu9YC+IE1/OKiHwX8B4l+h1rYYhIrIgsFZH3vcfF8v6EXKIHXgFScu0bDnzqnKsPfOo9DhevcOz1ADzpnEv0tg9LOKaTdQi4zznXCLgYGCQijQjf9+l41wPh+z4dAC53zjUBEoEUEbkYeBS9pnOBn4G+PsZYGMe7HoBhAe/RMv9CLJK7gNUBj4vl/Qm5RO+cWwDszLX7GuBV7/dXga4lGtRJOM71hDXn3I/OuSXe77vR/1FrEKbv0wmuJ2w5tcd7WNrbHHA58Ia3P5zeo+NdT9gSkXjgKuBF77FQTO9PyCX64zjLOfej9/tPwFl+BhMkg0Vkhde0ExZNHHkRkTpAU+ArIuB9ynU9EMbvk9cssAzYBswB1gO7nHOHvFMyCaMPtNzX45zLeo/Gee/RkyJS1scQC2sicD9wxHtchWJ6f8Il0WdzOh40rD/JgWeAeuhX0B+Bv/sbTtGISEXgTeBu59yvgcfC8X3K43rC+n1yzh12ziUC8UBz4HyfQzopua9HRC4ARqDX1Qw4A/izjyEWmIh0BrY559JK4vXCJdFvFZFqAN7PbT7Hc1Kcc1u9/2mPAC+g/wjDioiURpPiP51zb3m7w/Z9yut6IuF9AnDO7QLmAS2BSiJSyjsUD2z2LbAiCrieFK/ZzTnnDgAvEz7v0aVAFxHZCExHm2yeopjen3BJ9LOAW7zfbwHe9TGWk5aVDD3XAunHOzcUeW2JLwGrnXNPBBwKy/fpeNcTzu+TiMSJSCXv9/J
Ae7TvYR7QzTstnN6jvK7n24AbC0Hbs8PiPXLOjXDOxTvn6gA9gLnOuZsopvcn5GbGisg0oA1arnMrMBp4B5gJ1EJLGHd3zoVFB+dxrqcN2hzggI3AnwLatkOeiCQDnwPfkNO++ADarh1279MJrqcnYfo+iUgC2pkXi97QzXTOjRWRc9A7yDOApUAv7244pJ3geuYCcYAAy4ABAZ22YUFE2gBDnXOdi+v9CblEb4wxJrjCpenGGGNMEVmiN8aYCGeJ3hhjIpwlemOMiXCW6I0xJsJZojfGmAhnid4YYyLc/wMcH02NeVcblAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "df=pd.DataFrame(model.learning_process[10:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n", "plt.plot('epoch', 'train_RMSE', data=df, color='blue')\n", "plt.plot('epoch', 'test_RMSE', data=df, color='yellow', linestyle='dashed')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Saving and evaluating recommendations" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model.estimations()\n", "\n", "top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n", "\n", "top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv', index=False, header=False)\n", "\n", "estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n", "estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', index=False, header=False)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 4506.01it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
00.9153040.7190160.1008480.0422280.0511910.0678850.0922750.070730.1043660.0496060.1929990.5178310.4655360.8678690.1500723.8477960.972676
\n", "
" ], "text/plain": [ " RMSE MAE precision recall F_1 F_05 \\\n", "0 0.915304 0.719016 0.100848 0.042228 0.051191 0.067885 \n", "\n", " precision_super recall_super NDCG mAP MRR LAUC \\\n", "0 0.092275 0.07073 0.104366 0.049606 0.192999 0.517831 \n", "\n", " HR Reco in test Test coverage Shannon Gini \n", "0 0.465536 0.867869 0.150072 3.847796 0.972676 " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n", "reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n", "\n", "ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n", " estimations_df=estimations_df, \n", " reco=reco,\n", " super_reactions=[4,5])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 4962.65it/s]\n", "943it [00:00, 4020.02it/s]\n", "943it [00:00, 3974.62it/s]\n", "943it [00:00, 4763.58it/s]\n", "943it [00:00, 4203.99it/s]\n", "943it [00:00, 4230.96it/s]\n", "943it [00:00, 3909.05it/s]\n", "943it [00:00, 4215.47it/s]\n", "943it [00:00, 4293.30it/s]\n", "943it [00:00, 4389.83it/s]\n", "943it [00:00, 4410.76it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Ready_LightFMpureMF7.9531927.4620080.3344640.2199970.2172250.2549810.2337980.2669520.3987780.2630580.6291290.6077090.9130431.0000000.2756135.0858180.913665
0Ready_LightFM162.707436160.8554830.3408270.2176820.2179900.2580100.2438840.2606630.4038500.2682660.6375900.6065680.8981971.0000000.3513715.3662910.885046
0Self_P33.7024463.5272730.2821850.1920920.1867490.2169800.2041850.2400960.3391140.2049050.5721570.5935440.8759281.0000000.0772013.8758920.974947
0Ready_ImplicitALS3.2661013.0658240.2550370.1886530.1768520.2011890.1666310.2149250.3059080.1725460.5238710.5917090.8897141.0000000.5028865.7229570.827507
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_LightFMcontent182.471340180.4052100.1603390.1012240.1021980.1210740.1026820.1124550.1800790.0874290.3378250.5475720.7041360.9749730.2647914.9098930.926201
0Self_SVD0.9153040.7190160.1008480.0422280.0511910.0678850.0922750.0707300.1043660.0496060.1929990.5178310.4655360.8678690.1500723.8477960.972676
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5143551.2163830.0497350.0223000.0257820.0335980.0282190.0217510.0543830.0211190.1339780.5076800.3393430.9869570.1774895.0886700.907676
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall \\\n", "0 Ready_LightFMpureMF 7.953192 7.462008 0.334464 0.219997 \n", "0 Ready_LightFM 162.707436 160.855483 0.340827 0.217682 \n", "0 Self_P3 3.702446 3.527273 0.282185 0.192092 \n", "0 Ready_ImplicitALS 3.266101 3.065824 0.255037 0.188653 \n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 \n", "0 Ready_LightFMcontent 182.471340 180.405210 0.160339 0.101224 \n", "0 Self_SVD 0.915304 0.719016 0.100848 0.042228 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 \n", "0 Ready_Random 1.514355 1.216383 0.049735 0.022300 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 \n", "\n", " F_1 F_05 precision_super recall_super NDCG mAP \\\n", "0 0.217225 0.254981 0.233798 0.266952 0.398778 0.263058 \n", "0 0.217990 0.258010 0.243884 0.260663 0.403850 0.268266 \n", "0 0.186749 0.216980 0.204185 0.240096 0.339114 0.204905 \n", "0 0.176852 0.201189 0.166631 0.214925 0.305908 0.172546 \n", "0 0.118732 0.141584 0.130472 0.137473 0.214651 0.111707 \n", "0 0.102198 0.121074 0.102682 0.112455 0.180079 0.087429 \n", "0 0.051191 0.067885 0.092275 0.070730 0.104366 0.049606 \n", "0 0.046030 0.061286 0.079614 0.056463 0.095957 0.043178 \n", "0 0.031383 0.041343 0.040558 0.032107 0.067695 0.027470 \n", "0 0.025782 0.033598 0.028219 0.021751 0.054383 0.021119 \n", "0 0.000278 0.000463 0.000644 0.000189 0.000752 0.000168 \n", "\n", " MRR LAUC HR Reco in test Test coverage Shannon \\\n", "0 0.629129 0.607709 0.913043 1.000000 0.275613 5.085818 \n", "0 0.637590 0.606568 0.898197 1.000000 0.351371 5.366291 \n", "0 0.572157 0.593544 0.875928 1.000000 0.077201 3.875892 \n", "0 0.523871 0.591709 0.889714 1.000000 0.502886 5.722957 \n", "0 0.400939 0.555546 0.765642 1.000000 0.038961 3.159079 \n", "0 0.337825 0.547572 0.704136 0.974973 0.264791 4.909893 \n", "0 0.192999 0.517831 0.465536 0.867869 0.150072 3.847796 \n", "0 0.198193 0.515501 0.437964 1.000000 
0.033911 2.836513 \n", "0 0.171187 0.509546 0.384942 1.000000 0.025974 2.711772 \n", "0 0.133978 0.507680 0.339343 0.986957 0.177489 5.088670 \n", "0 0.001677 0.496424 0.009544 0.600530 0.005051 1.803126 \n", "\n", " Gini \n", "0 0.913665 \n", "0 0.885046 \n", "0 0.974947 \n", "0 0.827507 \n", "0 0.987317 \n", "0 0.926201 \n", "0 0.972676 \n", "0 0.991139 \n", "0 0.992003 \n", "0 0.907676 \n", "0 0.996380 " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import evaluation_measures as ev\n", "\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1, 2],\n", " [3, 4]])" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "array([[0.4472136 , 0.89442719],\n", " [0.6 , 0.8 ]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x=np.array([[1,2],[3,4]])\n", "display(x)\n", "x/np.linalg.norm(x, axis=1)[:,None]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
codescoreitem_ididtitlegenres
0441.0000004545Eat Drink Man Woman (1994)Comedy, Drama
18550.966812856856Night on Earth (1991)Comedy, Drama
214030.96657114041404Withnail and I (1987)Comedy
31120.966115113113Horseman on the Roof, The (Hussard sur le toit...Drama
49550.965365956956Nobody's Fool (1994)Drama
512220.96523212231223King of the Hill (1993)Drama
6600.9644816161Three Colors: White (1994)Drama
75350.963322536536Ponette (1996)Drama
811020.96259711031103Trust (1990)Comedy, Drama
97130.962459714714Carrington (1995)Drama, Romance
\n", "
" ], "text/plain": [ " code score item_id id \\\n", "0 44 1.000000 45 45 \n", "1 855 0.966812 856 856 \n", "2 1403 0.966571 1404 1404 \n", "3 112 0.966115 113 113 \n", "4 955 0.965365 956 956 \n", "5 1222 0.965232 1223 1223 \n", "6 60 0.964481 61 61 \n", "7 535 0.963322 536 536 \n", "8 1102 0.962597 1103 1103 \n", "9 713 0.962459 714 714 \n", "\n", " title genres \n", "0 Eat Drink Man Woman (1994) Comedy, Drama \n", "1 Night on Earth (1991) Comedy, Drama \n", "2 Withnail and I (1987) Comedy \n", "3 Horseman on the Roof, The (Hussard sur le toit... Drama \n", "4 Nobody's Fool (1994) Drama \n", "5 King of the Hill (1993) Drama \n", "6 Three Colors: White (1994) Drama \n", "7 Ponette (1996) Drama \n", "8 Trust (1990) Comedy, Drama \n", "9 Carrington (1995) Drama, Romance " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item=random.choice(list(set(train_ui.indices)))\n", "\n", "embeddings_norm=model.Qi/np.linalg.norm(model.Qi, axis=1)[:,None] # we do not mean-center here\n", "# omitting normalization also makes sense, but items with a greater magnitude will be recommended more often\n", "\n", "similarity_scores=np.dot(embeddings_norm,embeddings_norm[item].T)\n", "top_similar_items=pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\\\n", ".sort_values(by=['score'], ascending=[False])[:10]\n", "\n", "top_similar_items['item_id']=top_similar_items['code'].apply(lambda x: item_code_id[x])\n", "\n", "items=pd.read_csv('./Datasets/ml-100k/movies.csv')\n", "\n", "result=pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n", "\n", "result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# project task 5: implement SVD on top baseline (as it is in Surprise library)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# making changes to our implementation by considering additional parameters in the gradient descent procedure \n", "# seems 
to be the fastest option\n", "# please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n", "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Ready-made SVD - Surprise implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD(biased=False) # to use unbiased version\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SVD biased - on top baseline" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Generating predictions...\n", "Generating top N recommendations...\n", "Generating predictions...\n" ] } ], "source": [ "import helpers\n", "import surprise as sp\n", "import imp\n", "imp.reload(helpers)\n", "\n", "algo = sp.SVD() # default is biased=True\n", "\n", "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n", " estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "943it [00:00, 4110.82it/s]\n", "943it [00:00, 4014.43it/s]\n", "943it [00:00, 3946.85it/s]\n", "943it [00:00, 4832.12it/s]\n", "943it [00:00, 4090.40it/s]\n", "943it [00:00, 4152.16it/s]\n", "943it [00:00, 
4456.22it/s]\n", "943it [00:00, 3943.66it/s]\n", "943it [00:00, 4298.65it/s]\n", "943it [00:00, 4243.05it/s]\n", "943it [00:00, 4391.40it/s]\n", "943it [00:00, 4528.23it/s]\n", "943it [00:00, 4419.43it/s]\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ModelRMSEMAEprecisionrecallF_1F_05precision_superrecall_superNDCGmAPMRRLAUCHRReco in testTest coverageShannonGini
0Ready_LightFMpureMF7.9531927.4620080.3344640.2199970.2172250.2549810.2337980.2669520.3987780.2630580.6291290.6077090.9130431.0000000.2756135.0858180.913665
0Ready_LightFM162.707436160.8554830.3408270.2176820.2179900.2580100.2438840.2606630.4038500.2682660.6375900.6065680.8981971.0000000.3513715.3662910.885046
0Self_P33.7024463.5272730.2821850.1920920.1867490.2169800.2041850.2400960.3391140.2049050.5721570.5935440.8759281.0000000.0772013.8758920.974947
0Ready_ImplicitALS3.2661013.0658240.2550370.1886530.1768520.2011890.1666310.2149250.3059080.1725460.5238710.5917090.8897141.0000000.5028865.7229570.827507
0Self_TopPop2.5082582.2179090.1888650.1169190.1187320.1415840.1304720.1374730.2146510.1117070.4009390.5555460.7656421.0000000.0389613.1590790.987317
0Ready_LightFMcontent182.471340180.4052100.1603390.1012240.1021980.1210740.1026820.1124550.1800790.0874290.3378250.5475720.7041360.9749730.2647914.9098930.926201
0Ready_SVD0.9514750.7502250.0994700.0514070.0560040.0702290.0881970.0831660.1154220.0535150.2533290.5224340.5228000.9967130.2164504.4245050.952962
0Self_SVD0.9153040.7190160.1008480.0422280.0511910.0678850.0922750.0707300.1043660.0496060.1929990.5178310.4655360.8678690.1500723.8477960.972676
0Ready_Baseline0.9494590.7524870.0914100.0376520.0460300.0612860.0796140.0564630.0959570.0431780.1981930.5155010.4379641.0000000.0339112.8365130.991139
0Ready_SVDBiased0.9378410.7399060.0794270.0325700.0398040.0530220.0710300.0506390.0884900.0393080.2015650.5129290.4252390.9970310.1709964.1670510.963929
0Self_GlobalAvg1.1257600.9435340.0611880.0259680.0313830.0413430.0405580.0321070.0676950.0274700.1711870.5095460.3849421.0000000.0259742.7117720.992003
0Ready_Random1.5143551.2163830.0497350.0223000.0257820.0335980.0282190.0217510.0543830.0211190.1339780.5076800.3393430.9869570.1774895.0886700.907676
0Self_BaselineUI0.9675850.7627400.0009540.0001700.0002780.0004630.0006440.0001890.0007520.0001680.0016770.4964240.0095440.6005300.0050511.8031260.996380
\n", "
" ], "text/plain": [ " Model RMSE MAE precision recall \\\n", "0 Ready_LightFMpureMF 7.953192 7.462008 0.334464 0.219997 \n", "0 Ready_LightFM 162.707436 160.855483 0.340827 0.217682 \n", "0 Self_P3 3.702446 3.527273 0.282185 0.192092 \n", "0 Ready_ImplicitALS 3.266101 3.065824 0.255037 0.188653 \n", "0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 \n", "0 Ready_LightFMcontent 182.471340 180.405210 0.160339 0.101224 \n", "0 Ready_SVD 0.951475 0.750225 0.099470 0.051407 \n", "0 Self_SVD 0.915304 0.719016 0.100848 0.042228 \n", "0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 \n", "0 Ready_SVDBiased 0.937841 0.739906 0.079427 0.032570 \n", "0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 \n", "0 Ready_Random 1.514355 1.216383 0.049735 0.022300 \n", "0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 \n", "\n", " F_1 F_05 precision_super recall_super NDCG mAP \\\n", "0 0.217225 0.254981 0.233798 0.266952 0.398778 0.263058 \n", "0 0.217990 0.258010 0.243884 0.260663 0.403850 0.268266 \n", "0 0.186749 0.216980 0.204185 0.240096 0.339114 0.204905 \n", "0 0.176852 0.201189 0.166631 0.214925 0.305908 0.172546 \n", "0 0.118732 0.141584 0.130472 0.137473 0.214651 0.111707 \n", "0 0.102198 0.121074 0.102682 0.112455 0.180079 0.087429 \n", "0 0.056004 0.070229 0.088197 0.083166 0.115422 0.053515 \n", "0 0.051191 0.067885 0.092275 0.070730 0.104366 0.049606 \n", "0 0.046030 0.061286 0.079614 0.056463 0.095957 0.043178 \n", "0 0.039804 0.053022 0.071030 0.050639 0.088490 0.039308 \n", "0 0.031383 0.041343 0.040558 0.032107 0.067695 0.027470 \n", "0 0.025782 0.033598 0.028219 0.021751 0.054383 0.021119 \n", "0 0.000278 0.000463 0.000644 0.000189 0.000752 0.000168 \n", "\n", " MRR LAUC HR Reco in test Test coverage Shannon \\\n", "0 0.629129 0.607709 0.913043 1.000000 0.275613 5.085818 \n", "0 0.637590 0.606568 0.898197 1.000000 0.351371 5.366291 \n", "0 0.572157 0.593544 0.875928 1.000000 0.077201 3.875892 \n", "0 0.523871 0.591709 0.889714 1.000000 0.502886 
5.722957 \n", "0 0.400939 0.555546 0.765642 1.000000 0.038961 3.159079 \n", "0 0.337825 0.547572 0.704136 0.974973 0.264791 4.909893 \n", "0 0.253329 0.522434 0.522800 0.996713 0.216450 4.424505 \n", "0 0.192999 0.517831 0.465536 0.867869 0.150072 3.847796 \n", "0 0.198193 0.515501 0.437964 1.000000 0.033911 2.836513 \n", "0 0.201565 0.512929 0.425239 0.997031 0.170996 4.167051 \n", "0 0.171187 0.509546 0.384942 1.000000 0.025974 2.711772 \n", "0 0.133978 0.507680 0.339343 0.986957 0.177489 5.088670 \n", "0 0.001677 0.496424 0.009544 0.600530 0.005051 1.803126 \n", "\n", " Gini \n", "0 0.913665 \n", "0 0.885046 \n", "0 0.974947 \n", "0 0.827507 \n", "0 0.987317 \n", "0 0.926201 \n", "0 0.952962 \n", "0 0.972676 \n", "0 0.991139 \n", "0 0.963929 \n", "0 0.992003 \n", "0 0.907676 \n", "0 0.996380 " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import imp\n", "imp.reload(ev)\n", "\n", "import evaluation_measures as ev\n", "dir_path=\"Recommendations generated/ml-100k/\"\n", "super_reactions=[4,5]\n", "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n", "\n", "ev.evaluate_all(test, dir_path, super_reactions)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }