workshops_recommender_systems/P4. Matrix Factorization.ipynb

3206 lines
192 KiB
Plaintext
Raw Normal View History

2020-06-15 00:15:17 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Self made SVD"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import helpers\n",
"import pandas as pd\n",
"import numpy as np\n",
"import scipy.sparse as sparse\n",
"from collections import defaultdict\n",
"from itertools import chain\n",
"import random\n",
"\n",
"# Fix the seeds: the SVD factor initialisation, the SGD shuffling and the\n",
"# random item picked in the embeddings section all use the global numpy /\n",
"# Python RNGs, so without this the results and saved CSVs are not reproducible.\n",
"SEED = 42\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"train_read=pd.read_csv('./Datasets/ml-100k/train.csv', sep='\\t', header=None)\n",
"test_read=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"train_ui, test_ui, user_code_id, user_id_code, item_code_id, item_id_code = helpers.data_to_csr(train_read, test_read)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Done similarly to https://github.com/albertauyeung/matrix-factorization-in-python\n",
"from tqdm import tqdm\n",
"\n",
"class SVD():\n",
"    \"\"\"Matrix factorisation trained with plain SGD (no bias terms).\n",
"\n",
"    The predicted rating of user code u for item code i is Pu[u] . Qi[i].\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
"        \"\"\"Store hyper-parameters and draw the initial factor matrices.\n",
"\n",
"        train_ui: scipy sparse user x item rating matrix in CSR layout (recommend slices indptr/indices).\n",
"        learning_rate: SGD step size; regularization: L2 penalty on the factors.\n",
"        nb_factors: number of latent factors; iterations: number of SGD epochs run by train().\n",
"        \"\"\"\n",
"        self.train_ui=train_ui\n",
"        self.uir=self._to_uir(train_ui)\n",
"\n",
"        self.learning_rate=learning_rate\n",
"        self.regularization=regularization\n",
"        self.iterations=iterations\n",
"        self.nb_users, self.nb_items=train_ui.shape\n",
"        self.nb_ratings=train_ui.nnz\n",
"        self.nb_factors=nb_factors\n",
"\n",
"        # Small random initialisation: N(0, (1/nb_factors)^2) for every factor.\n",
"        self.Pu=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
"        self.Qi=np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
"\n",
"    @staticmethod\n",
"    def _to_uir(ui):\n",
"        \"\"\"Return the nonzero entries of sparse matrix `ui` as a list of (user, item, rating) triples.\"\"\"\n",
"        # tocoo() keeps row, col and data aligned. Zipping ui.nonzero() with ui.data\n",
"        # misaligns them as soon as the matrix stores an explicit zero.\n",
"        coo=ui.tocoo()\n",
"        mask=coo.data!=0\n",
"        return list(zip(coo.row[mask], coo.col[mask], coo.data[mask]))\n",
"\n",
"    def train(self, test_ui=None):\n",
"        \"\"\"Run `iterations` SGD epochs and record the RMSE after each one.\n",
"\n",
"        test_ui: optional sparse matrix; when given, the test RMSE is recorded as well.\n",
"        Fills self.learning_process with [epoch, train_RMSE] rows, or with\n",
"        [epoch, train_RMSE, test_RMSE] rows when test_ui is given.\n",
"        \"\"\"\n",
"        # Identity check: == / != are overloaded on scipy sparse matrices.\n",
"        has_test=test_ui is not None\n",
"        if has_test:\n",
"            self.test_uir=self._to_uir(test_ui)\n",
"\n",
"        self.learning_process=[]\n",
"        pbar = tqdm(range(self.iterations))\n",
"        for i in pbar:\n",
"            # The RMSE shown is the training RMSE after the previous epoch.\n",
"            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i>0 else 0}. Training epoch {i+1}...')\n",
"            np.random.shuffle(self.uir)\n",
"            self.sgd(self.uir)\n",
"            if has_test:\n",
"                self.learning_process.append([i+1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
"            else:\n",
"                self.learning_process.append([i+1, self.RMSE_total(self.uir)])\n",
"\n",
"    def sgd(self, uir):\n",
"        \"\"\"One SGD pass over the (user, item, rating) triples in `uir`, in the given order.\"\"\"\n",
"        for u, i, score in uir:\n",
"            # Compute prediction and error\n",
"            prediction = self.get_rating(u,i)\n",
"            e = (score - prediction)\n",
"\n",
"            # Both updates are computed from the factors before this step (simultaneous update).\n",
"            Pu_update=self.learning_rate * (e * self.Qi[i] - self.regularization * self.Pu[u])\n",
"            Qi_update=self.learning_rate * (e * self.Pu[u] - self.regularization * self.Qi[i])\n",
"\n",
"            self.Pu[u] += Pu_update\n",
"            self.Qi[i] += Qi_update\n",
"\n",
"    def get_rating(self, u, i):\n",
"        \"\"\"Predicted rating of user code `u` for item code `i`.\"\"\"\n",
"        return self.Pu[u].dot(self.Qi[i])\n",
"\n",
"    def RMSE_total(self, uir):\n",
"        \"\"\"Root mean squared error of the model over the (user, item, rating) triples in `uir`.\n",
"\n",
"        Returns nan for an empty `uir` instead of dividing by zero.\n",
"        \"\"\"\n",
"        if len(uir)==0:\n",
"            return np.nan\n",
"        squared_error=0\n",
"        for u, i, score in uir:\n",
"            squared_error+=(score - self.get_rating(u,i))**2\n",
"        return np.sqrt(squared_error/len(uir))\n",
"\n",
"    def estimations(self):\n",
"        \"\"\"Compute the dense user x item matrix of predicted ratings, Pu @ Qi.T.\n",
"\n",
"        NOTE: the result is stored in `self.estimations`, which replaces this\n",
"        method on the instance, so it can be called only once per model.\n",
"        recommend() and estimate() read that attribute.\n",
"        \"\"\"\n",
"        self.estimations=np.dot(self.Pu,self.Qi.T)\n",
"\n",
"    def recommend(self, user_code_id, item_code_id, topK=10):\n",
"        \"\"\"Top-`topK` items per user as rows [user_id, item1, score1, item2, score2, ...].\n",
"\n",
"        Requires estimations() to have been called. Items the user rated in\n",
"        train_ui and nan estimates are skipped.\n",
"        \"\"\"\n",
"        top_k = defaultdict(list)\n",
"        for nb_user, user in enumerate(self.estimations):\n",
"            # Item codes rated in train (CSR row slice); a set gives O(1) membership tests.\n",
"            user_rated=set(self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user+1]].tolist())\n",
"            for item, score in enumerate(user):\n",
"                if item not in user_rated and not np.isnan(score):\n",
"                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
"        result=[]\n",
"        # Let's choose k best items in the format: (user, item1, score1, item2, score2, ...)\n",
"        for uid, item_scores in top_k.items():\n",
"            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
"            result.append([uid]+list(chain(*item_scores[:topK])))\n",
"        return result\n",
"\n",
"    def estimate(self, user_code_id, item_code_id, test_ui):\n",
"        \"\"\"Predicted rating for every nonzero (user, item) of `test_ui`, as [user_id, item_id, estimate] rows.\n",
"\n",
"        Requires estimations() to have been called; nan estimates are replaced by 1.\n",
"        \"\"\"\n",
"        result=[]\n",
"        for user, item in zip(*test_ui.nonzero()):\n",
"            result.append([user_code_id[user], item_code_id[item],\n",
"                self.estimations[user,item] if not np.isnan(self.estimations[user,item]) else 1])\n",
"        return result"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 39 RMSE: 0.7467772350811145. Training epoch 40...: 100%|██████████| 40/40 [00:59<00:00, 1.50s/it]\n"
]
}
],
"source": [
"# 100 latent factors, 40 SGD epochs; the test RMSE is tracked after every epoch.\n",
"model = SVD(\n",
"    train_ui,\n",
"    learning_rate=0.005,\n",
"    regularization=0.02,\n",
"    nb_factors=100,\n",
"    iterations=40,\n",
")\n",
"model.train(test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fedc32039b0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAb6UlEQVR4nO3dfXRV9Z3v8feXJBAegoEkWkp4knGWVNQgiaMVOyjLWUCnlapXce5MdTouOhYXerXeAtexY7s6U2TVpz65uK0Pbb1irQ/VLrVXC9apz+FREQRUbALcEgOBIEQJfO8f+xxyCCfJOclJ9snen9dav7X32Xuffb7ZCz57n99+OObuiIhI/zcg7AJERCQ3FOgiIhGhQBcRiQgFuohIRCjQRUQiojCsDy4vL/fx48eH9fEiIv3SqlWrPnL3inTzQgv08ePHU1tbG9bHi4j0S2b2YUfz1OUiIhIRCnQRkYhQoIuIRERofegi0v8cOnSI+vp6Wlpawi4l8oqLi6msrKSoqCjj9yjQRSRj9fX1lJSUMH78eMws7HIiy91pbGykvr6eCRMmZPw+dbmISMZaWlooKytTmPcyM6OsrCzrb0IKdBHJisK8b3RnO/e7QP/zn+GGG+DQobArERHJL/0u0Nesgbvvhh/8IOxKRETyS78L9Isvhksugdtug61bw65GRPpSU1MTP/nJT7J+3+zZs2lqasr6fVdffTUTJkygqqqKM888kz/84Q9H502fPp2xY8eS+iNBc+bMYdiwYQAcOXKEBQsWMHnyZE4//XRqamr44IMPgOBO+dNPP52qqiqqqqpYsGBB1rWl0y+vcvnhD+GFF+Bf/xWefx7UpScSD8lA/8Y3vnHM9NbWVgoLO46zZ555ptufuXTpUi677DJWrlzJvHnz2LJly9F5paWlvPzyy0ybNo2mpiZ27tx5dN4jjzzCjh07WL9+PQMGDKC+vp6hQ4cenb9y5UrKy8u7XVc6/TLQP/tZWLIErr0WfvELuOqqsCsSiZ8bboC1a3O7zqoquOuujucvXLiQ9957j6qqKoqKiiguLmbEiBFs2rSJzZs3M2fOHOrq6mhpaeH6669n3rx5QNuzo/bv38+sWbOYNm0ar7zyCqNHj+a3v/0tgwcP7rK2c889l+3btx8zbe7cuSxfvpxp06bx+OOPc8kll7BhwwYAdu7cyahRoxgwIOgIqays7OZWyVy/63JJmjcPzjsPbrwRdu0KuxoR6Qvf//73mThxImvXrmXp0qWsXr2au+++m82bNwNw3333sWrVKmpra7nnnntobGw8bh1btmxh/vz5bNiwgdLSUh577LGMPvu5555jzpw5x0ybMWMGL730EocPH2b58uVcccUVR+ddfvnlPP3001RVVXHTTTexZs2aY957wQUXHO1yufPOO7PdFGn1yyN0gAEDYNmyYI9+443wq1+FXZFIvHR2JN1Xzj777GNuvLnnnnt44oknAKirq2PLli2UlZUd855knzjA1KlT2bZtW6efcfPNN7N48WLq6+t59dVXj5lXUFDAtGnTWL58OQcPHiT1keCVlZW8++67rFixghUrVjBjxgweffRRZsyYAfROl0u/PUIH+NznYPFieOgh+P3vw65GRPpaap/0iy++yAsvvMCrr77KunXrmDJlStobcwYNGnR0vKCggNbW1k4/Y+nSpWzevJklS5bwta997bj5c+fOZcGCBVx++eVpP2vWrFksXbqUxYsX8+STT2bz52WtXwc6wKJFcOqpwQnSjz8OuxoR6U0lJSU0Nzennbd3715GjBjBkCFD2LRpE6+99lpOP/u6667jyJEj/L7d0eP555/PokWLuPLKK4+Zvnr1anbs2AEEV7ysX7+ecePG5bSm9vp9oA8aFHS9bNsG3/522NWISG8qKyvjvPPOY/Lkydx8883HzJs5cyatra1MmjSJhQsXcs455+T0s82MW265hdtvv/
246d/85jeP6z7ZtWsXX/rSl5g8eTJnnHEGhYWFXHfddUfnp/ahf/WrX81NjanXUPal6upqz+UvFn396/Czn8Gbb8JZZ+VstSKSYuPGjUyaNCnsMmIj3fY2s1XuXp1u+X5/hJ60ZAmceCJccw100SUmIhJJXQa6mRWb2Rtmts7MNpjZbWmWudrMGsxsbaJd0zvldqy0NLjhKPloABGRTM2fP/9o90ey3X///WGXlbVMLlv8BLjQ3febWRHwJzN71t3bn3F4xN2vS/P+PnPppTB7Nnz3u8GljLqDVCT33D1yT1z88Y9/HHYJx+lOd3iXR+ge2J94WZRo4XS8d8EM/vZvYe9eOHAg7GpEoqe4uJjGxsZuhY1kLvkDF8XFxVm9L6Mbi8ysAFgF/BXwY3d/Pc1il5rZF4DNwP9w97o065kHzAMYO3ZsVoVmKnkPQWMjpFyiKiI5UFlZSX19PQ0NDWGXEnnJn6DLRkaB7u6HgSozKwWeMLPJ7v52yiJPAw+7+ydm9nXgQeDCNOtZBiyD4CqXrCrN0MiRwbCxEXppnyESW0VFRVn9JJr0rayucnH3JmAlMLPd9EZ3/yTx8mfA1NyUl73UI3QRkTjJ5CqXisSROWY2GLgI2NRumVEpL78MbMxlkdlQoItIXGXS5TIKeDDRjz4A+LW7/87MvgPUuvtTwAIz+zLQCuwGru6tgruiQBeRuOoy0N19PTAlzfRbU8YXAYtyW1r3pPahi4jESWTuFE0aOBBKShToIhI/kQt0CLpdFOgiEjeRDfTdu8OuQkSkb0U20HWELiJxo0AXEYkIBbqISERENtCbmuDw4bArERHpO5EM9JEjwR327Am7EhGRvhPJQNfdoiISRwp0EZGIUKCLiESEAl1EJCIU6CIiERHJQB8+HAoLFegiEi+RDHSz4NJFPc9FROIkkoEOultUROJHgS4iEhEKdBGRiIhsoI8cqUAXkXiJbKDrCF1E4ibSgd7SAgcOhF2JiEjfiHSgg47SRSQ+FOgiIhGhQBcRiYguA93Mis3sDTNbZ2YbzOy2NMsMMrNHzGyrmb1uZuN7o9hsKNBFJG4yOUL/BLjQ3c8EqoCZZnZOu2X+Bdjj7n8F3AksyW2Z2VOgi0jcdBnoHtifeFmUaN5usYuBBxPjvwFmmJnlrMpuSAa6nuciInGRUR+6mRWY2VpgF/C8u7/ebpHRQB2Au7cCe4GyNOuZZ2a1Zlbb0NDQs8q7MHAgDBumI3QRiY+MAt3dD7t7FVAJnG1mk7vzYe6+zN2r3b26oqKiO6vIim4uEpE4yeoqF3dvAlYCM9vN2g6MATCzQuAEIPQoVaCLSJxkcpVLhZmVJsYHAxcBm9ot9hRwVWL8MmCFu7fvZ+9zep6LiMRJYQbLjAIeNLMCgh3Ar939d2b2HaDW3Z8Cfg780sy2AruBub1WcRbKyuDDD8OuQkSkb3QZ6O6+HpiSZvqtKeMtwH/LbWk9py4XEYmTyN4pCkGg79kDhw+HXYmISO+LfKC7Q1NT2JWIiPS+yAc6qNtFROJBgS4iEhEKdBGRiFCgi4hERCwCXQ/oEpE4iHSgn3ACFBToCF1E4iHSgW6m2/9FJD4iHeigQBeR+Ih8oOv2fxGJCwW6iEhEKNBFRCJCgS4iEhGxCPSDB4MmIhJlsQh00FG6iESfAl1EJCIU6CIiERGbQNfzXEQk6mIT6DpCF5Goi3ygjxwZDBXoIhJ1kQ/04mIYMkSBLiLRF/lAB91cJCLxoEAXEYmILgPdzMaY2Uoze8fMNpjZ9WmWmW5me81sbaLd2jvldo8CXUTioDCDZVqBm9x9tZmVAKvM7Hl3f6fdcv/l7n+f+xJ7rqwM6urCrkJEpHd1eYTu7jvdfXVivBnYCIzu7cJySUfoIhIHWfWhm9l4YArweprZ55rZOjN71sxO6+D988ys1sxqGxoasi62u8rKYM8eOHKkzz5SRKTPZRzoZjYMeAy4wd33tZu9Ghjn7mcCPwSeTLcOd1/m7t
XuXl1RUdHdmrNWVhaEeVNTn32kiEifyyjQzayIIMwfcvfH2893933uvj8x/gxQZGblOa20B3S3qIjEQSZXuRjwc2C
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# learning_process rows are [epoch, train_RMSE(, test_RMSE)]; keep the first two columns.\n",
"df=pd.DataFrame(model.learning_process).iloc[:,:2]\n",
"df.columns=['epoch', 'train_RMSE']\n",
"fig, ax = plt.subplots()\n",
"ax.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
"ax.set(xlabel='epoch', ylabel='RMSE', title='Self-made SVD: training RMSE per epoch')\n",
"# Trailing ';' suppresses the Legend object repr in the cell output.\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fedc1043c18>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3deXxU5fXH8c8JYVO0solKZBFRQQgBg2IFARFZiojWIihuWKkVxFZB0VqpVKstWhXXqsUFLah131AUFH+uhFUQ2UECVBGkxQUReX5/nBsIIUASktyZyff9es0rM/fOJOc6eObOuc9zHgshICIiqSst7gBERKRsKdGLiKQ4JXoRkRSnRC8ikuKU6EVEUlx63AEUVKdOndCoUaO4wxARSSrTp0//KoRQt7B9CZfoGzVqRE5OTtxhiIgkFTNbsat9eyzdmNlYM/vSzObuYr+Z2RgzW2xmc8ysTb5955vZouh2fsnCFxGRvVGUGv0jQPfd7O8BNI1ug4D7AMysFjASOA44FhhpZjX3JlgRESm+PSb6EMJUYP1unnIa8FhwHwIHmNnBQDdgUghhfQjha2ASu//AEBGRMlAaNfr6wMp8j3OjbbvavhMzG4R/G6BBgwalEJKIlKcff/yR3NxcNm3aFHcoKa9atWpkZGRQuXLlIr8mIS7GhhAeAB4AyM7OVvMdkSSTm5vLfvvtR6NGjTCzuMNJWSEE1q1bR25uLo0bNy7y60pjHP0q4NB8jzOibbvaLiIpZtOmTdSuXVtJvoyZGbVr1y72N6fSSPQvAudFo2/aAf8NIawBXgdOMbOa0UXYU6JtIpKClOTLR0n+O++xdGNm44FOQB0zy8VH0lQGCCHcD7wK9AQWA98BF0b71pvZn4Fp0a8aFULY3UXdUvANsARoVbZ/RkQkiewx0YcQ+u9hfwAG72LfWGBsyUIrifuB4cDpwPVAVvn9aRGRBJVivW5+DfwJmAy0xhP+rDgDEpFysGHDBu69995iv65nz55s2LCh2K+74IILaNy4MVlZWbRq1Yq33npr275OnTrRoEED8i/q1KdPH2rUqAHA1q1bGTp0KC1atKBly5a0bduWZcuWAd4ZoGXLlmRlZZGVlcXQoUOLHVthEmLUTek5AK8sXQ7cCdwO/BdP/CKSqvIS/aWXXrrD9i1btpCevus09+qrr5b4b44ePZozzzyTKVOmMGjQIBYtWrRt3wEHHMB7771H+/bt2bBhA2vWrNm278knn2T16tXMmTOHtLQ0cnNz2XfffbftnzJlCnXq1ClxXIVJsUSfJ3/CXxdtWwX8DrgWP9sXkbLwu9/BrFL+Ip2VBXfcsev9I0aMYMmSJWRlZVG5cmWqVatGzZo1+eyzz1i4cCF9+vRh5cqVbNq0icsvv5xBgwYB23trffPNN/To0YP27dvz/vvvU79+fV544QWqV6++x9iOP/54Vq3acUBhv379mDBhAu3bt+fZZ5/ljDPOYN68eQCsWbOGgw8+mLQ0L6hkZGSU8L9K0aVY6aagA4Am0f1PgDeBNkAfQI3TRFLFLbfcQpMmTZg1axajR49mxowZ3HnnnSxcuBCAsWPHMn36dHJychgzZgzr1q3b6XcsWrSIwYMHM2/ePA444ACeeeaZIv3tiRMn0qdPnx22denShalTp/LTTz8xYcIEzjrrrG37+vbty0svvURWVhZXXnklM2fO3OG1nTt33la6uf3224v7n6JQKXpGX5juwDJgDF7SeQFvwfM+UCnGuERSy+7OvMvLscceu8OEojFjxvDcc88BsHLlShYtWkTt2rV3eE1ezR3gmGOOYfny5bv9G8OHD+faa68lNzeXDz74YId9lSpVon379kyYMIHvv/+e/K3XMzIyWLBgAZMnT2by5Ml06dKFp59+mi5dugBlU7pJ8TP6gg7AR+MsB+4BTmZ7kh8JvAH8FEtkIlJ68te83377bd58800++OADZs
+eTevWrQudcFS1atVt9ytVqsSWLVt2+zdGjx7NwoUL+etf/8rAgQN32t+vXz+GDh1K3759C/1bPXr0YPTo0Vx77bU8//zzxTm8YqtgiT7Pz4BLgZuix+vxxN8NaIx/GCyLJzQRKbb99tuPjRs3Frrvv//9LzVr1mSfffbhs88+48MPPyzVvz1kyBC2bt3K66/vOB+0Q4cOXHPNNfTvv+MI9RkzZrB69WrAR+DMmTOHhg0blmpMBVXQRF9QLfxi7VPA0cCNwGHAs3EGJSJFVLt2bU444QRatGjB8OHDd9jXvXt3tmzZQrNmzRgxYgTt2rUr1b9tZlx33XX87W9/22n7sGHDdirDfPnll5x66qm0aNGCzMxM0tPTGTJkyLb9+Wv05513XunEmH+sZyLIzs4O8a8wtRJ4FJ8HVhMYj0/wHQQcFWNcIolp/vz5NGvWLO4wKozC/nub2fQQQnZhz9cZfaEOBa7DkzzAfOBuoBnQGXgS2BxPaCIixaREXySj8LP8m/ELuf3w9VZEJJUNHjx4Wxkl7/bwww/HHVaxVaDhlXurHjACuApfLCvvP93XwLnAQOBUon5vIpIC7rnnnrhDKBVK9MWWho/OybMImA38EqiL99f5JV7iUdIXkfipdLPXjsWHYr4InAT8C/8gyI32rwZ+iCc0ERGU6EtJOl62mQB8ibdayJuVNwQ4EDgHH675XRwBikgFpkRf6qoDXfI9vhT4Fb64Vl5556oY4hKRikqJvsydDDwE/Ac/0z8fODja9wNwZLTtEWBFDPGJJL+S9qMHuOOOO/juu91/087rE5+ZmUnHjh1ZsWL7/6tmxoABA7Y93rJlC3Xr1qVXr14AfPHFF/Tq1YtWrVrRvHlzevbsCcDy5cupXr36DiN6HnvssRIdw54o0ZebdPxM/17g99G2r4FMfDXGC4FG+Izcl6P9iTWZTSRRlXWiB282NmfOHDp16sSNN964bfu+++7L3Llz+f777wGYNGkS9evX37b/+uuvp2vXrsyePZtPP/2UW265Zdu+vI6bebfSmglbkBJ9rA4Cnga+AObgi6W0wmv6AK/gk7SuxWfmKvFLsuhUyC0vEX+3i/2PRPu/KmTf7uXvRz98+HBGjx5N27ZtyczMZOTIkQB8++23/OIXv6BVq1a0aNGCJ598kjFjxrB69Wo6d+5M586di3RkhfWf79mzJ6+88goA48eP36G/zZo1a3boOZ+ZmVmkv1OalOgTQhrQEhgKPIeP5AHYDzgE+Fu0rWH0nMKbN4lUVPn70Xft2pVFixbx8ccfM2vWLKZPn87UqVOZOHEihxxyCLNnz2bu3Ll0796doUOHcsghhzBlyhSmTJlSpL9VWP/5vIVGNm3axJw5czjuuOO27Rs8eDAXXXQRnTt35qabbtrW0AzY9uGUd3v33XdL5z9IARpHn9A6Am/hq2S9jI/aeQXIa/j9BN6J82SgWhwBiuzC27vZt88e9tfZw/7de+ONN3jjjTdo3dpXkvvmm29YtGgRHTp04Morr+Tqq6+mV69edOjQoVi/t3Pnzqxfv54aNWrw5z//eYd9mZmZLF++nPHjx2+rwefp1q0bS5cuZeLEibz22mu0bt2auXPnAttLN2VNZ/RJoTZ+wfYFYCHb37a/4MM66wJ9gbF4qwaRiiuEwDXXXLOt7r148WIuuugijjjiCGbMmEHLli257rrrGDVqVLF+75QpU1ixYgVZWVnbykH59e7dm2HDhu3UlhigVq1anH322YwbN462bdsyderUEh9fSSjRJ538q2HNBCbiY/TfBS7CWywDbMW/Aey8ZJpIqsnfj75bt26MHTuWb775BoBVq1bx5Zdfsnr1avbZZx8GDBjA8OHDmTFjxk6v3ZP09HTuuOMOHnvsMdavX7/DvoEDBzJy5Ehatmy5w/bJkydvu9i7ceNGlixZQoMGDfbqeItLpZukVgWfhdsNuA+YF20DXyP3l4Dhi6F3iW4d8K/OIqkjfz/6Hj16cP
bZZ3P88ccDUKNGDR5//HEWL17M8OHDSUtLo3Llytx3330ADBo0iO7du2+r1e/JwQcfTP/+/bnnnnv44x//uG17RkY
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Skip the first 10 recorded epochs to zoom in on the later part of training.\n",
"df=pd.DataFrame(model.learning_process[10:], columns=['epoch', 'train_RMSE', 'test_RMSE'])\n",
"fig, ax = plt.subplots()\n",
"ax.plot('epoch', 'train_RMSE', data=df, color='blue')\n",
"# Orange rather than yellow: a yellow dashed line is barely visible on a white background.\n",
"ax.plot('epoch', 'test_RMSE', data=df, color='orange', linestyle='dashed')\n",
"ax.set(xlabel='epoch', ylabel='RMSE', title='Self-made SVD: train vs test RMSE')\n",
"ax.legend();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Saving and evaluating recommendations"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Materialise the dense prediction matrix (stored on model.estimations) first;\n",
"# both recommend() and estimate() read it.\n",
"model.estimations()\n",
"\n",
"# Top-10 unseen items per user.\n",
"top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"top_n.to_csv('Recommendations generated/ml-100k/Self_SVD_reco.csv',\n",
"             index=False, header=False)\n",
"\n",
"# Predicted rating for every (user, item) pair of the test set.\n",
"estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv',\n",
"                   index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 9025.30it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>HR2</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.915079</td>\n",
" <td>0.71824</td>\n",
" <td>0.104772</td>\n",
" <td>0.045496</td>\n",
" <td>0.054393</td>\n",
" <td>0.071374</td>\n",
" <td>0.094421</td>\n",
" <td>0.076826</td>\n",
" <td>0.109517</td>\n",
" <td>0.052005</td>\n",
" <td>0.206646</td>\n",
" <td>0.519484</td>\n",
" <td>0.487805</td>\n",
" <td>0.264051</td>\n",
" <td>0.874549</td>\n",
" <td>0.142136</td>\n",
" <td>3.890472</td>\n",
" <td>0.972126</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 0.915079 0.71824 0.104772 0.045496 0.054393 0.071374 \n",
"\n",
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.094421 0.076826 0.109517 0.052005 0.206646 0.519484 \n",
"\n",
" HR HR2 Reco in test Test coverage Shannon Gini \n",
"0 0.487805 0.264051 0.874549 0.142136 3.890472 0.972126 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"# Reload what the previous cell saved, exactly as the evaluator expects it on disk.\n",
"estimations_df = pd.read_csv('Recommendations generated/ml-100k/Self_SVD_estimations.csv', header=None)\n",
"reco = np.loadtxt('Recommendations generated/ml-100k/Self_SVD_reco.csv', delimiter=',')\n",
"\n",
"# super_reactions: the ratings the evaluator counts as super reactions.\n",
"ev.evaluate(\n",
"    test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n",
"    estimations_df=estimations_df,\n",
"    reco=reco,\n",
"    super_reactions=[4, 5],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 8433.36it/s]\n",
"943it [00:00, 8182.71it/s]\n",
"943it [00:00, 9546.13it/s]\n",
"943it [00:00, 8959.29it/s]\n",
"943it [00:00, 9016.78it/s]\n",
"943it [00:00, 8085.81it/s]\n",
"943it [00:00, 8341.37it/s]\n",
"943it [00:00, 9531.98it/s]\n",
"943it [00:00, 9952.14it/s]\n",
"943it [00:00, 9774.37it/s]\n",
"943it [00:00, 9543.76it/s]\n",
"943it [00:00, 9634.07it/s]\n",
"943it [00:00, 9988.71it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>HR2</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>0.492047</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.915079</td>\n",
" <td>0.718240</td>\n",
" <td>0.104772</td>\n",
" <td>0.045496</td>\n",
" <td>0.054393</td>\n",
" <td>0.071374</td>\n",
" <td>0.094421</td>\n",
" <td>0.076826</td>\n",
" <td>0.109517</td>\n",
" <td>0.052005</td>\n",
" <td>0.206646</td>\n",
" <td>0.519484</td>\n",
" <td>0.487805</td>\n",
" <td>0.264051</td>\n",
" <td>0.874549</td>\n",
" <td>0.142136</td>\n",
" <td>3.890472</td>\n",
" <td>0.972126</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>0.239661</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_KNNSurprisetask</td>\n",
" <td>0.946255</td>\n",
" <td>0.745209</td>\n",
" <td>0.083457</td>\n",
" <td>0.032848</td>\n",
" <td>0.041227</td>\n",
" <td>0.055493</td>\n",
" <td>0.074785</td>\n",
" <td>0.048890</td>\n",
" <td>0.089577</td>\n",
" <td>0.040902</td>\n",
" <td>0.189057</td>\n",
" <td>0.513076</td>\n",
" <td>0.417815</td>\n",
" <td>0.217391</td>\n",
" <td>0.888547</td>\n",
" <td>0.130592</td>\n",
" <td>3.611806</td>\n",
" <td>0.978659</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_GlobalAvg</td>\n",
" <td>1.125760</td>\n",
" <td>0.943534</td>\n",
" <td>0.061188</td>\n",
" <td>0.025968</td>\n",
" <td>0.031383</td>\n",
" <td>0.041343</td>\n",
" <td>0.040558</td>\n",
" <td>0.032107</td>\n",
" <td>0.067695</td>\n",
" <td>0.027470</td>\n",
" <td>0.171187</td>\n",
" <td>0.509546</td>\n",
" <td>0.384942</td>\n",
" <td>0.142100</td>\n",
" <td>1.000000</td>\n",
" <td>0.025974</td>\n",
" <td>2.711772</td>\n",
" <td>0.992003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.522798</td>\n",
" <td>1.222501</td>\n",
" <td>0.049841</td>\n",
" <td>0.020656</td>\n",
" <td>0.025232</td>\n",
" <td>0.033446</td>\n",
" <td>0.030579</td>\n",
" <td>0.022927</td>\n",
" <td>0.051680</td>\n",
" <td>0.019110</td>\n",
" <td>0.123085</td>\n",
" <td>0.506849</td>\n",
" <td>0.331919</td>\n",
" <td>0.119830</td>\n",
" <td>0.985048</td>\n",
" <td>0.183983</td>\n",
" <td>5.097973</td>\n",
" <td>0.907483</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.072110</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.004242</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.000000</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineIU</td>\n",
" <td>0.958136</td>\n",
" <td>0.754051</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.000000</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Self_SVD 0.915079 0.718240 0.104772 0.045496 0.054393 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Self_KNNSurprisetask 0.946255 0.745209 0.083457 0.032848 0.041227 \n",
"0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n",
"0 Ready_Random 1.522798 1.222501 0.049841 0.020656 0.025232 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_BaselineIU 0.958136 0.754051 0.000954 0.000188 0.000298 \n",
"0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.071374 0.094421 0.076826 0.109517 0.052005 0.206646 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.055493 0.074785 0.048890 0.089577 0.040902 0.189057 \n",
"0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n",
"0 0.033446 0.030579 0.022927 0.051680 0.019110 0.123085 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR HR2 Reco in test Test coverage Shannon \\\n",
"0 0.555546 0.765642 0.492047 1.000000 0.038961 3.159079 \n",
"0 0.519484 0.487805 0.264051 0.874549 0.142136 3.890472 \n",
"0 0.515501 0.437964 0.239661 1.000000 0.033911 2.836513 \n",
"0 0.513076 0.417815 0.217391 0.888547 0.130592 3.611806 \n",
"0 0.509546 0.384942 0.142100 1.000000 0.025974 2.711772 \n",
"0 0.506849 0.331919 0.119830 0.985048 0.183983 5.097973 \n",
"0 0.499885 0.154825 0.072110 0.402333 0.434343 5.133650 \n",
"0 0.496724 0.021209 0.004242 0.482821 0.059885 2.232578 \n",
"0 0.496441 0.007423 0.000000 0.602121 0.010823 2.089186 \n",
"0 0.496433 0.009544 0.000000 0.699046 0.005051 1.945910 \n",
"0 0.496433 0.009544 0.000000 0.699046 0.005051 1.945910 \n",
"0 0.496424 0.009544 0.000000 0.600530 0.005051 1.803126 \n",
"0 0.496391 0.003181 0.000000 0.392153 0.115440 4.174741 \n",
"\n",
" Gini \n",
"0 0.987317 \n",
"0 0.972126 \n",
"0 0.991139 \n",
"0 0.978659 \n",
"0 0.992003 \n",
"0 0.907483 \n",
"0 0.877999 \n",
"0 0.994487 \n",
"0 0.995706 \n",
"0 0.995669 \n",
"0 0.995669 \n",
"0 0.996380 \n",
"0 0.965327 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import importlib\n",
"\n",
"import evaluation_measures as ev\n",
"# Pick up edits to evaluation_measures.py without restarting the kernel.\n",
"# importlib replaces the deprecated `imp` module (removed in Python 3.12), and\n",
"# importing before reloading keeps this cell runnable even if `ev` is not yet defined.\n",
"ev = importlib.reload(ev)\n",
"\n",
"dir_path=\"Recommendations generated/ml-100k/\"\n",
"super_reactions=[4,5]\n",
"test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1, 2],\n",
" [3, 4]])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"array([[0.4472136 , 0.89442719],\n",
" [0.6 , 0.8 ]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Toy example of row-wise L2 normalisation: each row is divided by its Euclidean norm.\n",
"x=np.array([[1, 2],\n",
"            [3, 4]])\n",
"display(x)\n",
"x/np.linalg.norm(x, axis=1, keepdims=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>code</th>\n",
" <th>score</th>\n",
" <th>item_id</th>\n",
" <th>id</th>\n",
" <th>title</th>\n",
" <th>genres</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1638</td>\n",
" <td>1.000000</td>\n",
" <td>1639</td>\n",
" <td>1639</td>\n",
" <td>Bitter Sugar (Azucar Amargo) (1996)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>802</td>\n",
" <td>0.992833</td>\n",
" <td>803</td>\n",
" <td>803</td>\n",
" <td>Heaven &amp; Earth (1993)</td>\n",
" <td>Action, Drama, War</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1378</td>\n",
" <td>0.992618</td>\n",
" <td>1379</td>\n",
" <td>1379</td>\n",
" <td>Love and Other Catastrophes (1996)</td>\n",
" <td>Romance</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1130</td>\n",
" <td>0.991573</td>\n",
" <td>1131</td>\n",
" <td>1131</td>\n",
" <td>Safe (1995)</td>\n",
" <td>Thriller</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1199</td>\n",
" <td>0.991141</td>\n",
" <td>1200</td>\n",
" <td>1200</td>\n",
" <td>Kim (1950)</td>\n",
" <td>Children's, Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>1195</td>\n",
" <td>0.991040</td>\n",
" <td>1196</td>\n",
" <td>1196</td>\n",
" <td>Savage Nights (Nuits fauves, Les) (1992)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>1622</td>\n",
" <td>0.990832</td>\n",
" <td>1623</td>\n",
" <td>1623</td>\n",
" <td>Cérémonie, La (1995)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>1417</td>\n",
" <td>0.990285</td>\n",
" <td>1418</td>\n",
" <td>1418</td>\n",
" <td>Joy Luck Club, The (1993)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>1067</td>\n",
" <td>0.990192</td>\n",
" <td>1068</td>\n",
" <td>1068</td>\n",
" <td>Star Maker, The (Uomo delle stelle, L') (1995)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>1192</td>\n",
" <td>0.990168</td>\n",
" <td>1193</td>\n",
" <td>1193</td>\n",
" <td>Before the Rain (Pred dozhdot) (1994)</td>\n",
" <td>Drama</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" code score item_id id \\\n",
"0 1638 1.000000 1639 1639 \n",
"1 802 0.992833 803 803 \n",
"2 1378 0.992618 1379 1379 \n",
"3 1130 0.991573 1131 1131 \n",
"4 1199 0.991141 1200 1200 \n",
"5 1195 0.991040 1196 1196 \n",
"6 1622 0.990832 1623 1623 \n",
"7 1417 0.990285 1418 1418 \n",
"8 1067 0.990192 1068 1068 \n",
"9 1192 0.990168 1193 1193 \n",
"\n",
" title genres \n",
"0 Bitter Sugar (Azucar Amargo) (1996) Drama \n",
"1 Heaven & Earth (1993) Action, Drama, War \n",
"2 Love and Other Catastrophes (1996) Romance \n",
"3 Safe (1995) Thriller \n",
"4 Kim (1950) Children's, Drama \n",
"5 Savage Nights (Nuits fauves, Les) (1992) Drama \n",
"6 Cérémonie, La (1995) Drama \n",
"7 Joy Luck Club, The (1993) Drama \n",
"8 Star Maker, The (Uomo delle stelle, L') (1995) Drama \n",
"9 Before the Rain (Pred dozhdot) (1994) Drama "
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# pick a random item code that occurs in the training data\n",
"item = random.choice(list(set(train_ui.indices)))\n",
"\n",
"# Cosine similarity between item embeddings: every row of model.Qi is scaled to unit length.\n",
"# We do not mean-center here. Omitting the normalization also makes sense, but then items\n",
"# whose embeddings have a greater magnitude would be recommended more often.\n",
"row_norms = np.linalg.norm(model.Qi, axis=1)\n",
"embeddings_norm = model.Qi / row_norms[:, None]\n",
"similarity_scores = np.dot(embeddings_norm, embeddings_norm[item])\n",
"\n",
"# ten highest-scoring item codes; the query item itself (self-similarity ~1) is normally among them\n",
"scored_items = pd.DataFrame(enumerate(similarity_scores), columns=['code', 'score'])\n",
"top_similar_items = scored_items.sort_values(by=['score'], ascending=[False])[:10]\n",
"top_similar_items['item_id'] = top_similar_items['code'].map(lambda code: item_code_id[code])\n",
"\n",
"# attach movie titles and genres\n",
"items = pd.read_csv('./Datasets/ml-100k/movies.csv')\n",
"result = top_similar_items.merge(items, left_on='item_id', right_on='id')\n",
"\n",
"result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Project task 5: implement SVD on top of the baseline predictor (as in the Surprise library)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Extending our implementation with additional parameters (user and item biases, plus the global mean)\n",
"# in the gradient descent procedure seems to be the fastest option.\n",
"# Please save the output to 'Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv' and\n",
"# 'Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv'."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"class SVDbaseline():\n",
"    \"\"\"Matrix factorization (SVD) on top of a baseline predictor, as in the Surprise library.\n",
"\n",
"    Predicted rating: r_hat(u, i) = mu + Bu[u] + Bi[i] + Pu[u] . Qi[i], where mu is the global\n",
"    mean of the training ratings, Bu / Bi hold one scalar bias per user / item and Pu / Qi are\n",
"    the user / item latent-factor matrices. Everything except mu is learned with plain SGD.\n",
"\n",
"    train_ui is a scipy sparse (users x items) rating matrix; stored zeros are not ratings.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, train_ui, learning_rate, regularization, nb_factors, iterations):\n",
"        self.train_ui = train_ui\n",
"        self.uir = self._to_uir(train_ui)\n",
"\n",
"        self.learning_rate = learning_rate\n",
"        self.regularization = regularization\n",
"        self.iterations = iterations\n",
"        self.nb_users, self.nb_items = train_ui.shape\n",
"        self.nb_ratings = train_ui.nnz\n",
"        self.nb_factors = nb_factors\n",
"\n",
"        # global mean of the observed training ratings (the 'mu' of the baseline)\n",
"        self.mu = float(np.mean([score for _, _, score in self.uir])) if self.uir else 0.0\n",
"\n",
"        # one scalar bias per user / item, initialised to zero as in Surprise\n",
"        self.Bu = np.zeros(self.nb_users)\n",
"        self.Bi = np.zeros(self.nb_items)\n",
"\n",
"        self.Pu = np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_users, self.nb_factors))\n",
"        self.Qi = np.random.normal(loc=0, scale=1./self.nb_factors, size=(self.nb_items, self.nb_factors))\n",
"\n",
"    @staticmethod\n",
"    def _to_uir(matrix):\n",
"        \"\"\"List of (user, item, rating) triples for the non-zero entries of a sparse matrix.\"\"\"\n",
"        coo = matrix.tocoo()\n",
"        # filter stored zeros on rows, cols AND data together so the three stay aligned\n",
"        observed = coo.data != 0\n",
"        return list(zip(coo.row[observed], coo.col[observed], coo.data[observed]))\n",
"\n",
"    def train(self, test_ui=None):\n",
"        \"\"\"Run `iterations` epochs of SGD over shuffled training ratings.\n",
"\n",
"        After each epoch appends [epoch, train RMSE] to self.learning_process, or\n",
"        [epoch, train RMSE, test RMSE] when a sparse test matrix test_ui is given.\n",
"        \"\"\"\n",
"        if test_ui is not None:\n",
"            self.test_uir = self._to_uir(test_ui)\n",
"\n",
"        self.learning_process = []\n",
"        pbar = tqdm(range(self.iterations))\n",
"\n",
"        for i in pbar:\n",
"            pbar.set_description(f'Epoch {i} RMSE: {self.learning_process[-1][1] if i > 0 else 0}. Training epoch {i + 1}...')\n",
"            np.random.shuffle(self.uir)\n",
"            self.sgd(self.uir)\n",
"\n",
"            if test_ui is None:\n",
"                self.learning_process.append([i + 1, self.RMSE_total(self.uir)])\n",
"            else:\n",
"                self.learning_process.append([i + 1, self.RMSE_total(self.uir), self.RMSE_total(self.test_uir)])\n",
"\n",
"    def sgd(self, uir):\n",
"        \"\"\"One SGD pass over the (user, item, rating) triples in `uir`, in the given order.\"\"\"\n",
"        lr, reg = self.learning_rate, self.regularization\n",
"        for u, i, score in uir:\n",
"            e = score - self.get_rating(u, i)\n",
"\n",
"            # both factor updates are computed from the pre-update Pu[u] and Qi[i]\n",
"            Pu_update = lr * (e * self.Qi[i] - reg * self.Pu[u])\n",
"            Qi_update = lr * (e * self.Pu[u] - reg * self.Qi[i])\n",
"\n",
"            self.Bu[u] += lr * (e - reg * self.Bu[u])\n",
"            self.Bi[i] += lr * (e - reg * self.Bi[i])\n",
"\n",
"            self.Pu[u] += Pu_update\n",
"            self.Qi[i] += Qi_update\n",
"\n",
"    def get_rating(self, u, i):\n",
"        \"\"\"Scalar prediction mu + Bu[u] + Bi[i] + Pu[u] . Qi[i] for user code u and item code i.\"\"\"\n",
"        return self.mu + self.Bu[u] + self.Bi[i] + self.Pu[u].dot(self.Qi[i])\n",
"\n",
"    def RMSE_total(self, uir):\n",
"        \"\"\"Root mean squared error on the (user, item, rating) triples in `uir` (nan if empty).\"\"\"\n",
"        if len(uir) == 0:\n",
"            return float('nan')\n",
"        users, items, scores = (np.asarray(column) for column in zip(*uir))\n",
"        interactions = np.sum(self.Pu[users] * self.Qi[items], axis=1)\n",
"        predictions = self.mu + self.Bu[users] + self.Bi[items] + interactions\n",
"        return np.sqrt(np.mean((scores - predictions) ** 2))\n",
"\n",
"    def estimations(self):\n",
"        \"\"\"Compute the dense (users x items) matrix of predicted ratings.\n",
"\n",
"        The result is stored as self.estimations (shadowing this method on the instance);\n",
"        recommend() and estimate() read that attribute.\n",
"        \"\"\"\n",
"        self.estimations = self.mu + self.Bu[:, np.newaxis] + self.Bi[np.newaxis, :] + np.dot(self.Pu, self.Qi.T)\n",
"\n",
"    def recommend(self, user_code_id, item_code_id, topK=10):\n",
"        \"\"\"Top-`topK` items per user among items the user has no entry for in train_ui.\n",
"\n",
"        Requires estimations() to have been called first. Returns one row per user:\n",
"        [user_id, item_id_1, score_1, ..., item_id_k, score_k], best score first.\n",
"        \"\"\"\n",
"        top_k = defaultdict(list)\n",
"        for nb_user, user in enumerate(self.estimations):\n",
"            # set for O(1) membership tests in the inner loop\n",
"            user_rated = set(self.train_ui.indices[self.train_ui.indptr[nb_user]:self.train_ui.indptr[nb_user + 1]])\n",
"            for item, score in enumerate(user):\n",
"                if item not in user_rated and not np.isnan(score):\n",
"                    top_k[user_code_id[nb_user]].append((item_code_id[item], score))\n",
"        result = []\n",
"        for uid, item_scores in top_k.items():\n",
"            item_scores.sort(key=lambda x: x[1], reverse=True)\n",
"            result.append([uid] + list(chain(*item_scores[:topK])))\n",
"        return result\n",
"\n",
"    def estimate(self, user_code_id, item_code_id, test_ui):\n",
"        \"\"\"[user_id, item_id, predicted rating] for each non-zero (user, item) entry of test_ui.\n",
"\n",
"        Requires estimations() to have been called first; NaN predictions are replaced by 1.\n",
"        \"\"\"\n",
"        result = []\n",
"        for user, item in zip(*test_ui.nonzero()):\n",
"            result.append([user_code_id[user], item_code_id[item],\n",
"                           self.estimations[user, item] if not np.isnan(self.estimations[user, item]) else 1])\n",
"        return result"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 1 RMSE: [1.56628501 1.56647573 1.56581135 1.56616034 1.56624616 1.56602468/it]\n",
" 1.56609983 1.56608786 1.56617096 1.56609962 1.56627857 1.56624403\n",
" 1.56608856 1.56608238 1.56620356 1.56604183 1.56617325 1.56616721\n",
" 1.56608636 1.56621656 1.56601182 1.56629164 1.56603567 1.566008\n",
" 1.5659204 1.56600686 1.56632857 1.56623671 1.56650435 1.56614388\n",
" 1.56597602 1.56619724 1.56600564 1.56592808 1.5662823 1.56598423\n",
" 1.56630978 1.5661384 1.56637227 1.56600394 1.56599185 1.56598777\n",
" 1.56627173 1.56602758 1.56607052 1.56610967 1.56619676 1.5660723\n",
" 1.5661363 1.56643558 1.56599281 1.56602474 1.56615414 1.56619859\n",
" 1.56630092 1.56593518 1.56608721 1.56602644 1.56614329 1.56602186\n",
" 1.56606511 1.56610499 1.56610819 1.56593067 1.5661533 1.56616749\n",
" 1.56610497 1.5661211 1.56599833 1.56605015 1.56627997 1.56605438\n",
" 1.56626503 1.5660066 1.56630412 1.56620906 1.56620395 1.56631407\n",
" 1.56606347 1.56617128 1.56621069 1.56618785 1.56600689 1.56643735\n",
" 1.56598312 1.56615281 1.5661071 1.56631373 1.56604796 1.56585804\n",
" 1.56627589 1.56626862 1.56616867 1.56633676 1.56601226 1.5661184\n",
"Epoch 1 RMSE: [1.56628501 1.56647573 1.56581135 1.56616034 1.56624616 1.56602468 | 1/40 [00:02<01:40, 2.57s/it]\n",
" 1.56609983 1.56608786 1.56617096 1.56609962 1.56627857 1.56624403\n",
" 1.56608856 1.56608238 1.56620356 1.56604183 1.56617325 1.56616721\n",
" 1.56608636 1.56621656 1.56601182 1.56629164 1.56603567 1.566008\n",
" 1.5659204 1.56600686 1.56632857 1.56623671 1.56650435 1.56614388\n",
" 1.56597602 1.56619724 1.56600564 1.56592808 1.5662823 1.56598423\n",
" 1.56630978 1.5661384 1.56637227 1.56600394 1.56599185 1.56598777\n",
" 1.56627173 1.56602758 1.56607052 1.56610967 1.56619676 1.5660723\n",
" 1.5661363 1.56643558 1.56599281 1.56602474 1.56615414 1.56619859\n",
" 1.56630092 1.56593518 1.56608721 1.56602644 1.56614329 1.56602186\n",
" 1.56606511 1.56610499 1.56610819 1.56593067 1.5661533 1.56616749\n",
" 1.56610497 1.5661211 1.56599833 1.56605015 1.56627997 1.56605438\n",
" 1.56626503 1.5660066 1.56630412 1.56620906 1.56620395 1.56631407\n",
" 1.56606347 1.56617128 1.56621069 1.56618785 1.56600689 1.56643735\n",
" 1.56598312 1.56615281 1.5661071 1.56631373 1.56604796 1.56585804\n",
" 1.56627589 1.56626862 1.56616867 1.56633676 1.56601226 1.5661184\n",
"Epoch 2 RMSE: [1.2417018 1.24188159 1.24150993 1.24174429 1.24176255 1.24170722 | 2/40 [00:05<01:37, 2.58s/it]\n",
" 1.24171074 1.24167268 1.24172545 1.24174763 1.24183238 1.24173671\n",
" 1.24163026 1.24162065 1.24177366 1.24158469 1.24177297 1.24166952\n",
" 1.24171917 1.24183222 1.24162599 1.24178984 1.24168821 1.24166123\n",
" 1.24161223 1.24162659 1.24185753 1.24183293 1.24185593 1.24174315\n",
" 1.24165528 1.2416642 1.241631 1.24156043 1.24184828 1.24160591\n",
" 1.24178785 1.2416347 1.24184533 1.2416611 1.24161786 1.24155744\n",
" 1.24177065 1.24165326 1.24171443 1.24170905 1.24183533 1.24169338\n",
" 1.24169124 1.24192236 1.24165418 1.24165594 1.2417279 1.24175767\n",
" 1.24186364 1.24169617 1.24170661 1.24161604 1.24172478 1.24166061\n",
" 1.24167343 1.24165278 1.24162706 1.2415709 1.24172076 1.24168487\n",
" 1.24171406 1.24165323 1.24165323 1.24166652 1.24179703 1.24169395\n",
" 1.24174949 1.24163659 1.2418293 1.24176263 1.24171404 1.2418294\n",
" 1.24168063 1.24165637 1.24178972 1.24178604 1.24169735 1.241849\n",
" 1.24170329 1.24177262 1.24171872 1.24179172 1.24171928 1.24153613\n",
" 1.24177704 1.24185793 1.24168956 1.2417221 1.24165712 1.24172613\n",
"Epoch 2 RMSE: [1.2417018 1.24188159 1.24150993 1.24174429 1.24176255 1.24170722 | 2/40 [00:05<01:37, 2.58s/it]\n",
" 1.24171074 1.24167268 1.24172545 1.24174763 1.24183238 1.24173671\n",
" 1.24163026 1.24162065 1.24177366 1.24158469 1.24177297 1.24166952\n",
" 1.24171917 1.24183222 1.24162599 1.24178984 1.24168821 1.24166123\n",
" 1.24161223 1.24162659 1.24185753 1.24183293 1.24185593 1.24174315\n",
" 1.24165528 1.2416642 1.241631 1.24156043 1.24184828 1.24160591\n",
" 1.24178785 1.2416347 1.24184533 1.2416611 1.24161786 1.24155744\n",
" 1.24177065 1.24165326 1.24171443 1.24170905 1.24183533 1.24169338\n",
" 1.24169124 1.24192236 1.24165418 1.24165594 1.2417279 1.24175767\n",
" 1.24186364 1.24169617 1.24170661 1.24161604 1.24172478 1.24166061\n",
" 1.24167343 1.24165278 1.24162706 1.2415709 1.24172076 1.24168487\n",
" 1.24171406 1.24165323 1.24165323 1.24166652 1.24179703 1.24169395\n",
" 1.24174949 1.24163659 1.2418293 1.24176263 1.24171404 1.2418294\n",
" 1.24168063 1.24165637 1.24178972 1.24178604 1.24169735 1.241849\n",
" 1.24170329 1.24177262 1.24171872 1.24179172 1.24171928 1.24153613\n",
" 1.24177704 1.24185793 1.24168956 1.2417221 1.24165712 1.24172613\n",
"Epoch 3 RMSE: [1.13258243 1.13273769 1.13249398 1.13265935 1.13264375 1.13263674 | 3/40 [00:07<01:35, 2.58s/it]\n",
" 1.13264279 1.13259496 1.13263666 1.13267439 1.13270287 1.13263871\n",
" 1.13256654 1.13256675 1.13267736 1.13252626 1.13269081 1.13258232\n",
" 1.13262013 1.13274078 1.13258389 1.13267925 1.13263407 1.13259053\n",
" 1.13256681 1.13256461 1.13271984 1.13272441 1.13271252 1.13264455\n",
" 1.13259818 1.13258742 1.13258444 1.1325301 1.13271939 1.13255725\n",
" 1.13265228 1.13257112 1.13272558 1.13259942 1.13256249 1.13251124\n",
" 1.13267438 1.13259717 1.13263932 1.13263856 1.132744 1.13260791\n",
" 1.13260569 1.13276277 1.13258999 1.13260022 1.13263504 1.13267512\n",
" 1.13274647 1.13265384 1.13265274 1.13255779 1.13264575 1.13261409\n",
" 1.13261241 1.13257776 1.13255374 1.13251717 1.13263993 1.13258915\n",
" 1.13262861 1.13256498 1.13260396 1.13259002 1.13269051 1.13262309\n",
" 1.13264451 1.1326 1.13270356 1.13266447 1.13263253 1.13269319\n",
" 1.13262264 1.13256405 1.1326936 1.13267126 1.1326336 1.13270766\n",
" 1.13266901 1.1326705 1.13265684 1.13267642 1.13266019 1.13251334\n",
" 1.13266081 1.13273492 1.13259381 1.13259615 1.13258523 1.13264068\n",
"Epoch 3 RMSE: [1.13258243 1.13273769 1.13249398 1.13265935 1.13264375 1.13263674 | 3/40 [00:07<01:35, 2.58s/it]\n",
" 1.13264279 1.13259496 1.13263666 1.13267439 1.13270287 1.13263871\n",
" 1.13256654 1.13256675 1.13267736 1.13252626 1.13269081 1.13258232\n",
" 1.13262013 1.13274078 1.13258389 1.13267925 1.13263407 1.13259053\n",
" 1.13256681 1.13256461 1.13271984 1.13272441 1.13271252 1.13264455\n",
" 1.13259818 1.13258742 1.13258444 1.1325301 1.13271939 1.13255725\n",
" 1.13265228 1.13257112 1.13272558 1.13259942 1.13256249 1.13251124\n",
" 1.13267438 1.13259717 1.13263932 1.13263856 1.132744 1.13260791\n",
" 1.13260569 1.13276277 1.13258999 1.13260022 1.13263504 1.13267512\n",
" 1.13274647 1.13265384 1.13265274 1.13255779 1.13264575 1.13261409\n",
" 1.13261241 1.13257776 1.13255374 1.13251717 1.13263993 1.13258915\n",
" 1.13262861 1.13256498 1.13260396 1.13259002 1.13269051 1.13262309\n",
" 1.13264451 1.1326 1.13270356 1.13266447 1.13263253 1.13269319\n",
" 1.13262264 1.13256405 1.1326936 1.13267126 1.1326336 1.13270766\n",
" 1.13266901 1.1326705 1.13265684 1.13267642 1.13266019 1.13251334\n",
" 1.13266081 1.13273492 1.13259381 1.13259615 1.13258523 1.13264068\n",
"Epoch 4 RMSE: [1.07470103 1.07483205 1.074653 1.07478245 1.07474992 1.07475779 | 4/40 [00:10<01:30, 2.53s/it]\n",
" 1.07476855 1.07471954 1.07475657 1.07479501 1.07479633 1.07475995\n",
" 1.07471091 1.07471076 1.07478729 1.07466922 1.074806 1.07471156\n",
" 1.07473267 1.07485138 1.07472955 1.07478602 1.07476246 1.07471623\n",
" 1.07470393 1.07470486 1.07480787 1.07483042 1.07480929 1.07475672\n",
" 1.074736 1.07471934 1.07472499 1.07468818 1.07480974 1.0747082\n",
" 1.07475006 1.07471146 1.07482552 1.07473571 1.07469558 1.07466382\n",
" 1.07478714 1.07473432 1.07476276 1.07476954 1.07485164 1.07472884\n",
" 1.07472827 1.07483967 1.0747193 1.07473498 1.07475513 1.07479379\n",
" 1.07484481 1.07478596 1.074786 1.07469995 1.07477074 1.07475607\n",
" 1.07475005 1.07470749 1.07469114 1.0746581 1.07476066 1.0747095\n",
" 1.07474566 1.07469449 1.07474116 1.07472055 1.07479978 1.07475009\n",
" 1.07476008 1.07473934 1.07480159 1.07477503 1.07476399 1.07478594\n",
" 1.07475659 1.07469004 1.0748058 1.07477358 1.07476343 1.07479883\n",
" 1.07480411 1.07477767 1.0747872 1.07478549 1.07478493 1.07467267\n",
" 1.07476843 1.07482625 1.07471241 1.0747061 1.07471244 1.07475596\n",
"Epoch 4 RMSE: [1.07470103 1.07483205 1.074653 1.07478245 1.07474992 1.07475779 | 4/40 [00:10<01:30, 2.53s/it] \n",
" 1.07476855 1.07471954 1.07475657 1.07479501 1.07479633 1.07475995\n",
" 1.07471091 1.07471076 1.07478729 1.07466922 1.074806 1.07471156\n",
" 1.07473267 1.07485138 1.07472955 1.07478602 1.07476246 1.07471623\n",
" 1.07470393 1.07470486 1.07480787 1.07483042 1.07480929 1.07475672\n",
" 1.074736 1.07471934 1.07472499 1.07468818 1.07480974 1.0747082\n",
" 1.07475006 1.07471146 1.07482552 1.07473571 1.07469558 1.07466382\n",
" 1.07478714 1.07473432 1.07476276 1.07476954 1.07485164 1.07472884\n",
" 1.07472827 1.07483967 1.0747193 1.07473498 1.07475513 1.07479379\n",
" 1.07484481 1.07478596 1.074786 1.07469995 1.07477074 1.07475607\n",
" 1.07475005 1.07470749 1.07469114 1.0746581 1.07476066 1.0747095\n",
" 1.07474566 1.07469449 1.07474116 1.07472055 1.07479978 1.07475009\n",
" 1.07476008 1.07473934 1.07480159 1.07477503 1.07476399 1.07478594\n",
" 1.07475659 1.07469004 1.0748058 1.07477358 1.07476343 1.07479883\n",
" 1.07480411 1.07477767 1.0747872 1.07478549 1.07478493 1.07467267\n",
" 1.07476843 1.07482625 1.07471241 1.0747061 1.07471244 1.07475596\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 5 RMSE: [1.03814976 1.03826045 1.03811633 1.03822671 1.03818608 1.03819672 | 5/40 [00:12<01:28, 2.54s/it]\n",
" 1.03821107 1.03816502 1.03819735 1.03823394 1.03822099 1.03820276\n",
" 1.03817009 1.03816776 1.03822115 1.03812697 1.03824053 1.03816213\n",
" 1.0381714 1.03828314 1.03818505 1.03821938 1.03820537 1.03816112\n",
" 1.03815479 1.03816081 1.03823001 1.03826181 1.03823864 1.03819418\n",
" 1.03818894 1.03817048 1.03817742 1.03815259 1.03823377 1.03817128\n",
" 1.03818148 1.03816561 1.03825063 1.03819015 1.03814364 1.03812778\n",
" 1.0382225 1.03818517 1.03820551 1.03821651 1.038281 1.03817264\n",
" 1.03817241 1.03825447 1.0381662 1.03818432 1.03820159 1.03823161\n",
" 1.03826976 1.03823002 1.038232 1.03815758 1.03821445 1.03820907\n",
" 1.03820234 1.03815744 1.03814712 1.03811497 1.03820183 1.03815373\n",
" 1.03818401 1.03814703 1.03819164 1.03817238 1.03823326 1.03819768\n",
" 1.03820049 1.03818693 1.03822919 1.03820819 1.0382148 1.03821205\n",
" 1.0382052 1.03813962 1.03824113 1.03820524 1.0382107 1.03822166\n",
" 1.03824703 1.03821022 1.03823076 1.03822139 1.03822339 1.03813808\n",
" 1.03820354 1.038249 1.03815644 1.03814753 1.03816163 1.03819351\n",
"Epoch 5 RMSE: [1.03814976 1.03826045 1.03811633 1.03822671 1.03818608 1.03819672 | 5/40 [00:12<01:28, 2.54s/it]\n",
" 1.03821107 1.03816502 1.03819735 1.03823394 1.03822099 1.03820276\n",
" 1.03817009 1.03816776 1.03822115 1.03812697 1.03824053 1.03816213\n",
" 1.0381714 1.03828314 1.03818505 1.03821938 1.03820537 1.03816112\n",
" 1.03815479 1.03816081 1.03823001 1.03826181 1.03823864 1.03819418\n",
" 1.03818894 1.03817048 1.03817742 1.03815259 1.03823377 1.03817128\n",
" 1.03818148 1.03816561 1.03825063 1.03819015 1.03814364 1.03812778\n",
" 1.0382225 1.03818517 1.03820551 1.03821651 1.038281 1.03817264\n",
" 1.03817241 1.03825447 1.0381662 1.03818432 1.03820159 1.03823161\n",
" 1.03826976 1.03823002 1.038232 1.03815758 1.03821445 1.03820907\n",
" 1.03820234 1.03815744 1.03814712 1.03811497 1.03820183 1.03815373\n",
" 1.03818401 1.03814703 1.03819164 1.03817238 1.03823326 1.03819768\n",
" 1.03820049 1.03818693 1.03822919 1.03820819 1.0382148 1.03821205\n",
" 1.0382052 1.03813962 1.03824113 1.03820524 1.0382107 1.03822166\n",
" 1.03824703 1.03821022 1.03823076 1.03822139 1.03822339 1.03813808\n",
" 1.03820354 1.038249 1.03815644 1.03814753 1.03816163 1.03819351\n",
"Epoch 6 RMSE: [1.01281717 1.01291107 1.01278938 1.01288647 1.01284443 1.01285417 | 6/40 [00:15<01:28, 2.60s/it]\n",
" 1.01286901 1.01282764 1.01285544 1.01288953 1.0128703 1.01286329\n",
" 1.01284107 1.01283658 1.01287412 1.01279734 1.01289272 1.01282785\n",
" 1.01283052 1.01293233 1.01285139 1.01287313 1.01286354 1.01282315\n",
" 1.01281977 1.01283023 1.01287715 1.01291256 1.0128901 1.01285057\n",
" 1.01285464 1.01283705 1.01284262 1.01282647 1.01288275 1.01284451\n",
" 1.01283583 1.01283238 1.01289636 1.01285878 1.01280834 1.01280304\n",
" 1.01287596 1.01284933 1.01286522 1.01287812 1.0129298 1.01283542\n",
" 1.01283487 1.01289631 1.01282952 1.0128476 1.0128641 1.01288633\n",
" 1.01291558 1.01288799 1.01289213 1.01282762 1.01287417 1.01287394\n",
" 1.01286766 1.01282299 1.01281667 1.01278599 1.01286056 1.01281531\n",
" 1.01284136 1.01281493 1.01285498 1.01284074 1.01288644 1.012862\n",
" 1.01285987 1.01284788 1.01287841 1.01286073 1.01287958 1.01286209\n",
" 1.01286818 1.01280658 1.01289446 1.01285892 1.01287203 1.01286922\n",
" 1.01290259 1.01286311 1.01288884 1.01287717 1.01287654 1.01281316\n",
" 1.01285885 1.01289549 1.01281868 1.01280926 1.01282758 1.01284977\n",
"Epoch 6 RMSE: [1.01281717 1.01291107 1.01278938 1.01288647 1.01284443 1.01285417 | 6/40 [00:15<01:28, 2.60s/it] \n",
" 1.01286901 1.01282764 1.01285544 1.01288953 1.0128703 1.01286329\n",
" 1.01284107 1.01283658 1.01287412 1.01279734 1.01289272 1.01282785\n",
" 1.01283052 1.01293233 1.01285139 1.01287313 1.01286354 1.01282315\n",
" 1.01281977 1.01283023 1.01287715 1.01291256 1.0128901 1.01285057\n",
" 1.01285464 1.01283705 1.01284262 1.01282647 1.01288275 1.01284451\n",
" 1.01283583 1.01283238 1.01289636 1.01285878 1.01280834 1.01280304\n",
" 1.01287596 1.01284933 1.01286522 1.01287812 1.0129298 1.01283542\n",
" 1.01283487 1.01289631 1.01282952 1.0128476 1.0128641 1.01288633\n",
" 1.01291558 1.01288799 1.01289213 1.01282762 1.01287417 1.01287394\n",
" 1.01286766 1.01282299 1.01281667 1.01278599 1.01286056 1.01281531\n",
" 1.01284136 1.01281493 1.01285498 1.01284074 1.01288644 1.012862\n",
" 1.01285987 1.01284788 1.01287841 1.01286073 1.01287958 1.01286209\n",
" 1.01286818 1.01280658 1.01289446 1.01285892 1.01287203 1.01286922\n",
" 1.01290259 1.01286311 1.01288884 1.01287717 1.01287654 1.01281316\n",
" 1.01285885 1.01289549 1.01281868 1.01280926 1.01282758 1.01284977\n",
"Epoch 7 RMSE: [0.99440278 0.99448285 0.99437689 0.99446413 0.99442374 0.99443101 | 7/40 [00:17<01:23, 2.53s/it]\n",
" 0.99444568 0.9944093 0.99443276 0.99446378 0.99444228 0.99444224\n",
" 0.99442689 0.99442063 0.99444827 0.99438454 0.99446507 0.99441125\n",
" 0.99441043 0.99450116 0.99443336 0.99444832 0.9944398 0.99440386\n",
" 0.99440271 0.9944157 0.99444763 0.99448437 0.99446327 0.99442748\n",
" 0.99443694 0.99442101 0.9944241 0.99441448 0.99445494 0.99443242\n",
" 0.99441235 0.99441589 0.99446372 0.99444438 0.99439146 0.99439286\n",
" 0.99444993 0.99443034 0.99444365 0.99445697 0.99449893 0.99441738\n",
" 0.99441652 0.99446308 0.99441141 0.99442816 0.99444559 0.9944599\n",
" 0.99448334 0.99446342 0.99447032 0.99441352 0.99445264 0.99445534\n",
" 0.99444959 0.9944058 0.99440297 0.99437383 0.99443823 0.99439607\n",
" 0.99441945 0.9944003 0.99443535 0.99442618 0.99445976 0.99444387\n",
" 0.99443846 0.99442597 0.99444931 0.99443373 0.99446158 0.99443475\n",
" 0.99444772 0.99439111 0.99446809 0.99443495 0.99445024 0.99444\n",
" 0.99447577 0.99443738 0.99446482 0.99445296 0.99444964 0.99440218\n",
" 0.9944348 0.99446505 0.99440051 0.99439129 0.99441129 0.99442638\n",
"Epoch 7 RMSE: [0.99440278 0.99448285 0.99437689 0.99446413 0.99442374 0.99443101 | 7/40 [00:17<01:23, 2.53s/it] \n",
" 0.99444568 0.9944093 0.99443276 0.99446378 0.99444228 0.99444224\n",
" 0.99442689 0.99442063 0.99444827 0.99438454 0.99446507 0.99441125\n",
" 0.99441043 0.99450116 0.99443336 0.99444832 0.9944398 0.99440386\n",
" 0.99440271 0.9944157 0.99444763 0.99448437 0.99446327 0.99442748\n",
" 0.99443694 0.99442101 0.9944241 0.99441448 0.99445494 0.99443242\n",
" 0.99441235 0.99441589 0.99446372 0.99444438 0.99439146 0.99439286\n",
" 0.99444993 0.99443034 0.99444365 0.99445697 0.99449893 0.99441738\n",
" 0.99441652 0.99446308 0.99441141 0.99442816 0.99444559 0.9944599\n",
" 0.99448334 0.99446342 0.99447032 0.99441352 0.99445264 0.99445534\n",
" 0.99444959 0.9944058 0.99440297 0.99437383 0.99443823 0.99439607\n",
" 0.99441945 0.9944003 0.99443535 0.99442618 0.99445976 0.99444387\n",
" 0.99443846 0.99442597 0.99444931 0.99443373 0.99446158 0.99443475\n",
" 0.99444772 0.99439111 0.99446809 0.99443495 0.99445024 0.99444\n",
" 0.99447577 0.99443738 0.99446482 0.99445296 0.99444964 0.99440218\n",
" 0.9944348 0.99446505 0.99440051 0.99439129 0.99441129 0.99442638\n",
"Epoch 8 RMSE: [0.98042218 0.98049048 0.98039697 0.98047584 0.98043855 0.98044358 | 8/40 [00:20<01:19, 2.48s/it]\n",
" 0.9804572 0.98042485 0.98044487 0.98047328 0.9804515 0.98045602\n",
" 0.98044497 0.98043727 0.98045856 0.98040473 0.98047328 0.98042834\n",
" 0.98042557 0.98050617 0.98044835 0.98045934 0.98045096 0.98041956\n",
" 0.98041958 0.98043437 0.98045558 0.98049181 0.98047231 0.98043977\n",
" 0.98045187 0.98043782 0.98043933 0.98043387 0.98046433 0.98045169\n",
" 0.98042538 0.98043231 0.98046841 0.9804625 0.98040982 0.98041445\n",
" 0.98046011 0.98044478 0.98045634 0.98046963 0.98050411 0.98043312\n",
" 0.98043237 0.98046834 0.98042783 0.98044328 0.98046024 0.98046917\n",
" 0.98048834 0.98047363 0.98048218 0.98043166 0.98046508 0.9804699\n",
" 0.98046497 0.98042322 0.9804222 0.98039509 0.98045161 0.98041184\n",
" 0.98043266 0.98041952 0.98044893 0.98044451 0.98046886 0.98045955\n",
" 0.98045228 0.98043826 0.98045715 0.98044318 0.98047617 0.98044425\n",
" 0.98046183 0.98041042 0.98047757 0.9804468 0.98046262 0.98044893\n",
" 0.98048381 0.98044816 0.98047516 0.98046413 0.98045799 0.98042268\n",
" 0.98044622 0.98047213 0.98041723 0.98040793 0.98042941 0.98043912\n",
"Epoch 8 RMSE: [0.98042218 0.98049048 0.98039697 0.98047584 0.98043855 0.98044358 | 8/40 [00:20<01:19, 2.48s/it]\n",
" 0.9804572 0.98042485 0.98044487 0.98047328 0.9804515 0.98045602\n",
" 0.98044497 0.98043727 0.98045856 0.98040473 0.98047328 0.98042834\n",
" 0.98042557 0.98050617 0.98044835 0.98045934 0.98045096 0.98041956\n",
" 0.98041958 0.98043437 0.98045558 0.98049181 0.98047231 0.98043977\n",
" 0.98045187 0.98043782 0.98043933 0.98043387 0.98046433 0.98045169\n",
" 0.98042538 0.98043231 0.98046841 0.9804625 0.98040982 0.98041445\n",
" 0.98046011 0.98044478 0.98045634 0.98046963 0.98050411 0.98043312\n",
" 0.98043237 0.98046834 0.98042783 0.98044328 0.98046024 0.98046917\n",
" 0.98048834 0.98047363 0.98048218 0.98043166 0.98046508 0.9804699\n",
" 0.98046497 0.98042322 0.9804222 0.98039509 0.98045161 0.98041184\n",
" 0.98043266 0.98041952 0.98044893 0.98044451 0.98046886 0.98045955\n",
" 0.98045228 0.98043826 0.98045715 0.98044318 0.98047617 0.98044425\n",
" 0.98046183 0.98041042 0.98047757 0.9804468 0.98046262 0.98044893\n",
" 0.98048381 0.98044816 0.98047516 0.98046413 0.98045799 0.98042268\n",
" 0.98044622 0.98047213 0.98041723 0.98040793 0.98042941 0.98043912\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 9 RMSE: [0.96956864 0.96962721 0.96954354 0.96961556 0.96958205 0.96958458 | 9/40 [00:22<01:16, 2.47s/it]\n",
" 0.96959707 0.96956865 0.96958616 0.96961134 0.96959051 0.96959811\n",
" 0.96958923 0.96958072 0.96959807 0.96955239 0.96961096 0.96957265\n",
" 0.96956903 0.96964055 0.96959048 0.96959949 0.96959073 0.96956322\n",
" 0.96956373 0.96957968 0.96959395 0.9696286 0.96961054 0.96958063\n",
" 0.96959443 0.96958199 0.96958222 0.96957953 0.96960364 0.96959708\n",
" 0.96956749 0.96957626 0.96960309 0.96960737 0.96955603 0.96956241\n",
" 0.96959943 0.96958725 0.96959703 0.96961002 0.96963852 0.96957742\n",
" 0.96957636 0.9696043 0.96957224 0.96958563 0.96960255 0.96960695\n",
" 0.96962291 0.96961234 0.9696223 0.96957714 0.96960567 0.9696117\n",
" 0.96960751 0.96956774 0.96956802 0.96954356 0.9695932 0.96955556\n",
" 0.96957482 0.96956564 0.96959001 0.96958992 0.96960666 0.96960246\n",
" 0.96959448 0.96957906 0.96959454 0.96958207 0.96961755 0.96958331\n",
" 0.9696031 0.9695574 0.9696157 0.96958767 0.96960284 0.96958787\n",
" 0.96962011 0.969588 0.96961464 0.9696038 0.96959544 0.96956941\n",
" 0.96958616 0.96960883 0.9695618 0.96955266 0.9695746 0.96958055\n",
"Epoch 9 RMSE: [0.96956864 0.96962721 0.96954354 0.96961556 0.96958205 0.96958458 | 9/40 [00:22<01:16, 2.47s/it]\n",
" 0.96959707 0.96956865 0.96958616 0.96961134 0.96959051 0.96959811\n",
" 0.96958923 0.96958072 0.96959807 0.96955239 0.96961096 0.96957265\n",
" 0.96956903 0.96964055 0.96959048 0.96959949 0.96959073 0.96956322\n",
" 0.96956373 0.96957968 0.96959395 0.9696286 0.96961054 0.96958063\n",
" 0.96959443 0.96958199 0.96958222 0.96957953 0.96960364 0.96959708\n",
" 0.96956749 0.96957626 0.96960309 0.96960737 0.96955603 0.96956241\n",
" 0.96959943 0.96958725 0.96959703 0.96961002 0.96963852 0.96957742\n",
" 0.96957636 0.9696043 0.96957224 0.96958563 0.96960255 0.96960695\n",
" 0.96962291 0.96961234 0.9696223 0.96957714 0.96960567 0.9696117\n",
" 0.96960751 0.96956774 0.96956802 0.96954356 0.9695932 0.96955556\n",
" 0.96957482 0.96956564 0.96959001 0.96958992 0.96960666 0.96960246\n",
" 0.96959448 0.96957906 0.96959454 0.96958207 0.96961755 0.96958331\n",
" 0.9696031 0.9695574 0.9696157 0.96958767 0.96960284 0.96958787\n",
" 0.96962011 0.969588 0.96961464 0.9696038 0.96959544 0.96956941\n",
" 0.96958616 0.96960883 0.9695618 0.96955266 0.9695746 0.96958055\n",
"Epoch 10 RMSE: [0.96075673 0.96080761 0.96073196 0.96079776 0.96076804 0.96076873 | 10/40 [00:24<01:13, 2.44s/it]\n",
" 0.96078039 0.96075495 0.96077015 0.96079231 0.96077356 0.96078261\n",
" 0.96077559 0.96076643 0.96078098 0.9607412 0.96079177 0.96075896\n",
" 0.96075491 0.96081856 0.96077505 0.96078216 0.9607736 0.96074991\n",
" 0.96075047 0.96076696 0.96077613 0.9608088 0.9607921 0.96076446\n",
" 0.96077892 0.96076825 0.96076753 0.96076714 0.96078607 0.96078396\n",
" 0.9607528 0.96076217 0.96078209 0.9607937 0.96074464 0.96075185\n",
" 0.96078205 0.96077184 0.96078107 0.96079289 0.96081683 0.96076394\n",
" 0.96076258 0.96078479 0.96075914 0.96077056 0.96078714 0.96078814\n",
" 0.96080169 0.9607937 0.96080479 0.96076411 0.96078901 0.96079532\n",
" 0.96079197 0.9607546 0.96075574 0.96073366 0.96077769 0.96074198\n",
" 0.96075972 0.96075414 0.96077353 0.96077696 0.96078774 0.96078751\n",
" 0.96077925 0.96076313 0.9607757 0.96076468 0.96080126 0.96076611\n",
" 0.96078704 0.96074671 0.9607969 0.96077184 0.9607857 0.96077078\n",
" 0.9607998 0.96077123 0.96079652 0.96078645 0.96077646 0.96075732\n",
" 0.96076928 0.96078951 0.96074835 0.96074048 0.96076206 0.96076493\n",
"Epoch 10 RMSE: [0.96075673 0.96080761 0.96073196 0.96079776 0.96076804 0.96076873 | 10/40 [00:24<01:13, 2.44s/it]\n",
" 0.96078039 0.96075495 0.96077015 0.96079231 0.96077356 0.96078261\n",
" 0.96077559 0.96076643 0.96078098 0.9607412 0.96079177 0.96075896\n",
" 0.96075491 0.96081856 0.96077505 0.96078216 0.9607736 0.96074991\n",
" 0.96075047 0.96076696 0.96077613 0.9608088 0.9607921 0.96076446\n",
" 0.96077892 0.96076825 0.96076753 0.96076714 0.96078607 0.96078396\n",
" 0.9607528 0.96076217 0.96078209 0.9607937 0.96074464 0.96075185\n",
" 0.96078205 0.96077184 0.96078107 0.96079289 0.96081683 0.96076394\n",
" 0.96076258 0.96078479 0.96075914 0.96077056 0.96078714 0.96078814\n",
" 0.96080169 0.9607937 0.96080479 0.96076411 0.96078901 0.96079532\n",
" 0.96079197 0.9607546 0.96075574 0.96073366 0.96077769 0.96074198\n",
" 0.96075972 0.96075414 0.96077353 0.96077696 0.96078774 0.96078751\n",
" 0.96077925 0.96076313 0.9607757 0.96076468 0.96080126 0.96076611\n",
" 0.96078704 0.96074671 0.9607969 0.96077184 0.9607857 0.96077078\n",
" 0.9607998 0.96077123 0.96079652 0.96078645 0.96077646 0.96075732\n",
" 0.96076928 0.96078951 0.96074835 0.96074048 0.96076206 0.96076493\n",
"Epoch 11 RMSE: [0.95361422 0.9536583 0.95358983 0.95365028 0.95362411 0.95362327 | 11/40 [00:27<01:09, 2.41s/it]\n",
" 0.95363393 0.95361104 0.95362427 0.95364368 0.95362728 0.95363712\n",
" 0.95363111 0.95362154 0.95363403 0.95359944 0.95364333 0.95361499\n",
" 0.953611 0.95366754 0.95362947 0.95363561 0.95362648 0.95360615\n",
" 0.95360678 0.95362319 0.95362914 0.95365963 0.95364409 0.95361836\n",
" 0.95363346 0.95362369 0.95362264 0.95362344 0.95363932 0.95363981\n",
" 0.95360837 0.95361782 0.95363242 0.95364907 0.95360271 0.95361032\n",
" 0.95363511 0.95362669 0.95363478 0.95364577 0.95366592 0.95362025\n",
" 0.95361858 0.95363617 0.95361555 0.95362547 0.95364146 0.95363965\n",
" 0.9536515 0.95364549 0.95365732 0.95362043 0.95364252 0.95364887\n",
" 0.95364593 0.9536111 0.95361243 0.95359329 0.95363234 0.95359863\n",
" 0.95361472 0.95361184 0.95362725 0.95363318 0.95363953 0.95364235\n",
" 0.95363404 0.95361743 0.95362787 0.9536177 0.95365451 0.95361986\n",
" 0.95364087 0.9536052 0.95364885 0.9536261 0.9536388 0.95362439\n",
" 0.95365 0.9536251 0.95364881 0.95363917 0.95362836 0.95361419\n",
" 0.95362283 0.95364097 0.95360492 0.95359786 0.95361872 0.95361965\n",
"Epoch 11 RMSE: [0.95361422 0.9536583 0.95358983 0.95365028 0.95362411 0.95362327 | 11/40 [00:27<01:09, 2.41s/it]\n",
" 0.95363393 0.95361104 0.95362427 0.95364368 0.95362728 0.95363712\n",
" 0.95363111 0.95362154 0.95363403 0.95359944 0.95364333 0.95361499\n",
" 0.953611 0.95366754 0.95362947 0.95363561 0.95362648 0.95360615\n",
" 0.95360678 0.95362319 0.95362914 0.95365963 0.95364409 0.95361836\n",
" 0.95363346 0.95362369 0.95362264 0.95362344 0.95363932 0.95363981\n",
" 0.95360837 0.95361782 0.95363242 0.95364907 0.95360271 0.95361032\n",
" 0.95363511 0.95362669 0.95363478 0.95364577 0.95366592 0.95362025\n",
" 0.95361858 0.95363617 0.95361555 0.95362547 0.95364146 0.95363965\n",
" 0.9536515 0.95364549 0.95365732 0.95362043 0.95364252 0.95364887\n",
" 0.95364593 0.9536111 0.95361243 0.95359329 0.95363234 0.95359863\n",
" 0.95361472 0.95361184 0.95362725 0.95363318 0.95363953 0.95364235\n",
" 0.95363404 0.95361743 0.95362787 0.9536177 0.95365451 0.95361986\n",
" 0.95364087 0.9536052 0.95364885 0.9536261 0.9536388 0.95362439\n",
" 0.95365 0.9536251 0.95364881 0.95363917 0.95362836 0.95361419\n",
" 0.95362283 0.95364097 0.95360492 0.95359786 0.95361872 0.95361965\n",
"Epoch 12 RMSE: [0.94765708 0.94769565 0.94763277 0.94768893 0.94766618 0.947664 | 12/40 [00:29<01:06, 2.39s/it]\n",
" 0.94767364 0.94765312 0.94766496 0.94768184 0.94766724 0.94767774\n",
" 0.94767218 0.94766267 0.94767376 0.94764338 0.94768166 0.94765665\n",
" 0.94765283 0.94770359 0.94766981 0.94767537 0.94766609 0.94764863\n",
" 0.94764898 0.94766509 0.9476688 0.94769696 0.94768249 0.94765882\n",
" 0.94767371 0.94766519 0.94766387 0.94766566 0.94767913 0.94768116\n",
" 0.94765028 0.94765911 0.94766987 0.94769018 0.9476467 0.94765397\n",
" 0.94767479 0.94766754 0.94767501 0.94768491 0.9477023 0.94766229\n",
" 0.94766049 0.9476745 0.94765805 0.94766628 0.94768163 0.947678\n",
" 0.94768793 0.94768374 0.94769577 0.94766257 0.94768209 0.94768808\n",
" 0.94768622 0.94765347 0.94765496 0.94763823 0.94767304 0.94764105\n",
" 0.94765615 0.94765504 0.9476667 0.94767507 0.94767778 0.94768277\n",
" 0.9476749 0.94765828 0.94766685 0.94765747 0.94769349 0.9476598\n",
" 0.94768074 0.94764938 0.94768697 0.94766657 0.94767781 0.94766472\n",
" 0.94768706 0.94766521 0.94768765 0.9476782 0.94766698 0.94765685\n",
" 0.94766267 0.9476793 0.94764692 0.94764121 0.94766126 0.94766045\n",
"Epoch 12 RMSE: [0.94765708 0.94769565 0.94763277 0.94768893 0.94766618 0.947664 | 12/40 [00:29<01:06, 2.39s/it] \n",
" 0.94767364 0.94765312 0.94766496 0.94768184 0.94766724 0.94767774\n",
" 0.94767218 0.94766267 0.94767376 0.94764338 0.94768166 0.94765665\n",
" 0.94765283 0.94770359 0.94766981 0.94767537 0.94766609 0.94764863\n",
" 0.94764898 0.94766509 0.9476688 0.94769696 0.94768249 0.94765882\n",
" 0.94767371 0.94766519 0.94766387 0.94766566 0.94767913 0.94768116\n",
" 0.94765028 0.94765911 0.94766987 0.94769018 0.9476467 0.94765397\n",
" 0.94767479 0.94766754 0.94767501 0.94768491 0.9477023 0.94766229\n",
" 0.94766049 0.9476745 0.94765805 0.94766628 0.94768163 0.947678\n",
" 0.94768793 0.94768374 0.94769577 0.94766257 0.94768209 0.94768808\n",
" 0.94768622 0.94765347 0.94765496 0.94763823 0.94767304 0.94764105\n",
" 0.94765615 0.94765504 0.9476667 0.94767507 0.94767778 0.94768277\n",
" 0.9476749 0.94765828 0.94766685 0.94765747 0.94769349 0.9476598\n",
" 0.94768074 0.94764938 0.94768697 0.94766657 0.94767781 0.94766472\n",
" 0.94768706 0.94766521 0.94768765 0.9476782 0.94766698 0.94765685\n",
" 0.94766267 0.9476793 0.94764692 0.94764121 0.94766126 0.94766045\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 13 RMSE: [0.94242167 0.94245501 0.94239796 0.94244964 0.94242964 0.94242677 | 13/40 [00:32<01:04, 2.38s/it]\n",
" 0.94243566 0.94241669 0.94242748 0.94244225 0.94242968 0.94243994\n",
" 0.9424347 0.94242551 0.94243552 0.94240881 0.94244215 0.94242032\n",
" 0.94241652 0.94246188 0.94243226 0.94243719 0.94242775 0.94241276\n",
" 0.9424129 0.94242876 0.94243091 0.94245705 0.94244329 0.94242141\n",
" 0.94243586 0.94242807 0.94242708 0.94242915 0.9424409 0.94244413\n",
" 0.9424143 0.9424223 0.94242981 0.94245265 0.942412 0.94241917\n",
" 0.94243634 0.94243026 0.94243717 0.94244606 0.94246105 0.94242636\n",
" 0.94242434 0.94243546 0.94242238 0.94242902 0.94244364 0.94243855\n",
" 0.94244718 0.94244423 0.94245625 0.94242622 0.94244366 0.94244964\n",
" 0.94244806 0.94241736 0.94241905 0.94240443 0.94243573 0.94240543\n",
" 0.94241937 0.94241989 0.94242829 0.94243845 0.94243873 0.94244488\n",
" 0.94243726 0.94242083 0.94242829 0.94241988 0.94245415 0.94242219\n",
" 0.94244239 0.94241518 0.94244741 0.9424288 0.94243902 0.94242742\n",
" 0.94244647 0.94242764 0.9424486 0.94243932 0.94242807 0.94242085\n",
" 0.94242457 0.94244017 0.94241092 0.94240614 0.94242508 0.94242352\n",
"Epoch 13 RMSE: [0.94242167 0.94245501 0.94239796 0.94244964 0.94242964 0.94242677 | 13/40 [00:32<01:04, 2.38s/it]\n",
" 0.94243566 0.94241669 0.94242748 0.94244225 0.94242968 0.94243994\n",
" 0.9424347 0.94242551 0.94243552 0.94240881 0.94244215 0.94242032\n",
" 0.94241652 0.94246188 0.94243226 0.94243719 0.94242775 0.94241276\n",
" 0.9424129 0.94242876 0.94243091 0.94245705 0.94244329 0.94242141\n",
" 0.94243586 0.94242807 0.94242708 0.94242915 0.9424409 0.94244413\n",
" 0.9424143 0.9424223 0.94242981 0.94245265 0.942412 0.94241917\n",
" 0.94243634 0.94243026 0.94243717 0.94244606 0.94246105 0.94242636\n",
" 0.94242434 0.94243546 0.94242238 0.94242902 0.94244364 0.94243855\n",
" 0.94244718 0.94244423 0.94245625 0.94242622 0.94244366 0.94244964\n",
" 0.94244806 0.94241736 0.94241905 0.94240443 0.94243573 0.94240543\n",
" 0.94241937 0.94241989 0.94242829 0.94243845 0.94243873 0.94244488\n",
" 0.94243726 0.94242083 0.94242829 0.94241988 0.94245415 0.94242219\n",
" 0.94244239 0.94241518 0.94244741 0.9424288 0.94243902 0.94242742\n",
" 0.94244647 0.94242764 0.9424486 0.94243932 0.94242807 0.94242085\n",
" 0.94242457 0.94244017 0.94241092 0.94240614 0.94242508 0.94242352\n",
"Epoch 14 RMSE: [0.9380729 0.93810193 0.93804964 0.93809779 0.93808017 0.93807681 | 14/40 [00:34<01:01, 2.37s/it]\n",
" 0.93808492 0.93806746 0.93807753 0.93808962 0.93807912 0.93808927\n",
" 0.9380845 0.9380754 0.93808462 0.93806081 0.93809038 0.9380707\n",
" 0.93806719 0.93810823 0.93808146 0.93808617 0.93807663 0.93806401\n",
" 0.93806365 0.9380789 0.93807998 0.93810428 0.93809143 0.938071\n",
" 0.93808527 0.93807792 0.93807713 0.93807948 0.93808988 0.93809373\n",
" 0.93806515 0.93807248 0.93807742 0.93810192 0.93806415 0.93807098\n",
" 0.93808517 0.93808003 0.93808658 0.93809419 0.9381074 0.93807715\n",
" 0.93807482 0.93808371 0.93807322 0.93807872 0.93809271 0.93808645\n",
" 0.93809391 0.93809203 0.93810411 0.93807661 0.93809222 0.93809822\n",
" 0.93809697 0.93806833 0.93806988 0.93805712 0.93808548 0.93805662\n",
" 0.93806972 0.93807144 0.93807716 0.93808838 0.93808667 0.93809374\n",
" 0.93808691 0.93807064 0.93807691 0.93806925 0.93810202 0.93807175\n",
" 0.93809133 0.93806764 0.93809515 0.93807829 0.93808742 0.93807713\n",
" 0.93809335 0.93807706 0.9380968 0.93808785 0.93807651 0.93807192\n",
" 0.93807378 0.93808833 0.93806168 0.9380582 0.93807566 0.93807356\n",
"Epoch 14 RMSE: [0.9380729 0.93810193 0.93804964 0.93809779 0.93808017 0.93807681 | 14/40 [00:34<01:01, 2.37s/it] \n",
" 0.93808492 0.93806746 0.93807753 0.93808962 0.93807912 0.93808927\n",
" 0.9380845 0.9380754 0.93808462 0.93806081 0.93809038 0.9380707\n",
" 0.93806719 0.93810823 0.93808146 0.93808617 0.93807663 0.93806401\n",
" 0.93806365 0.9380789 0.93807998 0.93810428 0.93809143 0.938071\n",
" 0.93808527 0.93807792 0.93807713 0.93807948 0.93808988 0.93809373\n",
" 0.93806515 0.93807248 0.93807742 0.93810192 0.93806415 0.93807098\n",
" 0.93808517 0.93808003 0.93808658 0.93809419 0.9381074 0.93807715\n",
" 0.93807482 0.93808371 0.93807322 0.93807872 0.93809271 0.93808645\n",
" 0.93809391 0.93809203 0.93810411 0.93807661 0.93809222 0.93809822\n",
" 0.93809697 0.93806833 0.93806988 0.93805712 0.93808548 0.93805662\n",
" 0.93806972 0.93807144 0.93807716 0.93808838 0.93808667 0.93809374\n",
" 0.93808691 0.93807064 0.93807691 0.93806925 0.93810202 0.93807175\n",
" 0.93809133 0.93806764 0.93809515 0.93807829 0.93808742 0.93807713\n",
" 0.93809335 0.93807706 0.9380968 0.93808785 0.93807651 0.93807192\n",
" 0.93807378 0.93808833 0.93806168 0.9380582 0.93807566 0.93807356\n",
"Epoch 15 RMSE: [0.93405466 0.9340801 0.934032 0.93407668 0.93406119 0.93405746 | 15/40 [00:36<00:59, 2.38s/it]\n",
" 0.93406501 0.93404887 0.93405817 0.93406843 0.93405947 0.93406962\n",
" 0.93406469 0.93405606 0.93406446 0.9340433 0.93406934 0.93405188\n",
" 0.9340484 0.93408552 0.93406174 0.93406584 0.93405665 0.93404574\n",
" 0.93404505 0.93405986 0.93406013 0.93408281 0.93407048 0.93405159\n",
" 0.93406533 0.93405847 0.93405791 0.93406037 0.93406988 0.93407413\n",
" 0.93404685 0.9340534 0.93405653 0.93408173 0.9340467 0.9340531\n",
" 0.93406499 0.93406087 0.9340664 0.93407335 0.93408522 0.93405857\n",
" 0.93405595 0.93406317 0.93405485 0.93405906 0.93407257 0.93406555\n",
" 0.93407195 0.93407072 0.93408251 0.93405763 0.93407174 0.93407741\n",
" 0.93407665 0.93404982 0.93405133 0.93404042 0.93406611 0.93403863\n",
" 0.9340507 0.93405337 0.93405675 0.934069 0.93406573 0.93407359\n",
" 0.93406724 0.93405173 0.93405685 0.93404977 0.93408081 0.93405228\n",
" 0.93407092 0.93405053 0.93407401 0.93405854 0.93406664 0.93405777\n",
" 0.93407149 0.93405726 0.93407599 0.93406703 0.93405594 0.93405333\n",
" 0.93405387 0.93406767 0.93404327 0.93404073 0.93405704 0.93405458\n",
"Epoch 15 RMSE: [0.93405466 0.9340801 0.934032 0.93407668 0.93406119 0.93405746 | 15/40 [00:36<00:59, 2.38s/it]\n",
" 0.93406501 0.93404887 0.93405817 0.93406843 0.93405947 0.93406962\n",
" 0.93406469 0.93405606 0.93406446 0.9340433 0.93406934 0.93405188\n",
" 0.9340484 0.93408552 0.93406174 0.93406584 0.93405665 0.93404574\n",
" 0.93404505 0.93405986 0.93406013 0.93408281 0.93407048 0.93405159\n",
" 0.93406533 0.93405847 0.93405791 0.93406037 0.93406988 0.93407413\n",
" 0.93404685 0.9340534 0.93405653 0.93408173 0.9340467 0.9340531\n",
" 0.93406499 0.93406087 0.9340664 0.93407335 0.93408522 0.93405857\n",
" 0.93405595 0.93406317 0.93405485 0.93405906 0.93407257 0.93406555\n",
" 0.93407195 0.93407072 0.93408251 0.93405763 0.93407174 0.93407741\n",
" 0.93407665 0.93404982 0.93405133 0.93404042 0.93406611 0.93403863\n",
" 0.9340507 0.93405337 0.93405675 0.934069 0.93406573 0.93407359\n",
" 0.93406724 0.93405173 0.93405685 0.93404977 0.93408081 0.93405228\n",
" 0.93407092 0.93405053 0.93407401 0.93405854 0.93406664 0.93405777\n",
" 0.93407149 0.93405726 0.93407599 0.93406703 0.93405594 0.93405333\n",
" 0.93405387 0.93406767 0.93404327 0.93404073 0.93405704 0.93405458\n",
"Epoch 16 RMSE: [0.93049748 0.93051952 0.93047527 0.93051692 0.93050324 0.93049948 | 16/40 [00:39<00:57, 2.38s/it]\n",
" 0.9305064 0.93049127 0.93050022 0.93050846 0.93050115 0.93051083\n",
" 0.93050615 0.93049781 0.93050559 0.93048669 0.93050969 0.93049385\n",
" 0.93049065 0.93052478 0.93050302 0.93050672 0.93049789 0.93048859\n",
" 0.93048731 0.9305018 0.93050168 0.93052245 0.93051101 0.9304933\n",
" 0.93050653 0.93050002 0.93049977 0.93050252 0.93051119 0.93051542\n",
" 0.93048964 0.9304954 0.93049713 0.93052273 0.93049021 0.93049606\n",
" 0.93050613 0.93050248 0.93050784 0.93051398 0.93052447 0.93050086\n",
" 0.93049819 0.93050399 0.93049753 0.93050087 0.9305136 0.93050592\n",
" 0.93051163 0.9305109 0.93052208 0.93049988 0.9305126 0.93051792\n",
" 0.93051748 0.93049249 0.93049378 0.93048443 0.9305079 0.93048177\n",
" 0.93049304 0.93049618 0.93049766 0.93051063 0.93050627 0.93051455\n",
" 0.93050881 0.930494 0.93049817 0.93049168 0.93052071 0.93049398\n",
" 0.93051177 0.93049432 0.9305143 0.93049999 0.93050746 0.93049972\n",
" 0.93051116 0.93049898 0.93051637 0.93050739 0.93049683 0.93049562\n",
" 0.93049531 0.93050831 0.93048542 0.93048411 0.93049954 0.93049665\n",
"Epoch 16 RMSE: [0.93049748 0.93051952 0.93047527 0.93051692 0.93050324 0.93049948 | 16/40 [00:39<00:57, 2.38s/it]\n",
" 0.9305064 0.93049127 0.93050022 0.93050846 0.93050115 0.93051083\n",
" 0.93050615 0.93049781 0.93050559 0.93048669 0.93050969 0.93049385\n",
" 0.93049065 0.93052478 0.93050302 0.93050672 0.93049789 0.93048859\n",
" 0.93048731 0.9305018 0.93050168 0.93052245 0.93051101 0.9304933\n",
" 0.93050653 0.93050002 0.93049977 0.93050252 0.93051119 0.93051542\n",
" 0.93048964 0.9304954 0.93049713 0.93052273 0.93049021 0.93049606\n",
" 0.93050613 0.93050248 0.93050784 0.93051398 0.93052447 0.93050086\n",
" 0.93049819 0.93050399 0.93049753 0.93050087 0.9305136 0.93050592\n",
" 0.93051163 0.9305109 0.93052208 0.93049988 0.9305126 0.93051792\n",
" 0.93051748 0.93049249 0.93049378 0.93048443 0.9305079 0.93048177\n",
" 0.93049304 0.93049618 0.93049766 0.93051063 0.93050627 0.93051455\n",
" 0.93050881 0.930494 0.93049817 0.93049168 0.93052071 0.93049398\n",
" 0.93051177 0.93049432 0.9305143 0.93049999 0.93050746 0.93049972\n",
" 0.93051116 0.93049898 0.93051637 0.93050739 0.93049683 0.93049562\n",
" 0.93049531 0.93050831 0.93048542 0.93048411 0.93049954 0.93049665\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 17 RMSE: [0.92714744 0.92716697 0.92712595 0.92716482 0.9271528 0.92714876 | 17/40 [00:41<00:54, 2.37s/it]\n",
" 0.92715527 0.92714136 0.92714956 0.92715638 0.92715026 0.9271597\n",
" 0.92715493 0.92714704 0.92715426 0.92713721 0.92715782 0.92714357\n",
" 0.92714057 0.92717165 0.92715207 0.92715539 0.92714643 0.92713889\n",
" 0.92713736 0.92715091 0.92715067 0.92716998 0.92715892 0.92714273\n",
" 0.92715531 0.92714926 0.92714926 0.92715176 0.92715998 0.92716418\n",
" 0.92713979 0.92714502 0.92714533 0.92717078 0.92714096 0.92714635\n",
" 0.92715485 0.92715167 0.92715651 0.92716198 0.92717138 0.92715065\n",
" 0.9271479 0.92715235 0.92714731 0.92714985 0.92716222 0.92715401\n",
" 0.92715894 0.92715872 0.92716959 0.92714928 0.92716093 0.92716598\n",
" 0.92716575 0.9271423 0.92714358 0.92713565 0.92715696 0.92713227\n",
" 0.92714273 0.9271466 0.92714622 0.92715961 0.92715452 0.92716285\n",
" 0.92715791 0.92714354 0.92714721 0.92714098 0.92716835 0.92714319\n",
" 0.92716019 0.92714521 0.92716224 0.9271489 0.92715569 0.92714907\n",
" 0.92715862 0.92714818 0.92716453 0.92715568 0.92714544 0.92714559\n",
" 0.9271442 0.92715665 0.92713534 0.92713479 0.92714913 0.92714612\n",
"Epoch 17 RMSE: [0.92714744 0.92716697 0.92712595 0.92716482 0.9271528 0.92714876 | 17/40 [00:41<00:54, 2.37s/it]\n",
" 0.92715527 0.92714136 0.92714956 0.92715638 0.92715026 0.9271597\n",
" 0.92715493 0.92714704 0.92715426 0.92713721 0.92715782 0.92714357\n",
" 0.92714057 0.92717165 0.92715207 0.92715539 0.92714643 0.92713889\n",
" 0.92713736 0.92715091 0.92715067 0.92716998 0.92715892 0.92714273\n",
" 0.92715531 0.92714926 0.92714926 0.92715176 0.92715998 0.92716418\n",
" 0.92713979 0.92714502 0.92714533 0.92717078 0.92714096 0.92714635\n",
" 0.92715485 0.92715167 0.92715651 0.92716198 0.92717138 0.92715065\n",
" 0.9271479 0.92715235 0.92714731 0.92714985 0.92716222 0.92715401\n",
" 0.92715894 0.92715872 0.92716959 0.92714928 0.92716093 0.92716598\n",
" 0.92716575 0.9271423 0.92714358 0.92713565 0.92715696 0.92713227\n",
" 0.92714273 0.9271466 0.92714622 0.92715961 0.92715452 0.92716285\n",
" 0.92715791 0.92714354 0.92714721 0.92714098 0.92716835 0.92714319\n",
" 0.92716019 0.92714521 0.92716224 0.9271489 0.92715569 0.92714907\n",
" 0.92715862 0.92714818 0.92716453 0.92715568 0.92714544 0.92714559\n",
" 0.9271442 0.92715665 0.92713534 0.92713479 0.92714913 0.92714612\n",
"Epoch 18 RMSE: [0.92371638 0.92373374 0.92369539 0.92373195 0.92372119 0.92371734 | 18/40 [00:43<00:52, 2.36s/it]\n",
" 0.92372333 0.92371018 0.92371809 0.9237235 0.92371867 0.92372747\n",
" 0.92372281 0.92371545 0.92372208 0.92370676 0.9237253 0.92371207\n",
" 0.92370931 0.92373784 0.92372002 0.92372321 0.92371436 0.92370792\n",
" 0.92370615 0.9237192 0.92371863 0.92373675 0.92372621 0.92371132\n",
" 0.92372328 0.92371753 0.92371751 0.92372022 0.9237277 0.92373198\n",
" 0.92370893 0.92371344 0.92371275 0.92373809 0.92371035 0.92371553\n",
" 0.9237226 0.92372003 0.92372456 0.9237292 0.92373803 0.92371934\n",
" 0.92371631 0.92372011 0.92371632 0.92371803 0.92372989 0.9237215\n",
" 0.92372566 0.92372605 0.92373631 0.92371775 0.92372841 0.92373317\n",
" 0.92373334 0.92371127 0.92371246 0.92370561 0.92372518 0.92370161\n",
" 0.9237113 0.92371578 0.92371402 0.92372767 0.92372203 0.92373032\n",
" 0.92372593 0.92371218 0.92371532 0.92370948 0.92373509 0.92371158\n",
" 0.92372759 0.92371497 0.92372913 0.92371694 0.92372338 0.92371765\n",
" 0.92372538 0.92371633 0.92373194 0.92372305 0.92371332 0.92371424\n",
" 0.9237122 0.92372428 0.92370422 0.92370451 0.92371783 0.92371467\n",
"Epoch 18 RMSE: [0.92371638 0.92373374 0.92369539 0.92373195 0.92372119 0.92371734 | 18/40 [00:43<00:52, 2.36s/it]\n",
" 0.92372333 0.92371018 0.92371809 0.9237235 0.92371867 0.92372747\n",
" 0.92372281 0.92371545 0.92372208 0.92370676 0.9237253 0.92371207\n",
" 0.92370931 0.92373784 0.92372002 0.92372321 0.92371436 0.92370792\n",
" 0.92370615 0.9237192 0.92371863 0.92373675 0.92372621 0.92371132\n",
" 0.92372328 0.92371753 0.92371751 0.92372022 0.9237277 0.92373198\n",
" 0.92370893 0.92371344 0.92371275 0.92373809 0.92371035 0.92371553\n",
" 0.9237226 0.92372003 0.92372456 0.9237292 0.92373803 0.92371934\n",
" 0.92371631 0.92372011 0.92371632 0.92371803 0.92372989 0.9237215\n",
" 0.92372566 0.92372605 0.92373631 0.92371775 0.92372841 0.92373317\n",
" 0.92373334 0.92371127 0.92371246 0.92370561 0.92372518 0.92370161\n",
" 0.9237113 0.92371578 0.92371402 0.92372767 0.92372203 0.92373032\n",
" 0.92372593 0.92371218 0.92371532 0.92370948 0.92373509 0.92371158\n",
" 0.92372759 0.92371497 0.92372913 0.92371694 0.92372338 0.92371765\n",
" 0.92372538 0.92371633 0.92373194 0.92372305 0.92371332 0.92371424\n",
" 0.9237122 0.92372428 0.92370422 0.92370451 0.92371783 0.92371467\n",
"Epoch 19 RMSE: [0.92036649 0.92038166 0.92034596 0.92038016 0.92037059 0.92036677 | 19/40 [00:46<00:49, 2.37s/it]\n",
" 0.92037241 0.9203601 0.92036774 0.92037168 0.92036813 0.9203763\n",
" 0.92037182 0.92036494 0.92037122 0.92035722 0.92037376 0.92036178\n",
" 0.92035917 0.92038546 0.9203691 0.92037202 0.92036345 0.92035818\n",
" 0.92035624 0.92036857 0.92036789 0.92038479 0.92037488 0.92036083\n",
" 0.92037237 0.92036669 0.92036692 0.92036962 0.92037696 0.92038083\n",
" 0.9203592 0.92036312 0.92036172 0.92038654 0.92036075 0.92036567\n",
" 0.92037159 0.92036923 0.92037387 0.92037782 0.92038576 0.9203693\n",
" 0.92036627 0.920369 0.92036633 0.92036734 0.92037867 0.92037008\n",
" 0.92037368 0.92037421 0.92038414 0.92036738 0.92037709 0.92038137\n",
" 0.92038191 0.92036128 0.92036223 0.92035638 0.92037447 0.92035213\n",
" 0.92036125 0.92036582 0.92036302 0.92037659 0.92037066 0.92037869\n",
" 0.92037531 0.92036215 0.92036481 0.92035902 0.92038317 0.92036109\n",
" 0.92037616 0.92036588 0.92037759 0.92036632 0.92037216 0.92036717\n",
" 0.92037345 0.92036576 0.92038041 0.9203715 0.92036231 0.92036405\n",
" 0.92036156 0.92037306 0.92035393 0.92035518 0.92036757 0.92036435\n",
"Epoch 19 RMSE: [0.92036649 0.92038166 0.92034596 0.92038016 0.92037059 0.92036677 | 19/40 [00:46<00:49, 2.37s/it] \n",
" 0.92037241 0.9203601 0.92036774 0.92037168 0.92036813 0.9203763\n",
" 0.92037182 0.92036494 0.92037122 0.92035722 0.92037376 0.92036178\n",
" 0.92035917 0.92038546 0.9203691 0.92037202 0.92036345 0.92035818\n",
" 0.92035624 0.92036857 0.92036789 0.92038479 0.92037488 0.92036083\n",
" 0.92037237 0.92036669 0.92036692 0.92036962 0.92037696 0.92038083\n",
" 0.9203592 0.92036312 0.92036172 0.92038654 0.92036075 0.92036567\n",
" 0.92037159 0.92036923 0.92037387 0.92037782 0.92038576 0.9203693\n",
" 0.92036627 0.920369 0.92036633 0.92036734 0.92037867 0.92037008\n",
" 0.92037368 0.92037421 0.92038414 0.92036738 0.92037709 0.92038137\n",
" 0.92038191 0.92036128 0.92036223 0.92035638 0.92037447 0.92035213\n",
" 0.92036125 0.92036582 0.92036302 0.92037659 0.92037066 0.92037869\n",
" 0.92037531 0.92036215 0.92036481 0.92035902 0.92038317 0.92036109\n",
" 0.92037616 0.92036588 0.92037759 0.92036632 0.92037216 0.92036717\n",
" 0.92037345 0.92036576 0.92038041 0.9203715 0.92036231 0.92036405\n",
" 0.92036156 0.92037306 0.92035393 0.92035518 0.92036757 0.92036435\n",
"Epoch 20 RMSE: [0.91679736 0.9168108 0.91677763 0.9168097 0.91680124 0.9167976 | 20/40 [00:48<00:47, 2.36s/it]\n",
" 0.91680273 0.91679114 0.91679858 0.91680147 0.91679851 0.91680629\n",
" 0.91680192 0.91679561 0.91680131 0.91678853 0.91680357 0.9167927\n",
" 0.91679007 0.91681454 0.91679948 0.91680212 0.91679378 0.91678955\n",
" 0.91678719 0.91679889 0.9167983 0.91681408 0.91680477 0.91679167\n",
" 0.91680246 0.91679725 0.91679729 0.9168002 0.91680706 0.91681077\n",
" 0.91679052 0.9167938 0.91679187 0.91681619 0.9167924 0.91679686\n",
" 0.91680175 0.91679981 0.916804 0.91680747 0.91681498 0.91680007\n",
" 0.91679704 0.91679899 0.91679735 0.91679784 0.91680865 0.91679986\n",
" 0.91680297 0.91680382 0.91681324 0.91679804 0.91680704 0.91681089\n",
" 0.91681164 0.91679227 0.91679312 0.91678825 0.91680506 0.91678373\n",
" 0.91679208 0.91679687 0.91679319 0.91680679 0.91680066 0.91680845\n",
" 0.91680554 0.91679303 0.91679532 0.9167898 0.91681233 0.91679165\n",
" 0.91680591 0.91679738 0.91680716 0.91679653 0.916802 0.91679794\n",
" 0.91680275 0.91679611 0.9168098 0.91680122 0.91679272 0.9167949\n",
" 0.91679181 0.91680316 0.91678491 0.91678671 0.9167982 0.91679524\n",
"Epoch 20 RMSE: [0.91679736 0.9168108 0.91677763 0.9168097 0.91680124 0.9167976 | 20/40 [00:48<00:47, 2.36s/it] \n",
" 0.91680273 0.91679114 0.91679858 0.91680147 0.91679851 0.91680629\n",
" 0.91680192 0.91679561 0.91680131 0.91678853 0.91680357 0.9167927\n",
" 0.91679007 0.91681454 0.91679948 0.91680212 0.91679378 0.91678955\n",
" 0.91678719 0.91679889 0.9167983 0.91681408 0.91680477 0.91679167\n",
" 0.91680246 0.91679725 0.91679729 0.9168002 0.91680706 0.91681077\n",
" 0.91679052 0.9167938 0.91679187 0.91681619 0.9167924 0.91679686\n",
" 0.91680175 0.91679981 0.916804 0.91680747 0.91681498 0.91680007\n",
" 0.91679704 0.91679899 0.91679735 0.91679784 0.91680865 0.91679986\n",
" 0.91680297 0.91680382 0.91681324 0.91679804 0.91680704 0.91681089\n",
" 0.91681164 0.91679227 0.91679312 0.91678825 0.91680506 0.91678373\n",
" 0.91679208 0.91679687 0.91679319 0.91680679 0.91680066 0.91680845\n",
" 0.91680554 0.91679303 0.91679532 0.9167898 0.91681233 0.91679165\n",
" 0.91680591 0.91679738 0.91680716 0.91679653 0.916802 0.91679794\n",
" 0.91680275 0.91679611 0.9168098 0.91680122 0.91679272 0.9167949\n",
" 0.91679181 0.91680316 0.91678491 0.91678671 0.9167982 0.91679524\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 21 RMSE: [0.91288634 0.91289823 0.91286723 0.91289745 0.91288981 0.91288643 | 21/40 [00:50<00:44, 2.36s/it]\n",
" 0.9128911 0.91288024 0.91288735 0.9128891 0.91288708 0.91289446\n",
" 0.91289013 0.91288426 0.91288966 0.91287784 0.91289153 0.91288155\n",
" 0.91287905 0.91290158 0.91288794 0.91289021 0.91288232 0.91287877\n",
" 0.91287618 0.91288732 0.91288667 0.91290168 0.91289257 0.91288048\n",
" 0.91289076 0.91288575 0.91288588 0.91288878 0.91289542 0.9128987\n",
" 0.9128798 0.9128825 0.91288016 0.91290374 0.91288171 0.91288599\n",
" 0.91288986 0.91288848 0.91289212 0.91289528 0.91290243 0.91288905\n",
" 0.91288569 0.91288723 0.91288623 0.91288636 0.91289681 0.91288791\n",
" 0.9128906 0.91289159 0.91290042 0.91288679 0.91289487 0.91289864\n",
" 0.91289945 0.91288128 0.91288214 0.91287793 0.9128934 0.91287336\n",
" 0.91288113 0.91288609 0.91288154 0.91289499 0.91288885 0.91289632\n",
" 0.91289398 0.91288213 0.91288396 0.91287871 0.91289974 0.91288028\n",
" 0.91289392 0.91288699 0.9128948 0.91288487 0.91289004 0.91288658\n",
" 0.9128903 0.91288471 0.91289764 0.91288905 0.91288116 0.91288375\n",
" 0.91288023 0.91289138 0.91287401 0.91287621 0.9128869 0.91288384\n",
"Epoch 21 RMSE: [0.91288634 0.91289823 0.91286723 0.91289745 0.91288981 0.91288643 | 21/40 [00:50<00:44, 2.36s/it]\n",
" 0.9128911 0.91288024 0.91288735 0.9128891 0.91288708 0.91289446\n",
" 0.91289013 0.91288426 0.91288966 0.91287784 0.91289153 0.91288155\n",
" 0.91287905 0.91290158 0.91288794 0.91289021 0.91288232 0.91287877\n",
" 0.91287618 0.91288732 0.91288667 0.91290168 0.91289257 0.91288048\n",
" 0.91289076 0.91288575 0.91288588 0.91288878 0.91289542 0.9128987\n",
" 0.9128798 0.9128825 0.91288016 0.91290374 0.91288171 0.91288599\n",
" 0.91288986 0.91288848 0.91289212 0.91289528 0.91290243 0.91288905\n",
" 0.91288569 0.91288723 0.91288623 0.91288636 0.91289681 0.91288791\n",
" 0.9128906 0.91289159 0.91290042 0.91288679 0.91289487 0.91289864\n",
" 0.91289945 0.91288128 0.91288214 0.91287793 0.9128934 0.91287336\n",
" 0.91288113 0.91288609 0.91288154 0.91289499 0.91288885 0.91289632\n",
" 0.91289398 0.91288213 0.91288396 0.91287871 0.91289974 0.91288028\n",
" 0.91289392 0.91288699 0.9128948 0.91288487 0.91289004 0.91288658\n",
" 0.9128903 0.91288471 0.91289764 0.91288905 0.91288116 0.91288375\n",
" 0.91288023 0.91289138 0.91287401 0.91287621 0.9128869 0.91288384\n",
"Epoch 22 RMSE: [0.90878321 0.90879365 0.90876443 0.90879308 0.90878607 0.90878297 | 22/40 [00:53<00:42, 2.36s/it]\n",
" 0.90878753 0.90877714 0.90878393 0.90878485 0.90878358 0.90879038\n",
" 0.90878612 0.90878067 0.90878585 0.90877489 0.90878738 0.90877831\n",
" 0.90877571 0.90879683 0.90878424 0.90878624 0.90877857 0.90877572\n",
" 0.90877303 0.90878363 0.90878294 0.90879717 0.90878846 0.90877722\n",
" 0.90878694 0.90878216 0.90878241 0.90878522 0.9087916 0.90879471\n",
" 0.90877681 0.90877909 0.90877643 0.90879933 0.90877883 0.90878288\n",
" 0.90878604 0.90878479 0.90878845 0.9087911 0.90879779 0.90878555\n",
" 0.90878238 0.90878335 0.90878283 0.90878259 0.90879283 0.90878389\n",
" 0.90878621 0.90878722 0.90879575 0.90878326 0.9087908 0.9087942\n",
" 0.90879515 0.90877805 0.90877887 0.90877535 0.90878977 0.90877081\n",
" 0.90877778 0.90878299 0.90877778 0.90879106 0.90878492 0.908792\n",
" 0.9087903 0.90877897 0.90878051 0.90877554 0.90879516 0.90877682\n",
" 0.90878962 0.90878418 0.90879056 0.90878103 0.90878599 0.90878304\n",
" 0.90878581 0.90878106 0.90879333 0.90878476 0.90877748 0.90878039\n",
" 0.90877668 0.90878743 0.90877075 0.90877347 0.90878331 0.90878046\n",
"Epoch 22 RMSE: [0.90878321 0.90879365 0.90876443 0.90879308 0.90878607 0.90878297 | 22/40 [00:53<00:42, 2.36s/it] \n",
" 0.90878753 0.90877714 0.90878393 0.90878485 0.90878358 0.90879038\n",
" 0.90878612 0.90878067 0.90878585 0.90877489 0.90878738 0.90877831\n",
" 0.90877571 0.90879683 0.90878424 0.90878624 0.90877857 0.90877572\n",
" 0.90877303 0.90878363 0.90878294 0.90879717 0.90878846 0.90877722\n",
" 0.90878694 0.90878216 0.90878241 0.90878522 0.9087916 0.90879471\n",
" 0.90877681 0.90877909 0.90877643 0.90879933 0.90877883 0.90878288\n",
" 0.90878604 0.90878479 0.90878845 0.9087911 0.90879779 0.90878555\n",
" 0.90878238 0.90878335 0.90878283 0.90878259 0.90879283 0.90878389\n",
" 0.90878621 0.90878722 0.90879575 0.90878326 0.9087908 0.9087942\n",
" 0.90879515 0.90877805 0.90877887 0.90877535 0.90878977 0.90877081\n",
" 0.90877778 0.90878299 0.90877778 0.90879106 0.90878492 0.908792\n",
" 0.9087903 0.90877897 0.90878051 0.90877554 0.90879516 0.90877682\n",
" 0.90878962 0.90878418 0.90879056 0.90878103 0.90878599 0.90878304\n",
" 0.90878581 0.90878106 0.90879333 0.90878476 0.90877748 0.90878039\n",
" 0.90877668 0.90878743 0.90877075 0.90877347 0.90878331 0.90878046\n",
"Epoch 23 RMSE: [0.90422053 0.90422986 0.90420248 0.90422944 0.904223 0.90422032 | 23/40 [00:55<00:40, 2.36s/it]\n",
" 0.90422445 0.90421455 0.90422124 0.90422135 0.90422064 0.90422695\n",
" 0.90422281 0.90421798 0.90422258 0.90421247 0.90422392 0.90421567\n",
" 0.90421317 0.90423275 0.90422125 0.90422307 0.9042156 0.90421342\n",
" 0.90421034 0.90422029 0.90422001 0.90423299 0.90422515 0.90421446\n",
" 0.90422382 0.90421928 0.90421941 0.90422233 0.90422836 0.90423127\n",
" 0.9042146 0.90421628 0.90421338 0.90423546 0.90421644 0.90422036\n",
" 0.90422263 0.90422177 0.9042252 0.90422756 0.90423391 0.90422269\n",
" 0.90421953 0.90422021 0.90422029 0.90421968 0.90422942 0.9042206\n",
" 0.90422219 0.90422371 0.90423179 0.90422032 0.90422729 0.9042307\n",
" 0.90423174 0.90421557 0.90421616 0.90421333 0.90422676 0.90420883\n",
" 0.90421514 0.9042205 0.90421473 0.90422767 0.90422153 0.90422851\n",
" 0.90422709 0.90421633 0.90421769 0.90421279 0.90423119 0.9042139\n",
" 0.90422596 0.90422192 0.90422702 0.90421791 0.90422274 0.90422038\n",
" 0.90422228 0.90421833 0.90422976 0.90422142 0.90421453 0.90421772\n",
" 0.90421375 0.90422426 0.90420816 0.9042113 0.90422056 0.90421763\n",
"Epoch 23 RMSE: [0.90422053 0.90422986 0.90420248 0.90422944 0.904223 0.90422032 | 23/40 [00:55<00:40, 2.36s/it]\n",
" 0.90422445 0.90421455 0.90422124 0.90422135 0.90422064 0.90422695\n",
" 0.90422281 0.90421798 0.90422258 0.90421247 0.90422392 0.90421567\n",
" 0.90421317 0.90423275 0.90422125 0.90422307 0.9042156 0.90421342\n",
" 0.90421034 0.90422029 0.90422001 0.90423299 0.90422515 0.90421446\n",
" 0.90422382 0.90421928 0.90421941 0.90422233 0.90422836 0.90423127\n",
" 0.9042146 0.90421628 0.90421338 0.90423546 0.90421644 0.90422036\n",
" 0.90422263 0.90422177 0.9042252 0.90422756 0.90423391 0.90422269\n",
" 0.90421953 0.90422021 0.90422029 0.90421968 0.90422942 0.9042206\n",
" 0.90422219 0.90422371 0.90423179 0.90422032 0.90422729 0.9042307\n",
" 0.90423174 0.90421557 0.90421616 0.90421333 0.90422676 0.90420883\n",
" 0.90421514 0.9042205 0.90421473 0.90422767 0.90422153 0.90422851\n",
" 0.90422709 0.90421633 0.90421769 0.90421279 0.90423119 0.9042139\n",
" 0.90422596 0.90422192 0.90422702 0.90421791 0.90422274 0.90422038\n",
" 0.90422228 0.90421833 0.90422976 0.90422142 0.90421453 0.90421772\n",
" 0.90421375 0.90422426 0.90420816 0.9042113 0.90422056 0.90421763\n",
"Epoch 24 RMSE: [0.89944541 0.89945349 0.89942796 0.89945335 0.89944751 0.89944501 | 24/40 [00:58<00:38, 2.39s/it]\n",
" 0.89944878 0.89943965 0.899446 0.89944534 0.89944544 0.89945125\n",
" 0.899447 0.8994426 0.89944699 0.89943749 0.89944794 0.89944042\n",
" 0.89943797 0.89945648 0.8994455 0.89944736 0.89944025 0.89943837\n",
" 0.89943513 0.89944477 0.8994446 0.89945708 0.89944921 0.89943937\n",
" 0.89944796 0.89944377 0.89944392 0.89944665 0.89945284 0.89945524\n",
" 0.89943949 0.89944083 0.89943778 0.89945916 0.89944153 0.89944516\n",
" 0.89944711 0.89944631 0.89944975 0.89945155 0.89945769 0.89944752\n",
" 0.89944431 0.89944458 0.89944509 0.89944413 0.89945366 0.89944493\n",
" 0.89944639 0.89944775 0.89945525 0.899445 0.89945131 0.89945457\n",
" 0.89945558 0.89944056 0.89944123 0.89943846 0.8994512 0.89943387\n",
" 0.89944006 0.89944532 0.8994392 0.8994519 0.89944585 0.89945251\n",
" 0.89945128 0.89944128 0.89944248 0.89943771 0.8994548 0.89943861\n",
" 0.89944998 0.89944716 0.89945072 0.89944257 0.89944696 0.89944494\n",
" 0.89944599 0.89944286 0.89945349 0.8994455 0.89943896 0.89944237\n",
" 0.89943852 0.89944858 0.89943303 0.89943647 0.8994449 0.89944234\n",
"Epoch 24 RMSE: [0.89944541 0.89945349 0.89942796 0.89945335 0.89944751 0.89944501 | 24/40 [00:58<00:38, 2.39s/it]\n",
" 0.89944878 0.89943965 0.899446 0.89944534 0.89944544 0.89945125\n",
" 0.899447 0.8994426 0.89944699 0.89943749 0.89944794 0.89944042\n",
" 0.89943797 0.89945648 0.8994455 0.89944736 0.89944025 0.89943837\n",
" 0.89943513 0.89944477 0.8994446 0.89945708 0.89944921 0.89943937\n",
" 0.89944796 0.89944377 0.89944392 0.89944665 0.89945284 0.89945524\n",
" 0.89943949 0.89944083 0.89943778 0.89945916 0.89944153 0.89944516\n",
" 0.89944711 0.89944631 0.89944975 0.89945155 0.89945769 0.89944752\n",
" 0.89944431 0.89944458 0.89944509 0.89944413 0.89945366 0.89944493\n",
" 0.89944639 0.89944775 0.89945525 0.899445 0.89945131 0.89945457\n",
" 0.89945558 0.89944056 0.89944123 0.89943846 0.8994512 0.89943387\n",
" 0.89944006 0.89944532 0.8994392 0.8994519 0.89944585 0.89945251\n",
" 0.89945128 0.89944128 0.89944248 0.89943771 0.8994548 0.89943861\n",
" 0.89944998 0.89944716 0.89945072 0.89944257 0.89944696 0.89944494\n",
" 0.89944599 0.89944286 0.89945349 0.8994455 0.89943896 0.89944237\n",
" 0.89943852 0.89944858 0.89943303 0.89943647 0.8994449 0.89944234\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 25 RMSE: [0.8939617 0.89396902 0.89394454 0.89396896 0.89396355 0.89396113 | 25/40 [01:00<00:35, 2.38s/it]\n",
" 0.8939648 0.89395604 0.89396229 0.89396108 0.89396134 0.89396702\n",
" 0.89396299 0.89395879 0.89396286 0.89395418 0.89396375 0.89395685\n",
" 0.8939543 0.89397161 0.89396178 0.89396314 0.8939562 0.89395505\n",
" 0.89395179 0.89396073 0.89396054 0.89397242 0.89396513 0.8939557\n",
" 0.89396393 0.89395993 0.89395987 0.89396286 0.89396866 0.89397102\n",
" 0.89395619 0.89395716 0.89395402 0.8939746 0.89395806 0.89396162\n",
" 0.89396294 0.89396246 0.8939657 0.89396737 0.89397308 0.89396356\n",
" 0.8939605 0.89396057 0.89396137 0.89396035 0.89396944 0.89396085\n",
" 0.89396204 0.8939634 0.89397064 0.89396123 0.89396705 0.89397009\n",
" 0.89397127 0.89395682 0.89395731 0.89395543 0.89396713 0.89395092\n",
" 0.89395642 0.89396178 0.89395519 0.89396772 0.89396193 0.89396804\n",
" 0.89396738 0.89395783 0.89395883 0.89395421 0.89397033 0.89395483\n",
" 0.89396561 0.89396384 0.89396653 0.89395857 0.89396291 0.89396119\n",
" 0.89396164 0.89395909 0.89396916 0.89396109 0.8939552 0.89395868\n",
" 0.89395448 0.89396461 0.89394963 0.89395322 0.8939614 0.89395862\n",
"Epoch 25 RMSE: [0.8939617 0.89396902 0.89394454 0.89396896 0.89396355 0.89396113 | 25/40 [01:00<00:35, 2.38s/it] \n",
" 0.8939648 0.89395604 0.89396229 0.89396108 0.89396134 0.89396702\n",
" 0.89396299 0.89395879 0.89396286 0.89395418 0.89396375 0.89395685\n",
" 0.8939543 0.89397161 0.89396178 0.89396314 0.8939562 0.89395505\n",
" 0.89395179 0.89396073 0.89396054 0.89397242 0.89396513 0.8939557\n",
" 0.89396393 0.89395993 0.89395987 0.89396286 0.89396866 0.89397102\n",
" 0.89395619 0.89395716 0.89395402 0.8939746 0.89395806 0.89396162\n",
" 0.89396294 0.89396246 0.8939657 0.89396737 0.89397308 0.89396356\n",
" 0.8939605 0.89396057 0.89396137 0.89396035 0.89396944 0.89396085\n",
" 0.89396204 0.8939634 0.89397064 0.89396123 0.89396705 0.89397009\n",
" 0.89397127 0.89395682 0.89395731 0.89395543 0.89396713 0.89395092\n",
" 0.89395642 0.89396178 0.89395519 0.89396772 0.89396193 0.89396804\n",
" 0.89396738 0.89395783 0.89395883 0.89395421 0.89397033 0.89395483\n",
" 0.89396561 0.89396384 0.89396653 0.89395857 0.89396291 0.89396119\n",
" 0.89396164 0.89395909 0.89396916 0.89396109 0.8939552 0.89395868\n",
" 0.89395448 0.89396461 0.89394963 0.89395322 0.8939614 0.89395862\n",
"Epoch 26 RMSE: [0.88813915 0.88814554 0.88812237 0.88814555 0.88814064 0.88813834 | 26/40 [01:02<00:33, 2.39s/it]\n",
" 0.88814185 0.88813351 0.88813954 0.8881378 0.88813859 0.88814365\n",
" 0.8881398 0.88813593 0.88813996 0.88813158 0.88814047 0.88813423\n",
" 0.88813165 0.88814795 0.8881389 0.8881399 0.88813335 0.88813264\n",
" 0.8881292 0.88813777 0.88813767 0.88814902 0.88814186 0.88813312\n",
" 0.88814095 0.8881371 0.88813693 0.88813993 0.88814556 0.8881475\n",
" 0.88813369 0.88813441 0.8881312 0.88815094 0.8881356 0.88813904\n",
" 0.88813984 0.88813946 0.88814259 0.88814413 0.88814973 0.88814092\n",
" 0.88813763 0.88813751 0.88813871 0.88813734 0.88814611 0.8881377\n",
" 0.88813875 0.8881401 0.88814686 0.88813846 0.88814378 0.88814658\n",
" 0.88814785 0.88813425 0.88813473 0.88813305 0.88814419 0.88812872\n",
" 0.88813375 0.88813917 0.88813233 0.88814456 0.88813878 0.88814465\n",
" 0.88814438 0.88813536 0.88813623 0.88813153 0.88814687 0.88813212\n",
" 0.88814238 0.88814128 0.88814313 0.88813559 0.8881398 0.88813829\n",
" 0.88813821 0.88813624 0.88814575 0.88813772 0.88813232 0.8881359\n",
" 0.88813164 0.88814152 0.88812695 0.88813106 0.8881383 0.88813582\n",
"Epoch 26 RMSE: [0.88813915 0.88814554 0.88812237 0.88814555 0.88814064 0.88813834 | 26/40 [01:02<00:33, 2.39s/it] \n",
" 0.88814185 0.88813351 0.88813954 0.8881378 0.88813859 0.88814365\n",
" 0.8881398 0.88813593 0.88813996 0.88813158 0.88814047 0.88813423\n",
" 0.88813165 0.88814795 0.8881389 0.8881399 0.88813335 0.88813264\n",
" 0.8881292 0.88813777 0.88813767 0.88814902 0.88814186 0.88813312\n",
" 0.88814095 0.8881371 0.88813693 0.88813993 0.88814556 0.8881475\n",
" 0.88813369 0.88813441 0.8881312 0.88815094 0.8881356 0.88813904\n",
" 0.88813984 0.88813946 0.88814259 0.88814413 0.88814973 0.88814092\n",
" 0.88813763 0.88813751 0.88813871 0.88813734 0.88814611 0.8881377\n",
" 0.88813875 0.8881401 0.88814686 0.88813846 0.88814378 0.88814658\n",
" 0.88814785 0.88813425 0.88813473 0.88813305 0.88814419 0.88812872\n",
" 0.88813375 0.88813917 0.88813233 0.88814456 0.88813878 0.88814465\n",
" 0.88814438 0.88813536 0.88813623 0.88813153 0.88814687 0.88813212\n",
" 0.88814238 0.88814128 0.88814313 0.88813559 0.8881398 0.88813829\n",
" 0.88813821 0.88813624 0.88814575 0.88813772 0.88813232 0.8881359\n",
" 0.88813164 0.88814152 0.88812695 0.88813106 0.8881383 0.88813582\n",
"Epoch 27 RMSE: [0.88210562 0.88211126 0.88208923 0.88211131 0.88210686 0.88210478 | 27/40 [01:05<00:31, 2.39s/it]\n",
" 0.88210805 0.8821001 0.88210595 0.88210382 0.88210477 0.88210963\n",
" 0.88210586 0.88210238 0.8821062 0.8820982 0.88210642 0.88210084\n",
" 0.88209819 0.88211345 0.88210509 0.88210605 0.88209968 0.88209927\n",
" 0.88209582 0.88210388 0.8821039 0.88211464 0.88210788 0.88209959\n",
" 0.88210712 0.88210345 0.88210318 0.88210625 0.88211177 0.88211355\n",
" 0.88210046 0.88210083 0.88209761 0.88211654 0.88210219 0.88210552\n",
" 0.88210585 0.88210574 0.88210875 0.8821101 0.88211536 0.88210732\n",
" 0.88210405 0.88210369 0.88210507 0.88210357 0.88211216 0.88210397\n",
" 0.8821046 0.88210594 0.88211244 0.88210475 0.88210974 0.88211234\n",
" 0.88211377 0.88210078 0.88210134 0.88209997 0.88211021 0.88209584\n",
" 0.88210022 0.88210559 0.88209865 0.88211054 0.88210499 0.88211031\n",
" 0.8821106 0.88210202 0.88210277 0.88209825 0.88211251 0.88209844\n",
" 0.88210823 0.88210803 0.88210913 0.88210189 0.88210589 0.8821047\n",
" 0.88210414 0.88210258 0.88211149 0.88210375 0.88209868 0.88210237\n",
" 0.882098 0.8821077 0.88209366 0.88209777 0.88210481 0.88210228\n",
"Epoch 27 RMSE: [0.88210562 0.88211126 0.88208923 0.88211131 0.88210686 0.88210478 | 27/40 [01:05<00:31, 2.39s/it]\n",
" 0.88210805 0.8821001 0.88210595 0.88210382 0.88210477 0.88210963\n",
" 0.88210586 0.88210238 0.8821062 0.8820982 0.88210642 0.88210084\n",
" 0.88209819 0.88211345 0.88210509 0.88210605 0.88209968 0.88209927\n",
" 0.88209582 0.88210388 0.8821039 0.88211464 0.88210788 0.88209959\n",
" 0.88210712 0.88210345 0.88210318 0.88210625 0.88211177 0.88211355\n",
" 0.88210046 0.88210083 0.88209761 0.88211654 0.88210219 0.88210552\n",
" 0.88210585 0.88210574 0.88210875 0.8821101 0.88211536 0.88210732\n",
" 0.88210405 0.88210369 0.88210507 0.88210357 0.88211216 0.88210397\n",
" 0.8821046 0.88210594 0.88211244 0.88210475 0.88210974 0.88211234\n",
" 0.88211377 0.88210078 0.88210134 0.88209997 0.88211021 0.88209584\n",
" 0.88210022 0.88210559 0.88209865 0.88211054 0.88210499 0.88211031\n",
" 0.8821106 0.88210202 0.88210277 0.88209825 0.88211251 0.88209844\n",
" 0.88210823 0.88210803 0.88210913 0.88210189 0.88210589 0.8821047\n",
" 0.88210414 0.88210258 0.88211149 0.88210375 0.88209868 0.88210237\n",
" 0.882098 0.8821077 0.88209366 0.88209777 0.88210481 0.88210228\n",
"Epoch 28 RMSE: [0.87565692 0.87566189 0.87564095 0.87566218 0.87565795 0.87565605 | 28/40 [01:07<00:28, 2.39s/it]\n",
" 0.87565909 0.87565153 0.87565725 0.87565463 0.87565602 0.87566048\n",
" 0.87565678 0.87565358 0.87565725 0.87564969 0.87565727 0.87565217\n",
" 0.8756495 0.875664 0.87565625 0.87565704 0.87565086 0.8756508\n",
" 0.8756472 0.87565495 0.87565502 0.87566519 0.87565886 0.87565115\n",
" 0.87565807 0.87565482 0.87565438 0.87565743 0.87566278 0.87566426\n",
" 0.87565197 0.87565215 0.87564886 0.87566716 0.87565374 0.87565689\n",
" 0.87565681 0.87565688 0.87565987 0.87566099 0.87566612 0.87565845\n",
" 0.87565536 0.87565489 0.87565651 0.87565477 0.87566306 0.87565502\n",
" 0.87565545 0.87565673 0.87566285 0.87565596 0.8756605 0.87566306\n",
" 0.87566459 0.87565221 0.87565271 0.87565164 0.87566135 0.8756476\n",
" 0.87565177 0.87565694 0.87564981 0.87566128 0.87565583 0.87566112\n",
" 0.87566154 0.87565343 0.8756542 0.87564977 0.87566313 0.87564967\n",
" 0.87565908 0.87565954 0.87565998 0.87565301 0.87565698 0.875656\n",
" 0.87565497 0.87565387 0.87566218 0.87565457 0.87564996 0.87565371\n",
" 0.87564916 0.87565878 0.87564516 0.87564938 0.87565599 0.87565357\n",
"Epoch 28 RMSE: [0.87565692 0.87566189 0.87564095 0.87566218 0.87565795 0.87565605 | 28/40 [01:07<00:28, 2.39s/it] \n",
" 0.87565909 0.87565153 0.87565725 0.87565463 0.87565602 0.87566048\n",
" 0.87565678 0.87565358 0.87565725 0.87564969 0.87565727 0.87565217\n",
" 0.8756495 0.875664 0.87565625 0.87565704 0.87565086 0.8756508\n",
" 0.8756472 0.87565495 0.87565502 0.87566519 0.87565886 0.87565115\n",
" 0.87565807 0.87565482 0.87565438 0.87565743 0.87566278 0.87566426\n",
" 0.87565197 0.87565215 0.87564886 0.87566716 0.87565374 0.87565689\n",
" 0.87565681 0.87565688 0.87565987 0.87566099 0.87566612 0.87565845\n",
" 0.87565536 0.87565489 0.87565651 0.87565477 0.87566306 0.87565502\n",
" 0.87565545 0.87565673 0.87566285 0.87565596 0.8756605 0.87566306\n",
" 0.87566459 0.87565221 0.87565271 0.87565164 0.87566135 0.8756476\n",
" 0.87565177 0.87565694 0.87564981 0.87566128 0.87565583 0.87566112\n",
" 0.87566154 0.87565343 0.8756542 0.87564977 0.87566313 0.87564967\n",
" 0.87565908 0.87565954 0.87565998 0.87565301 0.87565698 0.875656\n",
" 0.87565497 0.87565387 0.87566218 0.87565457 0.87564996 0.87565371\n",
" 0.87564916 0.87565878 0.87564516 0.87564938 0.87565599 0.87565357\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 29 RMSE: [0.86889631 0.86890086 0.86888072 0.86890107 0.86889721 0.8688955█▎ | 29/40 [01:10<00:26, 2.40s/it]\n",
" 0.86889828 0.86889117 0.86889675 0.86889368 0.86889532 0.86889949\n",
" 0.86889593 0.86889294 0.86889654 0.86888923 0.86889628 0.86889179\n",
" 0.86888901 0.86890262 0.8688956 0.86889617 0.86889024 0.86889047\n",
" 0.86888686 0.86889417 0.86889428 0.86890411 0.86889792 0.86889061\n",
" 0.86889721 0.868894 0.86889356 0.86889675 0.8689019 0.8689032\n",
" 0.86889165 0.86889161 0.86888831 0.86890585 0.86889323 0.86889634\n",
" 0.86889585 0.8688961 0.86889915 0.86889995 0.86890494 0.86889787\n",
" 0.86889458 0.86889406 0.86889589 0.86889401 0.86890207 0.86889425\n",
" 0.86889451 0.8688958 0.86890159 0.86889536 0.86889957 0.86890191\n",
" 0.86890345 0.86889177 0.86889218 0.86889121 0.86890034 0.86888756\n",
" 0.86889132 0.86889642 0.8688892 0.86890046 0.86889511 0.86889996\n",
" 0.86890076 0.86889303 0.86889375 0.86888928 0.86890196 0.86888924\n",
" 0.86889784 0.86889914 0.86889884 0.86889234 0.86889602 0.86889544\n",
" 0.8688939 0.86889325 0.86890111 0.8688936 0.86888929 0.86889311\n",
" 0.86888863 0.868898 0.86888475 0.86888913 0.86889527 0.86889298\n",
"Epoch 29 RMSE: [0.86889631 0.86890086 0.86888072 0.86890107 0.86889721 0.8688955█▎ | 29/40 [01:10<00:26, 2.40s/it]\n",
" 0.86889828 0.86889117 0.86889675 0.86889368 0.86889532 0.86889949\n",
" 0.86889593 0.86889294 0.86889654 0.86888923 0.86889628 0.86889179\n",
" 0.86888901 0.86890262 0.8688956 0.86889617 0.86889024 0.86889047\n",
" 0.86888686 0.86889417 0.86889428 0.86890411 0.86889792 0.86889061\n",
" 0.86889721 0.868894 0.86889356 0.86889675 0.8689019 0.8689032\n",
" 0.86889165 0.86889161 0.86888831 0.86890585 0.86889323 0.86889634\n",
" 0.86889585 0.8688961 0.86889915 0.86889995 0.86890494 0.86889787\n",
" 0.86889458 0.86889406 0.86889589 0.86889401 0.86890207 0.86889425\n",
" 0.86889451 0.8688958 0.86890159 0.86889536 0.86889957 0.86890191\n",
" 0.86890345 0.86889177 0.86889218 0.86889121 0.86890034 0.86888756\n",
" 0.86889132 0.86889642 0.8688892 0.86890046 0.86889511 0.86889996\n",
" 0.86890076 0.86889303 0.86889375 0.86888928 0.86890196 0.86888924\n",
" 0.86889784 0.86889914 0.86889884 0.86889234 0.86889602 0.86889544\n",
" 0.8688939 0.86889325 0.86890111 0.8688936 0.86888929 0.86889311\n",
" 0.86888863 0.868898 0.86888475 0.86888913 0.86889527 0.86889298\n",
"Epoch 30 RMSE: [0.86192812 0.86193219 0.86191303 0.86193246 0.86192887 0.86192735▌ | 30/40 [01:12<00:23, 2.39s/it]\n",
" 0.86192998 0.86192315 0.86192862 0.86192525 0.86192709 0.86193092\n",
" 0.86192752 0.86192477 0.86192829 0.86192116 0.86192787 0.86192382\n",
" 0.86192098 0.86193383 0.86192743 0.8619278 0.8619221 0.86192256\n",
" 0.86191887 0.86192581 0.86192609 0.86193541 0.86192946 0.86192254\n",
" 0.86192879 0.86192587 0.86192527 0.8619285 0.86193354 0.86193463\n",
" 0.86192368 0.86192349 0.86192022 0.86193713 0.8619252 0.86192839\n",
" 0.86192746 0.86192779 0.86193076 0.86193148 0.86193627 0.86192961\n",
" 0.86192651 0.8619257 0.86192776 0.86192582 0.86193364 0.86192596\n",
" 0.86192595 0.86192726 0.86193289 0.86192717 0.86193094 0.8619332\n",
" 0.86193486 0.86192368 0.86192413 0.86192332 0.86193194 0.86191989\n",
" 0.86192321 0.86192831 0.86192107 0.86193196 0.86192682 0.86193139\n",
" 0.86193237 0.86192516 0.86192579 0.86192134 0.86193331 0.86192106\n",
" 0.86192917 0.86193116 0.86193034 0.86192405 0.86192776 0.86192727\n",
" 0.86192533 0.86192518 0.86193249 0.86192515 0.86192107 0.86192498\n",
" 0.86192046 0.86192969 0.86191688 0.8619213 0.861927 0.86192464\n",
"Epoch 30 RMSE: [0.86192812 0.86193219 0.86191303 0.86193246 0.86192887 0.86192735▌ | 30/40 [01:12<00:23, 2.39s/it]\n",
" 0.86192998 0.86192315 0.86192862 0.86192525 0.86192709 0.86193092\n",
" 0.86192752 0.86192477 0.86192829 0.86192116 0.86192787 0.86192382\n",
" 0.86192098 0.86193383 0.86192743 0.8619278 0.8619221 0.86192256\n",
" 0.86191887 0.86192581 0.86192609 0.86193541 0.86192946 0.86192254\n",
" 0.86192879 0.86192587 0.86192527 0.8619285 0.86193354 0.86193463\n",
" 0.86192368 0.86192349 0.86192022 0.86193713 0.8619252 0.86192839\n",
" 0.86192746 0.86192779 0.86193076 0.86193148 0.86193627 0.86192961\n",
" 0.86192651 0.8619257 0.86192776 0.86192582 0.86193364 0.86192596\n",
" 0.86192595 0.86192726 0.86193289 0.86192717 0.86193094 0.8619332\n",
" 0.86193486 0.86192368 0.86192413 0.86192332 0.86193194 0.86191989\n",
" 0.86192321 0.86192831 0.86192107 0.86193196 0.86192682 0.86193139\n",
" 0.86193237 0.86192516 0.86192579 0.86192134 0.86193331 0.86192106\n",
" 0.86192917 0.86193116 0.86193034 0.86192405 0.86192776 0.86192727\n",
" 0.86192533 0.86192518 0.86193249 0.86192515 0.86192107 0.86192498\n",
" 0.86192046 0.86192969 0.86191688 0.8619213 0.861927 0.86192464\n",
"Epoch 31 RMSE: [0.85460494 0.85460845 0.85459006 0.85460887 0.85460553 0.85460401▊ | 31/40 [01:15<00:22, 2.48s/it]\n",
" 0.85460646 0.85459995 0.85460541 0.85460188 0.85460372 0.85460731\n",
" 0.854604 0.85460138 0.85460479 0.8545981 0.85460427 0.8546007\n",
" 0.85459779 0.8546099 0.85460408 0.85460425 0.85459882 0.8545994\n",
" 0.85459572 0.85460232 0.85460275 0.85461164 0.85460591 0.8545995\n",
" 0.85460535 0.85460262 0.85460197 0.85460507 0.85460999 0.85461095\n",
" 0.8546007 0.85460036 0.85459693 0.8546133 0.85460194 0.85460515\n",
" 0.85460396 0.85460441 0.85460736 0.85460799 0.85461262 0.85460627\n",
" 0.8546031 0.85460228 0.85460449 0.85460242 0.85461011 0.85460262\n",
" 0.8546025 0.85460376 0.85460904 0.8546038 0.85460733 0.85460953\n",
" 0.85461131 0.85460064 0.85460087 0.85460028 0.85460846 0.85459713\n",
" 0.85460006 0.85460517 0.85459777 0.85460837 0.8546034 0.85460774\n",
" 0.85460892 0.85460199 0.85460267 0.85459819 0.8546096 0.8545978\n",
" 0.85460559 0.85460796 0.85460667 0.85460072 0.85460433 0.85460402\n",
" 0.85460173 0.85460189 0.85460865 0.85460155 0.85459776 0.85460177\n",
" 0.85459706 0.85460631 0.85459377 0.85459838 0.85460372 0.85460149\n",
"Epoch 31 RMSE: [0.85460494 0.85460845 0.85459006 0.85460887 0.85460553 0.85460401▊ | 31/40 [01:15<00:22, 2.48s/it] \n",
" 0.85460646 0.85459995 0.85460541 0.85460188 0.85460372 0.85460731\n",
" 0.854604 0.85460138 0.85460479 0.8545981 0.85460427 0.8546007\n",
" 0.85459779 0.8546099 0.85460408 0.85460425 0.85459882 0.8545994\n",
" 0.85459572 0.85460232 0.85460275 0.85461164 0.85460591 0.8545995\n",
" 0.85460535 0.85460262 0.85460197 0.85460507 0.85460999 0.85461095\n",
" 0.8546007 0.85460036 0.85459693 0.8546133 0.85460194 0.85460515\n",
" 0.85460396 0.85460441 0.85460736 0.85460799 0.85461262 0.85460627\n",
" 0.8546031 0.85460228 0.85460449 0.85460242 0.85461011 0.85460262\n",
" 0.8546025 0.85460376 0.85460904 0.8546038 0.85460733 0.85460953\n",
" 0.85461131 0.85460064 0.85460087 0.85460028 0.85460846 0.85459713\n",
" 0.85460006 0.85460517 0.85459777 0.85460837 0.8546034 0.85460774\n",
" 0.85460892 0.85460199 0.85460267 0.85459819 0.8546096 0.8545978\n",
" 0.85460559 0.85460796 0.85460667 0.85460072 0.85460433 0.85460402\n",
" 0.85460173 0.85460189 0.85460865 0.85460155 0.85459776 0.85460177\n",
" 0.85459706 0.85460631 0.85459377 0.85459838 0.85460372 0.85460149\n",
"Epoch 32 RMSE: [0.84699439 0.84699767 0.8469799 0.84699794 0.84699484 0.84699372█ | 32/40 [01:17<00:19, 2.46s/it]\n",
" 0.84699584 0.8469897 0.84699497 0.8469911 0.84699306 0.84699638\n",
" 0.84699326 0.84699088 0.84699428 0.84698774 0.8469936 0.84699039\n",
" 0.84698744 0.84699885 0.84699355 0.84699366 0.8469883 0.84698922\n",
" 0.84698555 0.84699181 0.84699215 0.84700073 0.84699522 0.8469891\n",
" 0.84699478 0.846992 0.84699133 0.8469945 0.84699923 0.84700008\n",
" 0.8469903 0.8469898 0.84698658 0.84700231 0.84699156 0.8469947\n",
" 0.84699318 0.84699383 0.84699687 0.84699715 0.84700179 0.84699568\n",
" 0.84699266 0.84699176 0.84699398 0.84699188 0.84699932 0.84699208\n",
" 0.84699184 0.84699306 0.84699801 0.8469934 0.84699651 0.84699856\n",
" 0.84700039 0.8469902 0.8469905 0.84698998 0.8469977 0.84698715\n",
" 0.84698967 0.84699458 0.84698729 0.84699746 0.84699281 0.84699688\n",
" 0.84699831 0.84699173 0.8469924 0.84698786 0.84699875 0.84698738\n",
" 0.84699483 0.8469976 0.84699594 0.84699025 0.8469936 0.84699353\n",
" 0.84699101 0.84699135 0.84699786 0.84699079 0.84698747 0.84699128\n",
" 0.84698664 0.84699562 0.84698352 0.84698809 0.84699317 0.84699104\n",
"Epoch 32 RMSE: [0.84699439 0.84699767 0.8469799 0.84699794 0.84699484 0.84699372█ | 32/40 [01:17<00:19, 2.46s/it]\n",
" 0.84699584 0.8469897 0.84699497 0.8469911 0.84699306 0.84699638\n",
" 0.84699326 0.84699088 0.84699428 0.84698774 0.8469936 0.84699039\n",
" 0.84698744 0.84699885 0.84699355 0.84699366 0.8469883 0.84698922\n",
" 0.84698555 0.84699181 0.84699215 0.84700073 0.84699522 0.8469891\n",
" 0.84699478 0.846992 0.84699133 0.8469945 0.84699923 0.84700008\n",
" 0.8469903 0.8469898 0.84698658 0.84700231 0.84699156 0.8469947\n",
" 0.84699318 0.84699383 0.84699687 0.84699715 0.84700179 0.84699568\n",
" 0.84699266 0.84699176 0.84699398 0.84699188 0.84699932 0.84699208\n",
" 0.84699184 0.84699306 0.84699801 0.8469934 0.84699651 0.84699856\n",
" 0.84700039 0.8469902 0.8469905 0.84698998 0.8469977 0.84698715\n",
" 0.84698967 0.84699458 0.84698729 0.84699746 0.84699281 0.84699688\n",
" 0.84699831 0.84699173 0.8469924 0.84698786 0.84699875 0.84698738\n",
" 0.84699483 0.8469976 0.84699594 0.84699025 0.8469936 0.84699353\n",
" 0.84699101 0.84699135 0.84699786 0.84699079 0.84698747 0.84699128\n",
" 0.84698664 0.84699562 0.84698352 0.84698809 0.84699317 0.84699104\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 33 RMSE: [0.83875367 0.83875651 0.83873955 0.83875691 0.83875405 0.83875292█▎ | 33/40 [01:19<00:17, 2.43s/it]\n",
" 0.83875482 0.83874898 0.83875422 0.83875015 0.83875224 0.83875543\n",
" 0.8387522 0.83875012 0.8387534 0.83874705 0.83875253 0.83874978\n",
" 0.83874684 0.83875766 0.8387528 0.8387526 0.83874762 0.83874864\n",
" 0.83874501 0.83875084 0.83875135 0.83875956 0.83875426 0.83874856\n",
" 0.8387537 0.83875149 0.83875047 0.83875362 0.83875843 0.83875902\n",
" 0.83874978 0.83874908 0.83874593 0.83876096 0.83875086 0.83875414\n",
" 0.83875214 0.83875291 0.83875595 0.83875619 0.83876049 0.83875497\n",
" 0.83875172 0.83875078 0.83875316 0.83875099 0.83875813 0.83875121\n",
" 0.83875075 0.83875199 0.83875683 0.83875258 0.83875547 0.83875734\n",
" 0.83875926 0.83874945 0.83874978 0.83874942 0.83875676 0.8387468\n",
" 0.83874893 0.8387539 0.8387466 0.83875661 0.83875198 0.83875579\n",
" 0.83875737 0.83875116 0.83875181 0.83874742 0.83875752 0.83874663\n",
" 0.8387537 0.83875691 0.83875493 0.83874943 0.83875288 0.83875274\n",
" 0.83874997 0.83875062 0.83875674 0.83874987 0.83874648 0.83875053\n",
" 0.83874586 0.83875478 0.83874302 0.83874753 0.83875232 0.83875024\n",
"Epoch 33 RMSE: [0.83875367 0.83875651 0.83873955 0.83875691 0.83875405 0.83875292█▎ | 33/40 [01:19<00:17, 2.43s/it]\n",
" 0.83875482 0.83874898 0.83875422 0.83875015 0.83875224 0.83875543\n",
" 0.8387522 0.83875012 0.8387534 0.83874705 0.83875253 0.83874978\n",
" 0.83874684 0.83875766 0.8387528 0.8387526 0.83874762 0.83874864\n",
" 0.83874501 0.83875084 0.83875135 0.83875956 0.83875426 0.83874856\n",
" 0.8387537 0.83875149 0.83875047 0.83875362 0.83875843 0.83875902\n",
" 0.83874978 0.83874908 0.83874593 0.83876096 0.83875086 0.83875414\n",
" 0.83875214 0.83875291 0.83875595 0.83875619 0.83876049 0.83875497\n",
" 0.83875172 0.83875078 0.83875316 0.83875099 0.83875813 0.83875121\n",
" 0.83875075 0.83875199 0.83875683 0.83875258 0.83875547 0.83875734\n",
" 0.83875926 0.83874945 0.83874978 0.83874942 0.83875676 0.8387468\n",
" 0.83874893 0.8387539 0.8387466 0.83875661 0.83875198 0.83875579\n",
" 0.83875737 0.83875116 0.83875181 0.83874742 0.83875752 0.83874663\n",
" 0.8387537 0.83875691 0.83875493 0.83874943 0.83875288 0.83875274\n",
" 0.83874997 0.83875062 0.83875674 0.83874987 0.83874648 0.83875053\n",
" 0.83874586 0.83875478 0.83874302 0.83874753 0.83875232 0.83875024\n",
"Epoch 34 RMSE: [0.83035544 0.83035804 0.83034167 0.83035847 0.83035574 0.83035477█▌ | 34/40 [01:22<00:14, 2.40s/it]\n",
" 0.83035649 0.83035096 0.83035615 0.83035188 0.830354 0.83035688\n",
" 0.830354 0.83035198 0.83035515 0.83034898 0.83035412 0.83035166\n",
" 0.83034873 0.83035899 0.83035465 0.83035423 0.83034937 0.83035063\n",
" 0.83034704 0.83035257 0.83035301 0.83036108 0.83035596 0.83035061\n",
" 0.8303553 0.83035321 0.83035214 0.83035538 0.83035994 0.83036051\n",
" 0.83035182 0.83035092 0.83034791 0.83036238 0.8303527 0.83035591\n",
" 0.83035386 0.83035462 0.83035766 0.83035786 0.83036194 0.83035665\n",
" 0.83035356 0.83035258 0.83035501 0.83035271 0.83035977 0.83035306\n",
" 0.83035243 0.8303536 0.83035819 0.83035435 0.83035698 0.83035884\n",
" 0.83036084 0.83035151 0.83035163 0.83035135 0.83035833 0.83034914\n",
" 0.83035078 0.83035569 0.83034846 0.83035808 0.83035365 0.8303573\n",
" 0.83035897 0.83035315 0.83035383 0.83034943 0.83035898 0.83034858\n",
" 0.8303552 0.8303588 0.8303565 0.83035125 0.83035462 0.8303546\n",
" 0.83035159 0.83035255 0.83035809 0.83035158 0.83034838 0.8303524\n",
" 0.83034775 0.83035644 0.83034512 0.83034972 0.83035411 0.83035205\n",
"Epoch 34 RMSE: [0.83035544 0.83035804 0.83034167 0.83035847 0.83035574 0.83035477█▌ | 34/40 [01:22<00:14, 2.40s/it] \n",
" 0.83035649 0.83035096 0.83035615 0.83035188 0.830354 0.83035688\n",
" 0.830354 0.83035198 0.83035515 0.83034898 0.83035412 0.83035166\n",
" 0.83034873 0.83035899 0.83035465 0.83035423 0.83034937 0.83035063\n",
" 0.83034704 0.83035257 0.83035301 0.83036108 0.83035596 0.83035061\n",
" 0.8303553 0.83035321 0.83035214 0.83035538 0.83035994 0.83036051\n",
" 0.83035182 0.83035092 0.83034791 0.83036238 0.8303527 0.83035591\n",
" 0.83035386 0.83035462 0.83035766 0.83035786 0.83036194 0.83035665\n",
" 0.83035356 0.83035258 0.83035501 0.83035271 0.83035977 0.83035306\n",
" 0.83035243 0.8303536 0.83035819 0.83035435 0.83035698 0.83035884\n",
" 0.83036084 0.83035151 0.83035163 0.83035135 0.83035833 0.83034914\n",
" 0.83035078 0.83035569 0.83034846 0.83035808 0.83035365 0.8303573\n",
" 0.83035897 0.83035315 0.83035383 0.83034943 0.83035898 0.83034858\n",
" 0.8303552 0.8303588 0.8303565 0.83035125 0.83035462 0.8303546\n",
" 0.83035159 0.83035255 0.83035809 0.83035158 0.83034838 0.8303524\n",
" 0.83034775 0.83035644 0.83034512 0.83034972 0.83035411 0.83035205\n",
"Epoch 35 RMSE: [0.8214729 0.82147524 0.82145939 0.82147562 0.82147307 0.82147219█▊ | 35/40 [01:24<00:11, 2.40s/it]\n",
" 0.8214738 0.82146852 0.82147354 0.82146919 0.82147131 0.8214741\n",
" 0.82147121 0.82146941 0.82147254 0.82146649 0.82147127 0.82146921\n",
" 0.82146623 0.82147592 0.82147201 0.82147148 0.8214668 0.82146827\n",
" 0.82146472 0.82146992 0.82147043 0.82147809 0.82147317 0.82146808\n",
" 0.8214725 0.82147075 0.82146953 0.82147277 0.82147715 0.82147767\n",
" 0.82146933 0.82146842 0.82146546 0.82147929 0.82147015 0.82147345\n",
" 0.82147114 0.82147193 0.82147493 0.82147502 0.82147899 0.82147406\n",
" 0.82147082 0.82146983 0.8214724 0.82147011 0.82147688 0.82147045\n",
" 0.8214698 0.82147083 0.8214753 0.82147177 0.82147411 0.82147595\n",
" 0.82147806 0.82146902 0.82146919 0.82146889 0.82147551 0.82146696\n",
" 0.82146833 0.82147311 0.82146583 0.82147531 0.82147094 0.82147445\n",
" 0.82147624 0.82147083 0.82147144 0.82146699 0.82147605 0.82146607\n",
" 0.82147232 0.82147618 0.82147384 0.82146869 0.82147197 0.821472\n",
" 0.82146881 0.82147009 0.82147523 0.82146891 0.82146581 0.82146996\n",
" 0.82146526 0.8214737 0.82146278 0.82146736 0.82147153 0.82146937\n",
"Epoch 35 RMSE: [0.8214729 0.82147524 0.82145939 0.82147562 0.82147307 0.82147219█▊ | 35/40 [01:24<00:11, 2.40s/it]\n",
" 0.8214738 0.82146852 0.82147354 0.82146919 0.82147131 0.8214741\n",
" 0.82147121 0.82146941 0.82147254 0.82146649 0.82147127 0.82146921\n",
" 0.82146623 0.82147592 0.82147201 0.82147148 0.8214668 0.82146827\n",
" 0.82146472 0.82146992 0.82147043 0.82147809 0.82147317 0.82146808\n",
" 0.8214725 0.82147075 0.82146953 0.82147277 0.82147715 0.82147767\n",
" 0.82146933 0.82146842 0.82146546 0.82147929 0.82147015 0.82147345\n",
" 0.82147114 0.82147193 0.82147493 0.82147502 0.82147899 0.82147406\n",
" 0.82147082 0.82146983 0.8214724 0.82147011 0.82147688 0.82147045\n",
" 0.8214698 0.82147083 0.8214753 0.82147177 0.82147411 0.82147595\n",
" 0.82147806 0.82146902 0.82146919 0.82146889 0.82147551 0.82146696\n",
" 0.82146833 0.82147311 0.82146583 0.82147531 0.82147094 0.82147445\n",
" 0.82147624 0.82147083 0.82147144 0.82146699 0.82147605 0.82146607\n",
" 0.82147232 0.82147618 0.82147384 0.82146869 0.82147197 0.821472\n",
" 0.82146881 0.82147009 0.82147523 0.82146891 0.82146581 0.82146996\n",
" 0.82146526 0.8214737 0.82146278 0.82146736 0.82147153 0.82146937\n",
"Epoch 36 RMSE: [0.81217636 0.81217838 0.81216314 0.81217887 0.81217654 0.81217581██ | 36/40 [01:27<00:09, 2.38s/it]\n",
" 0.81217706 0.81217216 0.8121771 0.81217266 0.81217472 0.81217733\n",
" 0.81217459 0.81217295 0.81217602 0.81217003 0.8121746 0.81217288\n",
" 0.81216985 0.8121791 0.81217559 0.81217486 0.81217041 0.8121719\n",
" 0.81216843 0.8121734 0.81217388 0.81218128 0.81217656 0.81217177\n",
" 0.81217588 0.81217428 0.81217299 0.81217616 0.81218048 0.81218104\n",
" 0.81217301 0.81217194 0.81216907 0.81218239 0.81217368 0.812177\n",
" 0.81217443 0.81217538 0.81217836 0.81217839 0.8121822 0.81217747\n",
" 0.81217426 0.81217336 0.81217596 0.81217355 0.81218016 0.812174\n",
" 0.81217316 0.81217415 0.81217845 0.81217527 0.81217745 0.81217914\n",
" 0.81218127 0.81217264 0.81217281 0.8121725 0.8121788 0.81217083\n",
" 0.81217185 0.8121766 0.81216943 0.81217854 0.81217442 0.8121777\n",
" 0.81217972 0.81217442 0.81217514 0.8121707 0.81217926 0.81216959\n",
" 0.81217561 0.81217967 0.81217712 0.81217216 0.81217545 0.8121755\n",
" 0.8121722 0.81217354 0.81217842 0.81217225 0.81216932 0.81217352\n",
" 0.81216881 0.81217714 0.81216646 0.81217107 0.81217497 0.81217293\n",
"Epoch 36 RMSE: [0.81217636 0.81217838 0.81216314 0.81217887 0.81217654 0.81217581██ | 36/40 [01:27<00:09, 2.38s/it] \n",
" 0.81217706 0.81217216 0.8121771 0.81217266 0.81217472 0.81217733\n",
" 0.81217459 0.81217295 0.81217602 0.81217003 0.8121746 0.81217288\n",
" 0.81216985 0.8121791 0.81217559 0.81217486 0.81217041 0.8121719\n",
" 0.81216843 0.8121734 0.81217388 0.81218128 0.81217656 0.81217177\n",
" 0.81217588 0.81217428 0.81217299 0.81217616 0.81218048 0.81218104\n",
" 0.81217301 0.81217194 0.81216907 0.81218239 0.81217368 0.812177\n",
" 0.81217443 0.81217538 0.81217836 0.81217839 0.8121822 0.81217747\n",
" 0.81217426 0.81217336 0.81217596 0.81217355 0.81218016 0.812174\n",
" 0.81217316 0.81217415 0.81217845 0.81217527 0.81217745 0.81217914\n",
" 0.81218127 0.81217264 0.81217281 0.8121725 0.8121788 0.81217083\n",
" 0.81217185 0.8121766 0.81216943 0.81217854 0.81217442 0.8121777\n",
" 0.81217972 0.81217442 0.81217514 0.8121707 0.81217926 0.81216959\n",
" 0.81217561 0.81217967 0.81217712 0.81217216 0.81217545 0.8121755\n",
" 0.8121722 0.81217354 0.81217842 0.81217225 0.81216932 0.81217352\n",
" 0.81216881 0.81217714 0.81216646 0.81217107 0.81217497 0.81217293\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 37 RMSE: [0.80252532 0.80252725 0.80251242 0.80252783 0.80252539 0.80252478██▎| 37/40 [01:29<00:07, 2.38s/it]\n",
" 0.8025259 0.80252122 0.80252609 0.80252151 0.80252362 0.80252608\n",
" 0.80252346 0.80252195 0.80252503 0.80251916 0.80252351 0.80252204\n",
" 0.80251897 0.80252778 0.80252459 0.80252369 0.80251944 0.80252106\n",
" 0.80251763 0.80252237 0.80252287 0.80252996 0.8025255 0.80252094\n",
" 0.80252473 0.80252344 0.80252194 0.80252521 0.80252927 0.80252972\n",
" 0.80252212 0.80252092 0.80251818 0.80253098 0.80252267 0.8025262\n",
" 0.80252327 0.80252429 0.80252731 0.80252717 0.80253087 0.80252644\n",
" 0.80252318 0.80252222 0.80252488 0.80252256 0.80252894 0.8025229\n",
" 0.80252205 0.80252291 0.80252718 0.80252424 0.80252625 0.8025278\n",
" 0.80253003 0.8025218 0.80252183 0.80252168 0.80252762 0.80252022\n",
" 0.80252098 0.80252566 0.8025185 0.80252745 0.8025234 0.80252656\n",
" 0.80252855 0.80252361 0.80252428 0.80252 0.80252802 0.80251868\n",
" 0.8025244 0.80252873 0.80252603 0.80252113 0.80252447 0.80252464\n",
" 0.80252102 0.80252265 0.80252716 0.80252113 0.80251829 0.80252258\n",
" 0.80251785 0.80252601 0.80251562 0.80252025 0.80252401 0.80252183\n",
"Epoch 37 RMSE: [0.80252532 0.80252725 0.80251242 0.80252783 0.80252539 0.80252478██▎| 37/40 [01:29<00:07, 2.38s/it]\n",
" 0.8025259 0.80252122 0.80252609 0.80252151 0.80252362 0.80252608\n",
" 0.80252346 0.80252195 0.80252503 0.80251916 0.80252351 0.80252204\n",
" 0.80251897 0.80252778 0.80252459 0.80252369 0.80251944 0.80252106\n",
" 0.80251763 0.80252237 0.80252287 0.80252996 0.8025255 0.80252094\n",
" 0.80252473 0.80252344 0.80252194 0.80252521 0.80252927 0.80252972\n",
" 0.80252212 0.80252092 0.80251818 0.80253098 0.80252267 0.8025262\n",
" 0.80252327 0.80252429 0.80252731 0.80252717 0.80253087 0.80252644\n",
" 0.80252318 0.80252222 0.80252488 0.80252256 0.80252894 0.8025229\n",
" 0.80252205 0.80252291 0.80252718 0.80252424 0.80252625 0.8025278\n",
" 0.80253003 0.8025218 0.80252183 0.80252168 0.80252762 0.80252022\n",
" 0.80252098 0.80252566 0.8025185 0.80252745 0.8025234 0.80252656\n",
" 0.80252855 0.80252361 0.80252428 0.80252 0.80252802 0.80251868\n",
" 0.8025244 0.80252873 0.80252603 0.80252113 0.80252447 0.80252464\n",
" 0.80252102 0.80252265 0.80252716 0.80252113 0.80251829 0.80252258\n",
" 0.80251785 0.80252601 0.80251562 0.80252025 0.80252401 0.80252183\n",
"Epoch 38 RMSE: [0.79251062 0.79251227 0.79249797 0.79251285 0.79251066 0.79251015██▌| 38/40 [01:31<00:04, 2.38s/it]\n",
" 0.79251103 0.79250659 0.79251131 0.79250683 0.79250881 0.7925112\n",
" 0.79250865 0.7925072 0.79251026 0.79250454 0.79250855 0.79250742\n",
" 0.79250433 0.79251263 0.79250989 0.79250885 0.79250477 0.79250644\n",
" 0.79250308 0.79250752 0.79250805 0.792515 0.79251052 0.79250641\n",
" 0.79250984 0.79250861 0.79250719 0.7925104 0.79251439 0.79251481\n",
" 0.79250758 0.79250625 0.79250359 0.79251587 0.79250798 0.79251144\n",
" 0.79250841 0.79250943 0.7925125 0.79251235 0.79251588 0.7925116\n",
" 0.79250837 0.79250745 0.79251016 0.7925077 0.79251408 0.79250827\n",
" 0.7925073 0.79250815 0.79251217 0.79250951 0.79251136 0.79251279\n",
" 0.79251522 0.79250715 0.7925072 0.79250696 0.79251271 0.79250588\n",
" 0.79250626 0.79251087 0.79250379 0.79251246 0.79250858 0.79251161\n",
" 0.7925137 0.79250899 0.79250973 0.79250533 0.792513 0.79250409\n",
" 0.79250943 0.79251398 0.79251109 0.79250638 0.79250969 0.79250979\n",
" 0.79250619 0.79250803 0.79251216 0.79250638 0.79250363 0.79250783\n",
" 0.79250324 0.79251124 0.79250114 0.79250579 0.79250924 0.79250711\n",
"Epoch 38 RMSE: [0.79251062 0.79251227 0.79249797 0.79251285 0.79251066 0.79251015██▌| 38/40 [01:31<00:04, 2.38s/it]\n",
" 0.79251103 0.79250659 0.79251131 0.79250683 0.79250881 0.7925112\n",
" 0.79250865 0.7925072 0.79251026 0.79250454 0.79250855 0.79250742\n",
" 0.79250433 0.79251263 0.79250989 0.79250885 0.79250477 0.79250644\n",
" 0.79250308 0.79250752 0.79250805 0.792515 0.79251052 0.79250641\n",
" 0.79250984 0.79250861 0.79250719 0.7925104 0.79251439 0.79251481\n",
" 0.79250758 0.79250625 0.79250359 0.79251587 0.79250798 0.79251144\n",
" 0.79250841 0.79250943 0.7925125 0.79251235 0.79251588 0.7925116\n",
" 0.79250837 0.79250745 0.79251016 0.7925077 0.79251408 0.79250827\n",
" 0.7925073 0.79250815 0.79251217 0.79250951 0.79251136 0.79251279\n",
" 0.79251522 0.79250715 0.7925072 0.79250696 0.79251271 0.79250588\n",
" 0.79250626 0.79251087 0.79250379 0.79251246 0.79250858 0.79251161\n",
" 0.7925137 0.79250899 0.79250973 0.79250533 0.792513 0.79250409\n",
" 0.79250943 0.79251398 0.79251109 0.79250638 0.79250969 0.79250979\n",
" 0.79250619 0.79250803 0.79251216 0.79250638 0.79250363 0.79250783\n",
" 0.79250324 0.79251124 0.79250114 0.79250579 0.79250924 0.79250711\n",
"Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019██▊| 39/40 [01:34<00:02, 2.38s/it]\n",
" 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
" 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
" 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
" 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
" 0.78210964 0.78210863 0.78210705 0.7821103 0.78211415 0.78211461\n",
" 0.78210756 0.7821062 0.78210364 0.78211536 0.78210784 0.78211148\n",
" 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
" 0.78210825 0.7821074 0.78211008 0.78210755 0.78211375 0.78210818\n",
" 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
" 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
" 0.78210627 0.78211081 0.7821038 0.78211226 0.78210851 0.78211143\n",
" 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
" 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
" 0.78210602 0.782108 0.7821118 0.78210625 0.7821036 0.78210788\n",
" 0.7821032 0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
"Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019██▊| 39/40 [01:34<00:02, 2.38s/it]\n",
" 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
" 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
" 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
" 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
" 0.78210964 0.78210863 0.78210705 0.7821103 0.78211415 0.78211461\n",
" 0.78210756 0.7821062 0.78210364 0.78211536 0.78210784 0.78211148\n",
" 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
" 0.78210825 0.7821074 0.78211008 0.78210755 0.78211375 0.78210818\n",
" 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
" 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
" 0.78210627 0.78211081 0.7821038 0.78211226 0.78210851 0.78211143\n",
" 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
" 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
" 0.78210602 0.782108 0.7821118 0.78210625 0.7821036 0.78210788\n",
" 0.7821032 0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
"Epoch 39 RMSE: [0.78211049 0.78211201 0.78209813 0.78211264 0.78211053 0.78211019███| 40/40 [01:36<00:00, 2.37s/it]\n",
" 0.78211079 0.78210665 0.78211125 0.78210675 0.78210874 0.78211085\n",
" 0.78210849 0.78210722 0.78211018 0.78210456 0.78210833 0.7821075\n",
" 0.78210433 0.78211224 0.78210978 0.78210864 0.78210482 0.78210649\n",
" 0.78210325 0.78210742 0.78210795 0.78211477 0.78211035 0.78210647\n",
" 0.78210964 0.78210863 0.78210705 0.7821103 0.78211415 0.78211461\n",
" 0.78210756 0.7821062 0.78210364 0.78211536 0.78210784 0.78211148\n",
" 0.78210821 0.78210925 0.78211235 0.78211215 0.78211556 0.78211147\n",
" 0.78210825 0.7821074 0.78211008 0.78210755 0.78211375 0.78210818\n",
" 0.78210719 0.78210798 0.78211186 0.78210941 0.78211107 0.78211257\n",
" 0.78211489 0.78210719 0.78210722 0.78210696 0.78211244 0.78210618\n",
" 0.78210627 0.78211081 0.7821038 0.78211226 0.78210851 0.78211143\n",
" 0.78211353 0.78210906 0.78210978 0.78210549 0.78211269 0.78210406\n",
" 0.78210928 0.78211382 0.78211093 0.78210638 0.78210957 0.78210971\n",
" 0.78210602 0.782108 0.7821118 0.78210625 0.7821036 0.78210788\n",
" 0.7821032 0.78211109 0.78210125 0.78210593 0.78210907 0.78210702\n",
" 0.78210826 0.78211145 0.78211014 0.78210739]. Training epoch 40...: 100%|██████████| 40/40 [01:36<00:00, 2.41s/it]\n"
]
}
],
"source": [
"model = SVDbaseline(train_ui,\n",
"                    learning_rate=0.005,\n",
"                    regularization=0.02,\n",
"                    nb_factors=100,\n",
"                    iterations=40)\n",
"model.train(test_ui)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# NOTE(review): estimations() is defined on the model class in an earlier cell;\n",
"# presumably it precomputes the scores that recommend()/estimate() read below.\n",
"model.estimations()\n",
"\n",
"# Top-10 recommendations per user, written without index/header\n",
"top_n = pd.DataFrame(model.recommend(user_code_id, item_code_id, topK=10))\n",
"top_n.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv',\n",
"             index=False, header=False)\n",
"\n",
"# Rating estimations for the pairs in test_ui, written without index/header\n",
"estimations = pd.DataFrame(model.estimate(user_code_id, item_code_id, test_ui))\n",
"estimations.to_csv('Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv',\n",
"                   index=False, header=False)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 8806.80it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>HR2</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>3.64479</td>\n",
" <td>3.479397</td>\n",
" <td>0.13701</td>\n",
" <td>0.082007</td>\n",
" <td>0.083942</td>\n",
" <td>0.100776</td>\n",
" <td>0.106974</td>\n",
" <td>0.105605</td>\n",
" <td>0.160418</td>\n",
" <td>0.080222</td>\n",
" <td>0.322261</td>\n",
" <td>0.537895</td>\n",
" <td>0.626723</td>\n",
" <td>0.360551</td>\n",
" <td>0.999894</td>\n",
" <td>0.276335</td>\n",
" <td>5.123235</td>\n",
" <td>0.910511</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RMSE MAE precision recall F_1 F_05 \\\n",
"0 3.64479 3.479397 0.13701 0.082007 0.083942 0.100776 \n",
"\n",
" precision_super recall_super NDCG mAP MRR LAUC \\\n",
"0 0.106974 0.105605 0.160418 0.080222 0.322261 0.537895 \n",
"\n",
" HR HR2 Reco in test Test coverage Shannon Gini \n",
"0 0.626723 0.360551 0.999894 0.276335 5.123235 0.910511 "
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import evaluation_measures as ev\n",
"\n",
"# Re-read the two files written by the previous cell and score them on the test split.\n",
"estimations_df = pd.read_csv('Recommendations generated/ml-100k/Self_SVDBaseline_estimations.csv', header=None)\n",
"reco = np.loadtxt('Recommendations generated/ml-100k/Self_SVDBaseline_reco.csv', delimiter=',')\n",
"\n",
"# super_reactions: presumably the rating values counted by the *_super metrics\n",
"ev.evaluate(\n",
"    test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None),\n",
"    estimations_df=estimations_df,\n",
"    reco=reco,\n",
"    super_reactions=[4, 5],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Ready-made SVD - Surprise implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import importlib\n",
"\n",
"import helpers\n",
"import surprise as sp\n",
"\n",
"# Pick up edits to helpers.py without restarting the kernel.\n",
"# (The imp module is deprecated since Python 3.4 and removed in 3.12; use importlib.)\n",
"importlib.reload(helpers)\n",
"\n",
"# biased=False -> plain (unbiased) SVD, no baseline/bias terms.\n",
"# random_state seeds the factor initialisation so reruns give the same results.\n",
"algo = sp.SVD(biased=False, random_state=42)\n",
"\n",
"helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVD_reco.csv',\n",
"                   estimations_path='Recommendations generated/ml-100k/Ready_SVD_estimations.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVD biased - on top of baseline"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Generating predictions...\n",
"Generating top N recommendations...\n",
"Generating predictions...\n"
]
}
],
"source": [
"import importlib\n",
"\n",
"import helpers\n",
"import surprise as sp\n",
"\n",
"# Pick up edits to helpers.py without restarting the kernel.\n",
"# (The imp module is deprecated since Python 3.4 and removed in 3.12; use importlib.)\n",
"importlib.reload(helpers)\n",
"\n",
"# Default is biased=True: SVD learned on top of user/item baseline terms.\n",
"# random_state seeds the factor initialisation so reruns give the same results.\n",
"algo = sp.SVD(random_state=42)\n",
"\n",
"helpers.ready_made(algo, reco_path='Recommendations generated/ml-100k/Ready_SVDBiased_reco.csv',\n",
"                   estimations_path='Recommendations generated/ml-100k/Ready_SVDBiased_estimations.csv')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"943it [00:00, 9840.17it/s]\n",
"943it [00:00, 9173.71it/s]\n",
"943it [00:00, 9859.58it/s]\n",
"943it [00:00, 9112.74it/s]\n",
"943it [00:00, 9551.34it/s]\n",
"943it [00:00, 7830.24it/s]\n",
"943it [00:00, 8983.95it/s]\n",
"943it [00:00, 9447.94it/s]\n",
"943it [00:00, 9866.98it/s]\n",
"943it [00:00, 10127.04it/s]\n",
"943it [00:00, 9035.26it/s]\n",
"943it [00:00, 9754.44it/s]\n",
"943it [00:00, 9524.64it/s]\n",
"943it [00:00, 8451.27it/s]\n",
"943it [00:00, 9054.06it/s]\n",
"943it [00:00, 10007.79it/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Model</th>\n",
" <th>RMSE</th>\n",
" <th>MAE</th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>F_1</th>\n",
" <th>F_05</th>\n",
" <th>precision_super</th>\n",
" <th>recall_super</th>\n",
" <th>NDCG</th>\n",
" <th>mAP</th>\n",
" <th>MRR</th>\n",
" <th>LAUC</th>\n",
" <th>HR</th>\n",
" <th>HR2</th>\n",
" <th>Reco in test</th>\n",
" <th>Test coverage</th>\n",
" <th>Shannon</th>\n",
" <th>Gini</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopPop</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.188865</td>\n",
" <td>0.116919</td>\n",
" <td>0.118732</td>\n",
" <td>0.141584</td>\n",
" <td>0.130472</td>\n",
" <td>0.137473</td>\n",
" <td>0.214651</td>\n",
" <td>0.111707</td>\n",
" <td>0.400939</td>\n",
" <td>0.555546</td>\n",
" <td>0.765642</td>\n",
" <td>0.492047</td>\n",
" <td>1.000000</td>\n",
" <td>0.038961</td>\n",
" <td>3.159079</td>\n",
" <td>0.987317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVDBaseline</td>\n",
" <td>3.644790</td>\n",
" <td>3.479397</td>\n",
" <td>0.137010</td>\n",
" <td>0.082007</td>\n",
" <td>0.083942</td>\n",
" <td>0.100776</td>\n",
" <td>0.106974</td>\n",
" <td>0.105605</td>\n",
" <td>0.160418</td>\n",
" <td>0.080222</td>\n",
" <td>0.322261</td>\n",
" <td>0.537895</td>\n",
" <td>0.626723</td>\n",
" <td>0.360551</td>\n",
" <td>0.999894</td>\n",
" <td>0.276335</td>\n",
" <td>5.123235</td>\n",
" <td>0.910511</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVD</td>\n",
" <td>0.950945</td>\n",
" <td>0.749680</td>\n",
" <td>0.098834</td>\n",
" <td>0.049106</td>\n",
" <td>0.054037</td>\n",
" <td>0.068741</td>\n",
" <td>0.087768</td>\n",
" <td>0.073987</td>\n",
" <td>0.113242</td>\n",
" <td>0.054201</td>\n",
" <td>0.243492</td>\n",
" <td>0.521280</td>\n",
" <td>0.493107</td>\n",
" <td>0.248144</td>\n",
" <td>0.998515</td>\n",
" <td>0.214286</td>\n",
" <td>4.413166</td>\n",
" <td>0.953488</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_SVD</td>\n",
" <td>0.915079</td>\n",
" <td>0.718240</td>\n",
" <td>0.104772</td>\n",
" <td>0.045496</td>\n",
" <td>0.054393</td>\n",
" <td>0.071374</td>\n",
" <td>0.094421</td>\n",
" <td>0.076826</td>\n",
" <td>0.109517</td>\n",
" <td>0.052005</td>\n",
" <td>0.206646</td>\n",
" <td>0.519484</td>\n",
" <td>0.487805</td>\n",
" <td>0.264051</td>\n",
" <td>0.874549</td>\n",
" <td>0.142136</td>\n",
" <td>3.890472</td>\n",
" <td>0.972126</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Baseline</td>\n",
" <td>0.949459</td>\n",
" <td>0.752487</td>\n",
" <td>0.091410</td>\n",
" <td>0.037652</td>\n",
" <td>0.046030</td>\n",
" <td>0.061286</td>\n",
" <td>0.079614</td>\n",
" <td>0.056463</td>\n",
" <td>0.095957</td>\n",
" <td>0.043178</td>\n",
" <td>0.198193</td>\n",
" <td>0.515501</td>\n",
" <td>0.437964</td>\n",
" <td>0.239661</td>\n",
" <td>1.000000</td>\n",
" <td>0.033911</td>\n",
" <td>2.836513</td>\n",
" <td>0.991139</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_SVDBiased</td>\n",
" <td>0.938535</td>\n",
" <td>0.738678</td>\n",
" <td>0.085366</td>\n",
" <td>0.036921</td>\n",
" <td>0.044151</td>\n",
" <td>0.057832</td>\n",
" <td>0.074893</td>\n",
" <td>0.056396</td>\n",
" <td>0.095960</td>\n",
" <td>0.044204</td>\n",
" <td>0.212483</td>\n",
" <td>0.515132</td>\n",
" <td>0.446448</td>\n",
" <td>0.217391</td>\n",
" <td>0.997561</td>\n",
" <td>0.168110</td>\n",
" <td>4.191946</td>\n",
" <td>0.963341</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_KNNSurprisetask</td>\n",
" <td>0.946255</td>\n",
" <td>0.745209</td>\n",
" <td>0.083457</td>\n",
" <td>0.032848</td>\n",
" <td>0.041227</td>\n",
" <td>0.055493</td>\n",
" <td>0.074785</td>\n",
" <td>0.048890</td>\n",
" <td>0.089577</td>\n",
" <td>0.040902</td>\n",
" <td>0.189057</td>\n",
" <td>0.513076</td>\n",
" <td>0.417815</td>\n",
" <td>0.217391</td>\n",
" <td>0.888547</td>\n",
" <td>0.130592</td>\n",
" <td>3.611806</td>\n",
" <td>0.978659</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_GlobalAvg</td>\n",
" <td>1.125760</td>\n",
" <td>0.943534</td>\n",
" <td>0.061188</td>\n",
" <td>0.025968</td>\n",
" <td>0.031383</td>\n",
" <td>0.041343</td>\n",
" <td>0.040558</td>\n",
" <td>0.032107</td>\n",
" <td>0.067695</td>\n",
" <td>0.027470</td>\n",
" <td>0.171187</td>\n",
" <td>0.509546</td>\n",
" <td>0.384942</td>\n",
" <td>0.142100</td>\n",
" <td>1.000000</td>\n",
" <td>0.025974</td>\n",
" <td>2.711772</td>\n",
" <td>0.992003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_Random</td>\n",
" <td>1.522798</td>\n",
" <td>1.222501</td>\n",
" <td>0.049841</td>\n",
" <td>0.020656</td>\n",
" <td>0.025232</td>\n",
" <td>0.033446</td>\n",
" <td>0.030579</td>\n",
" <td>0.022927</td>\n",
" <td>0.051680</td>\n",
" <td>0.019110</td>\n",
" <td>0.123085</td>\n",
" <td>0.506849</td>\n",
" <td>0.331919</td>\n",
" <td>0.119830</td>\n",
" <td>0.985048</td>\n",
" <td>0.183983</td>\n",
" <td>5.097973</td>\n",
" <td>0.907483</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNN</td>\n",
" <td>1.030386</td>\n",
" <td>0.813067</td>\n",
" <td>0.026087</td>\n",
" <td>0.006908</td>\n",
" <td>0.010593</td>\n",
" <td>0.016046</td>\n",
" <td>0.021137</td>\n",
" <td>0.009522</td>\n",
" <td>0.024214</td>\n",
" <td>0.008958</td>\n",
" <td>0.048068</td>\n",
" <td>0.499885</td>\n",
" <td>0.154825</td>\n",
" <td>0.072110</td>\n",
" <td>0.402333</td>\n",
" <td>0.434343</td>\n",
" <td>5.133650</td>\n",
" <td>0.877999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_I-KNNBaseline</td>\n",
" <td>0.935327</td>\n",
" <td>0.737424</td>\n",
" <td>0.002545</td>\n",
" <td>0.000755</td>\n",
" <td>0.001105</td>\n",
" <td>0.001602</td>\n",
" <td>0.002253</td>\n",
" <td>0.000930</td>\n",
" <td>0.003444</td>\n",
" <td>0.001362</td>\n",
" <td>0.011760</td>\n",
" <td>0.496724</td>\n",
" <td>0.021209</td>\n",
" <td>0.004242</td>\n",
" <td>0.482821</td>\n",
" <td>0.059885</td>\n",
" <td>2.232578</td>\n",
" <td>0.994487</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Ready_U-KNN</td>\n",
" <td>1.023495</td>\n",
" <td>0.807913</td>\n",
" <td>0.000742</td>\n",
" <td>0.000205</td>\n",
" <td>0.000305</td>\n",
" <td>0.000449</td>\n",
" <td>0.000536</td>\n",
" <td>0.000198</td>\n",
" <td>0.000845</td>\n",
" <td>0.000274</td>\n",
" <td>0.002744</td>\n",
" <td>0.496441</td>\n",
" <td>0.007423</td>\n",
" <td>0.000000</td>\n",
" <td>0.602121</td>\n",
" <td>0.010823</td>\n",
" <td>2.089186</td>\n",
" <td>0.995706</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineIU</td>\n",
" <td>0.958136</td>\n",
" <td>0.754051</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_TopRated</td>\n",
" <td>2.508258</td>\n",
" <td>2.217909</td>\n",
" <td>0.000954</td>\n",
" <td>0.000188</td>\n",
" <td>0.000298</td>\n",
" <td>0.000481</td>\n",
" <td>0.000644</td>\n",
" <td>0.000223</td>\n",
" <td>0.001043</td>\n",
" <td>0.000335</td>\n",
" <td>0.003348</td>\n",
" <td>0.496433</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.699046</td>\n",
" <td>0.005051</td>\n",
" <td>1.945910</td>\n",
" <td>0.995669</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_BaselineUI</td>\n",
" <td>0.967585</td>\n",
" <td>0.762740</td>\n",
" <td>0.000954</td>\n",
" <td>0.000170</td>\n",
" <td>0.000278</td>\n",
" <td>0.000463</td>\n",
" <td>0.000644</td>\n",
" <td>0.000189</td>\n",
" <td>0.000752</td>\n",
" <td>0.000168</td>\n",
" <td>0.001677</td>\n",
" <td>0.496424</td>\n",
" <td>0.009544</td>\n",
" <td>0.000000</td>\n",
" <td>0.600530</td>\n",
" <td>0.005051</td>\n",
" <td>1.803126</td>\n",
" <td>0.996380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Self_IKNN</td>\n",
" <td>1.018363</td>\n",
" <td>0.808793</td>\n",
" <td>0.000318</td>\n",
" <td>0.000108</td>\n",
" <td>0.000140</td>\n",
" <td>0.000189</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000214</td>\n",
" <td>0.000037</td>\n",
" <td>0.000368</td>\n",
" <td>0.496391</td>\n",
" <td>0.003181</td>\n",
" <td>0.000000</td>\n",
" <td>0.392153</td>\n",
" <td>0.115440</td>\n",
" <td>4.174741</td>\n",
" <td>0.965327</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Model RMSE MAE precision recall F_1 \\\n",
"0 Self_TopPop 2.508258 2.217909 0.188865 0.116919 0.118732 \n",
"0 Self_SVDBaseline 3.644790 3.479397 0.137010 0.082007 0.083942 \n",
"0 Ready_SVD 0.950945 0.749680 0.098834 0.049106 0.054037 \n",
"0 Self_SVD 0.915079 0.718240 0.104772 0.045496 0.054393 \n",
"0 Ready_Baseline 0.949459 0.752487 0.091410 0.037652 0.046030 \n",
"0 Ready_SVDBiased 0.938535 0.738678 0.085366 0.036921 0.044151 \n",
"0 Self_KNNSurprisetask 0.946255 0.745209 0.083457 0.032848 0.041227 \n",
"0 Self_GlobalAvg 1.125760 0.943534 0.061188 0.025968 0.031383 \n",
"0 Ready_Random 1.522798 1.222501 0.049841 0.020656 0.025232 \n",
"0 Ready_I-KNN 1.030386 0.813067 0.026087 0.006908 0.010593 \n",
"0 Ready_I-KNNBaseline 0.935327 0.737424 0.002545 0.000755 0.001105 \n",
"0 Ready_U-KNN 1.023495 0.807913 0.000742 0.000205 0.000305 \n",
"0 Self_BaselineIU 0.958136 0.754051 0.000954 0.000188 0.000298 \n",
"0 Self_TopRated 2.508258 2.217909 0.000954 0.000188 0.000298 \n",
"0 Self_BaselineUI 0.967585 0.762740 0.000954 0.000170 0.000278 \n",
"0 Self_IKNN 1.018363 0.808793 0.000318 0.000108 0.000140 \n",
"\n",
" F_05 precision_super recall_super NDCG mAP MRR \\\n",
"0 0.141584 0.130472 0.137473 0.214651 0.111707 0.400939 \n",
"0 0.100776 0.106974 0.105605 0.160418 0.080222 0.322261 \n",
"0 0.068741 0.087768 0.073987 0.113242 0.054201 0.243492 \n",
"0 0.071374 0.094421 0.076826 0.109517 0.052005 0.206646 \n",
"0 0.061286 0.079614 0.056463 0.095957 0.043178 0.198193 \n",
"0 0.057832 0.074893 0.056396 0.095960 0.044204 0.212483 \n",
"0 0.055493 0.074785 0.048890 0.089577 0.040902 0.189057 \n",
"0 0.041343 0.040558 0.032107 0.067695 0.027470 0.171187 \n",
"0 0.033446 0.030579 0.022927 0.051680 0.019110 0.123085 \n",
"0 0.016046 0.021137 0.009522 0.024214 0.008958 0.048068 \n",
"0 0.001602 0.002253 0.000930 0.003444 0.001362 0.011760 \n",
"0 0.000449 0.000536 0.000198 0.000845 0.000274 0.002744 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000481 0.000644 0.000223 0.001043 0.000335 0.003348 \n",
"0 0.000463 0.000644 0.000189 0.000752 0.000168 0.001677 \n",
"0 0.000189 0.000000 0.000000 0.000214 0.000037 0.000368 \n",
"\n",
" LAUC HR HR2 Reco in test Test coverage Shannon \\\n",
"0 0.555546 0.765642 0.492047 1.000000 0.038961 3.159079 \n",
"0 0.537895 0.626723 0.360551 0.999894 0.276335 5.123235 \n",
"0 0.521280 0.493107 0.248144 0.998515 0.214286 4.413166 \n",
"0 0.519484 0.487805 0.264051 0.874549 0.142136 3.890472 \n",
"0 0.515501 0.437964 0.239661 1.000000 0.033911 2.836513 \n",
"0 0.515132 0.446448 0.217391 0.997561 0.168110 4.191946 \n",
"0 0.513076 0.417815 0.217391 0.888547 0.130592 3.611806 \n",
"0 0.509546 0.384942 0.142100 1.000000 0.025974 2.711772 \n",
"0 0.506849 0.331919 0.119830 0.985048 0.183983 5.097973 \n",
"0 0.499885 0.154825 0.072110 0.402333 0.434343 5.133650 \n",
"0 0.496724 0.021209 0.004242 0.482821 0.059885 2.232578 \n",
"0 0.496441 0.007423 0.000000 0.602121 0.010823 2.089186 \n",
"0 0.496433 0.009544 0.000000 0.699046 0.005051 1.945910 \n",
"0 0.496433 0.009544 0.000000 0.699046 0.005051 1.945910 \n",
"0 0.496424 0.009544 0.000000 0.600530 0.005051 1.803126 \n",
"0 0.496391 0.003181 0.000000 0.392153 0.115440 4.174741 \n",
"\n",
" Gini \n",
"0 0.987317 \n",
"0 0.910511 \n",
"0 0.953488 \n",
"0 0.972126 \n",
"0 0.991139 \n",
"0 0.963341 \n",
"0 0.978659 \n",
"0 0.992003 \n",
"0 0.907483 \n",
"0 0.877999 \n",
"0 0.994487 \n",
"0 0.995706 \n",
"0 0.995669 \n",
"0 0.995669 \n",
"0 0.996380 \n",
"0 0.965327 "
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import importlib\n",
"\n",
"import evaluation_measures as ev\n",
"\n",
"# Import first, then reload: reloading before the import only works when an\n",
"# earlier cell has already bound `ev`, which breaks Restart & Run All from here.\n",
"importlib.reload(ev)\n",
"\n",
"dir_path=\"Recommendations generated/ml-100k/\"\n",
"super_reactions=[4,5]\n",
"test=pd.read_csv('./Datasets/ml-100k/test.csv', sep='\\t', header=None)\n",
"\n",
"# NOTE(review): presumably scores every saved reco/estimations pair under dir_path\n",
"# (one row per model in the table below) - confirm in evaluation_measures.\n",
"ev.evaluate_all(test, dir_path, super_reactions)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}