{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "CNN_working_details.ipynb",
   "provenance": [],
   "include_colab_link": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "\"Open"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cnnTitleMd"
   },
   "source": [
    "# How a CNN computes its output\n",
    "\n",
    "Train a tiny one-filter CNN on two 4x4 images, then reproduce its forward pass for the first image by hand\n",
    "(convolution -> ReLU -> max-pool -> linear -> sigmoid) and check that the result matches `model(...)`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "pipInstallSummary"
   },
   "source": [
    "# Colab dependency for the layer summary; pinned so the summary output is reproducible.\n",
    "%pip install -q torch-summary==1.4.3"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "nJnTbZVIF55X"
   },
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.optim import SGD, Adam\n",
    "from torch.utils.data import TensorDataset, Dataset, DataLoader\n",
    "from torchsummary import summary\n",
    "from torchvision import datasets\n",
    "%matplotlib inline"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "configSeedDevice"
   },
   "source": [
    "SEED = 0\n",
    "N_EPOCHS = 2000\n",
    "torch.manual_seed(SEED)  # reproducible weight initialisation and training\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dataMd"
   },
   "source": [
    "## Data\n",
    "\n",
    "Two single-channel 4x4 images: the first has only positive pixels (label 0), the second mixes signs (label 1)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "CW87aBN_F7h5"
   },
   "source": [
    "X_train = torch.tensor([[[[1,2,3,4],[2,3,4,5],[5,6,7,8],[1,3,4,5]]],\n",
    "                        [[[-1,2,3,-4],[2,-3,4,5],[-5,6,-7,8],[-1,-3,-4,-5]]]]).to(device).float()\n",
    "X_train /= 8  # largest |pixel| is 8, so this scales the inputs into [-1, 1]\n",
    "# Shape (N, 1) so the targets match the model's (N, 1) sigmoid output, as BCELoss requires.\n",
    "y_train = torch.tensor([[0.], [1.]]).to(device)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "modelMd"
   },
   "source": [
    "## Model\n",
    "\n",
    "One 3x3 convolution filter turns the 4x4 image into a 2x2 feature map. A 2x2 max-pool reduces that map to a single value,\n",
    "which a 1 -> 1 linear layer and a sigmoid turn into a probability."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "WDr-XS8HF_UG"
   },
   "source": [
    "def get_model():\n",
    "    \"\"\"Build the tiny CNN together with its BCE loss and Adam optimizer (lr=1e-2).\n",
    "\n",
    "    Returns:\n",
    "        (model, loss_fn, optimizer), with the model already moved to `device`.\n",
    "    \"\"\"\n",
    "    model = nn.Sequential(\n",
    "        nn.Conv2d(1, 1, kernel_size=3),  # 4x4 input -> 2x2 feature map\n",
    "        nn.MaxPool2d(2),                 # 2x2 -> 1x1\n",
    "        nn.ReLU(),\n",
    "        nn.Flatten(),\n",
    "        nn.Linear(1, 1),\n",
    "        nn.Sigmoid(),\n",
    "    ).to(device)\n",
    "    loss_fn = nn.BCELoss()\n",
    "    optimizer = Adam(model.parameters(), lr=1e-2)\n",
    "    return model, loss_fn, optimizer"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "alInllQdGC13"
   },
   "source": [
    "model, loss_fn, optimizer = get_model()\n",
    "summary(model, X_train);"
   ],
   "execution_count": null,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "==========================================================================================\n",
      "Layer (type:depth-idx) Output Shape Param #\n",
      "==========================================================================================\n",
      "├─Conv2d: 1-1 [-1, 1, 2, 2] 10\n",
      "├─MaxPool2d: 1-2 [-1, 1, 1, 1] --\n",
      "├─ReLU: 1-3 [-1, 1, 1, 1] --\n",
      "├─Flatten: 1-4 [-1, 1] --\n",
      "├─Linear: 1-5 [-1, 1] 2\n",
      "├─Sigmoid: 1-6 [-1, 1] --\n",
      "==========================================================================================\n",
      "Total params: 12\n",
      "Trainable params: 12\n",
      "Non-trainable params: 0\n",
      "Total mult-adds (M): 0.00\n",
      "==========================================================================================\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.00\n",
      "Params size (MB): 0.00\n",
      "Estimated Total Size (MB): 0.00\n",
      "==========================================================================================\n"
     ]
    }
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "trainMd"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "NqnAmC52GEz0"
   },
   "source": [
    "def train_batch(x, y, model, opt, loss_fn):\n",
    "    \"\"\"Run one optimisation step on a single batch and return its loss as a float.\n",
    "\n",
    "    Steps the `opt` argument rather than any global optimizer, so the function\n",
    "    works with whatever model/optimizer pair the caller passes in.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    prediction = model(x)\n",
    "    batch_loss = loss_fn(prediction, y)\n",
    "    batch_loss.backward()\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "    return batch_loss.item()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "GPFRvgAlGIbp"
   },
   "source": [
    "trn_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=1)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "yHALwqudGJzh"
   },
   "source": [
    "for epoch in range(N_EPOCHS):\n",
    "    for x, y in trn_dl:\n",
    "        batch_loss = train_batch(x, y, model, optimizer, loss_fn)\n",
    "batch_loss  # loss of the last batch seen"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "YMEA3dcUGMA2"
   },
   "source": [
    "with torch.no_grad():\n",
    "    model_output = model(X_train[:1])\n",
    "model_output"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "manualMd"
   },
   "source": [
    "## Reproduce the forward pass by hand\n",
    "\n",
    "Pull the learned weights out of the model and recompute the prediction for the first image step by step."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Zw-Ou9J4Gail"
   },
   "source": [
    "list(model.children())"
   ],
   "execution_count": null,
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "[Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)),\n",
       " MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),\n",
       " ReLU(),\n",
       " Flatten(),\n",
       " Linear(in_features=1, out_features=1, bias=True),\n",
       " Sigmoid()]"
      ]
     },
     "metadata": {
      "tags": []
     },
     "execution_count": null
    }
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "3ACidFxRGa_0"
   },
   "source": [
    "# Only Conv2d and Linear carry weights, so this unpacks exactly two (weight, bias) pairs.\n",
    "(cnn_w, cnn_b), (lin_w, lin_b) = [(layer.weight.data, layer.bias.data) for layer in list(model.children()) if hasattr(layer, 'weight')]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Lj-SiW6sGcyy"
   },
   "source": [
    "h_im, w_im = X_train.shape[2:]\n",
    "h_conv, w_conv = cnn_w.shape[2:]\n",
    "# Output size of a 'valid' convolution (stride 1, no padding), on the same device as the weights.\n",
    "sumprod = torch.zeros((h_im - h_conv + 1, w_im - w_conv + 1), device=device)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "RPPgrk51GgL1"
   },
   "source": [
    "model_filter = cnn_w[0, 0]  # the single (h_conv, w_conv) kernel: 1 input channel, 1 output channel\n",
    "for i in range(h_im - h_conv + 1):\n",
    "    for j in range(w_im - w_conv + 1):\n",
    "        img_subset = X_train[0, 0, i:i + h_conv, j:j + w_conv]\n",
    "        sumprod[i, j] = torch.sum(img_subset * model_filter) + cnn_b[0]"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "reluOrderMd"
   },
   "source": [
    "The model applies max-pool *before* ReLU. ReLU is monotonic, so applying it first gives the same result:\n",
    "$\\max(\\mathrm{relu}(x)) = \\mathrm{relu}(\\max(x))$."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "GqaZ2WlpGjTN"
   },
   "source": [
    "relu_output = sumprod.clamp_min(0)  # ReLU, without mutating sumprod\n",
    "relu_output"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "vpb3vZvuGkyX"
   },
   "source": [
    "# MaxPool2d(2) over the 2x2 feature map is simply its global maximum.\n",
    "pooling_layer_output = torch.max(relu_output)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "NfeX-EYuGmJ3"
   },
   "source": [
    "intermediate_output_value = pooling_layer_output * lin_w + lin_b"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "rwA6srUgGniP"
   },
   "source": [
    "manual_output = torch.sigmoid(intermediate_output_value)\n",
    "manual_output"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "h6QhLf0xGpIE"
   },
   "source": [
    "# The hand-computed forward pass must reproduce the model's own prediction.\n",
    "assert torch.allclose(manual_output, model_output, atol=1e-6), (manual_output, model_output)"
   ],
   "execution_count": null,
   "outputs": []
  }
 ]
}