uczenie-maszynowe/lab/10_Sieci_neuronowe.ipynb

2023-05-25 10:22:21 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "raw",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Machine Learning: Applications\n",
"### Laboratory classes\n",
"# 10. Introduction to neural networks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is an implementation of a simple neural network for binary classification, demonstrated on a randomly generated dataset.\n",
"\n",
"In feedforward networks, the values of the neurons in layer $i$ are computed from the values of the neurons in layer $i-1$. Given an $n$-layer neural network with parameters $\\Theta^{(1)}, \\ldots, \\Theta^{(n)}$ and $\\beta^{(1)}, \\ldots, \\beta^{(n)}$, we compute:\n",
"$$a^{(i)} = g^{(i)}\\left( a^{(i-1)} \\Theta^{(i)} + \\beta^{(i)} \\right) \\; , $$\n",
"where $a^{(0)}$ is the input, $a^{(n)}$ is the network output, and $g^{(i)}$ are the so-called **activation functions**. In the implementation below $n = 2$: $\\Theta^{(1)}, \\Theta^{(2)}$ correspond to `W1`, `W2`, $\\beta^{(1)}, \\beta^{(2)}$ to `b1`, `b2`, $g^{(1)} = \\tanh$, and $g^{(2)}$ is the softmax function."
]
},
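{
"cell_type": "markdown",
"metadata": {},
"source": [
"The formula above can be sketched for an arbitrary number of layers as a loop over $(\\Theta^{(i)}, \\beta^{(i)}, g^{(i)})$ triples. This is a minimal illustration under stated assumptions, not part of the exercise: the helper names `feedforward` and `softmax_rows` and the example layer sizes (2 → 3 → 2) are hypothetical, and the parameters are random, so only the shapes and the row sums of the output are meaningful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def feedforward(x, thetas, betas, activations):\n",
"    # Hypothetical helper: starts from a^(0) = x and applies\n",
"    # a^(i) = g^(i)(a^(i-1) Theta^(i) + beta^(i)) layer by layer\n",
"    a = x\n",
"    for theta, beta, g in zip(thetas, betas, activations):\n",
"        a = g(a.dot(theta) + beta)\n",
"    return a\n",
"\n",
"def softmax_rows(z):\n",
"    # Row-wise softmax; subtracting the row maximum avoids overflow\n",
"    e = np.exp(z - np.max(z, axis=1, keepdims=True))\n",
"    return e / np.sum(e, axis=1, keepdims=True)\n",
"\n",
"# Example network: 2 inputs -> 3 hidden units (tanh) -> 2 outputs (softmax)\n",
"rng = np.random.default_rng(0)\n",
"thetas = [rng.standard_normal((2, 3)), rng.standard_normal((3, 2))]\n",
"betas = [np.zeros((1, 3)), np.zeros((1, 2))]\n",
"activations = [np.tanh, softmax_rows]\n",
"\n",
"x = rng.standard_normal((4, 2))  # a batch of 4 examples, one per row\n",
"out = feedforward(x, thetas, betas, activations)\n",
"print(out.shape)        # (4, 2)\n",
"print(out.sum(axis=1))  # each row sums to 1 (up to rounding)"
]
},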
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercise 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic part (4 points)\n",
"\n",
" 1. Implement the `accuracy()` function, which computes the classification accuracy, in the indicated place.\n",
" 2. Use it to compute and print the final accuracy of the classifier.\n",
" 3. Also print the `accuracy` value during training (alongside the cost function value).\n",
" 4. Build neural networks with different hidden layer sizes (`dim_hid` = 1, 2, 5, 10, 25). Compare the accuracy of these models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced part (3 points)\n",
"\n",
"Apply the neural network implementation below to binary classification of a dataset generated with a [sklearn.datasets](http://scikit-learn.org/stable/modules/classes.html#samples-generator) function of your choice. Set the sizes of the input and output layers, and tune the network parameters (the learning rate $\\alpha$, the number of epochs, the size of the hidden layer). Report the classification accuracy."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def generate_data():\n",
"    # Keep results deterministic\n",
"    np.random.seed(1234)\n",
"    X, y = datasets.make_moons(200, noise=0.25)\n",
"    # X, y = datasets.make_classification(n_samples=200, n_features=2, n_informative=2, n_redundant=0)\n",
"    return X, y\n",
"\n",
"def visualize(X, y, model=None):\n",
"    # Plot the data points and, if a model is given, its decision regions\n",
"    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5\n",
"    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5\n",
"    h = 0.01\n",
"    xx, yy = np.meshgrid(\n",
"        np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))\n",
"    if model is not None:\n",
"        Z = predict(model, np.c_[xx.ravel(), yy.ravel()])\n",
"        Z = Z.reshape(xx.shape)\n",
"        plt.contourf(xx, yy, Z, cmap=plt.cm.viridis)\n",
"    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.viridis)\n",
"    plt.show()\n",
"\n",
"def initialize_model(dim_in=2, dim_hid=3, dim_out=2):\n",
"    # Keep results deterministic\n",
"    np.random.seed(1234)\n",
"    W1 = np.random.randn(dim_in, dim_hid) / np.sqrt(dim_in)\n",
"    b1 = np.zeros((1, dim_hid))\n",
"    W2 = np.random.randn(dim_hid, dim_out) / np.sqrt(dim_hid)\n",
"    b2 = np.zeros((1, dim_out))\n",
"    return W1, b1, W2, b2\n",
"\n",
"def softmax(X):\n",
"    # Subtracting the row-wise maximum avoids overflow in np.exp\n",
"    # and does not change the result\n",
"    e = np.exp(X - np.max(X, axis=1, keepdims=True))\n",
"    return e / np.sum(e, axis=1, keepdims=True)\n",
"\n",
"def predict(model, X):\n",
"    W1, b1, W2, b2 = model\n",
"    z1 = X.dot(W1) + b1\n",
"    a1 = np.tanh(z1)\n",
"    z2 = a1.dot(W2) + b2\n",
"    probs = softmax(z2)\n",
"    return np.argmax(probs, axis=1)\n",
"\n",
"def calculate_cost(model, X, y):\n",
"    # Binary cross-entropy computed from the predicted probability of class 1\n",
"    # (assumes y contains 0/1 labels)\n",
"    W1, b1, W2, b2 = model\n",
"    z1 = X.dot(W1) + b1\n",
"    a1 = np.tanh(z1)\n",
"    z2 = a1.dot(W2) + b2\n",
"    probs = softmax(z2)\n",
"    preds = probs[:, 1]\n",
"    return -1. / len(y) * np.sum(\n",
"        np.multiply(y, np.log(preds)) + np.multiply(1 - y, np.log(1 - preds)),\n",
"        axis=0)\n",
"\n",
"def accuracy(model, X, y):\n",
"    # TODO: 1. Write a function that computes the accuracy.\n",
"    #          Use the `predict` function.\n",
"    pass\n",
"\n",
"def train(model, X, y, alpha=0.01, epochs=10000, debug=False):\n",
"    W1, b1, W2, b2 = model\n",
"    m = len(X)\n",
"\n",
"    for i in range(epochs):\n",
"        # Forward propagation\n",
"        z1 = X.dot(W1) + b1\n",
"        a1 = np.tanh(z1)\n",
"        z2 = a1.dot(W2) + b2\n",
"        probs = softmax(z2)\n",
"\n",
"        # Backpropagation\n",
"        # Gradient of the summed cross-entropy w.r.t. z2: probs - one_hot(y)\n",
"        # (note: this modifies probs in place)\n",
"        delta3 = probs\n",
"        delta3[range(m), y] -= 1\n",
"        dW2 = (a1.T).dot(delta3)\n",
"        db2 = np.sum(delta3, axis=0, keepdims=True)\n",
"        # Derivative of tanh: 1 - tanh(z1)^2 = 1 - a1^2\n",
"        delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))\n",
"        dW1 = np.dot(X.T, delta2)\n",
"        db1 = np.sum(delta2, axis=0, keepdims=True)\n",
"\n",
"        # Parameter update (gradient descent step)\n",
"        W1 -= alpha * dW1\n",
"        b1 -= alpha * db1\n",
"        W2 -= alpha * dW2\n",
"        b2 -= alpha * db2\n",
"\n",
"        # Print loss\n",
"        if debug and i % 1000 == 0:\n",
"            model = (W1, b1, W2, b2)\n",
"            print(\"Cost after iteration {}: {:.4f}\".format(i, calculate_cost(\n",
"                model, X, y)))\n",
"            # TODO: 3. Print the classification accuracy here.\n",
"\n",
"    # TODO: 2. Print the final classification accuracy here.\n",
"    return W1, b1, W2, b2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"X, y = generate_data()\n",
"visualize(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cost after iteration 0: 0.4692\n",
"Cost after iteration 1000: 0.1527\n",
"Cost after iteration 2000: 0.1494\n",
"Cost after iteration 3000: 0.1478\n",
"Cost after iteration 4000: 0.1458\n",
"Cost after iteration 5000: 0.1444\n",
"Cost after iteration 6000: 0.1433\n",
"Cost after iteration 7000: 0.1423\n",
"Cost after iteration 8000: 0.1415\n",
"Cost after iteration 9000: 0.1407\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"model1 = train(initialize_model(dim_hid=5), X, y, debug=True)\n",
"visualize(X, y, model1)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cost after iteration 0: 0.6739\n",
"Cost after iteration 1000: 0.3618\n",
"Cost after iteration 2000: 0.3618\n",
"Cost after iteration 3000: 0.3618\n",
"Cost after iteration 4000: 0.3618\n",
"Cost after iteration 5000: 0.3618\n",
"Cost after iteration 6000: 0.3618\n",
"Cost after iteration 7000: 0.3618\n",
"Cost after iteration 8000: 0.3618\n",
"Cost after iteration 9000: 0.3618\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"model2 = train(initialize_model(dim_hid=1), X, y, debug=True)\n",
"visualize(X, y, model2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}