{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Self-made SVD (matrix factorization trained with SGD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import helpers\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "from collections import defaultdict\n",
    "from itertools import chain\n",
    "import random\n",
    "\n",
    "# Seed the global NumPy RNG (SVD draws its initial factors and per-epoch shuffles from it)\n",
    "# and the stdlib RNG, so that re-running the notebook reproduces the same results\n",
    "SEED=42\n",
    "np.random.seed(SEED)\n",
    "random.seed(SEED)\n",
    "\n",
    "DATA_DIR='./Datasets/ml-100k'\n",
    "train_read=pd.read_csv(f'{DATA_DIR}/train.csv', sep='\\t', header=None)\n",
    "test_read=pd.read_csv(f'{DATA_DIR}/test.csv', sep='\\t', header=None)\n",
    "train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
    "from tqdm import tqdm\n",
    "\n",
    "class SVD():\n",
    "    \"\"\"Latent-factor recommender trained with stochastic gradient descent.\n",
    "\n",
    "    Predicts the rating of user code u for item code i as Pu[u] . Qi[i]\n",
    "    (no global-mean or user/item bias terms).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_ui : scipy CSR matrix (users x items) of training ratings; its stored\n",
    "        non-zero entries are the observed ratings (recommend() reads indptr/indices).\n",
    "    learning_rate : SGD step size.\n",
    "    regularization : L2 penalty on the factor vectors, applied in every update.\n",
    "    nb_factors : number of latent factors per user and per item.\n",
    "    iterations : number of epochs run by train().\n",
    "\n",
    "    Factor initialization and per-epoch shuffling use the global np.random state,\n",
    "    so call np.random.seed beforehand for reproducible runs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
    "        self.train_ui=train_ui\n",
    "        self.uir=self._to_uir(train_ui)\n",
    "\n",
    "        self.learning_rate=learning_rate\n",
    "        self.regularization=regularization\n",
    "        self.iterations=iterations\n",
    "        self.nb_users, self.nb_items=train_ui.shape\n",
    "        self.nb_ratings=train_ui.nnz\n",
    "        self.nb_factors=nb_factors\n",
    "\n",
    "        # Small random initial factors; the spread shrinks as nb_factors grows\n",
    "        self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
    "        self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
    "\n",
    "    @staticmethod\n",
    "    def _to_uir(matrix):\n",
    "        \"\"\"Return the non-zero entries of a scipy sparse matrix as (row, col, value) tuples.\n",
    "\n",
    "        Rows, columns and values come from a single COO view, so they stay aligned even\n",
    "        when the matrix stores explicit zeros (nonzero() skips those, .data keeps them).\n",
    "        \"\"\"\n",
    "        coo=matrix.tocoo()\n",
    "        observed=coo.data!=0\n",
    "        return list(zip(coo.row[observed], coo.col[observed], coo.data[observed]))\n",
    "\n",
    "    def train(self, test_ui=None):\n",
    "        \"\"\"Run self.iterations SGD epochs over the shuffled training ratings.\n",
    "\n",
    "        After each epoch appends [epoch, train_RMSE] to self.learning_process, or\n",
    "        [epoch, train_RMSE, test_RMSE] when a test_ui matrix is given.\n",
    "        \"\"\"\n",
    "        # `is None`, not `== None`: on a matrix `==`/`!=` are element-wise comparisons\n",
    "        if test_ui is not None:\n",
    "            self.test_uir=self._to_uir(test_ui)\n",
    "\n",
    "        self.learning_process=[]\n",
    "        pbar = tqdm(range(self.iterations))\n",
    "        for i in pbar:\n",
    "            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
    "            np.random.shuffle(self.uir)\n",
    "            self.sgd(self.uir)\n",
    "            if test_ui is None:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
    "            else:\n",
    "                self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
    "\n",
    "    def sgd(self, uir):\n",
    "        \"\"\"One SGD pass over (user, item, rating) triples, in the given order.\"\"\"\n",
    "        for u, i, score in uir:\n",
    "            # Compute prediction and error\n",
    "            e = score - self.get_rating(u, i)\n",
    "\n",
    "            # Both updates are computed from the pre-update factors before either is applied\n",
    "            Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
    "            Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
    "\n",
    "            self.Pu[u] += Pu_update\n",
    "            self.Qi[i] += Qi_update\n",
    "\n",
    "    def get_rating(self, u, i):\n",
    "        \"\"\"Predicted rating of user code u for item code i.\"\"\"\n",
    "        return self.Pu[u].dot(self.Qi[i])\n",
    "\n",
    "    def RMSE_total(self, uir):\n",
    "        \"\"\"Root-mean-square error of the current factors over (user, item, rating) triples.\n",
    "\n",
    "        Returns nan for an empty collection instead of raising ZeroDivisionError.\n",
    "        \"\"\"\n",
    "        if len(uir)==0:\n",
    "            return np.nan\n",
    "        squared_error=sum((score - self.get_rating(u, i))**2 for u, i, score in uir)\n",
    "        return np.sqrt(squared_error/len(uir))\n",
    "\n",
    "    def estimations(self):\n",
    "        \"\"\"Compute the dense matrix of predicted ratings, Pu @ Qi.T (users x items).\n",
    "\n",
    "        NOTE: the result is assigned to self.estimations, which replaces this method on the\n",
    "        instance. recommend() and estimate() read it as an array, so call this exactly once\n",
    "        after training; a second call raises TypeError (an ndarray is not callable).\n",
    "        \"\"\"\n",
    "        self.estimations=np.dot(self.Pu, self.Qi.T)\n",
    "\n",
    "    def recommend(self, user_code_id, item_code_id, topK=10):\n",
    "        \"\"\"Top-K unrated items per user, based on self.estimations (call estimations() first).\n",
    "\n",
    "        Items the user rated in train_ui and NaN scores are skipped. Each result row is\n",
    "        [user_id, item_id1, score1, item_id2, score2, ...], with codes mapped back through\n",
    "        user_code_id / item_code_id.\n",
    "        \"\"\"\n",
    "        top_k = defaultdict(list)\n",
    "        for nb_user, user in enumerate(self.estimations):\n",
    "            # Set lookup is O(1) per item, unlike `in` on the numpy slice of column indices\n",
    "            user_rated=set(self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]].tolist())\n",
    "            for item, score in enumerate(user):\n",
    "                if item not in user_rated and not np.isnan(score):\n",
    "                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
    "        result=[]\n",
    "        # Keep the k best items in the format: (user, item1, score1, item2, score2, ...)\n",
    "        for uid, item_scores in top_k.items():\n",
    "            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
    "            result.append([uid]+list(chain(*item_scores[:topK])))\n",
    "        return result\n",
    "\n",
    "    def estimate(self, user_code_id, item_code_id, test_ui):\n",
    "        \"\"\"Predicted rating for every non-zero (user, item) cell of test_ui.\n",
    "\n",
    "        Returns [user_id, item_id, estimate] rows; NaN estimates are replaced by 1.\n",
    "        \"\"\"\n",
    "        result=[]\n",
    "        for user, item in zip(*test_ui.nonzero()):\n",
    "            estimate=self.estimations[user,item]\n",
    "            result.append([user_code_id[user], item_code_id[item],\n",
    "                           estimate if not np.isnan(estimate) else 1])\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39 RMSE: 0.7467772350811145. Training epoch 40...: 100%|██████████| 40/40 [00:59<00:00,  1.50s/it]\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters of the SGD matrix factorization\n",
    "svd_params=dict(learning_rate=0.005, regularization=0.02, nb_factors=100, iterations=40)\n",
    "model=SVD(train_ui, **svd_params)\n",
    "model.train(test_ui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7fedc32039b0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAb6UlEQVR4nO3dfXRV9Z3v8feXJBAegoEkWkp4knGWVNQgiaMVOyjLWUCnlapXce5MdTouOhYXerXeAtexY7s6U2TVpz65uK0Pbb1irQ/VLrVXC9apz+FREQRUbALcEgOBIEQJfO8f+xxyCCfJOclJ9snen9dav7X32Xuffb7ZCz57n99+OObuiIhI/zcg7AJERCQ3FOgiIhGhQBcRiQgFuohIRCjQRUQiojCsDy4vL/fx48eH9fEiIv3SqlWrPnL3inTzQgv08ePHU1tbG9bHi4j0S2b2YUfz1OUiIhIRCnQRkYhQoIuIRERofegi0v8cOnSI+vp6Wlpawi4l8oqLi6msrKSoqCjj9yjQRSRj9fX1lJSUMH78eMws7HIiy91pbGykvr6eCRMmZPw+dbmISMZaWlooKytTmPcyM6OsrCzrb0IKdBHJisK8b3RnO/e7QP/zn+GGG+DQobArERHJL/0u0Nesgbvvhh/8IOxKRETyS78L9Isvhksugdtug61bw65GRPpSU1MTP/nJT7J+3+zZs2lqasr6fVdffTUTJkygqqqKM888kz/84Q9H502fPp2xY8eS+iNBc+bMYdiwYQAcOXKEBQsWMHnyZE4//XRqamr44IMPgOBO+dNPP52qqiqqqqpYsGBB1rWl0y+vcvnhD+GFF+Bf/xWefx7UpScSD8lA/8Y3vnHM9NbWVgoLO46zZ555ptufuXTpUi677DJWrlzJvHnz2LJly9F5paWlvPzyy0ybNo2mpiZ27tx5dN4jjzzCjh07WL9+PQMGDKC+vp6hQ4cenb9y5UrKy8u7XVc6/TLQP/tZWLIErr0WfvELuOqqsCsSiZ8bboC1a3O7zqoquOuujucvXLiQ9957j6qqKoqKiiguLmbEiBFs2rSJzZs3M2fOHOrq6mhpaeH6669n3rx5QNuzo/bv38+sWbOYNm0ar7zyCqNHj+a3v/0tgwcP7rK2c889l+3btx8zbe7cuSxfvpxp06bx+OOPc8kll7BhwwYAdu7cyahRoxgwIOgIqays7OZWyVy/63JJmjcPzjsPbrwRdu0KuxoR6Qvf//73mThxImvXrmXp0qWsXr2au+++m82bNwNw3333sWrVKmpra7nnnntobGw8bh1btmxh/vz5bNiwgdLSUh577LGMPvu5555jzpw5x0ybMWMGL730EocPH2b58uVcccUVR+ddfvnlPP3001RVVXHTTTexZs2aY957wQUXHO1yufPOO7PdFGn1yyN0gAEDYNmyYI9+443wq1+FXZFIvHR2JN1Xzj777GNuvLnnnnt44oknAKirq2PLli2UlZUd855knzjA1KlT2bZtW6efcfPNN7N48WLq6+t59dVXj5lXUFDAtGnTWL58OQcPHiT1keCVlZW8++67rFixghUrVjBjxgweffRRZsyYAfROl0u/PUIH+NznYPFieOgh+P3vw65GRPpaap/0iy++yAsvvMCrr77KunXrmDJlStobcwYNGnR0vKCggNbW1k4/Y+nSpWzevJklS5bwta997bj5c+fOZcGCBVx++eVpP2vWrFksXbqUxYsX8+STT2bz52WtXwc6wKJFcOqpwQnSjz8OuxoR6U0lJSU0Nzennbd3715GjBjBkCFD2LRpE6+99lpOP/u6667jyJEj/L7d0eP555/PokWLuPLKK4+Zvnr1anbs2AEEV7ysX7+ecePG5bSm9vp9oA8aFHS9bNsG3/522NWISG8qKyvjvPPOY/Lkydx8883HzJs5cyatra1MmjSJhQsXcs455+T0s82MW265
hdtvv/246d/85jeP6z7ZtWsXX/rSl5g8eTJnnHEGhYWFXHfddUfnp/ahf/WrX81NjanXUPal6upqz+UvFn396/Czn8Gbb8JZZ+VstSKSYuPGjUyaNCnsMmIj3fY2s1XuXp1u+X5/hJ60ZAmceCJccw100SUmIhJJXQa6mRWb2Rtmts7MNpjZbWmWudrMGsxsbaJd0zvldqy0NLjhKPloABGRTM2fP/9o90ey3X///WGXlbVMLlv8BLjQ3febWRHwJzN71t3bn3F4xN2vS/P+PnPppTB7Nnz3u8GljLqDVCT33D1yT1z88Y9/HHYJx+lOd3iXR+ge2J94WZRo4XS8d8EM/vZvYe9eOHAg7GpEoqe4uJjGxsZuhY1kLvkDF8XFxVm9L6Mbi8ysAFgF/BXwY3d/Pc1il5rZF4DNwP9w97o065kHzAMYO3ZsVoVmKnkPQWMjpFyiKiI5UFlZSX19PQ0NDWGXEnnJn6DLRkaB7u6HgSozKwWeMLPJ7v52yiJPAw+7+ydm9nXgQeDCNOtZBiyD4CqXrCrN0MiRwbCxEXppnyESW0VFRVn9JJr0rayucnH3JmAlMLPd9EZ3/yTx8mfA1NyUl73UI3QRkTjJ5CqXisSROWY2GLgI2NRumVEpL78MbMxlkdlQoItIXGXS5TIKeDDRjz4A+LW7/87MvgPUuvtTwAIz+zLQCuwGru6tgruiQBeRuOoy0N19PTAlzfRbU8YXAYtyW1r3pPahi4jESWTuFE0aOBBKShToIhI/kQt0CLpdFOgiEjeRDfTdu8OuQkSkb0U20HWELiJxo0AXEYkIBbqISERENtCbmuDw4bArERHpO5EM9JEjwR327Am7EhGRvhPJQNfdoiISRwp0EZGIUKCLiESEAl1EJCIU6CIiERHJQB8+HAoLFegiEi+RDHSz4NJFPc9FROIkkoEOultUROJHgS4iEhEKdBGRiIhsoI8cqUAXkXiJbKDrCF1E4ibSgd7SAgcOhF2JiEjfiHSgg47SRSQ+FOgiIhGhQBcRiYguA93Mis3sDTNbZ2YbzOy2NMsMMrNHzGyrmb1uZuN7o9hsKNBFJG4yOUL/BLjQ3c8EqoCZZnZOu2X+Bdjj7n8F3AksyW2Z2VOgi0jcdBnoHtifeFmUaN5usYuBBxPjvwFmmJnlrMpuSAa6nuciInGRUR+6mRWY2VpgF/C8u7/ebpHRQB2Au7cCe4GyNOuZZ2a1Zlbb0NDQs8q7MHAgDBumI3QRiY+MAt3dD7t7FVAJnG1mk7vzYe6+zN2r3b26oqKiO6vIim4uEpE4yeoqF3dvAlYCM9vN2g6MATCzQuAEIPQoVaCLSJxkcpVLhZmVJsYHAxcBm9ot9hRwVWL8MmCFu7fvZ+9zep6LiMRJYQbLjAIeNLMCgh3Ar939d2b2HaDW3Z8Cfg780sy2AruBub1WcRbKyuDDD8OuQkSkb3QZ6O6+HpiSZvqtKeMtwH/LbWk9py4XEYmTyN4pCkGg79kDhw+HXYmISO+LfKC7Q1NT2JWIiPS+yAc6qNtFROJBgS4iEhEKdBGRiFCgi4hERCwCXQ/oEpE4iHSgn3ACFBToCF1E4iHSgW6m2/9FJD4iHeigQBeR+Ih8oOv2fxGJCwW6iEhEKNBFRCJCgS4iEhGxCPSDB4MmIhJlsQh00FG6iESfAl1EJCIU6CIiERGbQNfzXEQk6mIT6DpCF5Goi3ygjxwZDBXoIhJ1kQ/04mIYMkSBLiLRF/lAB91cJCLxoEAXEYmILgPdzMaY2Uoze8fMNpjZ9WmWmW5me81sbaLd2jvldo8CXUTioDCDZVqBm9x9tZmVAKvM7Hl3f6fdcv/l7n+f+xJ7rqwM6urCrkJEpHd1eYTu7jvdfXVivBnYCIzu7cJySUfoIhIHWfWhm9l4YArweprZ55rZOjN71sxO6+D988ys1sxqGxoasi62u8rKYM8eOHKkzz5SRKTPZRzoZjYMeAy4wd33tZu9Ghjn7mcCPwSeTLcO
d1/m7tXuXl1RUdHdmrNWVhaEeVNTn32kiEifyyjQzayIIMwfcvfH2893933uvj8x/gxQZGblOa20B3S3qIjEQSZXuRjwc2Cju9/RwTKfSSyHmZ2dWG/exKee5yIicZDJVS7nAf8EvGVmaxPTFgNjAdz9XuAy4FozawUOAnPd3Xuh3m7REbqIxEGXge7ufwKsi2V+BPwoV0Xlmp7nIiJxEJs7RUGBLiLRFotALy2FAQMU6CISbbEI9AEDYMQIBbqIRFssAh10t6iIRJ8CXUQkIhToIiIRoUAXEYkIBbqISETEKtAPHICWlrArERHpHbEKdNDzXEQkumIX6Op2EZGoik2g63kuIhJ1sQl0HaGLSNQp0EVEIkKBLiISEbEJ9MGDg6ZAF5Goik2gg24uEpFoU6CLiESEAl1EJCIU6CIiEaFAFxGJiNgF+u7dcORI2JWIiORe7AL9yBHYty/sSkREci9Wga7nuYhIlMUq0HW3qIhEWZeBbmZjzGylmb1jZhvM7Po0y5iZ3WNmW81svZmd1Tvl9owCXUSirDCDZVqBm9x9tZmVAKvM7Hl3fydlmVnAKYn2N8BPE8O8okAXkSjr8gjd3Xe6++rEeDOwERjdbrGLgV944DWg1MxG5bzaHlKgi0iUZdWHbmbjgSnA6+1mjQbqUl7Xc3zoY2bzzKzWzGobGhqyqzQHSkvBTIEuItGUcaCb2TDgMeAGd+/WhX/uvszdq929uqKiojur6JGCAhgxQoEuItGUUaCbWRFBmD/k7o+nWWQ7MCbldWViWt7R3aIiElWZXOViwM+Bje5+RweLPQV8NXG1yznAXnffmcM6c0aBLiJRlclVLucB/wS8ZWZrE9MWA2MB3P1e4BlgNrAVOAD8c+5LzY0xY2DVqrCrEBHJvS4D3d3/BFgXyzgwP1dF9aapU+HRR4Oj9ORVLyIiURCrO0UBamqCoY7SRSRqYhfoU6cGwzffDLcOEZFci12gn3AC/PVfK9BFJHpiF+gQdLvU1oZdhYhIbsU20Ldvh515eWGliEj3xDLQq6uDobpdRCRKYhnoU6YEjwFQt4uIREksA33IEDjtNB2hi0i0xDLQIeh2efNNcA+7EhGR3IhtoNfUBHeLbtsWdiUiIrkR60AH9aOLSHTENtBPPx0GDlQ/uohER2wDfeBAOPNMBbqIREdsAx2CbpdVq+DIkbArERHpudgHenMzbN4cdiUiIj0X60DXHaMiEiWxDvRJk2DoUF3pIiLREOtALyiAs87SEbqIREOsAx2Cbpc1a+DQobArERHpmdgHek0NtLTAO++EXYmISM8o0BN3jKrbRUT6u9gH+sSJUFqqQBeR/i/2gW7W9uRFEZH+LPaBDkG3y1tvBX3pIiL9VZeBbmb3mdkuM3u7g/nTzWyvma1NtFtzX2bvqq6G1lZYty7sSkREui+TI/QHgJldLPNf7l6VaN/peVl9SydGRSQKugx0d38J2N0HtYSmshJOOkl3jIpI/5arPvRzzWydmT1rZqd1tJCZzTOzWjOrbWhoyNFH95xOjIpIFOQi0FcD49z9TOCHwJMdLejuy9y92t2rKyoqcvDRuVNTAxs3Bk9fFBHpj3oc6O6+z933J8afAYrMrLzHlfWxmprgB6PXrAm7EhGR7ulxoJvZZ8zMEuNnJ9bZ2NP19jU9SldE+rvCrhYws4eB6UC5mdUD3waKANz9XuAy4FozawUOAnPd3Xut4l5y4okwdqwCXUT6ry4D3d2v7GL+j4Af5ayiENXUKNBFpP/SnaIpamrg/fdhd6Qv0hSRqFKgp0j2o+t6dBHpjxToKaZODYavvx5uHSIi3aFAT1FaCp//PNx7L3z8cdjViIhkR4HeztKlsGMH3H572JWIiGRHgd7O5z8PV1wRBHtdXdjViIhkToGexpIlcOQILFoUdiUiIplToKcxbhzcdBM89BC88UbY1YiIZEaB3oGFC+Ezn4Ebbgie8SIiku8U6B0oKYHvfQ9efRUeeSTsakRE
uqZA78RVV0FVFXzrW3DwYNjViIh0ToHeiYICuPNO+POf4Y47wq5GRKRzCvQuTJ8OX/kK/Od/ws6dYVcjItIxBXoGli6FTz+FW24JuxIRkY4p0DMwcSJcfz3cfz+sXh12NSIi6SnQM3TLLVBWBjfeqMsYRSQ/KdAzdMIJ8N3vwh//CP/xHwp1Eck/CvQsXHNN8JyXW26BuXP1REYRyS8K9CwUFsLDDwdPYvzNb+Ccc2Dr1rCrEhEJKNCzZAY33wzPPRc8ZremBp59NuyqREQU6N120UXBT9WNGwdf/GLwmAD1q4tImBToPTBhArzyCvzDPwT96pdeCs3NYVclInGlQO+hIUPgl78MHhHw1FNBF8zDDwc3IomI9CUFeg6YBY/Zff55OHw4OGIfOxb+7d+gvj7s6kQkLroMdDO7z8x2mdnbHcw3M7vHzLaa2XozOyv3ZfYPF1wA774bnCStqQn61cePD7piVqxQH7uI9K5MjtAfAGZ2Mn8WcEqizQN+2vOy+q8BA2DmTHj6aXjvveCXj/74R5gxA047LXhq41tvBT9xJyKSS10Guru/BOzuZJGLgV944DWg1MxG5arA/mzChOD3Sevq4IEHYOjQIODPOAMqKoKnON51F6xdq4AXkZ4rzME6RgN1Ka/rE9OOe9ismc0jOIpn7NixOfjo/mHw4ODHMq66Cj74IDhiT7YnnwyWKS2F888PblaaNAlOPTV4KNjAgeHWLiL9Ry4CPWPuvgxYBlBdXR3LHuUJE4J29dXB67q6INhffDEYPv1027IFBUGon3pqWzv5ZKishNGjobg4jL9ARPJVLgJ9OzAm5XVlYppkYMwY+Md/DBoE17G/+y5s2nRse/ZZOHTo2PeWlwfvr6xsayeeGExPbSNGBDsHEYm2XAT6U8B1ZrYc+Btgr7vrt326qaQEqquDlqq1Neiu+fDD4FLIZKurC6a9/DLs7uBMhxmMHNkW7qWlHQ+HD29rJSXBcNgw7RBE+oMuA93MHgamA+VmVg98GygCcPd7gWeA2cBW4ADwz71VbJwVFsIppwStIwcOwEcfpW+NjcFwz55guHVrMN7UFFw735Vhw9oCvqQk/fiwYcGNVkOHtrXU18l1lJQE081yt31EJINAd/cru5jvwPycVSTdNmRIcENTNueb3WH//rZwb26GffuObanTkuPNzcE3htRp7buEOmPWFvCpw9SWOi25Q2g/bD+ubxISZ316UlTyj1nbUXNPLzw6dCh4RvzHHwffFpLjybZ/f9Cam48fNjcHy+zc2bZcsmVzSefgwW1/T+o3go7Gu1quqKhn20SkLynQJWeKioJ++NLS3K3THVpa2gI/uWNIHX788fE7h9S2axe8//6xy2R61+6gQcd3MXXU7ZSuJecnu6PUzSS9SYEuec0sOOoePDh363QPvkGk7gDa7ww62jk0N8Nf/hKcg0h2NWX6y1UDBqT/NjB8ePATh8lh6nhymFwuuYPQNwdJR4EusWPWdqL2pJN6vr7Dh9u+JSTPKaS2ffs632G8/z7s3dt2niKTLqbBg9sCPrkTKC3NfDh8uM43RJECXaSHCgrawnX06J6tyz3YOezd29ZSdxTtT06nLrdzZzBsasrsW0NJSRDuI0e2tREjjn+dnJYcHz5cXUf5SoEukkeSV/8MG9aznUNra1u4JwM/OZ463LMnaLt3w8aNwXhjY+fP8y8oaLt3IdlSX6eOjxwJZWVtbfBg7Qx6kwJdJIIKC9tCNFvucPBgEOzJwE8N/tTx5E5h27a28c4uXx006NiATz36b/+NIHkzXHm5TihnSoEuIscwCwJ0yJDg0RLZSJ5wbmoKAn/37mDH0Nh47Hiybd7ctnNoael4vcXFxz/Sory8bceQbnzo0PjtBBToIpIzqSecs+0yOnjw2G8DqXc4t2/btrV9g+jIoEHBY6rLyzsett8x9PenmyrQRSQvJC9P/exnM39Pa2tbv38y/FPHGxqC9tFHwdVEH30UnD/oyPDh
bSFfUZG+nXhi23DIkJ7/3bmkQBeRfquwsC1oM/Xpp50f/Tc2BjuBHTtg3bpg/JNP0q9r6NBjAz7Z2od/svX2I68V6CISKwMHBt8CMv0m4B5cJpo82m9oCO4+Th3ftQu2b4c1a4Lxjk4Ml5QEwT5/Ptx4Y+7+piQFuohIJ8za7jOYOLHr5d2DewSSod8+/BsacnNDWzoKdBGRHDJru3u3s8dd94YufyRaRET6BwW6iEhEKNBFRCJCgS4iEhEKdBGRiFCgi4hEhAJdRCQiFOgiIhFhnumv5eb6g80agA87WaQc+KiPysmWause1dY9qq17olrbOHdP+/Sa0AK9K2ZW6+7VYdeRjmrrHtXWPaqte+JYm7pcREQiQoEuIhIR+Rzoy8IuoBOqrXtUW/eotu6JXW1524cuIiLZyecjdBERyYICXUQkIvIu0M1sppm9a2ZbzWxh2PWkMrNtZvaWma01s9qQa7nPzHaZ2dsp00aa2fNmtiUxHJFHtf27mW1PbLu1ZjY7pNrGmNlKM3vHzDaY2fWJ6aFvu05qC33bmVmxmb1hZusStd2WmD7BzF5P/H99xMwG5lFtD5jZBynbraqva0upscDM1pjZ7xKve2e7uXveNKAAeA84GRgIrAM+F3ZdKfVtA8rDriNRyxeAs4C3U6bdDixMjC8EluRRbf8OfDMPttso4KzEeAmwGfhcPmy7TmoLfdsBBgxLjBcBrwPnAL8G5iam3wtcm0e1PQBcFva/uURdNwL/B/hd4nWvbLd8O0I/G9jq7u+7+6fAcuDikGvKS+7+ErC73eSLgQcT4w8Cc/q0qIQOassL7r7T3VcnxpuBjcBo8mDbdVJb6DywP/GyKNEcuBD4TWJ6WNuto9rygplVAl8EfpZ4bfTSdsu3QB8N1KW8ridP/kEnOPB/zWyVmc0Lu5g0TnL3nYnx/wf00k/Rdtt1ZrY+0SUTSndQKjMbD0whOKLLq23XrjbIg22X6DZYC+wCnif4Nt3k7q2JRUL7/9q+NndPbrfvJbbbnWY2KIzagLuA/wkcSbwuo5e2W74Fer6b5u5nAbOA+Wb2hbAL6ogH3+Xy5igF+CkwEagCdgI/CLMYMxsGPAbc4O77UueFve3S1JYX287dD7t7FVBJ8G361DDqSKd9bWY2GVhEUGMNMBL4Vl/XZWZ/D+xy91V98Xn5FujbgTEprysT0/KCu29PDHcBTxD8o84nfzGzUQCJ4a6Q6znK3f+S+E93BPjfhLjtzKyIIDAfcvfHE5PzYtulqy2ftl2iniZgJXAuUGpmhYlZof9/TaltZqILy939E+B+wtlu5wFfNrNtBF3IFwJ300vbLd8C/U3glMQZ4IHAXOCpkGsCwMyGmllJchz4O+Dtzt/V554CrkqMXwX8NsRajpEMy4SvENK2S/Rf/hzY6O53pMwKfdt1VFs+bDszqzCz0sT4YOAigj7+lcBlicXC2m7patuUsoM2gj7qPt9u7r7I3SvdfTxBnq1w9/9Ob223sM/+pjkbPJvg7P57wP8Ku56Uuk4muOpmHbAh7NqAhwm+fh8i6IP7F4K+uT8AW4AXgJF5VNsvgbeA9QThOSqk2qYRdKesB9Ym2ux82Had1Bb6tgPOANYkangbuDUx/WTgDWAr8CgwKI9qW5HYbm8DvyJxJUxYDZhO21UuvbLddOu/iEhE5FuXi4iIdJMCXUQkIhToIiIRoUAXEYkIBbqISEQo0EVEIkKBLiISEf8fAkPtxu0qi3AAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# learning_process rows start with [epoch, train_RMSE]; only those two columns are plotted here\n",
    "df=pd.DataFrame(model.learning_process).iloc[:,:2]\n",
    "df.columns=['epoch', 'train_RMSE']\n",
    "fig, ax=plt.subplots()\n",
    "ax.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
    "ax.set(title='Self-made SVD: training RMSE per epoch', xlabel='epoch', ylabel='RMSE')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7fedc1043c18>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXxU5fXH8c8JYVO0solKZBFRQQgBg2IFARFZiojWIihuWKkVxFZB0VqpVKstWhXXqsUFLah131AUFH+uhFUQ2UECVBGkxQUReX5/nBsIIUASktyZyff9es0rM/fOJOc6eObOuc9zHgshICIiqSst7gBERKRsKdGLiKQ4JXoRkRSnRC8ikuKU6EVEUlx63AEUVKdOndCoUaO4wxARSSrTp0//KoRQt7B9CZfoGzVqRE5OTtxhiIgkFTNbsat9eyzdmNlYM/vSzObuYr+Z2RgzW2xmc8ysTb5955vZouh2fsnCFxGRvVGUGv0jQPfd7O8BNI1ug4D7AMysFjASOA44FhhpZjX3JlgRESm+PSb6EMJUYP1unnIa8FhwHwIHmNnBQDdgUghhfQjha2ASu//AEBGRMlAaNfr6wMp8j3OjbbvavhMzG4R/G6BBgwalEJKIlKcff/yR3NxcNm3aFHcoKa9atWpkZGRQuXLlIr8mIS7GhhAeAB4AyM7OVvMdkSSTm5vLfvvtR6NGjTCzuMNJWSEE1q1bR25uLo0bNy7y60pjHP0q4NB8jzOibbvaLiIpZtOmTdSuXVtJvoyZGbVr1y72N6fSSPQvAudFo2/aAf8NIawBXgdOMbOa0UXYU6JtIpKClOTLR0n+O++xdGNm44FOQB0zy8VH0lQGCCHcD7wK9AQWA98BF0b71pvZn4Fp0a8aFULY3UXdUvANsARoVbZ/RkQkiewx0YcQ+u9hfwAG72LfWGBsyUIrifuB4cDpwPVAVvn9aRGRBJVivW5+DfwJmAy0xhP+rDgDEpFysGHDBu69995iv65nz55s2LCh2K+74IILaNy4MVlZWbRq1Yq33npr275OnTrRoEED8i/q1KdPH2rUqAHA1q1bGTp0KC1atKBly5a0bduWZcuWAd4ZoGXLlmRlZZGVlcXQoUOLHVthEmLUTek5AK8sXQ7cCdwO/BdP/CKSqvIS/aWXXrrD9i1btpCevus09+qrr5b4b44ePZozzzyTKVOmMGjQIBYtWrRt3wEHHMB7771H+/bt2bBhA2vWrNm278knn2T16tXMmTOHtLQ0cnNz2XfffbftnzJlCnXq1ClxXIVJsUSfJ3/CXxdtWwX8DrgWP9sXkbLwu9/BrFL+Ip2VBXfcsev9I0aMYMmSJWRlZVG5cmWqVatGzZo1+eyzz1i4cCF9+vRh5cqVbNq0icsvv5xBgwYB23trffPNN/To0YP27dvz/vvvU79+fV544QWqV6++x9iOP/54Vq3acUBhv379mDBhAu3bt+fZZ5/ljDPOYN68eQCsWbOGgw8+mLQ0L6hkZGSU8L9K0aVY6aagA4Am0f1PgDeBNkAfQI3TRFLFLbfcQpMmTZg1axajR49mxowZ3HnnnSxcuBCAsWPHMn36dHJychgzZgzr1q3b6XcsWrSIwYMHM2/ePA444ACeeeaZIv3tiRMn0qdPnx22denShalTp/LTTz8xYcIEzjrrrG37+vbty0svvURWVhZXXnklM2fO3OG1nTt33la6uf3224v7n6JQKXpGX5juwDJgDF7SeQFvwfM+UCnGuERSy+7OvMvLscceu8OEojFjxvDcc88BsHLlShYtWkTt2rV3eE1ezR3gmGOOYfny5bv9G8OHD+faa68lNzeXDz74YId9lSpVon379kyYMIHvv/+e/K3XMzIyWLBgAZMnT2by5Ml06dKFp59+mi5dugBlU7pJ8TP6gg7AR+MsB+4BTmZ7kh8JvAH8FEtkIlJ68te83377bd58800+
+OADZs+eTevWrQudcFS1atVt9ytVqsSWLVt2+zdGjx7NwoUL+etf/8rAgQN32t+vXz+GDh1K3759C/1bPXr0YPTo0Vx77bU8//zzxTm8YqtgiT7Pz4BLgZuix+vxxN8NaIx/GCyLJzQRKbb99tuPjRs3Frrvv//9LzVr1mSfffbhs88+48MPPyzVvz1kyBC2bt3K66/vOB+0Q4cOXHPNNfTvv+MI9RkzZrB69WrAR+DMmTOHhg0blmpMBVXQRF9QLfxi7VPA0cCNwGHAs3EGJSJFVLt2bU444QRatGjB8OHDd9jXvXt3tmzZQrNmzRgxYgTt2rUr1b9tZlx33XX87W9/22n7sGHDdirDfPnll5x66qm0aNGCzMxM0tPTGTJkyLb9+Wv05513XunEmH+sZyLIzs4O8a8wtRJ4FJ8HVhMYj0/wHQQcFWNcIolp/vz5NGvWLO4wKozC/nub2fQQQnZhz9cZfaEOBa7DkzzAfOBuoBnQGXgS2BxPaCIixaREXySj8LP8m/ELuf3w9VZEJJUNHjx4Wxkl7/bwww/HHVaxVaDhlXurHjACuApfLCvvP93XwLnAQOBUon5vIpIC7rnnnrhDKBVK9MWWho/OybMImA38EqiL99f5JV7iUdIXkfipdLPXjsWHYr4InAT8C/8gyI32rwZ+iCc0ERGU6EtJOl62mQB8ibdayJuVNwQ4EDgHH675XRwBikgFpkRf6qoDXfI9vhT4Fb64Vl5556oY4hKRikqJvsydDDwE/Ac/0z8fODja9wNwZLTtEWBFDPGJJL+S9qMHuOOOO/juu91/087rE5+ZmUnHjh1ZsWL7/6tmxoABA7Y93rJlC3Xr1qVXr14AfPHFF/Tq1YtWrVrRvHlzevbsCcDy5cupXr36DiN6HnvssRIdw54o0ZebdPxM/17g99G2r4FMfDXGC4FG+Izcl6P9iTWZTSRRlXWiB282NmfOHDp16sSNN964bfu+++7L3Llz+f777wGYNGkS9evX37b/+uuvp2vXrsyePZtPP/2UW265Zdu+vI6bebfSmglbkBJ9rA4Cnga+AObgi6W0wmv6AK/gk7SuxWfmKvFLsuhUyC0vEX+3i/2PRPu/KmTf7uXvRz98+HBGjx5N27ZtyczMZOTIkQB8++23/OIXv6BVq1a0aNGCJ598kjFjxrB69Wo6d+5M586di3RkhfWf79mzJ6+88goA48eP36G/zZo1a3boOZ+ZmVmkv1OalOgTQhrQEhgKPIeP5AHYDzgE+Fu0rWH0nMKbN4lUVPn70Xft2pVFixbx8ccfM2vWLKZPn87UqVOZOHEihxxyCLNnz2bu3Ll0796doUOHcsghhzBlyhSmTJlSpL9VWP/5vIVGNm3axJw5czjuuOO27Rs8eDAXXXQRnTt35qabbtrW0AzY9uGUd3v33XdL5z9IARpHn9A6Am/hq2S9jI/aeQXIa/j9BN6J82SgWhwBiuzC27vZt88e9tfZw/7de+ONN3jjjTdo3dpXkvvmm29YtGgRHTp04Morr+Tqq6+mV69edOjQoVi/t3Pnzqxfv54aNWrw5z//eYd9mZmZLF++nPHjx2+rwefp1q0bS5cuZeLEibz22mu0bt2auXPnAttLN2VNZ/RJoTZ+wfYFYCHb37a/4MM66wJ9gbF4qwaRiiuEwDXXXLOt7r148WIuuugijjjiCGbMmEHLli257rrrGDVqVLF+75QpU1ixYgVZWVnbykH59e7dm2HDhu3UlhigVq1anH322YwbN462bdsyderUEh9fSSjRJ538q2HNBCbiY/TfBS7CWywDbMW/Aey8ZJpIqsnfj75bt26MHTuWb775BoBVq1bx5Zdfsnr1avbZZx8GDBjA8OHDmTFjxk6v3ZP09HTuuOMOHnvsMdavX7/DvoEDBzJy5Ehatmy5w/bJkydvu9i7ceNGlixZQoMGDfbqeItLpZukVgWfhdsNuA+YF20DXyP3l4Dhi6F3iW4d8K/OIqkjfz/6
Hj16cPbZZ3P88ccDUKNGDR5//HEWL17M8OHDSUtLo3Llytx3330ADBo0iO7du2+r1e/JwQcfTP/+/bnnnnv44x//uG17RkYGQ4cO3en506dPZ8iQIaSnp7N161Z+/etf07ZtW5YvX76tRp9n4MCBhf6OvaV+9ClrCz5S5y18/P77wI94rf8XeGuGjfg4fpG9o3705Uv96CWSDhyP99V/Gx+zPxG/wAvwIL6IyhHAFcBk1GNfJDWpdFNh7MuOXTcH4hdxX8bHN9+Oj+tfif+z+BF135SK5rjjjuOHH3ZsQjhu3Lid6u7JRom+wjoU78NzKfANXuJZwfZ/Eu2j+72iWwu83i9SuBACZsn9b+Sjjz6KO4Q9Kkm5XaUbAWrgK2blXQQKQE9gEz4rNxNvz/CPOIKTJFCtWjXWrVtXoiQkRRdCYN26dVSrVrx5M0U6ozez7vj8/ErAQyGEWwrsb4gP4q4LrAcGhBByo30/4UNAAD4PIfQuVoQSAwNGRrfVeC+el/HOnETbfouf6f8Cn70rFVlGRga5ubmsXbs27lBSXrVq1XZoqVAUexx1Y2aV8Fk6XfHVNKYB/UMIn+Z7ztPAyyGER83sJODCEMK50b5vQgg1ihqQRt0kg/fwsft5HfyOwf95DGV7Z04RKU97O+rmWGBxCGFpCGEzvrpGwZWxm+PDNgCmFLJfUsoJ+Kpan+ALplcFbgV+ivbfC/wcr/8/AHyEFlwRiU9REn19dpxXnxtty282cEZ0/3RgPzOrHT2uZmY5ZvahmfVBUoThF2hH4Gf437D9n8X++MStfwG/AdoBtdg+fHMm6r0vUn5K62LsMKCjmc3EB2qvYvvpXcPo68TZwB1m1qTgi81sUPRhkLM3Nb4CM5KlXFVl+6icAWwfu78M78h5K9tn7f4Ov7jbDO/NPxH4vvxCFalgipLoV+Fj8fJkRNu2CSGsDiGcEUJoDfwh2rYh+rkq+rkU/7+/dcE/EEJ4IISQHULIrlu3bkmOgy++gMaN4cILoUCraImN4Qm9D752bp5/4OP2G+KtG3qw/Qsh+IeDRm+IlJaiJPppQFMza2xmVYB+wIv5n2Bmdcws73ddg4/AwcxqmlnVvOfgxd1PKQPVqsHFF8O//gVHHAEjR0LU00gSzlH4Wf1EfJDWa2xfR3cd0AT/EPg1MB5fmEVESmqPiT6EsAU/HXsdmA88FUKYZ2ajzCxvqGQnYIGZLQTqATdF25sBOWY2G79Ie0v+0Tql6Wc/g1tvhfnzoVcvGDXKE/4//wk//bTn10tc9gG6A3mr+1QG7geygX/jFb+DgMej/d8CG8o5RpHklrJNzT74AK64Aj78EDIz/UOga9dSCFDK0U/ADHxAV1+gMTAOuAAf0nlSdGuPOnJKRVchm5odfzy8/z48+SRs3AinnAI9esC8eXFHJkVXCWgLXI0nefAz/evwi7+34f17fsb25RWfw79QPoGPBlqF9+YXqbhSNtEDmEHfvl7OufVWP8vPzITf/Ab+85+4o5OSaQbcgC+08jVe378OX18XvMJ4HT7ypz0+duDAfK+fADyM9+5XTU8qhpQt3RRm3Tqv3d97L1SuDBddBFdeCY0alcmfk9h8B3wOLMfH63+Lt2IGv5z0TnS/Bl4C6oF/axBJXrsr3VSoRJ9n0SK4+WZ4/HHYuhX69YOrrvKzfUl1W/GOHh/jA8o+Bg7HSz3gpaF6+ITw1kAroAHq3CmJrkLW6HenaVMYOxaWLoXLL4cXXoBWraBnT5g6FRLss09KVRo+vPM84C68PUPeiJ4f8U6dK/Dy0Gn4PIAr8+0fi18g3lRuEYvsrQqZ6PNkZMBtt8Hnn8ONN0JODnTsCD//OTz/vJ/tS0WQd7ZeGU/kc4H/AR/gE7ryJnMtwBdgPwYv+7TAm7t9GO3XGYIkpgqd6PPUrAl/+AOsWAH33OMXak8/HY4+2s/8Cyw4IxVCDbxHzyX4RV3w3n0LgafxeYGN8YvCeb033sQv/vbCLwj/
G1iCRv1I3CpkjX5PtmyBp5+Gv/4VZs+Ggw+GIUPgkkugVq1YQ5OEFPBvBdPwctBMfG5h3qie6UAb/Mz/Y+Cw6NYIjf+X0qIafTGlp0P//jBzJrz+OrRs6Wf8hx4Kgwf7xVyR7fJKP22Bx/D2zd8AOcBDwNHR/meAy4FTo2374v37v432vwk8AkzFF3dJrJMwSV46oy+iTz6B22+HJ56AH3+E3r19aGb79j5eX2TPAvAVsDS6LcMT+t3R/v74OP88P8NHAb0ZPZ6Lzxc4FJ2jSUEaXlmK/vMfr+Pfe6+3RW7b1lstnHmmfxMQKbnN+NIPS/FrAZ/iI30eiPafALyPfxNoFt064heIiZ5buRzjlUSiRF8GvvsOHn3Uz/IXLYIGDbysM2AAHKIlVKVMfISv8fNpdJuP1/5fiPY3wi/8No1uR+ArfbUr70AlBkr0ZWjrVnj5ZR+mOXUqpKVBly5w7rk+cqdGkVfLFSmJLUA6Xhb6Cz4EdCGwCB8NdClwT/S8ZviooEb5bu2AI8s3ZCkTSvTlZMECr+E//jgsWwb77ANnnOFn+SefDJUqxR2hVCzr8ARfD2/tPARvC7Gc7Rd7b8aXg8wFTmT7B0DeyKAO7LjukCQqJfpyFgK89x6MGwdPPQUbNsBBB8HZZ/uZfqtWuoArcfsB7we0H97vfwVwLf4hsAxYEz3vcXxSWA5wMb4ozGH5fh6LXzSWuCnRx2jTJnjlFU/6r77qI3ZatIBBg+D882H//eOOUKQw3+EJ/2B8YfccYCQ+AWwZ2xd6fwf/JjARXx6yCd47qEl0OxJdIC4fSvQJYt0674//yCMwbZrX788/3y/iNmsWd3QiRbUV7/O/BG8HsR/wPHBjtC3/CmBL8DP/J/GZwgdGt3rRz1PxD4LN0U991S0pJfoENG0a3H03TJgAmzf7BdwhQ3wZRA3TlOS2Hk/wS4Az8YvFt+PDRL9ke8sI2J7gh0T76+KloHRgf+D/oufdgK80lo4vSJOOt6l4Ev9wWAXUBqqV3WElOM2MTUBt2/rwzNxc+MtfYOFCH6XTpAnccgt89VXcEYqUVC18lnA/PCED/B4fDroOT+6rgTlsL+v0wLuEdsdnDR+OjxDKk4Yn9M34amJf4fMG8r4BDMDnFxyFf7jcgC9CI6Az+oSxZQu89JKf5U+eDFWrehuGwYMhu9DPaBHZ7mW819An0W0JcAp+7QB8Ytn3eJlp/+hnJ2BgtP9BfHnK/YE60e2Q6HFyUOkmyXz6qc++ffRR+PZbT/SXXOILpOy7b9zRiSSD7/ClJutHjy8AvsC/DfwvuvUB7sCHmVZm56Ulh+BN6jbji9DkfQDUjX72wGcrb8TLTmls/+aRhnc9zcKvWTwLVIlulaOfmfjQ1W/xbztV8NbXJSu0KNEnqf/+18fk33efL2q+//5w3nme9I8+es+vF5GiCMBaPGH/Fy8LfYWPGjoO/1AYGG1bm2//n4A/4m0rGhTye28HfoevT9yikP0P4e0rPmL77OXvKel1BiX6JJc3Lv/++7198ubN0KGDJ/xf/tLLPCJSnrbi3wAqR/e/wT8wtka3gLeg3ge/lrA6+rk5360RPvLoa/yi82bgdHRGL3z1FTz8MPzjH7BkCdSpAwMH+rj8Jk3ijk5E4qJRNymkTh0YPtxH6bzxBpx4ovfZOfxwb7MwfrxP0hIRyaNEn6TS0qBrV3jmGV8C8YYb/Az/7LO9e+Zll8GsWXFHKSKJQIk+BdSvD9df74l+0iTo1g0efBBat4Y2bXwEz9dfxx2liMRFiT6FpKVtL9+sXg1jxngb5SFD/Cz/nHN8jP5WrVUtUqEo0aeoWrW8fDNzJuTk+AXbV17xVgtHHAF33OHDN0Uk9SnRpzgzOOYYL9+sWePj8uvVg9//3ks+gwfDZ5/FHaWIlKUiJXoz625mC8xssZmNKGR/QzN7y8zmmNnbZpaRb9/5ZrYoup1fmsFL8VSv7uWb997zs/wzz4SHHvLOmd26+UpZ
KuuIpJ49Jnozq4SvRdYDaA70N7PmBZ52K/BYCCETGIUvW4OZ1cKbWB+Hr1Aw0sxqll74UlLHHOPtkleuhBtvhLlz4dRTVdYRSUVFOaM/FlgcQlgaQtgMTABOK/Cc5ngPUYAp+fZ3AyaFENaHEL4GJuHt6SRBHHgg/OEPsHy5t0w+6KAdyzrz58cdoYjsraIk+vp4M4c8uWzvFJRnNnBGdP90YD8zq13E12Jmg8wsx8xy1q5dW9TYpRRVrgxnnQX/938wfTr86lfwz39C8+ZwyineWfOngj2fRCQplNbF2GFARzObifcDXcXOreB2KYTwQAghO4SQXbdu3VIKSUqqTRtvs7ByJdx0k3fT7N3byzp//7uvgSsiyaMoiX4VOy4DnxFt2yaEsDqEcEYIoTXwh2jbhqK8VhJX3bpw7bWwbJkvcl6/Plx5pf+85BLvqCkiia8oiX4a0NTMGptZFXzZmBfzP8HM6phZ3u+6Bhgb3X8dOMXMakYXYU9By74kncqVvZQzdaqPy+/Xz3vlt2jh4/JfeEFlHZFEtsdEH0LYgnfgfx3vjv9UCGGemY0ys97R0zoBC8xsIb7q703Ra9cDf8Y/LKYBo6JtkqSysrx2v3Il3HwzLFoEffp4U7XRo2G93l2RhKM2xbJXtmzxM/q77oJ33tk+Vv+yyyAzM+7oRCoOtSmWMpOe7oufvP02zJ4NAwbAE09Aq1bQsSP8+9/+YSAi8VGil1KTmQkPPAC5uV7G+fxzr+03bgx/+Qto5KxIPJTopdTVqgXDhsHixV7WOeoon5R16KFwwQV+5i8i5UeJXspMpUo+/n7SJB+KOXCgl3Kysryd8sSJvh6uiJQtJXopF82bw733+midW27x1go9ekDLljB2LPzwQ9wRiqQuJXopVzVrwtVX+ySsxx7zs/6LLoJGjXwW7rp1cUcoknqU6CUWVarAuef6uraTJnk557rroEEDXxFryZK4IxRJHUr0Eiszr9e/9hp88ok3VnvwQWjaFM44A959V3V8kb2lRC8Jo0ULr9evWOE9dt55B048Edq29ZWxNm+OO0KR5KRELwnnoIN8MZSVK+H+++Hbb73MkzceX3V8keJRopeEtc8+8Jvf+NDM117zM/688fiXXKJFUUSKSoleEl5aGnTvDq+/7ksennOOL4PYvLkP0XzjDdXxRXZHiV6SytFH+8XalSth1Chvm9ytm2/PK/OIyI6U6CUp1a0Lf/yjX7h99FEv8/z2t5CR4e0Xli+PO0KRxKFEL0mtalU47zyYNg3ee8/Xt73jDmjSBE4/HaZMUVlHRIleUoIZ/Pzn8OSTfjY/YoSPwT/pJG+Z/NBD8N13cUcpEg8lekk5GRneTmHlSh+Xn5YGF1/so3WuvVbtkqXiUaKXlFW9Olx4oV+wfecd6NTJG6o1agTDh8MXX8QdoUj5UKKXlGfmM2yfecbH5J9+Ovz9757wf/c7WL067ghFypYSvVQozZp5O4XPPoN+/eDuu+Gww2DwYF8RSyQVKdFLhdS0KTz8MCxc6KN2HngADj/cZ+IuWxZ3dCKlS4leKrTDDvMkv3gx/PrXPuO2aVNfDWvBgrijEykdSvQiQMOGvgLW0qVexhk/3ss8v/wlfPxx3NGJ7B0lepF86teHO+/0sfjXXAOTJ8Nxx0Hnzt5YTZOvJBkp0YsUol49H4v/+edw222waBH07AmtW8O//gVbtsQdoUjRKdGL7MZ++8EVV3hJ5+GHffGTc87xOv7dd2u2rSQHJXqRIqhSBS64wNskv/ACHHIIXHaZr3F7442wcWPcEYrsmhK9SDGkpUHv3t5A7d134fjjvYtm48Ze4vn++7gjFNmZEr1ICbVvDy+95KNyjjnG2yMffjjcd5/Wt5XEokQvspfatvXVr955x8flX3opHHmk98nXRVtJBEVK9GbW3cwWmNliMxtRyP4GZjbFzGaa2Rwz6xltb2Rm35vZrOh2f2kfgEiiOPFEmDrVh2HWru01/ZYt4amn
YOvWuKOTimyPid7MKgH3AD2A5kB/M2te4GnXAU+FEFoD/YB78+1bEkLIim6XlFLcIgnJzNe3nTYNnn0WKlWCs87y0s7LL2scvsSjKGf0xwKLQwhLQwibgQnAaQWeE4D9o/s/A9QPUCo0M++SOXu2N1HbuBFOPRVOOMFXvRIpT0VJ9PWBlfke50bb8vsTMMDMcoFXgcvy7WsclXTeMbMOhf0BMxtkZjlmlrNWq0JICqlUycfdz5/vPXVWrvRVr7p2VWsFKT+ldTG2P/BICCED6AmMM7M0YA3QICrpXAH8y8z2L/jiEMIDIYTsEEJ23bp1SykkkcRRubKvcrVoEdx+u5/pH3cc9OkDn3wSd3SS6oqS6FcBh+Z7nBFty+8i4CmAEMIHQDWgTgjhhxDCumj7dGAJcMTeBi2SrKpV88VOli71iVZvv+1r2p5zjnfQFCkLRUn004CmZtbYzKrgF1tfLPCcz4EuAGbWDE/0a82sbnQxFzM7DGgKLC2t4EWSVY0a8Ic/eMK/+mp4/nk46igYNMjLOyKlaY+JPoSwBRgCvA7Mx0fXzDOzUWbWO3ralcDFZjYbGA9cEEIIwInAHDObBfwbuCSEsL4sDkQkGdWqBTffDEuW+Pj7vH74w4bB//4Xd3SSKiwk2Hiv7OzskJOTE3cYIrFYsQJuuMETfr16MHq0l3XM4o5MEp2ZTQ8hZBe2TzNjRRJIw4Ywdix89BEceiicey507Ahz5sQdmSQzJXqRBNS2LXz4ITz4IHz6KbRpA5dfDhs2xB2ZJCMlepEElZbm69guXOgXae+6a3sPHbVUkOJQohdJcLVq+Xq2OTneNO2CC6BDB5g5M+7IJFko0YskiTZtvA/+2LE+8So72xcyX69xbLIHSvQiSSQtDS68EBYs8CR///1wxBHeA/+nn+KOThKVEr1IEqpZE8aM8fJNy5Y+Br9NG++JL1KQEr1IEsvMhMmT4emnfUROp07Qt6+PxxfJo0QvkuTM4Mwz4bPPfLLVyy97O4U//Qm++y7u6CQRKNGLpIjq1eH66z3hn3aaJ/1mzXyFqwSbAC/lTIleJMU0aAATJviyhrVq+QpXnTt7a2SpmJToRVJUhw4+9v4f/4B586B1a7joIlizJu7IpLwp0Y0ThIQAAA13SURBVIuksEqVfFbtwoVwxRUwbpx3x/zzn1W/r0iU6EUqgJo14dZbfUnDHj28ln/EEfDYY2qnUBEo0YtUIE2a+FDMd9+FQw6B88/3Bmoaf5/alOhFKqD27b075hNPwNq1Pv7+9NO9tYKkHiV6kQoqLQ3OPtvbKfzlL/Dmm9C8ua9pq3bIqUWJXqSCq14drrnGFycfONDbITdr5iUejb9PDUr0IgL40oX/+AdMm+b1+759oXdv+PzzuCOTvaVELyI7aNPGlzK87Tbvo9O8Odx5p7pjJjMlehHZSXq6j7ufNw9OPNHr9u3awaxZcUcmJaFELyK71KgRvPKKt1T4/HNf7OSqqzTZKtko0YvIbpl5v5z5833Rk9GjoUULeOONuCOTolKiF5EiqVULHnwQ3n4bqlSBbt18eOaqVXFHJnuiRC8ixdKxo9fqr78enn0WjjwSbrkFfvgh7shkV5ToRaTYqlXzfveffgonn+zj8Fu08Hq+JB4lehEpscMOg+efh4kTvVNmr17wi194t0xJHEr0IrLXunWDOXO8Q+a77/rZ/YgRsHFj3JEJKNGLSCmpUgWuvNLP5s85B/76V6/fP/64WinErUiJ3sy6m9kCM1tsZiMK2d/AzKaY2Uwzm2NmPfPtuyZ63QIz61aawYtI4jnoIHj4YfjgA6hfH84911e7mjcv7sgqrj0mejOrBNwD9ACaA/3NrHmBp10HPBVCaA30A+6NXts8enw00B24N/p9IpLi2rXzVgr//Kd3yGzTxs/yt2yJO7KKpyhn9McCi0MIS0MIm4EJwGkFnhOA/aP7PwNWR/dPAyaEEH4IISwDFke/T0QqgLQ074g5bx6ceqrX
7du3h88+izuyiqUoib4+sDLf49xoW35/AgaYWS7wKnBZMV6LmQ0ysxwzy1m7dm0RQxeRZHHggd72ePx4X9wkK8ubpqlRWvkorYux/YFHQggZQE9gnJkV+XeHEB4IIWSHELLr1q1bSiGJSCIxg379/Oy+e3cYNswbpmlVq7JXlGS8Cjg03+OMaFt+FwFPAYQQPgCqAXWK+FoRqUAOOgieew7GjfMJV61aeRtkLVJedoqS6KcBTc2ssZlVwS+uvljgOZ8DXQDMrBme6NdGz+tnZlXNrDHQFPi4tIIXkeRkBgMG+Nn9SSd5G+TOnWHJkrgjS017TPQhhC3AEOB1YD4+umaemY0ys97R064ELjaz2cB44ILg5uFn+p8CE4HBIQRV5UQE8JWsXnrJh2POmgWZmXD33Tq7L20WEmwmQ3Z2dsjJyYk7DBEpZ7m5cPHF3k6hfXt46CGfcCVFY2bTQwjZhe3TzFgRSQgZGfDqq/Doo17SadXKu2L++GPckSU/JXoRSRhmcN55vsjJqad6V8xjj4WZM+OOLLkp0YtIwqlXz8fdP/ss/Oc/0LatJ/3vv487suSkRC8iCev0030I5vnnexknK8u7Y0rxKNGLSEKrWdP75UyaBJs3+ySrwYPVArk4lOhFJCmcfDLMnetj7u+7D44+Gt58M+6okoMSvYgkjX33hdtvh/fe8/tdu8LQofDdd3FHltiU6EUk6Rx/PMyY4Un+rrvgmGNA0292TYleRJJS9ereI2fSJK/XH388jBqlfveFUaIXkaR28snwySfQty+MHOmzarU4+Y6U6EUk6dWsCU88ARMmeJLPyvILtgnW4SU2SvQikjLOOsvP7jt0gEsvhZ49YfXqPb8u1SnRi0hKqV/fG6PdfTe88w60bAlPPhl3VPFSoheRlGPmk6pmzoTDD/eVrX71K/jyy7gji4cSvYikrCOP9DH3N98ML74IzZv72X1Fq90r0YtISktPhxEj/Oz+sMP87P7MM+GLL+KOrPwo0YtIhdC8Obz/vjdHe/llb6EwYULFOLtXoheRCiM9Ha6+2s/umzSB/v0rxtm9Er2IVDjNm3vt/m9/g1de8cfjx6fu2b0SvYhUSOnpMHy4n903bQpnnw1nnJGaI3OU6EWkQmvWzM/uR4+G117zcfevvRZ3VKVLiV5EKrxKlWDYMO+AWa+ez6gdOjR1li5UohcRibRoAR9/7Iub3HWXL0z+ySdxR7X3lOhFRPKpVs0XN5k4Eb76yhcmv/NO2Lo17shKToleRKQQ3brBnDlwyil+ht+zJ6xZE3dUJaNELyKyC3XrwgsveMvjqVMhM9NbKSQbJXoRkd0wg0sugenTISMDTjsNfvvb5FqnVoleRKQImjWDDz/0sff33+/r1M6ZE3dURaNELyJSRFWr+mzaN9+EDRt8VM499yT+jFolehGRYurSBWbPhpNOgiFDfEbt+vVxR7VrRUr0ZtbdzBaY2WIzG1HI/tvNbFZ0W2hmG/Lt+ynfviS8jCEisrMDD/QumLfd5v1ysrLg3Xfjjqpwe0z0ZlYJuAfoATQH+ptZ8/zPCSH8PoSQFULIAu4Cns23+/u8fSGE3qUYu4hIrNLS4Ior4IMPvKzTqROMGgU//RR3ZDsqyhn9scDiEMLSEMJmYAJw2m6e3x8YXxrBiYgkg2OOgRkzvDHayJFe2snNjTuq7YqS6OsDK/M9zo227cTMGgKNgcn5Nlczsxwz+9DM+uzidYOi5+SsXbu2iKGLiCSO/faDcePg0Ue9Z06rVokz5r60L8b2A/4dQsj/xaVhCCEbOBu4w8yaFHxRCOGBEEJ2CCG7bt26pRySiEj5Oe88P7tv2NDH3A8dCps2xRtTURL9KuDQfI8zom2F6UeBsk0IYVX0cynwNtC62FGKiCSRI47wun1ec7TjjoN58+KLpyiJfhrQ1Mwam1kVPJnv9IXEzI4CagIf5NtW08yqRvfrACcAn5ZG4CIiiaxqVW+O9sor3iMnO9tbKcQx
5n6PiT6EsAUYArwOzAeeCiHMM7NRZpZ/FE0/YEIIOxxGMyDHzGYDU4BbQghK9CJSYfTs6TNoO3aESy+FPn28K2Z5spBgU7qys7NDTk5O3GGIiJSqrVthzBhfnLx2bXjsMTj55NL7/WY2PboeuhPNjBURKQdpaV6z/+gjOOAA6NrV++Zs3lwOf7vs/4SIiOTJyvLhl5dcArfeCu3awYIFZfs3lehFRMrZPvv4hdnnnoMVK6BNG3joobK7UKtELyISkz59/EJtu3Zw8cXQt2/ZLFmYXvq/UkREiqp+fZg0ycs4//uf1/JLmxK9iEjM0tLgqqvK8PeX3a8WEZFEoEQvIpLilOhFRFKcEr2ISIpTohcRSXFK9CIiKU6JXkQkxSnRi4ikuIRrU2xma4EVBTbXAcq5g3OZS7VjSrXjgdQ7plQ7Hki9Y9qb42kYQih0LdaES/SFMbOcXfVZTlapdkypdjyQeseUascDqXdMZXU8Kt2IiKQ4JXoRkRSXLIn+gbgDKAOpdkypdjyQeseUascDqXdMZXI8SVGjFxGRkkuWM3oRESkhJXoRkRSXcInezMaa2ZdmNjfftlpmNsnMFkU/a8YZY3Hs4nj+ZGarzGxWdOsZZ4zFZWaHmtkUM/vUzOaZ2eXR9qR8n3ZzPEn7PplZNTP72MxmR8d0Q7S9sZl9ZGaLzexJM6sSd6xFsZvjecTMluV7j7LijrU4zKySmc00s5ejx2Xy/iRcogceAboX2DYCeCuE0BR4K3qcLB5h5+MBuD2EkBXdXi3nmPbWFuDKEEJzoB0w2Myak7zv066OB5L3ffoBOCmE0ArIArqbWTvgr/gxHQ58DVwUY4zFsavjARie7z2aFV+IJXI5MD/f4zJ5fxIu0YcQpgLrC2w+DXg0uv8o0Kdcg9oLuziepBZCWBNCmBHd34j/Q61Pkr5PuzmepBXcN9HDytEtACcB/462J9N7tKvjSVpmlgH8AngoemyU0fuTcIl+F+qFENZE9/8D1IszmFIyxMzmRKWdpChxFMbMGgGtgY9IgfepwPFAEr9PUVlgFvAlMAlYAmwIIWyJnpJLEn2gFTyeEELee3RT9B7dbmZVYwyxuO4ArgK2Ro9rU0bvT7Ik+m2CjwdN6k9y4D6gCf4VdA1wW7zhlIyZ1QCeAX4XQvhf/n3J+D4VcjxJ/T6FEH4KIWQBGcCxwFExh7RXCh6PmbUArsGPqy1QC7g6xhCLzMx6AV+GEKaXx99LlkT/hZkdDBD9/DLmePZKCOGL6B/tVuBB/H/CpGJmlfGk+EQI4dloc9K+T4UdTyq8TwAhhA3AFOB44AAzS492ZQCrYgushPIdT/eo7BZCCD8AD5M879EJQG8zWw5MwEs2d1JG70+yJPoXgfOj++cDL8QYy17LS4aR04G5u3puIopqif8E5ocQ/p5vV1K+T7s6nmR+n8ysrpkdEN2vDnTFrz1MAc6MnpZM71Fhx/NZvhMLw+vZSfEehRCuCSFkhBAaAf2AySGEcyij9yfhZsaa2XigE96u8wtgJPA88BTQAG9h3DeEkBQXOHdxPJ3wckAAlgO/yVfbTnhm1h54F/iE7fXFa/G6dtK9T7s5nv4k6ftkZpn4xbxK+AndUyGEUWZ2GH4GWQuYCQyIzoYT2m6OZzJQFzBgFnBJvou2ScHMOgHDQgi9yur9SbhELyIipStZSjciIlJCSvQiIilOiV5EJMUp0YuIpDglehGRFKdELyKS4pToRURS3P8DLIEl5WXqJ5UAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Skip the first epochs: their (presumably much higher) RMSE would dwarf the rest of the curve.\n",
    "SKIP_EPOCHS = 10\n",
    "\n",
    "df = pd.DataFrame(model.learning_process[SKIP_EPOCHS:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
    "ax.plot('epoch', 'test_RMSE', data=df, color='yellow', linestyle='dashed')\n",
    "ax.set(title='SVD learning curve', xlabel='epoch', ylabel='RMSE')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Saving and evaluating recommendations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `estimations()` presumably rebinds itself to the predicted-rating matrix (as SVDbaseline.estimations\n",
    "# below does), so re-running this cell would try to call an ndarray; only call it while it is still a method.\n",
    "if callable(model.estimations):\n",
    "    model.estimations()\n",
    "\n",
    "top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
    "top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv', index=False, header=False)\n",
    "\n",
    "estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
    "estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "943it [00:00, 9025.30it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MAE</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>F_1</th>\n",
       "      <th>F_05</th>\n",
       "      <th>precision_super</th>\n",
       "      <th>recall_super</th>\n",
       "      <th>NDCG</th>\n",
       "      <th>mAP</th>\n",
       "      <th>MRR</th>\n",
       "      <th>LAUC</th>\n",
       "      <th>HR</th>\n",
       "      <th>HR2</th>\n",
       "      <th>Reco in test</th>\n",
       "      <th>Test coverage</th>\n",
       "      <th>Shannon</th>\n",
       "      <th>Gini</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.915079</td>\n",
       "      <td>0.71824</td>\n",
       "      <td>0.104772</td>\n",
       "      <td>0.045496</td>\n",
       "      <td>0.054393</td>\n",
       "      <td>0.071374</td>\n",
       "      <td>0.094421</td>\n",
       "      <td>0.076826</td>\n",
       "      <td>0.109517</td>\n",
       "      <td>0.052005</td>\n",
       "      <td>0.206646</td>\n",
       "      <td>0.519484</td>\n",
       "      <td>0.487805</td>\n",
       "      <td>0.264051</td>\n",
       "      <td>0.874549</td>\n",
       "      <td>0.142136</td>\n",
       "      <td>3.890472</td>\n",
       "      <td>0.972126</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       RMSE      MAE  precision    recall       F_1      F_05  \\\n",
       "0  0.915079  0.71824   0.104772  0.045496  0.054393  0.071374   \n",
       "\n",
       "   precision_super  recall_super      NDCG       mAP       MRR      LAUC  \\\n",
       "0         0.094421      0.076826  0.109517  0.052005  0.206646  0.519484   \n",
       "\n",
       "         HR       HR2  Reco in test  Test coverage   Shannon      Gini  \n",
       "0  0.487805  0.264051      0.874549       0.142136  3.890472  0.972126  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import evaluation_measures as ev\n",
    "\n",
    "# Reload what the previous cell saved: the top-K lists and the rating estimations for test pairs.\n",
    "reco = np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n",
    "estimations_df = pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n",
    "\n",
    "# Ratings listed in super_reactions presumably drive the *_super metrics.\n",
    "ev.evaluate(\n",
    "    test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n",
    "    estimations_df=estimations_df,\n",
    "    reco=reco,\n",
    "    super_reactions=[4, 5],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "943it [00:00, 8433.36it/s]\n",
      "943it [00:00, 8182.71it/s]\n",
      "943it [00:00, 9546.13it/s]\n",
      "943it [00:00, 8959.29it/s]\n",
      "943it [00:00, 9016.78it/s]\n",
      "943it [00:00, 8085.81it/s]\n",
      "943it [00:00, 8341.37it/s]\n",
      "943it [00:00, 9531.98it/s]\n",
      "943it [00:00, 9952.14it/s]\n",
      "943it [00:00, 9774.37it/s]\n",
      "943it [00:00, 9543.76it/s]\n",
      "943it [00:00, 9634.07it/s]\n",
      "943it [00:00, 9988.71it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MAE</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>F_1</th>\n",
       "      <th>F_05</th>\n",
       "      <th>precision_super</th>\n",
       "      <th>recall_super</th>\n",
       "      <th>NDCG</th>\n",
       "      <th>mAP</th>\n",
       "      <th>MRR</th>\n",
       "      <th>LAUC</th>\n",
       "      <th>HR</th>\n",
       "      <th>HR2</th>\n",
       "      <th>Reco in test</th>\n",
       "      <th>Test coverage</th>\n",
       "      <th>Shannon</th>\n",
       "      <th>Gini</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_TopPop</td>\n",
       "      <td>2.508258</td>\n",
       "      <td>2.217909</td>\n",
       "      <td>0.188865</td>\n",
       "      <td>0.116919</td>\n",
       "      <td>0.118732</td>\n",
       "      <td>0.141584</td>\n",
       "      <td>0.130472</td>\n",
       "      <td>0.137473</td>\n",
       "      <td>0.214651</td>\n",
       "      <td>0.111707</td>\n",
       "      <td>0.400939</td>\n",
       "      <td>0.555546</td>\n",
       "      <td>0.765642</td>\n",
       "      <td>0.492047</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.038961</td>\n",
       "      <td>3.159079</td>\n",
       "      <td>0.987317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_SVD</td>\n",
       "      <td>0.915079</td>\n",
       "      <td>0.718240</td>\n",
       "      <td>0.104772</td>\n",
       "      <td>0.045496</td>\n",
       "      <td>0.054393</td>\n",
       "      <td>0.071374</td>\n",
       "      <td>0.094421</td>\n",
       "      <td>0.076826</td>\n",
       "      <td>0.109517</td>\n",
       "      <td>0.052005</td>\n",
       "      <td>0.206646</td>\n",
       "      <td>0.519484</td>\n",
       "      <td>0.487805</td>\n",
       "      <td>0.264051</td>\n",
       "      <td>0.874549</td>\n",
       "      <td>0.142136</td>\n",
       "      <td>3.890472</td>\n",
       "      <td>0.972126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_Baseline</td>\n",
       "      <td>0.949459</td>\n",
       "      <td>0.752487</td>\n",
       "      <td>0.091410</td>\n",
       "      <td>0.037652</td>\n",
       "      <td>0.046030</td>\n",
       "      <td>0.061286</td>\n",
       "      <td>0.079614</td>\n",
       "      <td>0.056463</td>\n",
       "      <td>0.095957</td>\n",
       "      <td>0.043178</td>\n",
       "      <td>0.198193</td>\n",
       "      <td>0.515501</td>\n",
       "      <td>0.437964</td>\n",
       "      <td>0.239661</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.033911</td>\n",
       "      <td>2.836513</td>\n",
       "      <td>0.991139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_KNNSurprisetask</td>\n",
       "      <td>0.946255</td>\n",
       "      <td>0.745209</td>\n",
       "      <td>0.083457</td>\n",
       "      <td>0.032848</td>\n",
       "      <td>0.041227</td>\n",
       "      <td>0.055493</td>\n",
       "      <td>0.074785</td>\n",
       "      <td>0.048890</td>\n",
       "      <td>0.089577</td>\n",
       "      <td>0.040902</td>\n",
       "      <td>0.189057</td>\n",
       "      <td>0.513076</td>\n",
       "      <td>0.417815</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.888547</td>\n",
       "      <td>0.130592</td>\n",
       "      <td>3.611806</td>\n",
       "      <td>0.978659</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_GlobalAvg</td>\n",
       "      <td>1.125760</td>\n",
       "      <td>0.943534</td>\n",
       "      <td>0.061188</td>\n",
       "      <td>0.025968</td>\n",
       "      <td>0.031383</td>\n",
       "      <td>0.041343</td>\n",
       "      <td>0.040558</td>\n",
       "      <td>0.032107</td>\n",
       "      <td>0.067695</td>\n",
       "      <td>0.027470</td>\n",
       "      <td>0.171187</td>\n",
       "      <td>0.509546</td>\n",
       "      <td>0.384942</td>\n",
       "      <td>0.142100</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.025974</td>\n",
       "      <td>2.711772</td>\n",
       "      <td>0.992003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_Random</td>\n",
       "      <td>1.522798</td>\n",
       "      <td>1.222501</td>\n",
       "      <td>0.049841</td>\n",
       "      <td>0.020656</td>\n",
       "      <td>0.025232</td>\n",
       "      <td>0.033446</td>\n",
       "      <td>0.030579</td>\n",
       "      <td>0.022927</td>\n",
       "      <td>0.051680</td>\n",
       "      <td>0.019110</td>\n",
       "      <td>0.123085</td>\n",
       "      <td>0.506849</td>\n",
       "      <td>0.331919</td>\n",
       "      <td>0.119830</td>\n",
       "      <td>0.985048</td>\n",
       "      <td>0.183983</td>\n",
       "      <td>5.097973</td>\n",
       "      <td>0.907483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_I-KNN</td>\n",
       "      <td>1.030386</td>\n",
       "      <td>0.813067</td>\n",
       "      <td>0.026087</td>\n",
       "      <td>0.006908</td>\n",
       "      <td>0.010593</td>\n",
       "      <td>0.016046</td>\n",
       "      <td>0.021137</td>\n",
       "      <td>0.009522</td>\n",
       "      <td>0.024214</td>\n",
       "      <td>0.008958</td>\n",
       "      <td>0.048068</td>\n",
       "      <td>0.499885</td>\n",
       "      <td>0.154825</td>\n",
       "      <td>0.072110</td>\n",
       "      <td>0.402333</td>\n",
       "      <td>0.434343</td>\n",
       "      <td>5.133650</td>\n",
       "      <td>0.877999</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_I-KNNBaseline</td>\n",
       "      <td>0.935327</td>\n",
       "      <td>0.737424</td>\n",
       "      <td>0.002545</td>\n",
       "      <td>0.000755</td>\n",
       "      <td>0.001105</td>\n",
       "      <td>0.001602</td>\n",
       "      <td>0.002253</td>\n",
       "      <td>0.000930</td>\n",
       "      <td>0.003444</td>\n",
       "      <td>0.001362</td>\n",
       "      <td>0.011760</td>\n",
       "      <td>0.496724</td>\n",
       "      <td>0.021209</td>\n",
       "      <td>0.004242</td>\n",
       "      <td>0.482821</td>\n",
       "      <td>0.059885</td>\n",
       "      <td>2.232578</td>\n",
       "      <td>0.994487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_U-KNN</td>\n",
       "      <td>1.023495</td>\n",
       "      <td>0.807913</td>\n",
       "      <td>0.000742</td>\n",
       "      <td>0.000205</td>\n",
       "      <td>0.000305</td>\n",
       "      <td>0.000449</td>\n",
       "      <td>0.000536</td>\n",
       "      <td>0.000198</td>\n",
       "      <td>0.000845</td>\n",
       "      <td>0.000274</td>\n",
       "      <td>0.002744</td>\n",
       "      <td>0.496441</td>\n",
       "      <td>0.007423</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.602121</td>\n",
       "      <td>0.010823</td>\n",
       "      <td>2.089186</td>\n",
       "      <td>0.995706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_BaselineIU</td>\n",
       "      <td>0.958136</td>\n",
       "      <td>0.754051</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000188</td>\n",
       "      <td>0.000298</td>\n",
       "      <td>0.000481</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000223</td>\n",
       "      <td>0.001043</td>\n",
       "      <td>0.000335</td>\n",
       "      <td>0.003348</td>\n",
       "      <td>0.496433</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.699046</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.945910</td>\n",
       "      <td>0.995669</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_TopRated</td>\n",
       "      <td>2.508258</td>\n",
       "      <td>2.217909</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000188</td>\n",
       "      <td>0.000298</td>\n",
       "      <td>0.000481</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000223</td>\n",
       "      <td>0.001043</td>\n",
       "      <td>0.000335</td>\n",
       "      <td>0.003348</td>\n",
       "      <td>0.496433</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.699046</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.945910</td>\n",
       "      <td>0.995669</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_BaselineUI</td>\n",
       "      <td>0.967585</td>\n",
       "      <td>0.762740</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000170</td>\n",
       "      <td>0.000278</td>\n",
       "      <td>0.000463</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000189</td>\n",
       "      <td>0.000752</td>\n",
       "      <td>0.000168</td>\n",
       "      <td>0.001677</td>\n",
       "      <td>0.496424</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.600530</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.803126</td>\n",
       "      <td>0.996380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_IKNN</td>\n",
       "      <td>1.018363</td>\n",
       "      <td>0.808793</td>\n",
       "      <td>0.000318</td>\n",
       "      <td>0.000108</td>\n",
       "      <td>0.000140</td>\n",
       "      <td>0.000189</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000214</td>\n",
       "      <td>0.000037</td>\n",
       "      <td>0.000368</td>\n",
       "      <td>0.496391</td>\n",
       "      <td>0.003181</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.392153</td>\n",
       "      <td>0.115440</td>\n",
       "      <td>4.174741</td>\n",
       "      <td>0.965327</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model      RMSE       MAE  precision    recall       F_1  \\\n",
       "0           Self_TopPop  2.508258  2.217909   0.188865  0.116919  0.118732   \n",
       "0              Self_SVD  0.915079  0.718240   0.104772  0.045496  0.054393   \n",
       "0        Ready_Baseline  0.949459  0.752487   0.091410  0.037652  0.046030   \n",
       "0  Self_KNNSurprisetask  0.946255  0.745209   0.083457  0.032848  0.041227   \n",
       "0        Self_GlobalAvg  1.125760  0.943534   0.061188  0.025968  0.031383   \n",
       "0          Ready_Random  1.522798  1.222501   0.049841  0.020656  0.025232   \n",
       "0           Ready_I-KNN  1.030386  0.813067   0.026087  0.006908  0.010593   \n",
       "0   Ready_I-KNNBaseline  0.935327  0.737424   0.002545  0.000755  0.001105   \n",
       "0           Ready_U-KNN  1.023495  0.807913   0.000742  0.000205  0.000305   \n",
       "0       Self_BaselineIU  0.958136  0.754051   0.000954  0.000188  0.000298   \n",
       "0         Self_TopRated  2.508258  2.217909   0.000954  0.000188  0.000298   \n",
       "0       Self_BaselineUI  0.967585  0.762740   0.000954  0.000170  0.000278   \n",
       "0             Self_IKNN  1.018363  0.808793   0.000318  0.000108  0.000140   \n",
       "\n",
       "       F_05  precision_super  recall_super      NDCG       mAP       MRR  \\\n",
       "0  0.141584         0.130472      0.137473  0.214651  0.111707  0.400939   \n",
       "0  0.071374         0.094421      0.076826  0.109517  0.052005  0.206646   \n",
       "0  0.061286         0.079614      0.056463  0.095957  0.043178  0.198193   \n",
       "0  0.055493         0.074785      0.048890  0.089577  0.040902  0.189057   \n",
       "0  0.041343         0.040558      0.032107  0.067695  0.027470  0.171187   \n",
       "0  0.033446         0.030579      0.022927  0.051680  0.019110  0.123085   \n",
       "0  0.016046         0.021137      0.009522  0.024214  0.008958  0.048068   \n",
       "0  0.001602         0.002253      0.000930  0.003444  0.001362  0.011760   \n",
       "0  0.000449         0.000536      0.000198  0.000845  0.000274  0.002744   \n",
       "0  0.000481         0.000644      0.000223  0.001043  0.000335  0.003348   \n",
       "0  0.000481         0.000644      0.000223  0.001043  0.000335  0.003348   \n",
       "0  0.000463         0.000644      0.000189  0.000752  0.000168  0.001677   \n",
       "0  0.000189         0.000000      0.000000  0.000214  0.000037  0.000368   \n",
       "\n",
       "       LAUC        HR       HR2  Reco in test  Test coverage   Shannon  \\\n",
       "0  0.555546  0.765642  0.492047      1.000000       0.038961  3.159079   \n",
       "0  0.519484  0.487805  0.264051      0.874549       0.142136  3.890472   \n",
       "0  0.515501  0.437964  0.239661      1.000000       0.033911  2.836513   \n",
       "0  0.513076  0.417815  0.217391      0.888547       0.130592  3.611806   \n",
       "0  0.509546  0.384942  0.142100      1.000000       0.025974  2.711772   \n",
       "0  0.506849  0.331919  0.119830      0.985048       0.183983  5.097973   \n",
       "0  0.499885  0.154825  0.072110      0.402333       0.434343  5.133650   \n",
       "0  0.496724  0.021209  0.004242      0.482821       0.059885  2.232578   \n",
       "0  0.496441  0.007423  0.000000      0.602121       0.010823  2.089186   \n",
       "0  0.496433  0.009544  0.000000      0.699046       0.005051  1.945910   \n",
       "0  0.496433  0.009544  0.000000      0.699046       0.005051  1.945910   \n",
       "0  0.496424  0.009544  0.000000      0.600530       0.005051  1.803126   \n",
       "0  0.496391  0.003181  0.000000      0.392153       0.115440  4.174741   \n",
       "\n",
       "       Gini  \n",
       "0  0.987317  \n",
       "0  0.972126  \n",
       "0  0.991139  \n",
       "0  0.978659  \n",
       "0  0.992003  \n",
       "0  0.907483  \n",
       "0  0.877999  \n",
       "0  0.994487  \n",
       "0  0.995706  \n",
       "0  0.995669  \n",
       "0  0.995669  \n",
       "0  0.996380  \n",
       "0  0.965327  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib\n",
    "\n",
    "import evaluation_measures as ev\n",
    "# Re-execute the module so edits to evaluation_measures.py are picked up without a kernel restart\n",
    "# (the old `imp` module is deprecated and was removed in Python 3.12).\n",
    "ev = importlib.reload(ev)\n",
    "\n",
    "dir_path = \"Recommendations generated/ml-100k/\"\n",
    "super_reactions = [4, 5]\n",
    "test = pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
    "\n",
    "# Compare all models whose results are saved under dir_path (one row per model in the output).\n",
    "ev.evaluate_all(test, dir_path, super_reactions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 2],\n",
       "       [3, 4]])"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "array([[0.4472136 , 0.89442719],\n",
       "       [0.6       , 0.8       ]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check of row-wise L2 normalisation, the same operation applied to the item embeddings below.\n",
    "x = np.array([[1, 2], [3, 4]])\n",
    "display(x)\n",
    "x / np.linalg.norm(x, axis=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>code</th>\n",
       "      <th>score</th>\n",
       "      <th>item_id</th>\n",
       "      <th>id</th>\n",
       "      <th>title</th>\n",
       "      <th>genres</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1638</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1639</td>\n",
       "      <td>1639</td>\n",
       "      <td>Bitter Sugar (Azucar Amargo) (1996)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>802</td>\n",
       "      <td>0.992833</td>\n",
       "      <td>803</td>\n",
       "      <td>803</td>\n",
       "      <td>Heaven &amp; Earth (1993)</td>\n",
       "      <td>Action, Drama, War</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1378</td>\n",
       "      <td>0.992618</td>\n",
       "      <td>1379</td>\n",
       "      <td>1379</td>\n",
       "      <td>Love and Other Catastrophes (1996)</td>\n",
       "      <td>Romance</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1130</td>\n",
       "      <td>0.991573</td>\n",
       "      <td>1131</td>\n",
       "      <td>1131</td>\n",
       "      <td>Safe (1995)</td>\n",
       "      <td>Thriller</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1199</td>\n",
       "      <td>0.991141</td>\n",
       "      <td>1200</td>\n",
       "      <td>1200</td>\n",
       "      <td>Kim (1950)</td>\n",
       "      <td>Children's, Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1195</td>\n",
       "      <td>0.991040</td>\n",
       "      <td>1196</td>\n",
       "      <td>1196</td>\n",
       "      <td>Savage Nights (Nuits fauves, Les) (1992)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1622</td>\n",
       "      <td>0.990832</td>\n",
       "      <td>1623</td>\n",
       "      <td>1623</td>\n",
       "      <td>Cérémonie, La (1995)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1417</td>\n",
       "      <td>0.990285</td>\n",
       "      <td>1418</td>\n",
       "      <td>1418</td>\n",
       "      <td>Joy Luck Club, The (1993)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1067</td>\n",
       "      <td>0.990192</td>\n",
       "      <td>1068</td>\n",
       "      <td>1068</td>\n",
       "      <td>Star Maker, The (Uomo delle stelle, L') (1995)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1192</td>\n",
       "      <td>0.990168</td>\n",
       "      <td>1193</td>\n",
       "      <td>1193</td>\n",
       "      <td>Before the Rain (Pred dozhdot) (1994)</td>\n",
       "      <td>Drama</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   code     score  item_id    id  \\\n",
       "0  1638  1.000000     1639  1639   \n",
       "1   802  0.992833      803   803   \n",
       "2  1378  0.992618     1379  1379   \n",
       "3  1130  0.991573     1131  1131   \n",
       "4  1199  0.991141     1200  1200   \n",
       "5  1195  0.991040     1196  1196   \n",
       "6  1622  0.990832     1623  1623   \n",
       "7  1417  0.990285     1418  1418   \n",
       "8  1067  0.990192     1068  1068   \n",
       "9  1192  0.990168     1193  1193   \n",
       "\n",
       "                                            title              genres  \n",
       "0             Bitter Sugar (Azucar Amargo) (1996)               Drama  \n",
       "1                           Heaven & Earth (1993)  Action, Drama, War  \n",
       "2              Love and Other Catastrophes (1996)             Romance  \n",
       "3                                     Safe (1995)            Thriller  \n",
       "4                                      Kim (1950)   Children's, Drama  \n",
       "5        Savage Nights (Nuits fauves, Les) (1992)               Drama  \n",
       "6                            Cérémonie, La (1995)               Drama  \n",
       "7                       Joy Luck Club, The (1993)               Drama  \n",
       "8  Star Maker, The (Uomo delle stelle, L') (1995)               Drama  \n",
       "9           Before the Rain (Pred dozhdot) (1994)               Drama  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pick a random item code that has at least one rating in the training matrix.\n",
    "item = random.choice(list(set(train_ui.indices)))\n",
    "\n",
    "# Row-normalise the item factors so the dot product below is cosine similarity (no mean-centering here).\n",
    "# Omitting the normalisation also makes sense, but items with a greater magnitude would then be recommended more often.\n",
    "embeddings_norm = model.Qi / np.linalg.norm(model.Qi, axis=1, keepdims=True)\n",
    "\n",
    "similarity_scores = np.dot(embeddings_norm, embeddings_norm[item])\n",
    "top_similar_items = (\n",
    "    pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\n",
    "    .sort_values(by='score', ascending=False)\n",
    "    .head(10)\n",
    ")\n",
    "\n",
    "# Map internal item codes back to dataset item ids, then attach titles and genres from movies.csv.\n",
    "top_similar_items['item_id'] = top_similar_items['code'].apply(lambda code: item_code_id[code])\n",
    "\n",
    "items = pd.read_csv('./Datasets/ml-100k/movies.csv')\n",
    "\n",
    "result = pd.merge(top_similar_items, items, left_on='item_id', right_on='id')\n",
    "\n",
    "result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Project task 5: implement SVD on top of a baseline (as in the Surprise library)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extending our implementation with additional (baseline) parameters in the gradient-descent procedure\n",
    "# seems to be the fastest option.\n",
    "# Please save the output in 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
    "# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "class SVDbaseline():\n",
    "    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
    "        \"\"\"Biased matrix factorisation: SVD on top of a baseline, as in the Surprise library.\n",
    "\n",
    "        Predicted rating: r_ui = mu + b_u + b_i + p_u . q_i, with every term except mu fitted by SGD.\n",
    "\n",
    "        train_ui: sparse user x item rating matrix (CSR; presumably without explicit zeros).\n",
    "        learning_rate, regularization: SGD step size and L2 penalty shared by all learned terms.\n",
    "        nb_factors: number of latent factors; iterations: number of SGD epochs run by train().\n",
    "        \"\"\"\n",
    "        self.train_ui = train_ui\n",
    "        self.uir = list(zip(*[train_ui.nonzero()[0], train_ui.nonzero()[1], train_ui.data]))\n",
    "        \n",
    "        self.learning_rate = learning_rate\n",
    "        self.regularization = regularization\n",
    "        self.iterations = iterations\n",
    "        self.nb_users, self.nb_items = train_ui.shape\n",
    "        self.nb_ratings = train_ui.nnz\n",
    "        self.nb_factors = nb_factors\n",
    "        \n",
    "        # Global mean rating mu (the baseline offset); 0 for an empty training matrix.\n",
    "        self.mu = float(np.mean([score for _, _, score in self.uir])) if self.uir else 0.0\n",
    "        \n",
    "        # One scalar bias per user / item, started at zero. Previously these were random\n",
    "        # (n, nb_factors) matrices, which made get_rating return a vector instead of a rating.\n",
    "        self.Bu = np.zeros(self.nb_users)\n",
    "        self.Bi = np.zeros(self.nb_items)\n",
    "        # Kept for backward compatibility: aliases of the SAME arrays, so they hold the trained biases.\n",
    "        self.bias_u = self.Bu\n",
    "        self.bias_i = self.Bi\n",
    "        \n",
    "        self.Pu = np.random.normal(loc = 0, scale = 1./self.nb_factors, size = (self.nb_users, self.nb_factors))\n",
    "        self.Qi = np.random.normal(loc = 0, scale = 1./self.nb_factors, size = (self.nb_items, self.nb_factors))\n",
    "\n",
    "        \n",
    "    def train(self, test_ui = None):\n",
    "        \"\"\"Run `iterations` epochs of SGD over the shuffled training triples.\n",
    "\n",
    "        After each epoch appends [epoch, train RMSE] -- plus test RMSE when test_ui is given --\n",
    "        to self.learning_process.\n",
    "        \"\"\"\n",
    "        # Identity test: `==` / `!=` are overloaded on scipy sparse matrices, so compare with `is`.\n",
    "        if test_ui is not None:\n",
    "            self.test_uir = list(zip(*[test_ui.nonzero()[0], test_ui.nonzero()[1], test_ui.data]))\n",
    "            \n",
    "        self.learning_process = []\n",
    "        pbar = tqdm(range(self.iterations))\n",
    "        \n",
    "        for i in pbar:\n",
    "            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i > 0 else 0}. Training epoch {i + 1}...')\n",
    "            np.random.shuffle(self.uir)\n",
    "            self.sgd(self.uir)\n",
    "            \n",
    "            if test_ui is None:\n",
    "                self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
    "            else:\n",
    "                self.learning_process.append([i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
    "    \n",
    "    \n",
    "    def sgd(self, uir):\n",
    "        \"\"\"One SGD pass over (user, item, rating) triples; updates biases and factors in place.\"\"\"\n",
    "        for u, i, score in uir:\n",
    "            e = score - self.get_rating(u, i)\n",
    "            \n",
    "            # Both factor updates are computed from the pre-update vectors before either is applied.\n",
    "            Pu_update = self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
    "            Qi_update = self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
    "            \n",
    "            self.Bu[u] += self.learning_rate * (e - self.regularization * self.Bu[u])\n",
    "            self.Bi[i] += self.learning_rate * (e - self.regularization * self.Bi[i])\n",
    "\n",
    "            self.Pu[u] += Pu_update\n",
    "            self.Qi[i] += Qi_update\n",
    "    \n",
    "    \n",
    "    def get_rating(self, u, i):\n",
    "        \"\"\"Predicted rating mu + b_u + b_i + p_u . q_i for user code u and item code i (a scalar).\"\"\"\n",
    "        return self.mu + self.Bu[u] + self.Bi[i] + self.Pu[u].dot(self.Qi[i])\n",
    "    \n",
    "    \n",
    "    def RMSE_total(self, uir):\n",
    "        \"\"\"Root-mean-square error of get_rating over (user, item, rating) triples.\"\"\"\n",
    "        RMSE = 0\n",
    "        for u, i, score in uir:\n",
    "            RMSE += (score - self.get_rating(u, i)) ** 2\n",
    "        return np.sqrt(RMSE / len(uir))\n",
    "    \n",
    "    \n",
    "    def estimations(self):\n",
    "        \"\"\"Compute the full user x item matrix of predicted ratings (same formula as get_rating).\n",
    "\n",
    "        NOTE: this rebinds self.estimations from this method to the resulting array, which\n",
    "        recommend() and estimate() read -- so it can only be called once per instance.\n",
    "        \"\"\"\n",
    "        self.estimations = (\n",
    "            self.mu + self.Bu[:, np.newaxis] + self.Bi[np.newaxis, :] + np.dot(self.Pu, self.Qi.T)\n",
    "        )\n",
    "\n",
    "        \n",
    "    def recommend(self, user_code_id, item_code_id, topK = 10):\n",
    "        top_k = defaultdict(list)\n",
    "        for nb_user, user in enumerate(self.estimations):\n",
    "            user_rated = self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user + 1]]\n",
    "            for item, score in enumerate(user):\n",
    "                if item not in user_rated and not np.isnan(score):\n",
    "                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
    "        result = []\n",
    "        for uid, item_scores in top_k.items():\n",
    "            item_scores.sort(key = lambda x: x[1], reverse = True)\n",
    "            result.append([uid] + list(chain(*item_scores[:topK])))\n",
    "        return result\n",
    "    \n",
    "    \n",
    "    def estimate(self, user_code_id, item_code_id, test_ui):\n",
    "        result = []\n",
    "        for user, item in zip(*test_ui.nonzero()):\n",
    "            result.append([user_code_id[user], item_code_id[item], \n",
    "                           self.estimations[user, item] if not np.isnan(self.estimations[user, item]) else 1])\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1 RMSE: [1.56628501 1.56647573 1.56581135 1.56616034 1.56624616 1.56602468/it]\n",
      " 1.56609983 1.56608786 1.56617096 1.56609962 1.56627857 1.56624403\n",
      " 1.56608856 1.56608238 1.56620356 1.56604183 1.56617325 1.56616721\n",
      " 1.56608636 1.56621656 1.56601182 1.56629164 1.56603567 1.566008\n",
      " 1.5659204  1.56600686 1.56632857 1.56623671 1.56650435 1.56614388\n",
      " 1.56597602 1.56619724 1.56600564 1.56592808 1.5662823  1.56598423\n",
      " 1.56630978 1.5661384  1.56637227 1.56600394 1.56599185 1.56598777\n",
      " 1.56627173 1.56602758 1.56607052 1.56610967 1.56619676 1.5660723\n",
      " 1.5661363  1.56643558 1.56599281 1.56602474 1.56615414 1.56619859\n",
      " 1.56630092 1.56593518 1.56608721 1.56602644 1.56614329 1.56602186\n",
      " 1.56606511 1.56610499 1.56610819 1.56593067 1.5661533  1.56616749\n",
      " 1.56610497 1.5661211  1.56599833 1.56605015 1.56627997 1.56605438\n",
      " 1.56626503 1.5660066  1.56630412 1.56620906 1.56620395 1.56631407\n",
      " 1.56606347 1.56617128 1.56621069 1.56618785 1.56600689 1.56643735\n",
      " 1.56598312 1.56615281 1.5661071  1.56631373 1.56604796 1.56585804\n",
      " 1.56627589 1.56626862 1.56616867 1.56633676 1.56601226 1.5661184\n",
      "Epoch 1 RMSE: [1.56628501 1.56647573 1.56581135 1.56616034 1.56624616 1.56602468   | 1/40 [00:02<01:40,  2.57s/it]\n",
      " 1.56609983 1.56608786 1.56617096 1.56609962 1.56627857 1.56624403\n",
      " 1.56608856 1.56608238 1.56620356 1.56604183 1.56617325 1.56616721\n",
      " 1.56608636 1.56621656 1.56601182 1.56629164 1.56603567 1.566008\n",
      " 1.5659204  1.56600686 1.56632857 1.56623671 1.56650435 1.56614388\n",
      " 1.56597602 1.56619724 1.56600564 1.56592808 1.5662823  1.56598423\n",
      " 1.56630978 1.5661384  1.56637227 1.56600394 1.56599185 1.56598777\n",
      " 1.56627173 1.56602758 1.56607052 1.56610967 1.56619676 1.5660723\n",
      " 1.5661363  1.56643558 1.56599281 1.56602474 1.56615414 1.56619859\n",
      " 1.56630092 1.56593518 1.56608721 1.56602644 1.56614329 1.56602186\n",
      " 1.56606511 1.56610499 1.56610819 1.56593067 1.5661533  1.56616749\n",
      " 1.56610497 1.5661211  1.56599833 1.56605015 1.56627997 1.56605438\n",
      " 1.56626503 1.5660066  1.56630412 1.56620906 1.56620395 1.56631407\n",
      " 1.56606347 1.56617128 1.56621069 1.56618785 1.56600689 1.56643735\n",
      " 1.56598312 1.56615281 1.5661071  1.56631373 1.56604796 1.56585804\n",
      " 1.56627589 1.56626862 1.56616867 1.56633676 1.56601226 1.5661184\n",
      "Epoch 2 RMSE: [1.2417018  1.24188159 1.24150993 1.24174429 1.24176255 1.24170722   | 2/40 [00:05<01:37,  2.58s/it]\n",
      " 1.24171074 1.24167268 1.24172545 1.24174763 1.24183238 1.24173671\n",
      " 1.24163026 1.24162065 1.24177366 1.24158469 1.24177297 1.24166952\n",
      " 1.24171917 1.24183222 1.24162599 1.24178984 1.24168821 1.24166123\n",
      " 1.24161223 1.24162659 1.24185753 1.24183293 1.24185593 1.24174315\n",
      " 1.24165528 1.2416642  1.241631   1.24156043 1.24184828 1.24160591\n",
      " 1.24178785 1.2416347  1.24184533 1.2416611  1.24161786 1.24155744\n",
      " 1.24177065 1.24165326 1.24171443 1.24170905 1.24183533 1.24169338\n",
      " 1.24169124 1.24192236 1.24165418 1.24165594 1.2417279  1.24175767\n",
      " 1.24186364 1.24169617 1.24170661 1.24161604 1.24172478 1.24166061\n",
      " 1.24167343 1.24165278 1.24162706 1.2415709  1.24172076 1.24168487\n",
      " 1.24171406 1.24165323 1.24165323 1.24166652 1.24179703 1.24169395\n",
      " 1.24174949 1.24163659 1.2418293  1.24176263 1.24171404 1.2418294\n",
      " 1.24168063 1.24165637 1.24178972 1.24178604 1.24169735 1.241849\n",
      " 1.24170329 1.24177262 1.24171872 1.24179172 1.24171928 1.24153613\n",
      " 1.24177704 1.24185793 1.24168956 1.2417221  1.24165712 1.24172613\n",
      "Epoch 2 RMSE: [1.2417018  1.24188159 1.24150993 1.24174429 1.24176255 1.24170722   | 2/40 [00:05<01:37,  2.58s/it]\n",
      " 1.24171074 1.24167268 1.24172545 1.24174763 1.24183238 1.24173671\n",
      " 1.24163026 1.24162065 1.24177366 1.24158469 1.24177297 1.24166952\n",
      " 1.24171917 1.24183222 1.24162599 1.24178984 1.24168821 1.24166123\n",
      " 1.24161223 1.24162659 1.24185753 1.24183293 1.24185593 1.24174315\n",
      " 1.24165528 1.2416642  1.241631   1.24156043 1.24184828 1.24160591\n",
      " 1.24178785 1.2416347  1.24184533 1.2416611  1.24161786 1.24155744\n",
      " 1.24177065 1.24165326 1.24171443 1.24170905 1.24183533 1.24169338\n",
      " 1.24169124 1.24192236 1.24165418 1.24165594 1.2417279  1.24175767\n",
      " 1.24186364 1.24169617 1.24170661 1.24161604 1.24172478 1.24166061\n",
      " 1.24167343 1.24165278 1.24162706 1.2415709  1.24172076 1.24168487\n",
      " 1.24171406 1.24165323 1.24165323 1.24166652 1.24179703 1.24169395\n",
      " 1.24174949 1.24163659 1.2418293  1.24176263 1.24171404 1.2418294\n",
      " 1.24168063 1.24165637 1.24178972 1.24178604 1.24169735 1.241849\n",
      " 1.24170329 1.24177262 1.24171872 1.24179172 1.24171928 1.24153613\n",
      " 1.24177704 1.24185793 1.24168956 1.2417221  1.24165712 1.24172613\n",
      "Epoch 3 RMSE: [1.13258243 1.13273769 1.13249398 1.13265935 1.13264375 1.13263674   | 3/40 [00:07<01:35,  2.58s/it]\n",
      " 1.13264279 1.13259496 1.13263666 1.13267439 1.13270287 1.13263871\n",
      " 1.13256654 1.13256675 1.13267736 1.13252626 1.13269081 1.13258232\n",
      " 1.13262013 1.13274078 1.13258389 1.13267925 1.13263407 1.13259053\n",
      " 1.13256681 1.13256461 1.13271984 1.13272441 1.13271252 1.13264455\n",
      " 1.13259818 1.13258742 1.13258444 1.1325301  1.13271939 1.13255725\n",
      " 1.13265228 1.13257112 1.13272558 1.13259942 1.13256249 1.13251124\n",
      " 1.13267438 1.13259717 1.13263932 1.13263856 1.132744   1.13260791\n",
      " 1.13260569 1.13276277 1.13258999 1.13260022 1.13263504 1.13267512\n",
      " 1.13274647 1.13265384 1.13265274 1.13255779 1.13264575 1.13261409\n",
      " 1.13261241 1.13257776 1.13255374 1.13251717 1.13263993 1.13258915\n",
      " 1.13262861 1.13256498 1.13260396 1.13259002 1.13269051 1.13262309\n",
      " 1.13264451 1.1326     1.13270356 1.13266447 1.13263253 1.13269319\n",
      " 1.13262264 1.13256405 1.1326936  1.13267126 1.1326336  1.13270766\n",
      " 1.13266901 1.1326705  1.13265684 1.13267642 1.13266019 1.13251334\n",
      " 1.13266081 1.13273492 1.13259381 1.13259615 1.13258523 1.13264068\n",
      "Epoch 3 RMSE: [1.13258243 1.13273769 1.13249398 1.13265935 1.13264375 1.13263674   | 3/40 [00:07<01:35,  2.58s/it]\n",
      " 1.13264279 1.13259496 1.13263666 1.13267439 1.13270287 1.13263871\n",
      " 1.13256654 1.13256675 1.13267736 1.13252626 1.13269081 1.13258232\n",
      " 1.13262013 1.13274078 1.13258389 1.13267925 1.13263407 1.13259053\n",
      " 1.13256681 1.13256461 1.13271984 1.13272441 1.13271252 1.13264455\n",
      " 1.13259818 1.13258742 1.13258444 1.1325301  1.13271939 1.13255725\n",
      " 1.13265228 1.13257112 1.13272558 1.13259942 1.13256249 1.13251124\n",
      " 1.13267438 1.13259717 1.13263932 1.13263856 1.132744   1.13260791\n",
      " 1.13260569 1.13276277 1.13258999 1.13260022 1.13263504 1.13267512\n",
      " 1.13274647 1.13265384 1.13265274 1.13255779 1.13264575 1.13261409\n",
      " 1.13261241 1.13257776 1.13255374 1.13251717 1.13263993 1.13258915\n",
      " 1.13262861 1.13256498 1.13260396 1.13259002 1.13269051 1.13262309\n",
      " 1.13264451 1.1326     1.13270356 1.13266447 1.13263253 1.13269319\n",
      " 1.13262264 1.13256405 1.1326936  1.13267126 1.1326336  1.13270766\n",
      " 1.13266901 1.1326705  1.13265684 1.13267642 1.13266019 1.13251334\n",
      " 1.13266081 1.13273492 1.13259381 1.13259615 1.13258523 1.13264068\n",
      "Epoch 4 RMSE: [1.07470103 1.07483205 1.074653   1.07478245 1.07474992 1.07475779   | 4/40 [00:10<01:30,  2.53s/it]\n",
      " 1.07476855 1.07471954 1.07475657 1.07479501 1.07479633 1.07475995\n",
      " 1.07471091 1.07471076 1.07478729 1.07466922 1.074806   1.07471156\n",
      " 1.07473267 1.07485138 1.07472955 1.07478602 1.07476246 1.07471623\n",
      " 1.07470393 1.07470486 1.07480787 1.07483042 1.07480929 1.07475672\n",
      " 1.074736   1.07471934 1.07472499 1.07468818 1.07480974 1.0747082\n",
      " 1.07475006 1.07471146 1.07482552 1.07473571 1.07469558 1.07466382\n",
      " 1.07478714 1.07473432 1.07476276 1.07476954 1.07485164 1.07472884\n",
      " 1.07472827 1.07483967 1.0747193  1.07473498 1.07475513 1.07479379\n",
      " 1.07484481 1.07478596 1.074786   1.07469995 1.07477074 1.07475607\n",
      " 1.07475005 1.07470749 1.07469114 1.0746581  1.07476066 1.0747095\n",
      " 1.07474566 1.07469449 1.07474116 1.07472055 1.07479978 1.07475009\n",
      " 1.07476008 1.07473934 1.07480159 1.07477503 1.07476399 1.07478594\n",
      " 1.07475659 1.07469004 1.0748058  1.07477358 1.07476343 1.07479883\n",
      " 1.07480411 1.07477767 1.0747872  1.07478549 1.07478493 1.07467267\n",
      " 1.07476843 1.07482625 1.07471241 1.0747061  1.07471244 1.07475596\n",
      "Epoch 4 RMSE: [1.07470103 1.07483205 1.074653   1.07478245 1.07474992 1.07475779   | 4/40 [00:10<01:30,  2.53s/it]  \n",
      " 1.07476855 1.07471954 1.07475657 1.07479501 1.07479633 1.07475995\n",
      " 1.07471091 1.07471076 1.07478729 1.07466922 1.074806   1.07471156\n",
      " 1.07473267 1.07485138 1.07472955 1.07478602 1.07476246 1.07471623\n",
      " 1.07470393 1.07470486 1.07480787 1.07483042 1.07480929 1.07475672\n",
      " 1.074736   1.07471934 1.07472499 1.07468818 1.07480974 1.0747082\n",
      " 1.07475006 1.07471146 1.07482552 1.07473571 1.07469558 1.07466382\n",
      " 1.07478714 1.07473432 1.07476276 1.07476954 1.07485164 1.07472884\n",
      " 1.07472827 1.07483967 1.0747193  1.07473498 1.07475513 1.07479379\n",
      " 1.07484481 1.07478596 1.074786   1.07469995 1.07477074 1.07475607\n",
      " 1.07475005 1.07470749 1.07469114 1.0746581  1.07476066 1.0747095\n",
      " 1.07474566 1.07469449 1.07474116 1.07472055 1.07479978 1.07475009\n",
      " 1.07476008 1.07473934 1.07480159 1.07477503 1.07476399 1.07478594\n",
      " 1.07475659 1.07469004 1.0748058  1.07477358 1.07476343 1.07479883\n",
      " 1.07480411 1.07477767 1.0747872  1.07478549 1.07478493 1.07467267\n",
      " 1.07476843 1.07482625 1.07471241 1.0747061  1.07471244 1.07475596\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5 RMSE: [1.03814976 1.03826045 1.03811633 1.03822671 1.03818608 1.03819672   | 5/40 [00:12<01:28,  2.54s/it]\n",
      " 1.03821107 1.03816502 1.03819735 1.03823394 1.03822099 1.03820276\n",
      " 1.03817009 1.03816776 1.03822115 1.03812697 1.03824053 1.03816213\n",
      " 1.0381714  1.03828314 1.03818505 1.03821938 1.03820537 1.03816112\n",
      " 1.03815479 1.03816081 1.03823001 1.03826181 1.03823864 1.03819418\n",
      " 1.03818894 1.03817048 1.03817742 1.03815259 1.03823377 1.03817128\n",
      " 1.03818148 1.03816561 1.03825063 1.03819015 1.03814364 1.03812778\n",
      " 1.0382225  1.03818517 1.03820551 1.03821651 1.038281   1.03817264\n",
      " 1.03817241 1.03825447 1.0381662  1.03818432 1.03820159 1.03823161\n",
      " 1.03826976 1.03823002 1.038232   1.03815758 1.03821445 1.03820907\n",
      " 1.03820234 1.03815744 1.03814712 1.03811497 1.03820183 1.03815373\n",
      " 1.03818401 1.03814703 1.03819164 1.03817238 1.03823326 1.03819768\n",
      " 1.03820049 1.03818693 1.03822919 1.03820819 1.0382148  1.03821205\n",
      " 1.0382052  1.03813962 1.03824113 1.03820524 1.0382107  1.03822166\n",
      " 1.03824703 1.03821022 1.03823076 1.03822139 1.03822339 1.03813808\n",
      " 1.03820354 1.038249   1.03815644 1.03814753 1.03816163 1.03819351\n",
      "Epoch 5 RMSE: [1.03814976 1.03826045 1.03811633 1.03822671 1.03818608 1.03819672   | 5/40 [00:12<01:28,  2.54s/it]\n",
      " 1.03821107 1.03816502 1.03819735 1.03823394 1.03822099 1.03820276\n",
      " 1.03817009 1.03816776 1.03822115 1.03812697 1.03824053 1.03816213\n",
      " 1.0381714  1.03828314 1.03818505 1.03821938 1.03820537 1.03816112\n",
      " 1.03815479 1.03816081 1.03823001 1.03826181 1.03823864 1.03819418\n",
      " 1.03818894 1.03817048 1.03817742 1.03815259 1.03823377 1.03817128\n",
      " 1.03818148 1.03816561 1.03825063 1.03819015 1.03814364 1.03812778\n",
      " 1.0382225  1.03818517 1.03820551 1.03821651 1.038281   1.03817264\n",
      " 1.03817241 1.03825447 1.0381662  1.03818432 1.03820159 1.03823161\n",
      " 1.03826976 1.03823002 1.038232   1.03815758 1.03821445 1.03820907\n",
      " 1.03820234 1.03815744 1.03814712 1.03811497 1.03820183 1.03815373\n",
      " 1.03818401 1.03814703 1.03819164 1.03817238 1.03823326 1.03819768\n",
      " 1.03820049 1.03818693 1.03822919 1.03820819 1.0382148  1.03821205\n",
      " 1.0382052  1.03813962 1.03824113 1.03820524 1.0382107  1.03822166\n",
      " 1.03824703 1.03821022 1.03823076 1.03822139 1.03822339 1.03813808\n",
      " 1.03820354 1.038249   1.03815644 1.03814753 1.03816163 1.03819351\n",
      "Epoch 6 RMSE: [1.01281717 1.01291107 1.01278938 1.01288647 1.01284443 1.01285417   | 6/40 [00:15<01:28,  2.60s/it]\n",
      " 1.01286901 1.01282764 1.01285544 1.01288953 1.0128703  1.01286329\n",
      " 1.01284107 1.01283658 1.01287412 1.01279734 1.01289272 1.01282785\n",
      " 1.01283052 1.01293233 1.01285139 1.01287313 1.01286354 1.01282315\n",
      " 1.01281977 1.01283023 1.01287715 1.01291256 1.0128901  1.01285057\n",
      " 1.01285464 1.01283705 1.01284262 1.01282647 1.01288275 1.01284451\n",
      " 1.01283583 1.01283238 1.01289636 1.01285878 1.01280834 1.01280304\n",
      " 1.01287596 1.01284933 1.01286522 1.01287812 1.0129298  1.01283542\n",
      " 1.01283487 1.01289631 1.01282952 1.0128476  1.0128641  1.01288633\n",
      " 1.01291558 1.01288799 1.01289213 1.01282762 1.01287417 1.01287394\n",
      " 1.01286766 1.01282299 1.01281667 1.01278599 1.01286056 1.01281531\n",
      " 1.01284136 1.01281493 1.01285498 1.01284074 1.01288644 1.012862\n",
      " 1.01285987 1.01284788 1.01287841 1.01286073 1.01287958 1.01286209\n",
      " 1.01286818 1.01280658 1.01289446 1.01285892 1.01287203 1.01286922\n",
      " 1.01290259 1.01286311 1.01288884 1.01287717 1.01287654 1.01281316\n",
      " 1.01285885 1.01289549 1.01281868 1.01280926 1.01282758 1.01284977\n",
      "Epoch 6 RMSE: [1.01281717 1.01291107 1.01278938 1.01288647 1.01284443 1.01285417   | 6/40 [00:15<01:28,  2.60s/it]  \n",
      " 1.01286901 1.01282764 1.01285544 1.01288953 1.0128703  1.01286329\n",
      " 1.01284107 1.01283658 1.01287412 1.01279734 1.01289272 1.01282785\n",
      " 1.01283052 1.01293233 1.01285139 1.01287313 1.01286354 1.01282315\n",
      " 1.01281977 1.01283023 1.01287715 1.01291256 1.0128901  1.01285057\n",
      " 1.01285464 1.01283705 1.01284262 1.01282647 1.01288275 1.01284451\n",
      " 1.01283583 1.01283238 1.01289636 1.01285878 1.01280834 1.01280304\n",
      " 1.01287596 1.01284933 1.01286522 1.01287812 1.0129298  1.01283542\n",
      " 1.01283487 1.01289631 1.01282952 1.0128476  1.0128641  1.01288633\n",
      " 1.01291558 1.01288799 1.01289213 1.01282762 1.01287417 1.01287394\n",
      " 1.01286766 1.01282299 1.01281667 1.01278599 1.01286056 1.01281531\n",
      " 1.01284136 1.01281493 1.01285498 1.01284074 1.01288644 1.012862\n",
      " 1.01285987 1.01284788 1.01287841 1.01286073 1.01287958 1.01286209\n",
      " 1.01286818 1.01280658 1.01289446 1.01285892 1.01287203 1.01286922\n",
      " 1.01290259 1.01286311 1.01288884 1.01287717 1.01287654 1.01281316\n",
      " 1.01285885 1.01289549 1.01281868 1.01280926 1.01282758 1.01284977\n",
      "Epoch 7 RMSE: [0.99440278 0.99448285 0.99437689 0.99446413 0.99442374 0.99443101   | 7/40 [00:17<01:23,  2.53s/it]\n",
      " 0.99444568 0.9944093  0.99443276 0.99446378 0.99444228 0.99444224\n",
      " 0.99442689 0.99442063 0.99444827 0.99438454 0.99446507 0.99441125\n",
      " 0.99441043 0.99450116 0.99443336 0.99444832 0.9944398  0.99440386\n",
      " 0.99440271 0.9944157  0.99444763 0.99448437 0.99446327 0.99442748\n",
      " 0.99443694 0.99442101 0.9944241  0.99441448 0.99445494 0.99443242\n",
      " 0.99441235 0.99441589 0.99446372 0.99444438 0.99439146 0.99439286\n",
      " 0.99444993 0.99443034 0.99444365 0.99445697 0.99449893 0.99441738\n",
      " 0.99441652 0.99446308 0.99441141 0.99442816 0.99444559 0.9944599\n",
      " 0.99448334 0.99446342 0.99447032 0.99441352 0.99445264 0.99445534\n",
      " 0.99444959 0.9944058  0.99440297 0.99437383 0.99443823 0.99439607\n",
      " 0.99441945 0.9944003  0.99443535 0.99442618 0.99445976 0.99444387\n",
      " 0.99443846 0.99442597 0.99444931 0.99443373 0.99446158 0.99443475\n",
      " 0.99444772 0.99439111 0.99446809 0.99443495 0.99445024 0.99444\n",
      " 0.99447577 0.99443738 0.99446482 0.99445296 0.99444964 0.99440218\n",
      " 0.9944348  0.99446505 0.99440051 0.99439129 0.99441129 0.99442638\n",
      "Epoch 7 RMSE: [0.99440278 0.99448285 0.99437689 0.99446413 0.99442374 0.99443101   | 7/40 [00:17<01:23,  2.53s/it]  \n",
      " 0.99444568 0.9944093  0.99443276 0.99446378 0.99444228 0.99444224\n",
      " 0.99442689 0.99442063 0.99444827 0.99438454 0.99446507 0.99441125\n",
      " 0.99441043 0.99450116 0.99443336 0.99444832 0.9944398  0.99440386\n",
      " 0.99440271 0.9944157  0.99444763 0.99448437 0.99446327 0.99442748\n",
      " 0.99443694 0.99442101 0.9944241  0.99441448 0.99445494 0.99443242\n",
      " 0.99441235 0.99441589 0.99446372 0.99444438 0.99439146 0.99439286\n",
      " 0.99444993 0.99443034 0.99444365 0.99445697 0.99449893 0.99441738\n",
      " 0.99441652 0.99446308 0.99441141 0.99442816 0.99444559 0.9944599\n",
      " 0.99448334 0.99446342 0.99447032 0.99441352 0.99445264 0.99445534\n",
      " 0.99444959 0.9944058  0.99440297 0.99437383 0.99443823 0.99439607\n",
      " 0.99441945 0.9944003  0.99443535 0.99442618 0.99445976 0.99444387\n",
      " 0.99443846 0.99442597 0.99444931 0.99443373 0.99446158 0.99443475\n",
      " 0.99444772 0.99439111 0.99446809 0.99443495 0.99445024 0.99444\n",
      " 0.99447577 0.99443738 0.99446482 0.99445296 0.99444964 0.99440218\n",
      " 0.9944348  0.99446505 0.99440051 0.99439129 0.99441129 0.99442638\n",
      "Epoch 8 RMSE: [0.98042218 0.98049048 0.98039697 0.98047584 0.98043855 0.98044358   | 8/40 [00:20<01:19,  2.48s/it]\n",
      " 0.9804572  0.98042485 0.98044487 0.98047328 0.9804515  0.98045602\n",
      " 0.98044497 0.98043727 0.98045856 0.98040473 0.98047328 0.98042834\n",
      " 0.98042557 0.98050617 0.98044835 0.98045934 0.98045096 0.98041956\n",
      " 0.98041958 0.98043437 0.98045558 0.98049181 0.98047231 0.98043977\n",
      " 0.98045187 0.98043782 0.98043933 0.98043387 0.98046433 0.98045169\n",
      " 0.98042538 0.98043231 0.98046841 0.9804625  0.98040982 0.98041445\n",
      " 0.98046011 0.98044478 0.98045634 0.98046963 0.98050411 0.98043312\n",
      " 0.98043237 0.98046834 0.98042783 0.98044328 0.98046024 0.98046917\n",
      " 0.98048834 0.98047363 0.98048218 0.98043166 0.98046508 0.9804699\n",
      " 0.98046497 0.98042322 0.9804222  0.98039509 0.98045161 0.98041184\n",
      " 0.98043266 0.98041952 0.98044893 0.98044451 0.98046886 0.98045955\n",
      " 0.98045228 0.98043826 0.98045715 0.98044318 0.98047617 0.98044425\n",
      " 0.98046183 0.98041042 0.98047757 0.9804468  0.98046262 0.98044893\n",
      " 0.98048381 0.98044816 0.98047516 0.98046413 0.98045799 0.98042268\n",
      " 0.98044622 0.98047213 0.98041723 0.98040793 0.98042941 0.98043912\n",
      "Epoch 8 RMSE: [0.98042218 0.98049048 0.98039697 0.98047584 0.98043855 0.98044358   | 8/40 [00:20<01:19,  2.48s/it]\n",
      " 0.9804572  0.98042485 0.98044487 0.98047328 0.9804515  0.98045602\n",
      " 0.98044497 0.98043727 0.98045856 0.98040473 0.98047328 0.98042834\n",
      " 0.98042557 0.98050617 0.98044835 0.98045934 0.98045096 0.98041956\n",
      " 0.98041958 0.98043437 0.98045558 0.98049181 0.98047231 0.98043977\n",
      " 0.98045187 0.98043782 0.98043933 0.98043387 0.98046433 0.98045169\n",
      " 0.98042538 0.98043231 0.98046841 0.9804625  0.98040982 0.98041445\n",
      " 0.98046011 0.98044478 0.98045634 0.98046963 0.98050411 0.98043312\n",
      " 0.98043237 0.98046834 0.98042783 0.98044328 0.98046024 0.98046917\n",
      " 0.98048834 0.98047363 0.98048218 0.98043166 0.98046508 0.9804699\n",
      " 0.98046497 0.98042322 0.9804222  0.98039509 0.98045161 0.98041184\n",
      " 0.98043266 0.98041952 0.98044893 0.98044451 0.98046886 0.98045955\n",
      " 0.98045228 0.98043826 0.98045715 0.98044318 0.98047617 0.98044425\n",
      " 0.98046183 0.98041042 0.98047757 0.9804468  0.98046262 0.98044893\n",
      " 0.98048381 0.98044816 0.98047516 0.98046413 0.98045799 0.98042268\n",
      " 0.98044622 0.98047213 0.98041723 0.98040793 0.98042941 0.98043912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9 RMSE: [0.96956864 0.96962721 0.96954354 0.96961556 0.96958205 0.96958458   | 9/40 [00:22<01:16,  2.47s/it]\n",
      " 0.96959707 0.96956865 0.96958616 0.96961134 0.96959051 0.96959811\n",
      " 0.96958923 0.96958072 0.96959807 0.96955239 0.96961096 0.96957265\n",
      " 0.96956903 0.96964055 0.96959048 0.96959949 0.96959073 0.96956322\n",
      " 0.96956373 0.96957968 0.96959395 0.9696286  0.96961054 0.96958063\n",
      " 0.96959443 0.96958199 0.96958222 0.96957953 0.96960364 0.96959708\n",
      " 0.96956749 0.96957626 0.96960309 0.96960737 0.96955603 0.96956241\n",
      " 0.96959943 0.96958725 0.96959703 0.96961002 0.96963852 0.96957742\n",
      " 0.96957636 0.9696043  0.96957224 0.96958563 0.96960255 0.96960695\n",
      " 0.96962291 0.96961234 0.9696223  0.96957714 0.96960567 0.9696117\n",
      " 0.96960751 0.96956774 0.96956802 0.96954356 0.9695932  0.96955556\n",
      " 0.96957482 0.96956564 0.96959001 0.96958992 0.96960666 0.96960246\n",
      " 0.96959448 0.96957906 0.96959454 0.96958207 0.96961755 0.96958331\n",
      " 0.9696031  0.9695574  0.9696157  0.96958767 0.96960284 0.96958787\n",
      " 0.96962011 0.969588   0.96961464 0.9696038  0.96959544 0.96956941\n",
      " 0.96958616 0.96960883 0.9695618  0.96955266 0.9695746  0.96958055\n",
      "Epoch 9 RMSE: [0.96956864 0.96962721 0.96954354 0.96961556 0.96958205 0.96958458    | 9/40 [00:22<01:16,  2.47s/it]\n",
      " 0.96959707 0.96956865 0.96958616 0.96961134 0.96959051 0.96959811\n",
      " 0.96958923 0.96958072 0.96959807 0.96955239 0.96961096 0.96957265\n",
      " 0.96956903 0.96964055 0.96959048 0.96959949 0.96959073 0.96956322\n",
      " 0.96956373 0.96957968 0.96959395 0.9696286  0.96961054 0.96958063\n",
      " 0.96959443 0.96958199 0.96958222 0.96957953 0.96960364 0.96959708\n",
      " 0.96956749 0.96957626 0.96960309 0.96960737 0.96955603 0.96956241\n",
      " 0.96959943 0.96958725 0.96959703 0.96961002 0.96963852 0.96957742\n",
      " 0.96957636 0.9696043  0.96957224 0.96958563 0.96960255 0.96960695\n",
      " 0.96962291 0.96961234 0.9696223  0.96957714 0.96960567 0.9696117\n",
      " 0.96960751 0.96956774 0.96956802 0.96954356 0.9695932  0.96955556\n",
      " 0.96957482 0.96956564 0.96959001 0.96958992 0.96960666 0.96960246\n",
      " 0.96959448 0.96957906 0.96959454 0.96958207 0.96961755 0.96958331\n",
      " 0.9696031  0.9695574  0.9696157  0.96958767 0.96960284 0.96958787\n",
      " 0.96962011 0.969588   0.96961464 0.9696038  0.96959544 0.96956941\n",
      " 0.96958616 0.96960883 0.9695618  0.96955266 0.9695746  0.96958055\n",
      "Epoch 10 RMSE: [0.96075673 0.96080761 0.96073196 0.96079776 0.96076804 0.96076873   | 10/40 [00:24<01:13,  2.44s/it]\n",
      " 0.96078039 0.96075495 0.96077015 0.96079231 0.96077356 0.96078261\n",
      " 0.96077559 0.96076643 0.96078098 0.9607412  0.96079177 0.96075896\n",
      " 0.96075491 0.96081856 0.96077505 0.96078216 0.9607736  0.96074991\n",
      " 0.96075047 0.96076696 0.96077613 0.9608088  0.9607921  0.96076446\n",
      " 0.96077892 0.96076825 0.96076753 0.96076714 0.96078607 0.96078396\n",
      " 0.9607528  0.96076217 0.96078209 0.9607937  0.96074464 0.96075185\n",
      " 0.96078205 0.96077184 0.96078107 0.96079289 0.96081683 0.96076394\n",
      " 0.96076258 0.96078479 0.96075914 0.96077056 0.96078714 0.96078814\n",
      " 0.96080169 0.9607937  0.96080479 0.96076411 0.96078901 0.96079532\n",
      " 0.96079197 0.9607546  0.96075574 0.96073366 0.96077769 0.96074198\n",
      " 0.96075972 0.96075414 0.96077353 0.96077696 0.96078774 0.96078751\n",
      " 0.96077925 0.96076313 0.9607757  0.96076468 0.96080126 0.96076611\n",
      " 0.96078704 0.96074671 0.9607969  0.96077184 0.9607857  0.96077078\n",
      " 0.9607998  0.96077123 0.96079652 0.96078645 0.96077646 0.96075732\n",
      " 0.96076928 0.96078951 0.96074835 0.96074048 0.96076206 0.96076493\n",
      "Epoch 10 RMSE: [0.96075673 0.96080761 0.96073196 0.96079776 0.96076804 0.96076873   | 10/40 [00:24<01:13,  2.44s/it]\n",
      " 0.96078039 0.96075495 0.96077015 0.96079231 0.96077356 0.96078261\n",
      " 0.96077559 0.96076643 0.96078098 0.9607412  0.96079177 0.96075896\n",
      " 0.96075491 0.96081856 0.96077505 0.96078216 0.9607736  0.96074991\n",
      " 0.96075047 0.96076696 0.96077613 0.9608088  0.9607921  0.96076446\n",
      " 0.96077892 0.96076825 0.96076753 0.96076714 0.96078607 0.96078396\n",
      " 0.9607528  0.96076217 0.96078209 0.9607937  0.96074464 0.96075185\n",
      " 0.96078205 0.96077184 0.96078107 0.96079289 0.96081683 0.96076394\n",
      " 0.96076258 0.96078479 0.96075914 0.96077056 0.96078714 0.96078814\n",
      " 0.96080169 0.9607937  0.96080479 0.96076411 0.96078901 0.96079532\n",
      " 0.96079197 0.9607546  0.96075574 0.96073366 0.96077769 0.96074198\n",
      " 0.96075972 0.96075414 0.96077353 0.96077696 0.96078774 0.96078751\n",
      " 0.96077925 0.96076313 0.9607757  0.96076468 0.96080126 0.96076611\n",
      " 0.96078704 0.96074671 0.9607969  0.96077184 0.9607857  0.96077078\n",
      " 0.9607998  0.96077123 0.96079652 0.96078645 0.96077646 0.96075732\n",
      " 0.96076928 0.96078951 0.96074835 0.96074048 0.96076206 0.96076493\n",
      "Epoch 11 RMSE: [0.95361422 0.9536583  0.95358983 0.95365028 0.95362411 0.95362327   | 11/40 [00:27<01:09,  2.41s/it]\n",
      " 0.95363393 0.95361104 0.95362427 0.95364368 0.95362728 0.95363712\n",
      " 0.95363111 0.95362154 0.95363403 0.95359944 0.95364333 0.95361499\n",
      " 0.953611   0.95366754 0.95362947 0.95363561 0.95362648 0.95360615\n",
      " 0.95360678 0.95362319 0.95362914 0.95365963 0.95364409 0.95361836\n",
      " 0.95363346 0.95362369 0.95362264 0.95362344 0.95363932 0.95363981\n",
      " 0.95360837 0.95361782 0.95363242 0.95364907 0.95360271 0.95361032\n",
      " 0.95363511 0.95362669 0.95363478 0.95364577 0.95366592 0.95362025\n",
      " 0.95361858 0.95363617 0.95361555 0.95362547 0.95364146 0.95363965\n",
      " 0.9536515  0.95364549 0.95365732 0.95362043 0.95364252 0.95364887\n",
      " 0.95364593 0.9536111  0.95361243 0.95359329 0.95363234 0.95359863\n",
      " 0.95361472 0.95361184 0.95362725 0.95363318 0.95363953 0.95364235\n",
      " 0.95363404 0.95361743 0.95362787 0.9536177  0.95365451 0.95361986\n",
      " 0.95364087 0.9536052  0.95364885 0.9536261  0.9536388  0.95362439\n",
      " 0.95365    0.9536251  0.95364881 0.95363917 0.95362836 0.95361419\n",
      " 0.95362283 0.95364097 0.95360492 0.95359786 0.95361872 0.95361965\n",
      "Epoch 11 RMSE: [0.95361422 0.9536583  0.95358983 0.95365028 0.95362411 0.95362327   | 11/40 [00:27<01:09,  2.41s/it]\n",
      " 0.95363393 0.95361104 0.95362427 0.95364368 0.95362728 0.95363712\n",
      " 0.95363111 0.95362154 0.95363403 0.95359944 0.95364333 0.95361499\n",
      " 0.953611   0.95366754 0.95362947 0.95363561 0.95362648 0.95360615\n",
      " 0.95360678 0.95362319 0.95362914 0.95365963 0.95364409 0.95361836\n",
      " 0.95363346 0.95362369 0.95362264 0.95362344 0.95363932 0.95363981\n",
      " 0.95360837 0.95361782 0.95363242 0.95364907 0.95360271 0.95361032\n",
      " 0.95363511 0.95362669 0.95363478 0.95364577 0.95366592 0.95362025\n",
      " 0.95361858 0.95363617 0.95361555 0.95362547 0.95364146 0.95363965\n",
      " 0.9536515  0.95364549 0.95365732 0.95362043 0.95364252 0.95364887\n",
      " 0.95364593 0.9536111  0.95361243 0.95359329 0.95363234 0.95359863\n",
      " 0.95361472 0.95361184 0.95362725 0.95363318 0.95363953 0.95364235\n",
      " 0.95363404 0.95361743 0.95362787 0.9536177  0.95365451 0.95361986\n",
      " 0.95364087 0.9536052  0.95364885 0.9536261  0.9536388  0.95362439\n",
      " 0.95365    0.9536251  0.95364881 0.95363917 0.95362836 0.95361419\n",
      " 0.95362283 0.95364097 0.95360492 0.95359786 0.95361872 0.95361965\n",
      "Epoch 12 RMSE: [0.94765708 0.94769565 0.94763277 0.94768893 0.94766618 0.947664     | 12/40 [00:29<01:06,  2.39s/it]\n",
      " 0.94767364 0.94765312 0.94766496 0.94768184 0.94766724 0.94767774\n",
      " 0.94767218 0.94766267 0.94767376 0.94764338 0.94768166 0.94765665\n",
      " 0.94765283 0.94770359 0.94766981 0.94767537 0.94766609 0.94764863\n",
      " 0.94764898 0.94766509 0.9476688  0.94769696 0.94768249 0.94765882\n",
      " 0.94767371 0.94766519 0.94766387 0.94766566 0.94767913 0.94768116\n",
      " 0.94765028 0.94765911 0.94766987 0.94769018 0.9476467  0.94765397\n",
      " 0.94767479 0.94766754 0.94767501 0.94768491 0.9477023  0.94766229\n",
      " 0.94766049 0.9476745  0.94765805 0.94766628 0.94768163 0.947678\n",
      " 0.94768793 0.94768374 0.94769577 0.94766257 0.94768209 0.94768808\n",
      " 0.94768622 0.94765347 0.94765496 0.94763823 0.94767304 0.94764105\n",
      " 0.94765615 0.94765504 0.9476667  0.94767507 0.94767778 0.94768277\n",
      " 0.9476749  0.94765828 0.94766685 0.94765747 0.94769349 0.9476598\n",
      " 0.94768074 0.94764938 0.94768697 0.94766657 0.94767781 0.94766472\n",
      " 0.94768706 0.94766521 0.94768765 0.9476782  0.94766698 0.94765685\n",
      " 0.94766267 0.9476793  0.94764692 0.94764121 0.94766126 0.94766045\n",
      "Epoch 12 RMSE: [0.94765708 0.94769565 0.94763277 0.94768893 0.94766618 0.947664     | 12/40 [00:29<01:06,  2.39s/it]     \n",
      " 0.94767364 0.94765312 0.94766496 0.94768184 0.94766724 0.94767774\n",
      " 0.94767218 0.94766267 0.94767376 0.94764338 0.94768166 0.94765665\n",
      " 0.94765283 0.94770359 0.94766981 0.94767537 0.94766609 0.94764863\n",
      " 0.94764898 0.94766509 0.9476688  0.94769696 0.94768249 0.94765882\n",
      " 0.94767371 0.94766519 0.94766387 0.94766566 0.94767913 0.94768116\n",
      " 0.94765028 0.94765911 0.94766987 0.94769018 0.9476467  0.94765397\n",
      " 0.94767479 0.94766754 0.94767501 0.94768491 0.9477023  0.94766229\n",
      " 0.94766049 0.9476745  0.94765805 0.94766628 0.94768163 0.947678\n",
      " 0.94768793 0.94768374 0.94769577 0.94766257 0.94768209 0.94768808\n",
      " 0.94768622 0.94765347 0.94765496 0.94763823 0.94767304 0.94764105\n",
      " 0.94765615 0.94765504 0.9476667  0.94767507 0.94767778 0.94768277\n",
      " 0.9476749  0.94765828 0.94766685 0.94765747 0.94769349 0.9476598\n",
      " 0.94768074 0.94764938 0.94768697 0.94766657 0.94767781 0.94766472\n",
      " 0.94768706 0.94766521 0.94768765 0.9476782  0.94766698 0.94765685\n",
      " 0.94766267 0.9476793  0.94764692 0.94764121 0.94766126 0.94766045\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13 RMSE: [0.94242167 0.94245501 0.94239796 0.94244964 0.94242964 0.94242677   | 13/40 [00:32<01:04,  2.38s/it]\n",
      " 0.94243566 0.94241669 0.94242748 0.94244225 0.94242968 0.94243994\n",
      " 0.9424347  0.94242551 0.94243552 0.94240881 0.94244215 0.94242032\n",
      " 0.94241652 0.94246188 0.94243226 0.94243719 0.94242775 0.94241276\n",
      " 0.9424129  0.94242876 0.94243091 0.94245705 0.94244329 0.94242141\n",
      " 0.94243586 0.94242807 0.94242708 0.94242915 0.9424409  0.94244413\n",
      " 0.9424143  0.9424223  0.94242981 0.94245265 0.942412   0.94241917\n",
      " 0.94243634 0.94243026 0.94243717 0.94244606 0.94246105 0.94242636\n",
      " 0.94242434 0.94243546 0.94242238 0.94242902 0.94244364 0.94243855\n",
      " 0.94244718 0.94244423 0.94245625 0.94242622 0.94244366 0.94244964\n",
      " 0.94244806 0.94241736 0.94241905 0.94240443 0.94243573 0.94240543\n",
      " 0.94241937 0.94241989 0.94242829 0.94243845 0.94243873 0.94244488\n",
      " 0.94243726 0.94242083 0.94242829 0.94241988 0.94245415 0.94242219\n",
      " 0.94244239 0.94241518 0.94244741 0.9424288  0.94243902 0.94242742\n",
      " 0.94244647 0.94242764 0.9424486  0.94243932 0.94242807 0.94242085\n",
      " 0.94242457 0.94244017 0.94241092 0.94240614 0.94242508 0.94242352\n",
      "Epoch 13 RMSE: [0.94242167 0.94245501 0.94239796 0.94244964 0.94242964 0.94242677   | 13/40 [00:32<01:04,  2.38s/it]\n",
      " 0.94243566 0.94241669 0.94242748 0.94244225 0.94242968 0.94243994\n",
      " 0.9424347  0.94242551 0.94243552 0.94240881 0.94244215 0.94242032\n",
      " 0.94241652 0.94246188 0.94243226 0.94243719 0.94242775 0.94241276\n",
      " 0.9424129  0.94242876 0.94243091 0.94245705 0.94244329 0.94242141\n",
      " 0.94243586 0.94242807 0.94242708 0.94242915 0.9424409  0.94244413\n",
      " 0.9424143  0.9424223  0.94242981 0.94245265 0.942412   0.94241917\n",
      " 0.94243634 0.94243026 0.94243717 0.94244606 0.94246105 0.94242636\n",
      " 0.94242434 0.94243546 0.94242238 0.94242902 0.94244364 0.94243855\n",
      " 0.94244718 0.94244423 0.94245625 0.94242622 0.94244366 0.94244964\n",
      " 0.94244806 0.94241736 0.94241905 0.94240443 0.94243573 0.94240543\n",
      " 0.94241937 0.94241989 0.94242829 0.94243845 0.94243873 0.94244488\n",
      " 0.94243726 0.94242083 0.94242829 0.94241988 0.94245415 0.94242219\n",
      " 0.94244239 0.94241518 0.94244741 0.9424288  0.94243902 0.94242742\n",
      " 0.94244647 0.94242764 0.9424486  0.94243932 0.94242807 0.94242085\n",
      " 0.94242457 0.94244017 0.94241092 0.94240614 0.94242508 0.94242352\n",
      "Epoch 14 RMSE: [0.9380729  0.93810193 0.93804964 0.93809779 0.93808017 0.93807681   | 14/40 [00:34<01:01,  2.37s/it]\n",
      " 0.93808492 0.93806746 0.93807753 0.93808962 0.93807912 0.93808927\n",
      " 0.9380845  0.9380754  0.93808462 0.93806081 0.93809038 0.9380707\n",
      " 0.93806719 0.93810823 0.93808146 0.93808617 0.93807663 0.93806401\n",
      " 0.93806365 0.9380789  0.93807998 0.93810428 0.93809143 0.938071\n",
      " 0.93808527 0.93807792 0.93807713 0.93807948 0.93808988 0.93809373\n",
      " 0.93806515 0.93807248 0.93807742 0.93810192 0.93806415 0.93807098\n",
      " 0.93808517 0.93808003 0.93808658 0.93809419 0.9381074  0.93807715\n",
      " 0.93807482 0.93808371 0.93807322 0.93807872 0.93809271 0.93808645\n",
      " 0.93809391 0.93809203 0.93810411 0.93807661 0.93809222 0.93809822\n",
      " 0.93809697 0.93806833 0.93806988 0.93805712 0.93808548 0.93805662\n",
      " 0.93806972 0.93807144 0.93807716 0.93808838 0.93808667 0.93809374\n",
      " 0.93808691 0.93807064 0.93807691 0.93806925 0.93810202 0.93807175\n",
      " 0.93809133 0.93806764 0.93809515 0.93807829 0.93808742 0.93807713\n",
      " 0.93809335 0.93807706 0.9380968  0.93808785 0.93807651 0.93807192\n",
      " 0.93807378 0.93808833 0.93806168 0.9380582  0.93807566 0.93807356\n",
      "Epoch 14 RMSE: [0.9380729  0.93810193 0.93804964 0.93809779 0.93808017 0.93807681   | 14/40 [00:34<01:01,  2.37s/it]   \n",
      " 0.93808492 0.93806746 0.93807753 0.93808962 0.93807912 0.93808927\n",
      " 0.9380845  0.9380754  0.93808462 0.93806081 0.93809038 0.9380707\n",
      " 0.93806719 0.93810823 0.93808146 0.93808617 0.93807663 0.93806401\n",
      " 0.93806365 0.9380789  0.93807998 0.93810428 0.93809143 0.938071\n",
      " 0.93808527 0.93807792 0.93807713 0.93807948 0.93808988 0.93809373\n",
      " 0.93806515 0.93807248 0.93807742 0.93810192 0.93806415 0.93807098\n",
      " 0.93808517 0.93808003 0.93808658 0.93809419 0.9381074  0.93807715\n",
      " 0.93807482 0.93808371 0.93807322 0.93807872 0.93809271 0.93808645\n",
      " 0.93809391 0.93809203 0.93810411 0.93807661 0.93809222 0.93809822\n",
      " 0.93809697 0.93806833 0.93806988 0.93805712 0.93808548 0.93805662\n",
      " 0.93806972 0.93807144 0.93807716 0.93808838 0.93808667 0.93809374\n",
      " 0.93808691 0.93807064 0.93807691 0.93806925 0.93810202 0.93807175\n",
      " 0.93809133 0.93806764 0.93809515 0.93807829 0.93808742 0.93807713\n",
      " 0.93809335 0.93807706 0.9380968  0.93808785 0.93807651 0.93807192\n",
      " 0.93807378 0.93808833 0.93806168 0.9380582  0.93807566 0.93807356\n",
      "Epoch 15 RMSE: [0.93405466 0.9340801  0.934032   0.93407668 0.93406119 0.93405746   | 15/40 [00:36<00:59,  2.38s/it]\n",
      " 0.93406501 0.93404887 0.93405817 0.93406843 0.93405947 0.93406962\n",
      " 0.93406469 0.93405606 0.93406446 0.9340433  0.93406934 0.93405188\n",
      " 0.9340484  0.93408552 0.93406174 0.93406584 0.93405665 0.93404574\n",
      " 0.93404505 0.93405986 0.93406013 0.93408281 0.93407048 0.93405159\n",
      " 0.93406533 0.93405847 0.93405791 0.93406037 0.93406988 0.93407413\n",
      " 0.93404685 0.9340534  0.93405653 0.93408173 0.9340467  0.9340531\n",
      " 0.93406499 0.93406087 0.9340664  0.93407335 0.93408522 0.93405857\n",
      " 0.93405595 0.93406317 0.93405485 0.93405906 0.93407257 0.93406555\n",
      " 0.93407195 0.93407072 0.93408251 0.93405763 0.93407174 0.93407741\n",
      " 0.93407665 0.93404982 0.93405133 0.93404042 0.93406611 0.93403863\n",
      " 0.9340507  0.93405337 0.93405675 0.934069   0.93406573 0.93407359\n",
      " 0.93406724 0.93405173 0.93405685 0.93404977 0.93408081 0.93405228\n",
      " 0.93407092 0.93405053 0.93407401 0.93405854 0.93406664 0.93405777\n",
      " 0.93407149 0.93405726 0.93407599 0.93406703 0.93405594 0.93405333\n",
      " 0.93405387 0.93406767 0.93404327 0.93404073 0.93405704 0.93405458\n",
      "Epoch 15 RMSE: [0.93405466 0.9340801  0.934032   0.93407668 0.93406119 0.93405746   | 15/40 [00:36<00:59,  2.38s/it]\n",
      " 0.93406501 0.93404887 0.93405817 0.93406843 0.93405947 0.93406962\n",
      " 0.93406469 0.93405606 0.93406446 0.9340433  0.93406934 0.93405188\n",
      " 0.9340484  0.93408552 0.93406174 0.93406584 0.93405665 0.93404574\n",
      " 0.93404505 0.93405986 0.93406013 0.93408281 0.93407048 0.93405159\n",
      " 0.93406533 0.93405847 0.93405791 0.93406037 0.93406988 0.93407413\n",
      " 0.93404685 0.9340534  0.93405653 0.93408173 0.9340467  0.9340531\n",
      " 0.93406499 0.93406087 0.9340664  0.93407335 0.93408522 0.93405857\n",
      " 0.93405595 0.93406317 0.93405485 0.93405906 0.93407257 0.93406555\n",
      " 0.93407195 0.93407072 0.93408251 0.93405763 0.93407174 0.93407741\n",
      " 0.93407665 0.93404982 0.93405133 0.93404042 0.93406611 0.93403863\n",
      " 0.9340507  0.93405337 0.93405675 0.934069   0.93406573 0.93407359\n",
      " 0.93406724 0.93405173 0.93405685 0.93404977 0.93408081 0.93405228\n",
      " 0.93407092 0.93405053 0.93407401 0.93405854 0.93406664 0.93405777\n",
      " 0.93407149 0.93405726 0.93407599 0.93406703 0.93405594 0.93405333\n",
      " 0.93405387 0.93406767 0.93404327 0.93404073 0.93405704 0.93405458\n",
      "Epoch 16 RMSE: [0.93049748 0.93051952 0.93047527 0.93051692 0.93050324 0.93049948   | 16/40 [00:39<00:57,  2.38s/it]\n",
      " 0.9305064  0.93049127 0.93050022 0.93050846 0.93050115 0.93051083\n",
      " 0.93050615 0.93049781 0.93050559 0.93048669 0.93050969 0.93049385\n",
      " 0.93049065 0.93052478 0.93050302 0.93050672 0.93049789 0.93048859\n",
      " 0.93048731 0.9305018  0.93050168 0.93052245 0.93051101 0.9304933\n",
      " 0.93050653 0.93050002 0.93049977 0.93050252 0.93051119 0.93051542\n",
      " 0.93048964 0.9304954  0.93049713 0.93052273 0.93049021 0.93049606\n",
      " 0.93050613 0.93050248 0.93050784 0.93051398 0.93052447 0.93050086\n",
      " 0.93049819 0.93050399 0.93049753 0.93050087 0.9305136  0.93050592\n",
      " 0.93051163 0.9305109  0.93052208 0.93049988 0.9305126  0.93051792\n",
      " 0.93051748 0.93049249 0.93049378 0.93048443 0.9305079  0.93048177\n",
      " 0.93049304 0.93049618 0.93049766 0.93051063 0.93050627 0.93051455\n",
      " 0.93050881 0.930494   0.93049817 0.93049168 0.93052071 0.93049398\n",
      " 0.93051177 0.93049432 0.9305143  0.93049999 0.93050746 0.93049972\n",
      " 0.93051116 0.93049898 0.93051637 0.93050739 0.93049683 0.93049562\n",
      " 0.93049531 0.93050831 0.93048542 0.93048411 0.93049954 0.93049665\n",
      "Epoch 16 RMSE: [0.93049748 0.93051952 0.93047527 0.93051692 0.93050324 0.93049948   | 16/40 [00:39<00:57,  2.38s/it]\n",
      " 0.9305064  0.93049127 0.93050022 0.93050846 0.93050115 0.93051083\n",
      " 0.93050615 0.93049781 0.93050559 0.93048669 0.93050969 0.93049385\n",
      " 0.93049065 0.93052478 0.93050302 0.93050672 0.93049789 0.93048859\n",
      " 0.93048731 0.9305018  0.93050168 0.93052245 0.93051101 0.9304933\n",
      " 0.93050653 0.93050002 0.93049977 0.93050252 0.93051119 0.93051542\n",
      " 0.93048964 0.9304954  0.93049713 0.93052273 0.93049021 0.93049606\n",
      " 0.93050613 0.93050248 0.93050784 0.93051398 0.93052447 0.93050086\n",
      " 0.93049819 0.93050399 0.93049753 0.93050087 0.9305136  0.93050592\n",
      " 0.93051163 0.9305109  0.93052208 0.93049988 0.9305126  0.93051792\n",
      " 0.93051748 0.93049249 0.93049378 0.93048443 0.9305079  0.93048177\n",
      " 0.93049304 0.93049618 0.93049766 0.93051063 0.93050627 0.93051455\n",
      " 0.93050881 0.930494   0.93049817 0.93049168 0.93052071 0.93049398\n",
      " 0.93051177 0.93049432 0.9305143  0.93049999 0.93050746 0.93049972\n",
      " 0.93051116 0.93049898 0.93051637 0.93050739 0.93049683 0.93049562\n",
      " 0.93049531 0.93050831 0.93048542 0.93048411 0.93049954 0.93049665\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17 RMSE: [0.92714744 0.92716697 0.92712595 0.92716482 0.9271528  0.92714876   | 17/40 [00:41<00:54,  2.37s/it]\n",
      " 0.92715527 0.92714136 0.92714956 0.92715638 0.92715026 0.9271597\n",
      " 0.92715493 0.92714704 0.92715426 0.92713721 0.92715782 0.92714357\n",
      " 0.92714057 0.92717165 0.92715207 0.92715539 0.92714643 0.92713889\n",
      " 0.92713736 0.92715091 0.92715067 0.92716998 0.92715892 0.92714273\n",
      " 0.92715531 0.92714926 0.92714926 0.92715176 0.92715998 0.92716418\n",
      " 0.92713979 0.92714502 0.92714533 0.92717078 0.92714096 0.92714635\n",
      " 0.92715485 0.92715167 0.92715651 0.92716198 0.92717138 0.92715065\n",
      " 0.9271479  0.92715235 0.92714731 0.92714985 0.92716222 0.92715401\n",
      " 0.92715894 0.92715872 0.92716959 0.92714928 0.92716093 0.92716598\n",
      " 0.92716575 0.9271423  0.92714358 0.92713565 0.92715696 0.92713227\n",
      " 0.92714273 0.9271466  0.92714622 0.92715961 0.92715452 0.92716285\n",
      " 0.92715791 0.92714354 0.92714721 0.92714098 0.92716835 0.92714319\n",
      " 0.92716019 0.92714521 0.92716224 0.9271489  0.92715569 0.92714907\n",
      " 0.92715862 0.92714818 0.92716453 0.92715568 0.92714544 0.92714559\n",
      " 0.9271442  0.92715665 0.92713534 0.92713479 0.92714913 0.92714612\n",
      "Epoch 17 RMSE: [0.92714744 0.92716697 0.92712595 0.92716482 0.9271528  0.92714876   | 17/40 [00:41<00:54,  2.37s/it]\n",
      " 0.92715527 0.92714136 0.92714956 0.92715638 0.92715026 0.9271597\n",
      " 0.92715493 0.92714704 0.92715426 0.92713721 0.92715782 0.92714357\n",
      " 0.92714057 0.92717165 0.92715207 0.92715539 0.92714643 0.92713889\n",
      " 0.92713736 0.92715091 0.92715067 0.92716998 0.92715892 0.92714273\n",
      " 0.92715531 0.92714926 0.92714926 0.92715176 0.92715998 0.92716418\n",
      " 0.92713979 0.92714502 0.92714533 0.92717078 0.92714096 0.92714635\n",
      " 0.92715485 0.92715167 0.92715651 0.92716198 0.92717138 0.92715065\n",
      " 0.9271479  0.92715235 0.92714731 0.92714985 0.92716222 0.92715401\n",
      " 0.92715894 0.92715872 0.92716959 0.92714928 0.92716093 0.92716598\n",
      " 0.92716575 0.9271423  0.92714358 0.92713565 0.92715696 0.92713227\n",
      " 0.92714273 0.9271466  0.92714622 0.92715961 0.92715452 0.92716285\n",
      " 0.92715791 0.92714354 0.92714721 0.92714098 0.92716835 0.92714319\n",
      " 0.92716019 0.92714521 0.92716224 0.9271489  0.92715569 0.92714907\n",
      " 0.92715862 0.92714818 0.92716453 0.92715568 0.92714544 0.92714559\n",
      " 0.9271442  0.92715665 0.92713534 0.92713479 0.92714913 0.92714612\n",
      "Epoch 18 RMSE: [0.92371638 0.92373374 0.92369539 0.92373195 0.92372119 0.92371734   | 18/40 [00:43<00:52,  2.36s/it]\n",
      " 0.92372333 0.92371018 0.92371809 0.9237235  0.92371867 0.92372747\n",
      " 0.92372281 0.92371545 0.92372208 0.92370676 0.9237253  0.92371207\n",
      " 0.92370931 0.92373784 0.92372002 0.92372321 0.92371436 0.92370792\n",
      " 0.92370615 0.9237192  0.92371863 0.92373675 0.92372621 0.92371132\n",
      " 0.92372328 0.92371753 0.92371751 0.92372022 0.9237277  0.92373198\n",
      " 0.92370893 0.92371344 0.92371275 0.92373809 0.92371035 0.92371553\n",
      " 0.9237226  0.92372003 0.92372456 0.9237292  0.92373803 0.92371934\n",
      " 0.92371631 0.92372011 0.92371632 0.92371803 0.92372989 0.9237215\n",
      " 0.92372566 0.92372605 0.92373631 0.92371775 0.92372841 0.92373317\n",
      " 0.92373334 0.92371127 0.92371246 0.92370561 0.92372518 0.92370161\n",
      " 0.9237113  0.92371578 0.92371402 0.92372767 0.92372203 0.92373032\n",
      " 0.92372593 0.92371218 0.92371532 0.92370948 0.92373509 0.92371158\n",
      " 0.92372759 0.92371497 0.92372913 0.92371694 0.92372338 0.92371765\n",
      " 0.92372538 0.92371633 0.92373194 0.92372305 0.92371332 0.92371424\n",
      " 0.9237122  0.92372428 0.92370422 0.92370451 0.92371783 0.92371467\n",
      "Epoch 18 RMSE: [0.92371638 0.92373374 0.92369539 0.92373195 0.92372119 0.92371734   | 18/40 [00:43<00:52,  2.36s/it]\n",
      " 0.92372333 0.92371018 0.92371809 0.9237235  0.92371867 0.92372747\n",
      " 0.92372281 0.92371545 0.92372208 0.92370676 0.9237253  0.92371207\n",
      " 0.92370931 0.92373784 0.92372002 0.92372321 0.92371436 0.92370792\n",
      " 0.92370615 0.9237192  0.92371863 0.92373675 0.92372621 0.92371132\n",
      " 0.92372328 0.92371753 0.92371751 0.92372022 0.9237277  0.92373198\n",
      " 0.92370893 0.92371344 0.92371275 0.92373809 0.92371035 0.92371553\n",
      " 0.9237226  0.92372003 0.92372456 0.9237292  0.92373803 0.92371934\n",
      " 0.92371631 0.92372011 0.92371632 0.92371803 0.92372989 0.9237215\n",
      " 0.92372566 0.92372605 0.92373631 0.92371775 0.92372841 0.92373317\n",
      " 0.92373334 0.92371127 0.92371246 0.92370561 0.92372518 0.92370161\n",
      " 0.9237113  0.92371578 0.92371402 0.92372767 0.92372203 0.92373032\n",
      " 0.92372593 0.92371218 0.92371532 0.92370948 0.92373509 0.92371158\n",
      " 0.92372759 0.92371497 0.92372913 0.92371694 0.92372338 0.92371765\n",
      " 0.92372538 0.92371633 0.92373194 0.92372305 0.92371332 0.92371424\n",
      " 0.9237122  0.92372428 0.92370422 0.92370451 0.92371783 0.92371467\n",
      "Epoch 19 RMSE: [0.92036649 0.92038166 0.92034596 0.92038016 0.92037059 0.92036677   | 19/40 [00:46<00:49,  2.37s/it]\n",
      " 0.92037241 0.9203601  0.92036774 0.92037168 0.92036813 0.9203763\n",
      " 0.92037182 0.92036494 0.92037122 0.92035722 0.92037376 0.92036178\n",
      " 0.92035917 0.92038546 0.9203691  0.92037202 0.92036345 0.92035818\n",
      " 0.92035624 0.92036857 0.92036789 0.92038479 0.92037488 0.92036083\n",
      " 0.92037237 0.92036669 0.92036692 0.92036962 0.92037696 0.92038083\n",
      " 0.9203592  0.92036312 0.92036172 0.92038654 0.92036075 0.92036567\n",
      " 0.92037159 0.92036923 0.92037387 0.92037782 0.92038576 0.9203693\n",
      " 0.92036627 0.920369   0.92036633 0.92036734 0.92037867 0.92037008\n",
      " 0.92037368 0.92037421 0.92038414 0.92036738 0.92037709 0.92038137\n",
      " 0.92038191 0.92036128 0.92036223 0.92035638 0.92037447 0.92035213\n",
      " 0.92036125 0.92036582 0.92036302 0.92037659 0.92037066 0.92037869\n",
      " 0.92037531 0.92036215 0.92036481 0.92035902 0.92038317 0.92036109\n",
      " 0.92037616 0.92036588 0.92037759 0.92036632 0.92037216 0.92036717\n",
      " 0.92037345 0.92036576 0.92038041 0.9203715  0.92036231 0.92036405\n",
      " 0.92036156 0.92037306 0.92035393 0.92035518 0.92036757 0.92036435\n",
      "Epoch 19 RMSE: [0.92036649 0.92038166 0.92034596 0.92038016 0.92037059 0.92036677   | 19/40 [00:46<00:49,  2.37s/it] \n",
      " 0.92037241 0.9203601  0.92036774 0.92037168 0.92036813 0.9203763\n",
      " 0.92037182 0.92036494 0.92037122 0.92035722 0.92037376 0.92036178\n",
      " 0.92035917 0.92038546 0.9203691  0.92037202 0.92036345 0.92035818\n",
      " 0.92035624 0.92036857 0.92036789 0.92038479 0.92037488 0.92036083\n",
      " 0.92037237 0.92036669 0.92036692 0.92036962 0.92037696 0.92038083\n",
      " 0.9203592  0.92036312 0.92036172 0.92038654 0.92036075 0.92036567\n",
      " 0.92037159 0.92036923 0.92037387 0.92037782 0.92038576 0.9203693\n",
      " 0.92036627 0.920369   0.92036633 0.92036734 0.92037867 0.92037008\n",
      " 0.92037368 0.92037421 0.92038414 0.92036738 0.92037709 0.92038137\n",
      " 0.92038191 0.92036128 0.92036223 0.92035638 0.92037447 0.92035213\n",
      " 0.92036125 0.92036582 0.92036302 0.92037659 0.92037066 0.92037869\n",
      " 0.92037531 0.92036215 0.92036481 0.92035902 0.92038317 0.92036109\n",
      " 0.92037616 0.92036588 0.92037759 0.92036632 0.92037216 0.92036717\n",
      " 0.92037345 0.92036576 0.92038041 0.9203715  0.92036231 0.92036405\n",
      " 0.92036156 0.92037306 0.92035393 0.92035518 0.92036757 0.92036435\n",
      "Epoch 20 RMSE: [0.91679736 0.9168108  0.91677763 0.9168097  0.91680124 0.9167976    | 20/40 [00:48<00:47,  2.36s/it]\n",
      " 0.91680273 0.91679114 0.91679858 0.91680147 0.91679851 0.91680629\n",
      " 0.91680192 0.91679561 0.91680131 0.91678853 0.91680357 0.9167927\n",
      " 0.91679007 0.91681454 0.91679948 0.91680212 0.91679378 0.91678955\n",
      " 0.91678719 0.91679889 0.9167983  0.91681408 0.91680477 0.91679167\n",
      " 0.91680246 0.91679725 0.91679729 0.9168002  0.91680706 0.91681077\n",
      " 0.91679052 0.9167938  0.91679187 0.91681619 0.9167924  0.91679686\n",
      " 0.91680175 0.91679981 0.916804   0.91680747 0.91681498 0.91680007\n",
      " 0.91679704 0.91679899 0.91679735 0.91679784 0.91680865 0.91679986\n",
      " 0.91680297 0.91680382 0.91681324 0.91679804 0.91680704 0.91681089\n",
      " 0.91681164 0.91679227 0.91679312 0.91678825 0.91680506 0.91678373\n",
      " 0.91679208 0.91679687 0.91679319 0.91680679 0.91680066 0.91680845\n",
      " 0.91680554 0.91679303 0.91679532 0.9167898  0.91681233 0.91679165\n",
      " 0.91680591 0.91679738 0.91680716 0.91679653 0.916802   0.91679794\n",
      " 0.91680275 0.91679611 0.9168098  0.91680122 0.91679272 0.9167949\n",
      " 0.91679181 0.91680316 0.91678491 0.91678671 0.9167982  0.91679524\n",
      "Epoch 20 RMSE: [0.91679736 0.9168108  0.91677763 0.9168097  0.91680124 0.9167976    | 20/40 [00:48<00:47,  2.36s/it] \n",
      " 0.91680273 0.91679114 0.91679858 0.91680147 0.91679851 0.91680629\n",
      " 0.91680192 0.91679561 0.91680131 0.91678853 0.91680357 0.9167927\n",
      " 0.91679007 0.91681454 0.91679948 0.91680212 0.91679378 0.91678955\n",
      " 0.91678719 0.91679889 0.9167983  0.91681408 0.91680477 0.91679167\n",
      " 0.91680246 0.91679725 0.91679729 0.9168002  0.91680706 0.91681077\n",
      " 0.91679052 0.9167938  0.91679187 0.91681619 0.9167924  0.91679686\n",
      " 0.91680175 0.91679981 0.916804   0.91680747 0.91681498 0.91680007\n",
      " 0.91679704 0.91679899 0.91679735 0.91679784 0.91680865 0.91679986\n",
      " 0.91680297 0.91680382 0.91681324 0.91679804 0.91680704 0.91681089\n",
      " 0.91681164 0.91679227 0.91679312 0.91678825 0.91680506 0.91678373\n",
      " 0.91679208 0.91679687 0.91679319 0.91680679 0.91680066 0.91680845\n",
      " 0.91680554 0.91679303 0.91679532 0.9167898  0.91681233 0.91679165\n",
      " 0.91680591 0.91679738 0.91680716 0.91679653 0.916802   0.91679794\n",
      " 0.91680275 0.91679611 0.9168098  0.91680122 0.91679272 0.9167949\n",
      " 0.91679181 0.91680316 0.91678491 0.91678671 0.9167982  0.91679524\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21 RMSE: [0.91288634 0.91289823 0.91286723 0.91289745 0.91288981 0.91288643   | 21/40 [00:50<00:44,  2.36s/it]\n",
      " 0.9128911  0.91288024 0.91288735 0.9128891  0.91288708 0.91289446\n",
      " 0.91289013 0.91288426 0.91288966 0.91287784 0.91289153 0.91288155\n",
      " 0.91287905 0.91290158 0.91288794 0.91289021 0.91288232 0.91287877\n",
      " 0.91287618 0.91288732 0.91288667 0.91290168 0.91289257 0.91288048\n",
      " 0.91289076 0.91288575 0.91288588 0.91288878 0.91289542 0.9128987\n",
      " 0.9128798  0.9128825  0.91288016 0.91290374 0.91288171 0.91288599\n",
      " 0.91288986 0.91288848 0.91289212 0.91289528 0.91290243 0.91288905\n",
      " 0.91288569 0.91288723 0.91288623 0.91288636 0.91289681 0.91288791\n",
      " 0.9128906  0.91289159 0.91290042 0.91288679 0.91289487 0.91289864\n",
      " 0.91289945 0.91288128 0.91288214 0.91287793 0.9128934  0.91287336\n",
      " 0.91288113 0.91288609 0.91288154 0.91289499 0.91288885 0.91289632\n",
      " 0.91289398 0.91288213 0.91288396 0.91287871 0.91289974 0.91288028\n",
      " 0.91289392 0.91288699 0.9128948  0.91288487 0.91289004 0.91288658\n",
      " 0.9128903  0.91288471 0.91289764 0.91288905 0.91288116 0.91288375\n",
      " 0.91288023 0.91289138 0.91287401 0.91287621 0.9128869  0.91288384\n",
      "Epoch 21 RMSE: [0.91288634 0.91289823 0.91286723 0.91289745 0.91288981 0.91288643   | 21/40 [00:50<00:44,  2.36s/it]\n",
      " 0.9128911  0.91288024 0.91288735 0.9128891  0.91288708 0.91289446\n",
      " 0.91289013 0.91288426 0.91288966 0.91287784 0.91289153 0.91288155\n",
      " 0.91287905 0.91290158 0.91288794 0.91289021 0.91288232 0.91287877\n",
      " 0.91287618 0.91288732 0.91288667 0.91290168 0.91289257 0.91288048\n",
      " 0.91289076 0.91288575 0.91288588 0.91288878 0.91289542 0.9128987\n",
      " 0.9128798  0.9128825  0.91288016 0.91290374 0.91288171 0.91288599\n",
      " 0.91288986 0.91288848 0.91289212 0.91289528 0.91290243 0.91288905\n",
      " 0.91288569 0.91288723 0.91288623 0.91288636 0.91289681 0.91288791\n",
      " 0.9128906  0.91289159 0.91290042 0.91288679 0.91289487 0.91289864\n",
      " 0.91289945 0.91288128 0.91288214 0.91287793 0.9128934  0.91287336\n",
      " 0.91288113 0.91288609 0.91288154 0.91289499 0.91288885 0.91289632\n",
      " 0.91289398 0.91288213 0.91288396 0.91287871 0.91289974 0.91288028\n",
      " 0.91289392 0.91288699 0.9128948  0.91288487 0.91289004 0.91288658\n",
      " 0.9128903  0.91288471 0.91289764 0.91288905 0.91288116 0.91288375\n",
      " 0.91288023 0.91289138 0.91287401 0.91287621 0.9128869  0.91288384\n",
      "Epoch 22 RMSE: [0.90878321 0.90879365 0.90876443 0.90879308 0.90878607 0.90878297   | 22/40 [00:53<00:42,  2.36s/it]\n",
      " 0.90878753 0.90877714 0.90878393 0.90878485 0.90878358 0.90879038\n",
      " 0.90878612 0.90878067 0.90878585 0.90877489 0.90878738 0.90877831\n",
      " 0.90877571 0.90879683 0.90878424 0.90878624 0.90877857 0.90877572\n",
      " 0.90877303 0.90878363 0.90878294 0.90879717 0.90878846 0.90877722\n",
      " 0.90878694 0.90878216 0.90878241 0.90878522 0.9087916  0.90879471\n",
      " 0.90877681 0.90877909 0.90877643 0.90879933 0.90877883 0.90878288\n",
      " 0.90878604 0.90878479 0.90878845 0.9087911  0.90879779 0.90878555\n",
      " 0.90878238 0.90878335 0.90878283 0.90878259 0.90879283 0.90878389\n",
      " 0.90878621 0.90878722 0.90879575 0.90878326 0.9087908  0.9087942\n",
      " 0.90879515 0.90877805 0.90877887 0.90877535 0.90878977 0.90877081\n",
      " 0.90877778 0.90878299 0.90877778 0.90879106 0.90878492 0.908792\n",
      " 0.9087903  0.90877897 0.90878051 0.90877554 0.90879516 0.90877682\n",
      " 0.90878962 0.90878418 0.90879056 0.90878103 0.90878599 0.90878304\n",
      " 0.90878581 0.90878106 0.90879333 0.90878476 0.90877748 0.90878039\n",
      " 0.90877668 0.90878743 0.90877075 0.90877347 0.90878331 0.90878046\n",
      "Epoch 22 RMSE: [0.90878321 0.90879365 0.90876443 0.90879308 0.90878607 0.90878297   | 22/40 [00:53<00:42,  2.36s/it]  \n",
      " 0.90878753 0.90877714 0.90878393 0.90878485 0.90878358 0.90879038\n",
      " 0.90878612 0.90878067 0.90878585 0.90877489 0.90878738 0.90877831\n",
      " 0.90877571 0.90879683 0.90878424 0.90878624 0.90877857 0.90877572\n",
      " 0.90877303 0.90878363 0.90878294 0.90879717 0.90878846 0.90877722\n",
      " 0.90878694 0.90878216 0.90878241 0.90878522 0.9087916  0.90879471\n",
      " 0.90877681 0.90877909 0.90877643 0.90879933 0.90877883 0.90878288\n",
      " 0.90878604 0.90878479 0.90878845 0.9087911  0.90879779 0.90878555\n",
      " 0.90878238 0.90878335 0.90878283 0.90878259 0.90879283 0.90878389\n",
      " 0.90878621 0.90878722 0.90879575 0.90878326 0.9087908  0.9087942\n",
      " 0.90879515 0.90877805 0.90877887 0.90877535 0.90878977 0.90877081\n",
      " 0.90877778 0.90878299 0.90877778 0.90879106 0.90878492 0.908792\n",
      " 0.9087903  0.90877897 0.90878051 0.90877554 0.90879516 0.90877682\n",
      " 0.90878962 0.90878418 0.90879056 0.90878103 0.90878599 0.90878304\n",
      " 0.90878581 0.90878106 0.90879333 0.90878476 0.90877748 0.90878039\n",
      " 0.90877668 0.90878743 0.90877075 0.90877347 0.90878331 0.90878046\n",
      "Epoch 23 RMSE: [0.90422053 0.90422986 0.90420248 0.90422944 0.904223   0.90422032   | 23/40 [00:55<00:40,  2.36s/it]\n",
      " 0.90422445 0.90421455 0.90422124 0.90422135 0.90422064 0.90422695\n",
      " 0.90422281 0.90421798 0.90422258 0.90421247 0.90422392 0.90421567\n",
      " 0.90421317 0.90423275 0.90422125 0.90422307 0.9042156  0.90421342\n",
      " 0.90421034 0.90422029 0.90422001 0.90423299 0.90422515 0.90421446\n",
      " 0.90422382 0.90421928 0.90421941 0.90422233 0.90422836 0.90423127\n",
      " 0.9042146  0.90421628 0.90421338 0.90423546 0.90421644 0.90422036\n",
      " 0.90422263 0.90422177 0.9042252  0.90422756 0.90423391 0.90422269\n",
      " 0.90421953 0.90422021 0.90422029 0.90421968 0.90422942 0.9042206\n",
      " 0.90422219 0.90422371 0.90423179 0.90422032 0.90422729 0.9042307\n",
      " 0.90423174 0.90421557 0.90421616 0.90421333 0.90422676 0.90420883\n",
      " 0.90421514 0.9042205  0.90421473 0.90422767 0.90422153 0.90422851\n",
      " 0.90422709 0.90421633 0.90421769 0.90421279 0.90423119 0.9042139\n",
      " 0.90422596 0.90422192 0.90422702 0.90421791 0.90422274 0.90422038\n",
      " 0.90422228 0.90421833 0.90422976 0.90422142 0.90421453 0.90421772\n",
      " 0.90421375 0.90422426 0.90420816 0.9042113  0.90422056 0.90421763\n",
      "Epoch 23 RMSE: [0.90422053 0.90422986 0.90420248 0.90422944 0.904223   0.90422032   | 23/40 [00:55<00:40,  2.36s/it]\n",
      " 0.90422445 0.90421455 0.90422124 0.90422135 0.90422064 0.90422695\n",
      " 0.90422281 0.90421798 0.90422258 0.90421247 0.90422392 0.90421567\n",
      " 0.90421317 0.90423275 0.90422125 0.90422307 0.9042156  0.90421342\n",
      " 0.90421034 0.90422029 0.90422001 0.90423299 0.90422515 0.90421446\n",
      " 0.90422382 0.90421928 0.90421941 0.90422233 0.90422836 0.90423127\n",
      " 0.9042146  0.90421628 0.90421338 0.90423546 0.90421644 0.90422036\n",
      " 0.90422263 0.90422177 0.9042252  0.90422756 0.90423391 0.90422269\n",
      " 0.90421953 0.90422021 0.90422029 0.90421968 0.90422942 0.9042206\n",
      " 0.90422219 0.90422371 0.90423179 0.90422032 0.90422729 0.9042307\n",
      " 0.90423174 0.90421557 0.90421616 0.90421333 0.90422676 0.90420883\n",
      " 0.90421514 0.9042205  0.90421473 0.90422767 0.90422153 0.90422851\n",
      " 0.90422709 0.90421633 0.90421769 0.90421279 0.90423119 0.9042139\n",
      " 0.90422596 0.90422192 0.90422702 0.90421791 0.90422274 0.90422038\n",
      " 0.90422228 0.90421833 0.90422976 0.90422142 0.90421453 0.90421772\n",
      " 0.90421375 0.90422426 0.90420816 0.9042113  0.90422056 0.90421763\n",
      "Epoch 24 RMSE: [0.89944541 0.89945349 0.89942796 0.89945335 0.89944751 0.89944501   | 24/40 [00:58<00:38,  2.39s/it]\n",
      " 0.89944878 0.89943965 0.899446   0.89944534 0.89944544 0.89945125\n",
      " 0.899447   0.8994426  0.89944699 0.89943749 0.89944794 0.89944042\n",
      " 0.89943797 0.89945648 0.8994455  0.89944736 0.89944025 0.89943837\n",
      " 0.89943513 0.89944477 0.8994446  0.89945708 0.89944921 0.89943937\n",
      " 0.89944796 0.89944377 0.89944392 0.89944665 0.89945284 0.89945524\n",
      " 0.89943949 0.89944083 0.89943778 0.89945916 0.89944153 0.89944516\n",
      " 0.89944711 0.89944631 0.89944975 0.89945155 0.89945769 0.89944752\n",
      " 0.89944431 0.89944458 0.89944509 0.89944413 0.89945366 0.89944493\n",
      " 0.89944639 0.89944775 0.89945525 0.899445   0.89945131 0.89945457\n",
      " 0.89945558 0.89944056 0.89944123 0.89943846 0.8994512  0.89943387\n",
      " 0.89944006 0.89944532 0.8994392  0.8994519  0.89944585 0.89945251\n",
      " 0.89945128 0.89944128 0.89944248 0.89943771 0.8994548  0.89943861\n",
      " 0.89944998 0.89944716 0.89945072 0.89944257 0.89944696 0.89944494\n",
      " 0.89944599 0.89944286 0.89945349 0.8994455  0.89943896 0.89944237\n",
      " 0.89943852 0.89944858 0.89943303 0.89943647 0.8994449  0.89944234\n",
      "Epoch 24 RMSE: [0.89944541 0.89945349 0.89942796 0.89945335 0.89944751 0.89944501   | 24/40 [00:58<00:38,  2.39s/it]\n",
      " 0.89944878 0.89943965 0.899446   0.89944534 0.89944544 0.89945125\n",
      " 0.899447   0.8994426  0.89944699 0.89943749 0.89944794 0.89944042\n",
      " 0.89943797 0.89945648 0.8994455  0.89944736 0.89944025 0.89943837\n",
      " 0.89943513 0.89944477 0.8994446  0.89945708 0.89944921 0.89943937\n",
      " 0.89944796 0.89944377 0.89944392 0.89944665 0.89945284 0.89945524\n",
      " 0.89943949 0.89944083 0.89943778 0.89945916 0.89944153 0.89944516\n",
      " 0.89944711 0.89944631 0.89944975 0.89945155 0.89945769 0.89944752\n",
      " 0.89944431 0.89944458 0.89944509 0.89944413 0.89945366 0.89944493\n",
      " 0.89944639 0.89944775 0.89945525 0.899445   0.89945131 0.89945457\n",
      " 0.89945558 0.89944056 0.89944123 0.89943846 0.8994512  0.89943387\n",
      " 0.89944006 0.89944532 0.8994392  0.8994519  0.89944585 0.89945251\n",
      " 0.89945128 0.89944128 0.89944248 0.89943771 0.8994548  0.89943861\n",
      " 0.89944998 0.89944716 0.89945072 0.89944257 0.89944696 0.89944494\n",
      " 0.89944599 0.89944286 0.89945349 0.8994455  0.89943896 0.89944237\n",
      " 0.89943852 0.89944858 0.89943303 0.89943647 0.8994449  0.89944234\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25 RMSE: [0.8939617  0.89396902 0.89394454 0.89396896 0.89396355 0.89396113   | 25/40 [01:00<00:35,  2.38s/it]\n",
      " 0.8939648  0.89395604 0.89396229 0.89396108 0.89396134 0.89396702\n",
      " 0.89396299 0.89395879 0.89396286 0.89395418 0.89396375 0.89395685\n",
      " 0.8939543  0.89397161 0.89396178 0.89396314 0.8939562  0.89395505\n",
      " 0.89395179 0.89396073 0.89396054 0.89397242 0.89396513 0.8939557\n",
      " 0.89396393 0.89395993 0.89395987 0.89396286 0.89396866 0.89397102\n",
      " 0.89395619 0.89395716 0.89395402 0.8939746  0.89395806 0.89396162\n",
      " 0.89396294 0.89396246 0.8939657  0.89396737 0.89397308 0.89396356\n",
      " 0.8939605  0.89396057 0.89396137 0.89396035 0.89396944 0.89396085\n",
      " 0.89396204 0.8939634  0.89397064 0.89396123 0.89396705 0.89397009\n",
      " 0.89397127 0.89395682 0.89395731 0.89395543 0.89396713 0.89395092\n",
      " 0.89395642 0.89396178 0.89395519 0.89396772 0.89396193 0.89396804\n",
      " 0.89396738 0.89395783 0.89395883 0.89395421 0.89397033 0.89395483\n",
      " 0.89396561 0.89396384 0.89396653 0.89395857 0.89396291 0.89396119\n",
      " 0.89396164 0.89395909 0.89396916 0.89396109 0.8939552  0.89395868\n",
      " 0.89395448 0.89396461 0.89394963 0.89395322 0.8939614  0.89395862\n",
      "Epoch 25 RMSE: [0.8939617  0.89396902 0.89394454 0.89396896 0.89396355 0.89396113   | 25/40 [01:00<00:35,  2.38s/it] \n",
      " 0.8939648  0.89395604 0.89396229 0.89396108 0.89396134 0.89396702\n",
      " 0.89396299 0.89395879 0.89396286 0.89395418 0.89396375 0.89395685\n",
      " 0.8939543  0.89397161 0.89396178 0.89396314 0.8939562  0.89395505\n",
      " 0.89395179 0.89396073 0.89396054 0.89397242 0.89396513 0.8939557\n",
      " 0.89396393 0.89395993 0.89395987 0.89396286 0.89396866 0.89397102\n",
      " 0.89395619 0.89395716 0.89395402 0.8939746  0.89395806 0.89396162\n",
      " 0.89396294 0.89396246 0.8939657  0.89396737 0.89397308 0.89396356\n",
      " 0.8939605  0.89396057 0.89396137 0.89396035 0.89396944 0.89396085\n",
      " 0.89396204 0.8939634  0.89397064 0.89396123 0.89396705 0.89397009\n",
      " 0.89397127 0.89395682 0.89395731 0.89395543 0.89396713 0.89395092\n",
      " 0.89395642 0.89396178 0.89395519 0.89396772 0.89396193 0.89396804\n",
      " 0.89396738 0.89395783 0.89395883 0.89395421 0.89397033 0.89395483\n",
      " 0.89396561 0.89396384 0.89396653 0.89395857 0.89396291 0.89396119\n",
      " 0.89396164 0.89395909 0.89396916 0.89396109 0.8939552  0.89395868\n",
      " 0.89395448 0.89396461 0.89394963 0.89395322 0.8939614  0.89395862\n",
      "Epoch 26 RMSE: [0.88813915 0.88814554 0.88812237 0.88814555 0.88814064 0.88813834   | 26/40 [01:02<00:33,  2.39s/it]\n",
      " 0.88814185 0.88813351 0.88813954 0.8881378  0.88813859 0.88814365\n",
      " 0.8881398  0.88813593 0.88813996 0.88813158 0.88814047 0.88813423\n",
      " 0.88813165 0.88814795 0.8881389  0.8881399  0.88813335 0.88813264\n",
      " 0.8881292  0.88813777 0.88813767 0.88814902 0.88814186 0.88813312\n",
      " 0.88814095 0.8881371  0.88813693 0.88813993 0.88814556 0.8881475\n",
      " 0.88813369 0.88813441 0.8881312  0.88815094 0.8881356  0.88813904\n",
      " 0.88813984 0.88813946 0.88814259 0.88814413 0.88814973 0.88814092\n",
      " 0.88813763 0.88813751 0.88813871 0.88813734 0.88814611 0.8881377\n",
      " 0.88813875 0.8881401  0.88814686 0.88813846 0.88814378 0.88814658\n",
      " 0.88814785 0.88813425 0.88813473 0.88813305 0.88814419 0.88812872\n",
      " 0.88813375 0.88813917 0.88813233 0.88814456 0.88813878 0.88814465\n",
      " 0.88814438 0.88813536 0.88813623 0.88813153 0.88814687 0.88813212\n",
      " 0.88814238 0.88814128 0.88814313 0.88813559 0.8881398  0.88813829\n",
      " 0.88813821 0.88813624 0.88814575 0.88813772 0.88813232 0.8881359\n",
      " 0.88813164 0.88814152 0.88812695 0.88813106 0.8881383  0.88813582\n",
      "Epoch 26 RMSE: [0.88813915 0.88814554 0.88812237 0.88814555 0.88814064 0.88813834   | 26/40 [01:02<00:33,  2.39s/it]  \n",
      " 0.88814185 0.88813351 0.88813954 0.8881378  0.88813859 0.88814365\n",
      " 0.8881398  0.88813593 0.88813996 0.88813158 0.88814047 0.88813423\n",
      " 0.88813165 0.88814795 0.8881389  0.8881399  0.88813335 0.88813264\n",
      " 0.8881292  0.88813777 0.88813767 0.88814902 0.88814186 0.88813312\n",
      " 0.88814095 0.8881371  0.88813693 0.88813993 0.88814556 0.8881475\n",
      " 0.88813369 0.88813441 0.8881312  0.88815094 0.8881356  0.88813904\n",
      " 0.88813984 0.88813946 0.88814259 0.88814413 0.88814973 0.88814092\n",
      " 0.88813763 0.88813751 0.88813871 0.88813734 0.88814611 0.8881377\n",
      " 0.88813875 0.8881401  0.88814686 0.88813846 0.88814378 0.88814658\n",
      " 0.88814785 0.88813425 0.88813473 0.88813305 0.88814419 0.88812872\n",
      " 0.88813375 0.88813917 0.88813233 0.88814456 0.88813878 0.88814465\n",
      " 0.88814438 0.88813536 0.88813623 0.88813153 0.88814687 0.88813212\n",
      " 0.88814238 0.88814128 0.88814313 0.88813559 0.8881398  0.88813829\n",
      " 0.88813821 0.88813624 0.88814575 0.88813772 0.88813232 0.8881359\n",
      " 0.88813164 0.88814152 0.88812695 0.88813106 0.8881383  0.88813582\n",
      "Epoch 27 RMSE: [0.88210562 0.88211126 0.88208923 0.88211131 0.88210686 0.88210478   | 27/40 [01:05<00:31,  2.39s/it]\n",
      " 0.88210805 0.8821001  0.88210595 0.88210382 0.88210477 0.88210963\n",
      " 0.88210586 0.88210238 0.8821062  0.8820982  0.88210642 0.88210084\n",
      " 0.88209819 0.88211345 0.88210509 0.88210605 0.88209968 0.88209927\n",
      " 0.88209582 0.88210388 0.8821039  0.88211464 0.88210788 0.88209959\n",
      " 0.88210712 0.88210345 0.88210318 0.88210625 0.88211177 0.88211355\n",
      " 0.88210046 0.88210083 0.88209761 0.88211654 0.88210219 0.88210552\n",
      " 0.88210585 0.88210574 0.88210875 0.8821101  0.88211536 0.88210732\n",
      " 0.88210405 0.88210369 0.88210507 0.88210357 0.88211216 0.88210397\n",
      " 0.8821046  0.88210594 0.88211244 0.88210475 0.88210974 0.88211234\n",
      " 0.88211377 0.88210078 0.88210134 0.88209997 0.88211021 0.88209584\n",
      " 0.88210022 0.88210559 0.88209865 0.88211054 0.88210499 0.88211031\n",
      " 0.8821106  0.88210202 0.88210277 0.88209825 0.88211251 0.88209844\n",
      " 0.88210823 0.88210803 0.88210913 0.88210189 0.88210589 0.8821047\n",
      " 0.88210414 0.88210258 0.88211149 0.88210375 0.88209868 0.88210237\n",
      " 0.882098   0.8821077  0.88209366 0.88209777 0.88210481 0.88210228\n",
      "Epoch 27 RMSE: [0.88210562 0.88211126 0.88208923 0.88211131 0.88210686 0.88210478   | 27/40 [01:05<00:31,  2.39s/it]\n",
      " 0.88210805 0.8821001  0.88210595 0.88210382 0.88210477 0.88210963\n",
      " 0.88210586 0.88210238 0.8821062  0.8820982  0.88210642 0.88210084\n",
      " 0.88209819 0.88211345 0.88210509 0.88210605 0.88209968 0.88209927\n",
      " 0.88209582 0.88210388 0.8821039  0.88211464 0.88210788 0.88209959\n",
      " 0.88210712 0.88210345 0.88210318 0.88210625 0.88211177 0.88211355\n",
      " 0.88210046 0.88210083 0.88209761 0.88211654 0.88210219 0.88210552\n",
      " 0.88210585 0.88210574 0.88210875 0.8821101  0.88211536 0.88210732\n",
      " 0.88210405 0.88210369 0.88210507 0.88210357 0.88211216 0.88210397\n",
      " 0.8821046  0.88210594 0.88211244 0.88210475 0.88210974 0.88211234\n",
      " 0.88211377 0.88210078 0.88210134 0.88209997 0.88211021 0.88209584\n",
      " 0.88210022 0.88210559 0.88209865 0.88211054 0.88210499 0.88211031\n",
      " 0.8821106  0.88210202 0.88210277 0.88209825 0.88211251 0.88209844\n",
      " 0.88210823 0.88210803 0.88210913 0.88210189 0.88210589 0.8821047\n",
      " 0.88210414 0.88210258 0.88211149 0.88210375 0.88209868 0.88210237\n",
      " 0.882098   0.8821077  0.88209366 0.88209777 0.88210481 0.88210228\n",
      "Epoch 28 RMSE: [0.87565692 0.87566189 0.87564095 0.87566218 0.87565795 0.87565605   | 28/40 [01:07<00:28,  2.39s/it]\n",
      " 0.87565909 0.87565153 0.87565725 0.87565463 0.87565602 0.87566048\n",
      " 0.87565678 0.87565358 0.87565725 0.87564969 0.87565727 0.87565217\n",
      " 0.8756495  0.875664   0.87565625 0.87565704 0.87565086 0.8756508\n",
      " 0.8756472  0.87565495 0.87565502 0.87566519 0.87565886 0.87565115\n",
      " 0.87565807 0.87565482 0.87565438 0.87565743 0.87566278 0.87566426\n",
      " 0.87565197 0.87565215 0.87564886 0.87566716 0.87565374 0.87565689\n",
      " 0.87565681 0.87565688 0.87565987 0.87566099 0.87566612 0.87565845\n",
      " 0.87565536 0.87565489 0.87565651 0.87565477 0.87566306 0.87565502\n",
      " 0.87565545 0.87565673 0.87566285 0.87565596 0.8756605  0.87566306\n",
      " 0.87566459 0.87565221 0.87565271 0.87565164 0.87566135 0.8756476\n",
      " 0.87565177 0.87565694 0.87564981 0.87566128 0.87565583 0.87566112\n",
      " 0.87566154 0.87565343 0.8756542  0.87564977 0.87566313 0.87564967\n",
      " 0.87565908 0.87565954 0.87565998 0.87565301 0.87565698 0.875656\n",
      " 0.87565497 0.87565387 0.87566218 0.87565457 0.87564996 0.87565371\n",
      " 0.87564916 0.87565878 0.87564516 0.87564938 0.87565599 0.87565357\n",
      "Epoch 28 RMSE: [0.87565692 0.87566189 0.87564095 0.87566218 0.87565795 0.87565605   | 28/40 [01:07<00:28,  2.39s/it]   \n",
      " 0.87565909 0.87565153 0.87565725 0.87565463 0.87565602 0.87566048\n",
      " 0.87565678 0.87565358 0.87565725 0.87564969 0.87565727 0.87565217\n",
      " 0.8756495  0.875664   0.87565625 0.87565704 0.87565086 0.8756508\n",
      " 0.8756472  0.87565495 0.87565502 0.87566519 0.87565886 0.87565115\n",
      " 0.87565807 0.87565482 0.87565438 0.87565743 0.87566278 0.87566426\n",
      " 0.87565197 0.87565215 0.87564886 0.87566716 0.87565374 0.87565689\n",
      " 0.87565681 0.87565688 0.87565987 0.87566099 0.87566612 0.87565845\n",
      " 0.87565536 0.87565489 0.87565651 0.87565477 0.87566306 0.87565502\n",
      " 0.87565545 0.87565673 0.87566285 0.87565596 0.8756605  0.87566306\n",
      " 0.87566459 0.87565221 0.87565271 0.87565164 0.87566135 0.8756476\n",
      " 0.87565177 0.87565694 0.87564981 0.87566128 0.87565583 0.87566112\n",
      " 0.87566154 0.87565343 0.8756542  0.87564977 0.87566313 0.87564967\n",
      " 0.87565908 0.87565954 0.87565998 0.87565301 0.87565698 0.875656\n",
      " 0.87565497 0.87565387 0.87566218 0.87565457 0.87564996 0.87565371\n",
      " 0.87564916 0.87565878 0.87564516 0.87564938 0.87565599 0.87565357\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29 RMSE: [0.86889631 0.86890086 0.86888072 0.86890107 0.86889721 0.8688955█▎  | 29/40 [01:10<00:26,  2.40s/it]\n",
      " 0.86889828 0.86889117 0.86889675 0.86889368 0.86889532 0.86889949\n",
      " 0.86889593 0.86889294 0.86889654 0.86888923 0.86889628 0.86889179\n",
      " 0.86888901 0.86890262 0.8688956  0.86889617 0.86889024 0.86889047\n",
      " 0.86888686 0.86889417 0.86889428 0.86890411 0.86889792 0.86889061\n",
      " 0.86889721 0.868894   0.86889356 0.86889675 0.8689019  0.8689032\n",
      " 0.86889165 0.86889161 0.86888831 0.86890585 0.86889323 0.86889634\n",
      " 0.86889585 0.8688961  0.86889915 0.86889995 0.86890494 0.86889787\n",
      " 0.86889458 0.86889406 0.86889589 0.86889401 0.86890207 0.86889425\n",
      " 0.86889451 0.8688958  0.86890159 0.86889536 0.86889957 0.86890191\n",
      " 0.86890345 0.86889177 0.86889218 0.86889121 0.86890034 0.86888756\n",
      " 0.86889132 0.86889642 0.8688892  0.86890046 0.86889511 0.86889996\n",
      " 0.86890076 0.86889303 0.86889375 0.86888928 0.86890196 0.86888924\n",
      " 0.86889784 0.86889914 0.86889884 0.86889234 0.86889602 0.86889544\n",
      " 0.8688939  0.86889325 0.86890111 0.8688936  0.86888929 0.86889311\n",
      " 0.86888863 0.868898   0.86888475 0.86888913 0.86889527 0.86889298\n",
      "Epoch 29 RMSE: [0.86889631 0.86890086 0.86888072 0.86890107 0.86889721 0.8688955█▎  | 29/40 [01:10<00:26,  2.40s/it]\n",
      " 0.86889828 0.86889117 0.86889675 0.86889368 0.86889532 0.86889949\n",
      " 0.86889593 0.86889294 0.86889654 0.86888923 0.86889628 0.86889179\n",
      " 0.86888901 0.86890262 0.8688956  0.86889617 0.86889024 0.86889047\n",
      " 0.86888686 0.86889417 0.86889428 0.86890411 0.86889792 0.86889061\n",
      " 0.86889721 0.868894   0.86889356 0.86889675 0.8689019  0.8689032\n",
      " 0.86889165 0.86889161 0.86888831 0.86890585 0.86889323 0.86889634\n",
      " 0.86889585 0.8688961  0.86889915 0.86889995 0.86890494 0.86889787\n",
      " 0.86889458 0.86889406 0.86889589 0.86889401 0.86890207 0.86889425\n",
      " 0.86889451 0.8688958  0.86890159 0.86889536 0.86889957 0.86890191\n",
      " 0.86890345 0.86889177 0.86889218 0.86889121 0.86890034 0.86888756\n",
      " 0.86889132 0.86889642 0.8688892  0.86890046 0.86889511 0.86889996\n",
      " 0.86890076 0.86889303 0.86889375 0.86888928 0.86890196 0.86888924\n",
      " 0.86889784 0.86889914 0.86889884 0.86889234 0.86889602 0.86889544\n",
      " 0.8688939  0.86889325 0.86890111 0.8688936  0.86888929 0.86889311\n",
      " 0.86888863 0.868898   0.86888475 0.86888913 0.86889527 0.86889298\n",
      "Epoch 30 RMSE: [0.86192812 0.86193219 0.86191303 0.86193246 0.86192887 0.86192735▌  | 30/40 [01:12<00:23,  2.39s/it]\n",
      " 0.86192998 0.86192315 0.86192862 0.86192525 0.86192709 0.86193092\n",
      " 0.86192752 0.86192477 0.86192829 0.86192116 0.86192787 0.86192382\n",
      " 0.86192098 0.86193383 0.86192743 0.8619278  0.8619221  0.86192256\n",
      " 0.86191887 0.86192581 0.86192609 0.86193541 0.86192946 0.86192254\n",
      " 0.86192879 0.86192587 0.86192527 0.8619285  0.86193354 0.86193463\n",
      " 0.86192368 0.86192349 0.86192022 0.86193713 0.8619252  0.86192839\n",
      " 0.86192746 0.86192779 0.86193076 0.86193148 0.86193627 0.86192961\n",
      " 0.86192651 0.8619257  0.86192776 0.86192582 0.86193364 0.86192596\n",
      " 0.86192595 0.86192726 0.86193289 0.86192717 0.86193094 0.8619332\n",
      " 0.86193486 0.86192368 0.86192413 0.86192332 0.86193194 0.86191989\n",
      " 0.86192321 0.86192831 0.86192107 0.86193196 0.86192682 0.86193139\n",
      " 0.86193237 0.86192516 0.86192579 0.86192134 0.86193331 0.86192106\n",
      " 0.86192917 0.86193116 0.86193034 0.86192405 0.86192776 0.86192727\n",
      " 0.86192533 0.86192518 0.86193249 0.86192515 0.86192107 0.86192498\n",
      " 0.86192046 0.86192969 0.86191688 0.8619213  0.861927   0.86192464\n",
      "Epoch 30 RMSE: [0.86192812 0.86193219 0.86191303 0.86193246 0.86192887 0.86192735▌  | 30/40 [01:12<00:23,  2.39s/it]\n",
      " 0.86192998 0.86192315 0.86192862 0.86192525 0.86192709 0.86193092\n",
      " 0.86192752 0.86192477 0.86192829 0.86192116 0.86192787 0.86192382\n",
      " 0.86192098 0.86193383 0.86192743 0.8619278  0.8619221  0.86192256\n",
      " 0.86191887 0.86192581 0.86192609 0.86193541 0.86192946 0.86192254\n",
      " 0.86192879 0.86192587 0.86192527 0.8619285  0.86193354 0.86193463\n",
      " 0.86192368 0.86192349 0.86192022 0.86193713 0.8619252  0.86192839\n",
      " 0.86192746 0.86192779 0.86193076 0.86193148 0.86193627 0.86192961\n",
      " 0.86192651 0.8619257  0.86192776 0.86192582 0.86193364 0.86192596\n",
      " 0.86192595 0.86192726 0.86193289 0.86192717 0.86193094 0.8619332\n",
      " 0.86193486 0.86192368 0.86192413 0.86192332 0.86193194 0.86191989\n",
      " 0.86192321 0.86192831 0.86192107 0.86193196 0.86192682 0.86193139\n",
      " 0.86193237 0.86192516 0.86192579 0.86192134 0.86193331 0.86192106\n",
      " 0.86192917 0.86193116 0.86193034 0.86192405 0.86192776 0.86192727\n",
      " 0.86192533 0.86192518 0.86193249 0.86192515 0.86192107 0.86192498\n",
      " 0.86192046 0.86192969 0.86191688 0.8619213  0.861927   0.86192464\n",
      "Epoch 31 RMSE: [0.85460494 0.85460845 0.85459006 0.85460887 0.85460553 0.85460401▊  | 31/40 [01:15<00:22,  2.48s/it]\n",
      " 0.85460646 0.85459995 0.85460541 0.85460188 0.85460372 0.85460731\n",
      " 0.854604   0.85460138 0.85460479 0.8545981  0.85460427 0.8546007\n",
      " 0.85459779 0.8546099  0.85460408 0.85460425 0.85459882 0.8545994\n",
      " 0.85459572 0.85460232 0.85460275 0.85461164 0.85460591 0.8545995\n",
      " 0.85460535 0.85460262 0.85460197 0.85460507 0.85460999 0.85461095\n",
      " 0.8546007  0.85460036 0.85459693 0.8546133  0.85460194 0.85460515\n",
      " 0.85460396 0.85460441 0.85460736 0.85460799 0.85461262 0.85460627\n",
      " 0.8546031  0.85460228 0.85460449 0.85460242 0.85461011 0.85460262\n",
      " 0.8546025  0.85460376 0.85460904 0.8546038  0.85460733 0.85460953\n",
      " 0.85461131 0.85460064 0.85460087 0.85460028 0.85460846 0.85459713\n",
      " 0.85460006 0.85460517 0.85459777 0.85460837 0.8546034  0.85460774\n",
      " 0.85460892 0.85460199 0.85460267 0.85459819 0.8546096  0.8545978\n",
      " 0.85460559 0.85460796 0.85460667 0.85460072 0.85460433 0.85460402\n",
      " 0.85460173 0.85460189 0.85460865 0.85460155 0.85459776 0.85460177\n",
      " 0.85459706 0.85460631 0.85459377 0.85459838 0.85460372 0.85460149\n",
      "Epoch 31 RMSE: [0.85460494 0.85460845 0.85459006 0.85460887 0.85460553 0.85460401▊  | 31/40 [01:15<00:22,  2.48s/it]   \n",
      " 0.85460646 0.85459995 0.85460541 0.85460188 0.85460372 0.85460731\n",
      " 0.854604   0.85460138 0.85460479 0.8545981  0.85460427 0.8546007\n",
      " 0.85459779 0.8546099  0.85460408 0.85460425 0.85459882 0.8545994\n",
      " 0.85459572 0.85460232 0.85460275 0.85461164 0.85460591 0.8545995\n",
      " 0.85460535 0.85460262 0.85460197 0.85460507 0.85460999 0.85461095\n",
      " 0.8546007  0.85460036 0.85459693 0.8546133  0.85460194 0.85460515\n",
      " 0.85460396 0.85460441 0.85460736 0.85460799 0.85461262 0.85460627\n",
      " 0.8546031  0.85460228 0.85460449 0.85460242 0.85461011 0.85460262\n",
      " 0.8546025  0.85460376 0.85460904 0.8546038  0.85460733 0.85460953\n",
      " 0.85461131 0.85460064 0.85460087 0.85460028 0.85460846 0.85459713\n",
      " 0.85460006 0.85460517 0.85459777 0.85460837 0.8546034  0.85460774\n",
      " 0.85460892 0.85460199 0.85460267 0.85459819 0.8546096  0.8545978\n",
      " 0.85460559 0.85460796 0.85460667 0.85460072 0.85460433 0.85460402\n",
      " 0.85460173 0.85460189 0.85460865 0.85460155 0.85459776 0.85460177\n",
      " 0.85459706 0.85460631 0.85459377 0.85459838 0.85460372 0.85460149\n",
      "Epoch 32 RMSE: [0.84699439 0.84699767 0.8469799  0.84699794 0.84699484 0.84699372█  | 32/40 [01:17<00:19,  2.46s/it]\n",
      " 0.84699584 0.8469897  0.84699497 0.8469911  0.84699306 0.84699638\n",
      " 0.84699326 0.84699088 0.84699428 0.84698774 0.8469936  0.84699039\n",
      " 0.84698744 0.84699885 0.84699355 0.84699366 0.8469883  0.84698922\n",
      " 0.84698555 0.84699181 0.84699215 0.84700073 0.84699522 0.8469891\n",
      " 0.84699478 0.846992   0.84699133 0.8469945  0.84699923 0.84700008\n",
      " 0.8469903  0.8469898  0.84698658 0.84700231 0.84699156 0.8469947\n",
      " 0.84699318 0.84699383 0.84699687 0.84699715 0.84700179 0.84699568\n",
      " 0.84699266 0.84699176 0.84699398 0.84699188 0.84699932 0.84699208\n",
      " 0.84699184 0.84699306 0.84699801 0.8469934  0.84699651 0.84699856\n",
      " 0.84700039 0.8469902  0.8469905  0.84698998 0.8469977  0.84698715\n",
      " 0.84698967 0.84699458 0.84698729 0.84699746 0.84699281 0.84699688\n",
      " 0.84699831 0.84699173 0.8469924  0.84698786 0.84699875 0.84698738\n",
      " 0.84699483 0.8469976  0.84699594 0.84699025 0.8469936  0.84699353\n",
      " 0.84699101 0.84699135 0.84699786 0.84699079 0.84698747 0.84699128\n",
      " 0.84698664 0.84699562 0.84698352 0.84698809 0.84699317 0.84699104\n",
      "Epoch 32 RMSE: [0.84699439 0.84699767 0.8469799  0.84699794 0.84699484 0.84699372█  | 32/40 [01:17<00:19,  2.46s/it]\n",
      " 0.84699584 0.8469897  0.84699497 0.8469911  0.84699306 0.84699638\n",
      " 0.84699326 0.84699088 0.84699428 0.84698774 0.8469936  0.84699039\n",
      " 0.84698744 0.84699885 0.84699355 0.84699366 0.8469883  0.84698922\n",
      " 0.84698555 0.84699181 0.84699215 0.84700073 0.84699522 0.8469891\n",
      " 0.84699478 0.846992   0.84699133 0.8469945  0.84699923 0.84700008\n",
      " 0.8469903  0.8469898  0.84698658 0.84700231 0.84699156 0.8469947\n",
      " 0.84699318 0.84699383 0.84699687 0.84699715 0.84700179 0.84699568\n",
      " 0.84699266 0.84699176 0.84699398 0.84699188 0.84699932 0.84699208\n",
      " 0.84699184 0.84699306 0.84699801 0.8469934  0.84699651 0.84699856\n",
      " 0.84700039 0.8469902  0.8469905  0.84698998 0.8469977  0.84698715\n",
      " 0.84698967 0.84699458 0.84698729 0.84699746 0.84699281 0.84699688\n",
      " 0.84699831 0.84699173 0.8469924  0.84698786 0.84699875 0.84698738\n",
      " 0.84699483 0.8469976  0.84699594 0.84699025 0.8469936  0.84699353\n",
      " 0.84699101 0.84699135 0.84699786 0.84699079 0.84698747 0.84699128\n",
      " 0.84698664 0.84699562 0.84698352 0.84698809 0.84699317 0.84699104\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33 RMSE: [0.83875367 0.83875651 0.83873955 0.83875691 0.83875405 0.83875292█▎ | 33/40 [01:19<00:17,  2.43s/it]\n",
      " 0.83875482 0.83874898 0.83875422 0.83875015 0.83875224 0.83875543\n",
      " 0.8387522  0.83875012 0.8387534  0.83874705 0.83875253 0.83874978\n",
      " 0.83874684 0.83875766 0.8387528  0.8387526  0.83874762 0.83874864\n",
      " 0.83874501 0.83875084 0.83875135 0.83875956 0.83875426 0.83874856\n",
      " 0.8387537  0.83875149 0.83875047 0.83875362 0.83875843 0.83875902\n",
      " 0.83874978 0.83874908 0.83874593 0.83876096 0.83875086 0.83875414\n",
      " 0.83875214 0.83875291 0.83875595 0.83875619 0.83876049 0.83875497\n",
      " 0.83875172 0.83875078 0.83875316 0.83875099 0.83875813 0.83875121\n",
      " 0.83875075 0.83875199 0.83875683 0.83875258 0.83875547 0.83875734\n",
      " 0.83875926 0.83874945 0.83874978 0.83874942 0.83875676 0.8387468\n",
      " 0.83874893 0.8387539  0.8387466  0.83875661 0.83875198 0.83875579\n",
      " 0.83875737 0.83875116 0.83875181 0.83874742 0.83875752 0.83874663\n",
      " 0.8387537  0.83875691 0.83875493 0.83874943 0.83875288 0.83875274\n",
      " 0.83874997 0.83875062 0.83875674 0.83874987 0.83874648 0.83875053\n",
      " 0.83874586 0.83875478 0.83874302 0.83874753 0.83875232 0.83875024\n",
      "Epoch 33 RMSE: [0.83875367 0.83875651 0.83873955 0.83875691 0.83875405 0.83875292█▎ | 33/40 [01:19<00:17,  2.43s/it]\n",
      " 0.83875482 0.83874898 0.83875422 0.83875015 0.83875224 0.83875543\n",
      " 0.8387522  0.83875012 0.8387534  0.83874705 0.83875253 0.83874978\n",
      " 0.83874684 0.83875766 0.8387528  0.8387526  0.83874762 0.83874864\n",
      " 0.83874501 0.83875084 0.83875135 0.83875956 0.83875426 0.83874856\n",
      " 0.8387537  0.83875149 0.83875047 0.83875362 0.83875843 0.83875902\n",
      " 0.83874978 0.83874908 0.83874593 0.83876096 0.83875086 0.83875414\n",
      " 0.83875214 0.83875291 0.83875595 0.83875619 0.83876049 0.83875497\n",
      " 0.83875172 0.83875078 0.83875316 0.83875099 0.83875813 0.83875121\n",
      " 0.83875075 0.83875199 0.83875683 0.83875258 0.83875547 0.83875734\n",
      " 0.83875926 0.83874945 0.83874978 0.83874942 0.83875676 0.8387468\n",
      " 0.83874893 0.8387539  0.8387466  0.83875661 0.83875198 0.83875579\n",
      " 0.83875737 0.83875116 0.83875181 0.83874742 0.83875752 0.83874663\n",
      " 0.8387537  0.83875691 0.83875493 0.83874943 0.83875288 0.83875274\n",
      " 0.83874997 0.83875062 0.83875674 0.83874987 0.83874648 0.83875053\n",
      " 0.83874586 0.83875478 0.83874302 0.83874753 0.83875232 0.83875024\n",
      "Epoch 34 RMSE: [0.83035544 0.83035804 0.83034167 0.83035847 0.83035574 0.83035477█▌ | 34/40 [01:22<00:14,  2.40s/it]\n",
      " 0.83035649 0.83035096 0.83035615 0.83035188 0.830354   0.83035688\n",
      " 0.830354   0.83035198 0.83035515 0.83034898 0.83035412 0.83035166\n",
      " 0.83034873 0.83035899 0.83035465 0.83035423 0.83034937 0.83035063\n",
      " 0.83034704 0.83035257 0.83035301 0.83036108 0.83035596 0.83035061\n",
      " 0.8303553  0.83035321 0.83035214 0.83035538 0.83035994 0.83036051\n",
      " 0.83035182 0.83035092 0.83034791 0.83036238 0.8303527  0.83035591\n",
      " 0.83035386 0.83035462 0.83035766 0.83035786 0.83036194 0.83035665\n",
      " 0.83035356 0.83035258 0.83035501 0.83035271 0.83035977 0.83035306\n",
      " 0.83035243 0.8303536  0.83035819 0.83035435 0.83035698 0.83035884\n",
      " 0.83036084 0.83035151 0.83035163 0.83035135 0.83035833 0.83034914\n",
      " 0.83035078 0.83035569 0.83034846 0.83035808 0.83035365 0.8303573\n",
      " 0.83035897 0.83035315 0.83035383 0.83034943 0.83035898 0.83034858\n",
      " 0.8303552  0.8303588  0.8303565  0.83035125 0.83035462 0.8303546\n",
      " 0.83035159 0.83035255 0.83035809 0.83035158 0.83034838 0.8303524\n",
      " 0.83034775 0.83035644 0.83034512 0.83034972 0.83035411 0.83035205\n",
      "Epoch 34 RMSE: [0.83035544 0.83035804 0.83034167 0.83035847 0.83035574 0.83035477█▌ | 34/40 [01:22<00:14,  2.40s/it]  \n",
      " 0.83035649 0.83035096 0.83035615 0.83035188 0.830354   0.83035688\n",
      " 0.830354   0.83035198 0.83035515 0.83034898 0.83035412 0.83035166\n",
      " 0.83034873 0.83035899 0.83035465 0.83035423 0.83034937 0.83035063\n",
      " 0.83034704 0.83035257 0.83035301 0.83036108 0.83035596 0.83035061\n",
      " 0.8303553  0.83035321 0.83035214 0.83035538 0.83035994 0.83036051\n",
      " 0.83035182 0.83035092 0.83034791 0.83036238 0.8303527  0.83035591\n",
      " 0.83035386 0.83035462 0.83035766 0.83035786 0.83036194 0.83035665\n",
      " 0.83035356 0.83035258 0.83035501 0.83035271 0.83035977 0.83035306\n",
      " 0.83035243 0.8303536  0.83035819 0.83035435 0.83035698 0.83035884\n",
      " 0.83036084 0.83035151 0.83035163 0.83035135 0.83035833 0.83034914\n",
      " 0.83035078 0.83035569 0.83034846 0.83035808 0.83035365 0.8303573\n",
      " 0.83035897 0.83035315 0.83035383 0.83034943 0.83035898 0.83034858\n",
      " 0.8303552  0.8303588  0.8303565  0.83035125 0.83035462 0.8303546\n",
      " 0.83035159 0.83035255 0.83035809 0.83035158 0.83034838 0.8303524\n",
      " 0.83034775 0.83035644 0.83034512 0.83034972 0.83035411 0.83035205\n",
      "Epoch 35 RMSE: [0.8214729  0.82147524 0.82145939 0.82147562 0.82147307 0.82147219█▊ | 35/40 [01:24<00:11,  2.40s/it]\n",
      " 0.8214738  0.82146852 0.82147354 0.82146919 0.82147131 0.8214741\n",
      " 0.82147121 0.82146941 0.82147254 0.82146649 0.82147127 0.82146921\n",
      " 0.82146623 0.82147592 0.82147201 0.82147148 0.8214668  0.82146827\n",
      " 0.82146472 0.82146992 0.82147043 0.82147809 0.82147317 0.82146808\n",
      " 0.8214725  0.82147075 0.82146953 0.82147277 0.82147715 0.82147767\n",
      " 0.82146933 0.82146842 0.82146546 0.82147929 0.82147015 0.82147345\n",
      " 0.82147114 0.82147193 0.82147493 0.82147502 0.82147899 0.82147406\n",
      " 0.82147082 0.82146983 0.8214724  0.82147011 0.82147688 0.82147045\n",
      " 0.8214698  0.82147083 0.8214753  0.82147177 0.82147411 0.82147595\n",
      " 0.82147806 0.82146902 0.82146919 0.82146889 0.82147551 0.82146696\n",
      " 0.82146833 0.82147311 0.82146583 0.82147531 0.82147094 0.82147445\n",
      " 0.82147624 0.82147083 0.82147144 0.82146699 0.82147605 0.82146607\n",
      " 0.82147232 0.82147618 0.82147384 0.82146869 0.82147197 0.821472\n",
      " 0.82146881 0.82147009 0.82147523 0.82146891 0.82146581 0.82146996\n",
      " 0.82146526 0.8214737  0.82146278 0.82146736 0.82147153 0.82146937\n",
      "Epoch 35 RMSE: [0.8214729  0.82147524 0.82145939 0.82147562 0.82147307 0.82147219█▊ | 35/40 [01:24<00:11,  2.40s/it]\n",
      " 0.8214738  0.82146852 0.82147354 0.82146919 0.82147131 0.8214741\n",
      " 0.82147121 0.82146941 0.82147254 0.82146649 0.82147127 0.82146921\n",
      " 0.82146623 0.82147592 0.82147201 0.82147148 0.8214668  0.82146827\n",
      " 0.82146472 0.82146992 0.82147043 0.82147809 0.82147317 0.82146808\n",
      " 0.8214725  0.82147075 0.82146953 0.82147277 0.82147715 0.82147767\n",
      " 0.82146933 0.82146842 0.82146546 0.82147929 0.82147015 0.82147345\n",
      " 0.82147114 0.82147193 0.82147493 0.82147502 0.82147899 0.82147406\n",
      " 0.82147082 0.82146983 0.8214724  0.82147011 0.82147688 0.82147045\n",
      " 0.8214698  0.82147083 0.8214753  0.82147177 0.82147411 0.82147595\n",
      " 0.82147806 0.82146902 0.82146919 0.82146889 0.82147551 0.82146696\n",
      " 0.82146833 0.82147311 0.82146583 0.82147531 0.82147094 0.82147445\n",
      " 0.82147624 0.82147083 0.82147144 0.82146699 0.82147605 0.82146607\n",
      " 0.82147232 0.82147618 0.82147384 0.82146869 0.82147197 0.821472\n",
      " 0.82146881 0.82147009 0.82147523 0.82146891 0.82146581 0.82146996\n",
      " 0.82146526 0.8214737  0.82146278 0.82146736 0.82147153 0.82146937\n",
      "Epoch 36 RMSE: [0.81217636 0.81217838 0.81216314 0.81217887 0.81217654 0.81217581██ | 36/40 [01:27<00:09,  2.38s/it]\n",
      " 0.81217706 0.81217216 0.8121771  0.81217266 0.81217472 0.81217733\n",
      " 0.81217459 0.81217295 0.81217602 0.81217003 0.8121746  0.81217288\n",
      " 0.81216985 0.8121791  0.81217559 0.81217486 0.81217041 0.8121719\n",
      " 0.81216843 0.8121734  0.81217388 0.81218128 0.81217656 0.81217177\n",
      " 0.81217588 0.81217428 0.81217299 0.81217616 0.81218048 0.81218104\n",
      " 0.81217301 0.81217194 0.81216907 0.81218239 0.81217368 0.812177\n",
      " 0.81217443 0.81217538 0.81217836 0.81217839 0.8121822  0.81217747\n",
      " 0.81217426 0.81217336 0.81217596 0.81217355 0.81218016 0.812174\n",
      " 0.81217316 0.81217415 0.81217845 0.81217527 0.81217745 0.81217914\n",
      " 0.81218127 0.81217264 0.81217281 0.8121725  0.8121788  0.81217083\n",
      " 0.81217185 0.8121766  0.81216943 0.81217854 0.81217442 0.8121777\n",
      " 0.81217972 0.81217442 0.81217514 0.8121707  0.81217926 0.81216959\n",
      " 0.81217561 0.81217967 0.81217712 0.81217216 0.81217545 0.8121755\n",
      " 0.8121722  0.81217354 0.81217842 0.81217225 0.81216932 0.81217352\n",
      " 0.81216881 0.81217714 0.81216646 0.81217107 0.81217497 0.81217293\n",
      "Epoch 36 RMSE: [0.81217636 0.81217838 0.81216314 0.81217887 0.81217654 0.81217581██ | 36/40 [01:27<00:09,  2.38s/it]    \n",
      " 0.81217706 0.81217216 0.8121771  0.81217266 0.81217472 0.81217733\n",
      " 0.81217459 0.81217295 0.81217602 0.81217003 0.8121746  0.81217288\n",
      " 0.81216985 0.8121791  0.81217559 0.81217486 0.81217041 0.8121719\n",
      " 0.81216843 0.8121734  0.81217388 0.81218128 0.81217656 0.81217177\n",
      " 0.81217588 0.81217428 0.81217299 0.81217616 0.81218048 0.81218104\n",
      " 0.81217301 0.81217194 0.81216907 0.81218239 0.81217368 0.812177\n",
      " 0.81217443 0.81217538 0.81217836 0.81217839 0.8121822  0.81217747\n",
      " 0.81217426 0.81217336 0.81217596 0.81217355 0.81218016 0.812174\n",
      " 0.81217316 0.81217415 0.81217845 0.81217527 0.81217745 0.81217914\n",
      " 0.81218127 0.81217264 0.81217281 0.8121725  0.8121788  0.81217083\n",
      " 0.81217185 0.8121766  0.81216943 0.81217854 0.81217442 0.8121777\n",
      " 0.81217972 0.81217442 0.81217514 0.8121707  0.81217926 0.81216959\n",
      " 0.81217561 0.81217967 0.81217712 0.81217216 0.81217545 0.8121755\n",
      " 0.8121722  0.81217354 0.81217842 0.81217225 0.81216932 0.81217352\n",
      " 0.81216881 0.81217714 0.81216646 0.81217107 0.81217497 0.81217293\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37 RMSE: [0.80252532 0.80252725 0.80251242 0.80252783 0.80252539 0.80252478██▎| 37/40 [01:29<00:07,  2.38s/it]\n",
      " 0.8025259  0.80252122 0.80252609 0.80252151 0.80252362 0.80252608\n",
      " 0.80252346 0.80252195 0.80252503 0.80251916 0.80252351 0.80252204\n",
      " 0.80251897 0.80252778 0.80252459 0.80252369 0.80251944 0.80252106\n",
      " 0.80251763 0.80252237 0.80252287 0.80252996 0.8025255  0.80252094\n",
      " 0.80252473 0.80252344 0.80252194 0.80252521 0.80252927 0.80252972\n",
      " 0.80252212 0.80252092 0.80251818 0.80253098 0.80252267 0.8025262\n",
      " 0.80252327 0.80252429 0.80252731 0.80252717 0.80253087 0.80252644\n",
      " 0.80252318 0.80252222 0.80252488 0.80252256 0.80252894 0.8025229\n",
      " 0.80252205 0.80252291 0.80252718 0.80252424 0.80252625 0.8025278\n",
      " 0.80253003 0.8025218  0.80252183 0.80252168 0.80252762 0.80252022\n",
      " 0.80252098 0.80252566 0.8025185  0.80252745 0.8025234  0.80252656\n",
      " 0.80252855 0.80252361 0.80252428 0.80252    0.80252802 0.80251868\n",
      " 0.8025244  0.80252873 0.80252603 0.80252113 0.80252447 0.80252464\n",
      " 0.80252102 0.80252265 0.80252716 0.80252113 0.80251829 0.80252258\n",
      " 0.80251785 0.80252601 0.80251562 0.80252025 0.80252401 0.80252183\n",
      "Epoch 37 RMSE: [0.80252532 0.80252725 0.80251242 0.80252783 0.80252539 0.80252478██▎| 37/40 [01:29<00:07,  2.38s/it]\n",
      " 0.8025259  0.80252122 0.80252609 0.80252151 0.80252362 0.80252608\n",
      " 0.80252346 0.80252195 0.80252503 0.80251916 0.80252351 0.80252204\n",
      " 0.80251897 0.80252778 0.80252459 0.80252369 0.80251944 0.80252106\n",
      " 0.80251763 0.80252237 0.80252287 0.80252996 0.8025255  0.80252094\n",
      " 0.80252473 0.80252344 0.80252194 0.80252521 0.80252927 0.80252972\n",
      " 0.80252212 0.80252092 0.80251818 0.80253098 0.80252267 0.8025262\n",
      " 0.80252327 0.80252429 0.80252731 0.80252717 0.80253087 0.80252644\n",
      " 0.80252318 0.80252222 0.80252488 0.80252256 0.80252894 0.8025229\n",
      " 0.80252205 0.80252291 0.80252718 0.80252424 0.80252625 0.8025278\n",
      " 0.80253003 0.8025218  0.80252183 0.80252168 0.80252762 0.80252022\n",
      " 0.80252098 0.80252566 0.8025185  0.80252745 0.8025234  0.80252656\n",
      " 0.80252855 0.80252361 0.80252428 0.80252    0.80252802 0.80251868\n",
      " 0.8025244  0.80252873 0.80252603 0.80252113 0.80252447 0.80252464\n",
      " 0.80252102 0.80252265 0.80252716 0.80252113 0.80251829 0.80252258\n",
      " 0.80251785 0.80252601 0.80251562 0.80252025 0.80252401 0.80252183\n",
      "Epoch 38 RMSE: [0.79251062 0.79251227 0.79249797 0.79251285 0.79251066 0.79251015██▌| 38/40 [01:31<00:04,  2.38s/it]\n",
      " 0.79251103 0.79250659 0.79251131 0.79250683 0.79250881 0.7925112\n",
      " 0.79250865 0.7925072  0.79251026 0.79250454 0.79250855 0.79250742\n",
      " 0.79250433 0.79251263 0.79250989 0.79250885 0.79250477 0.79250644\n",
      " 0.79250308 0.79250752 0.79250805 0.792515   0.79251052 0.79250641\n",
      " 0.79250984 0.79250861 0.79250719 0.7925104  0.79251439 0.79251481\n",
      " 0.79250758 0.79250625 0.79250359 0.79251587 0.79250798 0.79251144\n",
      " 0.79250841 0.79250943 0.7925125  0.79251235 0.79251588 0.7925116\n",
      " 0.79250837 0.79250745 0.79251016 0.7925077  0.79251408 0.79250827\n",
      " 0.7925073  0.79250815 0.79251217 0.79250951 0.79251136 0.79251279\n",
      " 0.79251522 0.79250715 0.7925072  0.79250696 0.79251271 0.79250588\n",
      " 0.79250626 0.79251087 0.79250379 0.79251246 0.79250858 0.79251161\n",
      " 0.7925137  0.79250899 0.79250973 0.79250533 0.792513   0.79250409\n",
      " 0.79250943 0.79251398 0.79251109 0.79250638 0.79250969 0.79250979\n",
      " 0.79250619 0.79250803 0.79251216 0.79250638 0.79250363 0.79250783\n",
      " 0.79250324 0.79251124 0.79250114 0.79250579 0.79250924 0.79250711\n",
      "Epoch 38 RMSE: [0.79251062 0.79251227 0.79249797 0.79251285 0.79251066 0.79251015██▌| 38/40 [01:31<00:04,  2.38s/it]\n",
      " 0.79251103 0.79250659 0.79251131 0.79250683 0.79250881 0.7925112\n",
      " 0.79250865 0.7925072  0.79251026 0.79250454 0.79250855 0.79250742\n",
      " 0.79250433 0.79251263 0.79250989 0.79250885 0.79250477 0.79250644\n",
      " 0.79250308 0.79250752 0.79250805 0.792515   0.79251052 0.79250641\n",
      " 0.79250984 0.79250861 0.79250719 0.7925104  0.79251439 0.79251481\n",
      " 0.79250758 0.79250625 0.79250359 0.79251587 0.79250798 0.79251144\n",
      " 0.79250841 0.79250943 0.7925125  0.79251235 0.79251588 0.7925116\n",
      " 0.79250837 0.79250745 0.79251016 0.7925077  0.79251408 0.79250827\n",
      " 0.7925073  0.79250815 0.79251217 0.79250951 0.79251136 0.79251279\n",
      " 0.79251522 0.79250715 0.7925072  0.79250696 0.79251271 0.79250588\n",
      " 0.79250626 0.79251087 0.79250379 0.79251246 0.79250858 0.79251161\n",
      " 0.7925137  0.79250899 0.79250973 0.79250533 0.792513   0.79250409\n",
      " 0.79250943 0.79251398 0.79251109 0.79250638 0.79250969 0.79250979\n",
      " 0.79250619 0.79250803 0.79251216 0.79250638 0.79250363 0.79250783\n",
      " 0.79250324 0.79251124 0.79250114 0.79250579 0.79250924 0.79250711\n",
      "Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019██▊| 39/40 [01:34<00:02,  2.38s/it]\n",
      " 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
      " 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
      " 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
      " 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
      " 0.78210964 0.78210863 0.78210705 0.7821103  0.78211415 0.78211461\n",
      " 0.78210756 0.7821062  0.78210364 0.78211536 0.78210784 0.78211148\n",
      " 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
      " 0.78210825 0.7821074  0.78211008 0.78210755 0.78211375 0.78210818\n",
      " 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
      " 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
      " 0.78210627 0.78211081 0.7821038  0.78211226 0.78210851 0.78211143\n",
      " 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
      " 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
      " 0.78210602 0.782108   0.7821118  0.78210625 0.7821036  0.78210788\n",
      " 0.7821032  0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
      "Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019██▊| 39/40 [01:34<00:02,  2.38s/it]\n",
      " 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
      " 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
      " 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
      " 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
      " 0.78210964 0.78210863 0.78210705 0.7821103  0.78211415 0.78211461\n",
      " 0.78210756 0.7821062  0.78210364 0.78211536 0.78210784 0.78211148\n",
      " 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
      " 0.78210825 0.7821074  0.78211008 0.78210755 0.78211375 0.78210818\n",
      " 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
      " 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
      " 0.78210627 0.78211081 0.7821038  0.78211226 0.78210851 0.78211143\n",
      " 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
      " 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
      " 0.78210602 0.782108   0.7821118  0.78210625 0.7821036  0.78210788\n",
      " 0.7821032  0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
      "Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019███| 40/40 [01:36<00:00,  2.37s/it]\n",
      " 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
      " 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
      " 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
      " 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
      " 0.78210964 0.78210863 0.78210705 0.7821103  0.78211415 0.78211461\n",
      " 0.78210756 0.7821062  0.78210364 0.78211536 0.78210784 0.78211148\n",
      " 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
      " 0.78210825 0.7821074  0.78211008 0.78210755 0.78211375 0.78210818\n",
      " 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
      " 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
      " 0.78210627 0.78211081 0.7821038  0.78211226 0.78210851 0.78211143\n",
      " 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
      " 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
      " 0.78210602 0.782108   0.7821118  0.78210625 0.7821036  0.78210788\n",
      " 0.7821032  0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
      " 0.78210826 0.78211145 0.78211014 0.78210739]. Training epoch 40...: 100%|██████████| 40/40 [01:36<00:00,  2.41s/it]\n"
     ]
    }
   ],
   "source": [
    "model = SVDbaseline(train_ui, learning_rate = 0.005, regularization = 0.02, nb_factors = 100, iterations = 40)\n",
    "model.train(test_ui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the evaluation below reports RMSE ~3.64 for this model, while the per-epoch RMSE\n",
    "# printed during training was ~0.78 and was printed as an array instead of a scalar. That points to\n",
    "# a likely bug in SVDbaseline's estimate/RMSE code - verify before trusting these results.\n",
    "\n",
    "# Presumably precomputes the full prediction matrix used by recommend() - confirm in SVDbaseline.\n",
    "model.estimations()\n",
    "\n",
    "# Top-10 recommendations per user; user_code_id / item_code_id presumably map internal codes back to raw ids.\n",
    "top_n=pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
    "\n",
    "top_n.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv', index=False, header=False)\n",
    "\n",
    "# Presumably predicted ratings for the (user, item) pairs in test_ui; this file is read back by the evaluation cell.\n",
    "estimations=pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
    "estimations.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "943it [00:00, 8806.80it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MAE</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>F_1</th>\n",
       "      <th>F_05</th>\n",
       "      <th>precision_super</th>\n",
       "      <th>recall_super</th>\n",
       "      <th>NDCG</th>\n",
       "      <th>mAP</th>\n",
       "      <th>MRR</th>\n",
       "      <th>LAUC</th>\n",
       "      <th>HR</th>\n",
       "      <th>HR2</th>\n",
       "      <th>Reco in test</th>\n",
       "      <th>Test coverage</th>\n",
       "      <th>Shannon</th>\n",
       "      <th>Gini</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.64479</td>\n",
       "      <td>3.479397</td>\n",
       "      <td>0.13701</td>\n",
       "      <td>0.082007</td>\n",
       "      <td>0.083942</td>\n",
       "      <td>0.100776</td>\n",
       "      <td>0.106974</td>\n",
       "      <td>0.105605</td>\n",
       "      <td>0.160418</td>\n",
       "      <td>0.080222</td>\n",
       "      <td>0.322261</td>\n",
       "      <td>0.537895</td>\n",
       "      <td>0.626723</td>\n",
       "      <td>0.360551</td>\n",
       "      <td>0.999894</td>\n",
       "      <td>0.276335</td>\n",
       "      <td>5.123235</td>\n",
       "      <td>0.910511</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      RMSE       MAE  precision    recall       F_1      F_05  \\\n",
       "0  3.64479  3.479397    0.13701  0.082007  0.083942  0.100776   \n",
       "\n",
       "   precision_super  recall_super      NDCG       mAP       MRR      LAUC  \\\n",
       "0         0.106974      0.105605  0.160418  0.080222  0.322261  0.537895   \n",
       "\n",
       "         HR       HR2  Reco in test  Test coverage   Shannon      Gini  \n",
       "0  0.626723  0.360551      0.999894       0.276335  5.123235  0.910511  "
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import evaluation_measures as ev\n",
    "\n",
    "# Read back the files written by the previous cell so the evaluation scores exactly what was saved to disk.\n",
    "estimations_df=pd.read_csv('Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv', header=None)\n",
    "reco=np.loadtxt('Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv', delimiter=',')\n",
    "\n",
    "# super_reactions=[4,5]: presumably ratings of 4 and 5 count as 'super' reactions\n",
    "# (the precision_super / recall_super columns) - confirm in evaluation_measures.\n",
    "ev.evaluate(test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n",
    "            estimations_df=estimations_df, \n",
    "            reco=reco,\n",
    "            super_reactions=[4,5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ready-made SVD - Surprise implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SVD (unbiased)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating predictions...\n",
      "Generating top N recommendations...\n",
      "Generating predictions...\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "\n",
    "import helpers\n",
    "import surprise as sp\n",
    "\n",
    "# Pick up any edits to helpers.py without restarting the kernel.\n",
    "# importlib.reload replaces imp.reload: the imp module is deprecated and was removed in Python 3.12.\n",
    "importlib.reload(helpers)\n",
    "\n",
    "# Surprise SVD without user/item bias terms (biased=False): prediction is the user-item factor dot product.\n",
    "algo = sp.SVD(biased=False) # to use unbiased version\n",
    "\n",
    "# Project helper: presumably fits algo, then writes top-N recommendations to reco_path and\n",
    "# rating estimations to estimations_path (as the argument names indicate) - see helpers.py.\n",
    "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n",
    "          estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SVD biased - on top of baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating predictions...\n",
      "Generating top N recommendations...\n",
      "Generating predictions...\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "\n",
    "import helpers\n",
    "import surprise as sp\n",
    "\n",
    "# Pick up any edits to helpers.py without restarting the kernel.\n",
    "# importlib.reload replaces imp.reload: the imp module is deprecated and was removed in Python 3.12.\n",
    "importlib.reload(helpers)\n",
    "\n",
    "# Surprise SVD with user and item bias terms added to the global mean (the library's default).\n",
    "algo = sp.SVD() # default is biased=True\n",
    "\n",
    "# Project helper: presumably fits algo, then writes top-N recommendations to reco_path and\n",
    "# rating estimations to estimations_path (as the argument names indicate) - see helpers.py.\n",
    "helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n",
    "          estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "943it [00:00, 9840.17it/s]\n",
      "943it [00:00, 9173.71it/s]\n",
      "943it [00:00, 9859.58it/s]\n",
      "943it [00:00, 9112.74it/s]\n",
      "943it [00:00, 9551.34it/s]\n",
      "943it [00:00, 7830.24it/s]\n",
      "943it [00:00, 8983.95it/s]\n",
      "943it [00:00, 9447.94it/s]\n",
      "943it [00:00, 9866.98it/s]\n",
      "943it [00:00, 10127.04it/s]\n",
      "943it [00:00, 9035.26it/s]\n",
      "943it [00:00, 9754.44it/s]\n",
      "943it [00:00, 9524.64it/s]\n",
      "943it [00:00, 8451.27it/s]\n",
      "943it [00:00, 9054.06it/s]\n",
      "943it [00:00, 10007.79it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Model</th>\n",
       "      <th>RMSE</th>\n",
       "      <th>MAE</th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>F_1</th>\n",
       "      <th>F_05</th>\n",
       "      <th>precision_super</th>\n",
       "      <th>recall_super</th>\n",
       "      <th>NDCG</th>\n",
       "      <th>mAP</th>\n",
       "      <th>MRR</th>\n",
       "      <th>LAUC</th>\n",
       "      <th>HR</th>\n",
       "      <th>HR2</th>\n",
       "      <th>Reco in test</th>\n",
       "      <th>Test coverage</th>\n",
       "      <th>Shannon</th>\n",
       "      <th>Gini</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_TopPop</td>\n",
       "      <td>2.508258</td>\n",
       "      <td>2.217909</td>\n",
       "      <td>0.188865</td>\n",
       "      <td>0.116919</td>\n",
       "      <td>0.118732</td>\n",
       "      <td>0.141584</td>\n",
       "      <td>0.130472</td>\n",
       "      <td>0.137473</td>\n",
       "      <td>0.214651</td>\n",
       "      <td>0.111707</td>\n",
       "      <td>0.400939</td>\n",
       "      <td>0.555546</td>\n",
       "      <td>0.765642</td>\n",
       "      <td>0.492047</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.038961</td>\n",
       "      <td>3.159079</td>\n",
       "      <td>0.987317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_SVDBaseline</td>\n",
       "      <td>3.644790</td>\n",
       "      <td>3.479397</td>\n",
       "      <td>0.137010</td>\n",
       "      <td>0.082007</td>\n",
       "      <td>0.083942</td>\n",
       "      <td>0.100776</td>\n",
       "      <td>0.106974</td>\n",
       "      <td>0.105605</td>\n",
       "      <td>0.160418</td>\n",
       "      <td>0.080222</td>\n",
       "      <td>0.322261</td>\n",
       "      <td>0.537895</td>\n",
       "      <td>0.626723</td>\n",
       "      <td>0.360551</td>\n",
       "      <td>0.999894</td>\n",
       "      <td>0.276335</td>\n",
       "      <td>5.123235</td>\n",
       "      <td>0.910511</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_SVD</td>\n",
       "      <td>0.950945</td>\n",
       "      <td>0.749680</td>\n",
       "      <td>0.098834</td>\n",
       "      <td>0.049106</td>\n",
       "      <td>0.054037</td>\n",
       "      <td>0.068741</td>\n",
       "      <td>0.087768</td>\n",
       "      <td>0.073987</td>\n",
       "      <td>0.113242</td>\n",
       "      <td>0.054201</td>\n",
       "      <td>0.243492</td>\n",
       "      <td>0.521280</td>\n",
       "      <td>0.493107</td>\n",
       "      <td>0.248144</td>\n",
       "      <td>0.998515</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>4.413166</td>\n",
       "      <td>0.953488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_SVD</td>\n",
       "      <td>0.915079</td>\n",
       "      <td>0.718240</td>\n",
       "      <td>0.104772</td>\n",
       "      <td>0.045496</td>\n",
       "      <td>0.054393</td>\n",
       "      <td>0.071374</td>\n",
       "      <td>0.094421</td>\n",
       "      <td>0.076826</td>\n",
       "      <td>0.109517</td>\n",
       "      <td>0.052005</td>\n",
       "      <td>0.206646</td>\n",
       "      <td>0.519484</td>\n",
       "      <td>0.487805</td>\n",
       "      <td>0.264051</td>\n",
       "      <td>0.874549</td>\n",
       "      <td>0.142136</td>\n",
       "      <td>3.890472</td>\n",
       "      <td>0.972126</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_Baseline</td>\n",
       "      <td>0.949459</td>\n",
       "      <td>0.752487</td>\n",
       "      <td>0.091410</td>\n",
       "      <td>0.037652</td>\n",
       "      <td>0.046030</td>\n",
       "      <td>0.061286</td>\n",
       "      <td>0.079614</td>\n",
       "      <td>0.056463</td>\n",
       "      <td>0.095957</td>\n",
       "      <td>0.043178</td>\n",
       "      <td>0.198193</td>\n",
       "      <td>0.515501</td>\n",
       "      <td>0.437964</td>\n",
       "      <td>0.239661</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.033911</td>\n",
       "      <td>2.836513</td>\n",
       "      <td>0.991139</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_SVDBiased</td>\n",
       "      <td>0.938535</td>\n",
       "      <td>0.738678</td>\n",
       "      <td>0.085366</td>\n",
       "      <td>0.036921</td>\n",
       "      <td>0.044151</td>\n",
       "      <td>0.057832</td>\n",
       "      <td>0.074893</td>\n",
       "      <td>0.056396</td>\n",
       "      <td>0.095960</td>\n",
       "      <td>0.044204</td>\n",
       "      <td>0.212483</td>\n",
       "      <td>0.515132</td>\n",
       "      <td>0.446448</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.997561</td>\n",
       "      <td>0.168110</td>\n",
       "      <td>4.191946</td>\n",
       "      <td>0.963341</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_KNNSurprisetask</td>\n",
       "      <td>0.946255</td>\n",
       "      <td>0.745209</td>\n",
       "      <td>0.083457</td>\n",
       "      <td>0.032848</td>\n",
       "      <td>0.041227</td>\n",
       "      <td>0.055493</td>\n",
       "      <td>0.074785</td>\n",
       "      <td>0.048890</td>\n",
       "      <td>0.089577</td>\n",
       "      <td>0.040902</td>\n",
       "      <td>0.189057</td>\n",
       "      <td>0.513076</td>\n",
       "      <td>0.417815</td>\n",
       "      <td>0.217391</td>\n",
       "      <td>0.888547</td>\n",
       "      <td>0.130592</td>\n",
       "      <td>3.611806</td>\n",
       "      <td>0.978659</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_GlobalAvg</td>\n",
       "      <td>1.125760</td>\n",
       "      <td>0.943534</td>\n",
       "      <td>0.061188</td>\n",
       "      <td>0.025968</td>\n",
       "      <td>0.031383</td>\n",
       "      <td>0.041343</td>\n",
       "      <td>0.040558</td>\n",
       "      <td>0.032107</td>\n",
       "      <td>0.067695</td>\n",
       "      <td>0.027470</td>\n",
       "      <td>0.171187</td>\n",
       "      <td>0.509546</td>\n",
       "      <td>0.384942</td>\n",
       "      <td>0.142100</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.025974</td>\n",
       "      <td>2.711772</td>\n",
       "      <td>0.992003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_Random</td>\n",
       "      <td>1.522798</td>\n",
       "      <td>1.222501</td>\n",
       "      <td>0.049841</td>\n",
       "      <td>0.020656</td>\n",
       "      <td>0.025232</td>\n",
       "      <td>0.033446</td>\n",
       "      <td>0.030579</td>\n",
       "      <td>0.022927</td>\n",
       "      <td>0.051680</td>\n",
       "      <td>0.019110</td>\n",
       "      <td>0.123085</td>\n",
       "      <td>0.506849</td>\n",
       "      <td>0.331919</td>\n",
       "      <td>0.119830</td>\n",
       "      <td>0.985048</td>\n",
       "      <td>0.183983</td>\n",
       "      <td>5.097973</td>\n",
       "      <td>0.907483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_I-KNN</td>\n",
       "      <td>1.030386</td>\n",
       "      <td>0.813067</td>\n",
       "      <td>0.026087</td>\n",
       "      <td>0.006908</td>\n",
       "      <td>0.010593</td>\n",
       "      <td>0.016046</td>\n",
       "      <td>0.021137</td>\n",
       "      <td>0.009522</td>\n",
       "      <td>0.024214</td>\n",
       "      <td>0.008958</td>\n",
       "      <td>0.048068</td>\n",
       "      <td>0.499885</td>\n",
       "      <td>0.154825</td>\n",
       "      <td>0.072110</td>\n",
       "      <td>0.402333</td>\n",
       "      <td>0.434343</td>\n",
       "      <td>5.133650</td>\n",
       "      <td>0.877999</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_I-KNNBaseline</td>\n",
       "      <td>0.935327</td>\n",
       "      <td>0.737424</td>\n",
       "      <td>0.002545</td>\n",
       "      <td>0.000755</td>\n",
       "      <td>0.001105</td>\n",
       "      <td>0.001602</td>\n",
       "      <td>0.002253</td>\n",
       "      <td>0.000930</td>\n",
       "      <td>0.003444</td>\n",
       "      <td>0.001362</td>\n",
       "      <td>0.011760</td>\n",
       "      <td>0.496724</td>\n",
       "      <td>0.021209</td>\n",
       "      <td>0.004242</td>\n",
       "      <td>0.482821</td>\n",
       "      <td>0.059885</td>\n",
       "      <td>2.232578</td>\n",
       "      <td>0.994487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Ready_U-KNN</td>\n",
       "      <td>1.023495</td>\n",
       "      <td>0.807913</td>\n",
       "      <td>0.000742</td>\n",
       "      <td>0.000205</td>\n",
       "      <td>0.000305</td>\n",
       "      <td>0.000449</td>\n",
       "      <td>0.000536</td>\n",
       "      <td>0.000198</td>\n",
       "      <td>0.000845</td>\n",
       "      <td>0.000274</td>\n",
       "      <td>0.002744</td>\n",
       "      <td>0.496441</td>\n",
       "      <td>0.007423</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.602121</td>\n",
       "      <td>0.010823</td>\n",
       "      <td>2.089186</td>\n",
       "      <td>0.995706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_BaselineIU</td>\n",
       "      <td>0.958136</td>\n",
       "      <td>0.754051</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000188</td>\n",
       "      <td>0.000298</td>\n",
       "      <td>0.000481</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000223</td>\n",
       "      <td>0.001043</td>\n",
       "      <td>0.000335</td>\n",
       "      <td>0.003348</td>\n",
       "      <td>0.496433</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.699046</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.945910</td>\n",
       "      <td>0.995669</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_TopRated</td>\n",
       "      <td>2.508258</td>\n",
       "      <td>2.217909</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000188</td>\n",
       "      <td>0.000298</td>\n",
       "      <td>0.000481</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000223</td>\n",
       "      <td>0.001043</td>\n",
       "      <td>0.000335</td>\n",
       "      <td>0.003348</td>\n",
       "      <td>0.496433</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.699046</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.945910</td>\n",
       "      <td>0.995669</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_BaselineUI</td>\n",
       "      <td>0.967585</td>\n",
       "      <td>0.762740</td>\n",
       "      <td>0.000954</td>\n",
       "      <td>0.000170</td>\n",
       "      <td>0.000278</td>\n",
       "      <td>0.000463</td>\n",
       "      <td>0.000644</td>\n",
       "      <td>0.000189</td>\n",
       "      <td>0.000752</td>\n",
       "      <td>0.000168</td>\n",
       "      <td>0.001677</td>\n",
       "      <td>0.496424</td>\n",
       "      <td>0.009544</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.600530</td>\n",
       "      <td>0.005051</td>\n",
       "      <td>1.803126</td>\n",
       "      <td>0.996380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Self_IKNN</td>\n",
       "      <td>1.018363</td>\n",
       "      <td>0.808793</td>\n",
       "      <td>0.000318</td>\n",
       "      <td>0.000108</td>\n",
       "      <td>0.000140</td>\n",
       "      <td>0.000189</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000214</td>\n",
       "      <td>0.000037</td>\n",
       "      <td>0.000368</td>\n",
       "      <td>0.496391</td>\n",
       "      <td>0.003181</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.392153</td>\n",
       "      <td>0.115440</td>\n",
       "      <td>4.174741</td>\n",
       "      <td>0.965327</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  Model      RMSE       MAE  precision    recall       F_1  \\\n",
       "0           Self_TopPop  2.508258  2.217909   0.188865  0.116919  0.118732   \n",
       "0      Self_SVDBaseline  3.644790  3.479397   0.137010  0.082007  0.083942   \n",
       "0             Ready_SVD  0.950945  0.749680   0.098834  0.049106  0.054037   \n",
       "0              Self_SVD  0.915079  0.718240   0.104772  0.045496  0.054393   \n",
       "0        Ready_Baseline  0.949459  0.752487   0.091410  0.037652  0.046030   \n",
       "0       Ready_SVDBiased  0.938535  0.738678   0.085366  0.036921  0.044151   \n",
       "0  Self_KNNSurprisetask  0.946255  0.745209   0.083457  0.032848  0.041227   \n",
       "0        Self_GlobalAvg  1.125760  0.943534   0.061188  0.025968  0.031383   \n",
       "0          Ready_Random  1.522798  1.222501   0.049841  0.020656  0.025232   \n",
       "0           Ready_I-KNN  1.030386  0.813067   0.026087  0.006908  0.010593   \n",
       "0   Ready_I-KNNBaseline  0.935327  0.737424   0.002545  0.000755  0.001105   \n",
       "0           Ready_U-KNN  1.023495  0.807913   0.000742  0.000205  0.000305   \n",
       "0       Self_BaselineIU  0.958136  0.754051   0.000954  0.000188  0.000298   \n",
       "0         Self_TopRated  2.508258  2.217909   0.000954  0.000188  0.000298   \n",
       "0       Self_BaselineUI  0.967585  0.762740   0.000954  0.000170  0.000278   \n",
       "0             Self_IKNN  1.018363  0.808793   0.000318  0.000108  0.000140   \n",
       "\n",
       "       F_05  precision_super  recall_super      NDCG       mAP       MRR  \\\n",
       "0  0.141584         0.130472      0.137473  0.214651  0.111707  0.400939   \n",
       "0  0.100776         0.106974      0.105605  0.160418  0.080222  0.322261   \n",
       "0  0.068741         0.087768      0.073987  0.113242  0.054201  0.243492   \n",
       "0  0.071374         0.094421      0.076826  0.109517  0.052005  0.206646   \n",
       "0  0.061286         0.079614      0.056463  0.095957  0.043178  0.198193   \n",
       "0  0.057832         0.074893      0.056396  0.095960  0.044204  0.212483   \n",
       "0  0.055493         0.074785      0.048890  0.089577  0.040902  0.189057   \n",
       "0  0.041343         0.040558      0.032107  0.067695  0.027470  0.171187   \n",
       "0  0.033446         0.030579      0.022927  0.051680  0.019110  0.123085   \n",
       "0  0.016046         0.021137      0.009522  0.024214  0.008958  0.048068   \n",
       "0  0.001602         0.002253      0.000930  0.003444  0.001362  0.011760   \n",
       "0  0.000449         0.000536      0.000198  0.000845  0.000274  0.002744   \n",
       "0  0.000481         0.000644      0.000223  0.001043  0.000335  0.003348   \n",
       "0  0.000481         0.000644      0.000223  0.001043  0.000335  0.003348   \n",
       "0  0.000463         0.000644      0.000189  0.000752  0.000168  0.001677   \n",
       "0  0.000189         0.000000      0.000000  0.000214  0.000037  0.000368   \n",
       "\n",
       "       LAUC        HR       HR2  Reco in test  Test coverage   Shannon  \\\n",
       "0  0.555546  0.765642  0.492047      1.000000       0.038961  3.159079   \n",
       "0  0.537895  0.626723  0.360551      0.999894       0.276335  5.123235   \n",
       "0  0.521280  0.493107  0.248144      0.998515       0.214286  4.413166   \n",
       "0  0.519484  0.487805  0.264051      0.874549       0.142136  3.890472   \n",
       "0  0.515501  0.437964  0.239661      1.000000       0.033911  2.836513   \n",
       "0  0.515132  0.446448  0.217391      0.997561       0.168110  4.191946   \n",
       "0  0.513076  0.417815  0.217391      0.888547       0.130592  3.611806   \n",
       "0  0.509546  0.384942  0.142100      1.000000       0.025974  2.711772   \n",
       "0  0.506849  0.331919  0.119830      0.985048       0.183983  5.097973   \n",
       "0  0.499885  0.154825  0.072110      0.402333       0.434343  5.133650   \n",
       "0  0.496724  0.021209  0.004242      0.482821       0.059885  2.232578   \n",
       "0  0.496441  0.007423  0.000000      0.602121       0.010823  2.089186   \n",
       "0  0.496433  0.009544  0.000000      0.699046       0.005051  1.945910   \n",
       "0  0.496433  0.009544  0.000000      0.699046       0.005051  1.945910   \n",
       "0  0.496424  0.009544  0.000000      0.600530       0.005051  1.803126   \n",
       "0  0.496391  0.003181  0.000000      0.392153       0.115440  4.174741   \n",
       "\n",
       "       Gini  \n",
       "0  0.987317  \n",
       "0  0.910511  \n",
       "0  0.953488  \n",
       "0  0.972126  \n",
       "0  0.991139  \n",
       "0  0.963341  \n",
       "0  0.978659  \n",
       "0  0.992003  \n",
       "0  0.907483  \n",
       "0  0.877999  \n",
       "0  0.994487  \n",
       "0  0.995706  \n",
       "0  0.995669  \n",
       "0  0.995669  \n",
       "0  0.996380  \n",
       "0  0.965327  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import importlib\n",
    "\n",
    "# Import first, then reload: reloading before the import fails with NameError on a fresh kernel.\n",
    "# The reload picks up edits to evaluation_measures.py without restarting the kernel.\n",
    "import evaluation_measures as ev\n",
    "importlib.reload(ev)\n",
    "\n",
    "dir_path=\"Recommendations generated/ml-100k/\"\n",
    "# Ratings treated as 'super' reactions by the *_super metrics (assumed meaning - confirm in evaluation_measures)\n",
    "super_reactions=[4,5]\n",
    "test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
    "\n",
    "ev.evaluate_all(test, dir_path, super_reactions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}