{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "CNN_working_details.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "nJnTbZVIF55X"
},
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from torch import nn\n",
"from torch.optim import SGD, Adam\n",
"from torch.utils.data import TensorDataset, Dataset, DataLoader\n",
"from torchvision import datasets\n",
"%matplotlib inline\n",
"\n",
"# Seed torch's RNG so the model's random weight initialisation (and hence the\n",
"# trained weights and every value printed below) is reproducible across runs.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CW87aBN_F7h5"
},
"source": [
"# Two 4x4 single-channel images: shape (N=2, C=1, H=4, W=4), scaled by 1/8.\n",
"X_train = torch.tensor([[[[1,2,3,4],[2,3,4,5],[5,6,7,8],[1,3,4,5]]],[[[-1,2,3,-4],[2,-3,4,5],[-5,6,-7,8],[-1,-3,-4,-5]]]]).to(device).float()\n",
"X_train /= 8\n",
"# Targets are shaped (N, 1) to match the model's (batch, 1) sigmoid output.\n",
"# nn.BCELoss needs input and target of the same size. A (N,) target only\n",
"# warns on old PyTorch and raises a ValueError on newer releases.\n",
"y_train = torch.tensor([[0],[1]]).to(device).float()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WDr-XS8HF_UG"
},
"source": [
"def get_model():\n",
"    \"\"\"Build the toy CNN together with its loss and optimizer.\n",
"\n",
"    Returns a (model, loss_fn, optimizer) tuple:\n",
"      * model: Conv2d(1->1, 3x3) -> MaxPool2d(2) -> ReLU -> Flatten ->\n",
"        Linear(1->1) -> Sigmoid, moved to `device`;\n",
"      * loss_fn: binary cross-entropy (nn.BCELoss);\n",
"      * optimizer: Adam over the model's parameters with lr=1e-2.\n",
"    \"\"\"\n",
"    layers = [\n",
"        nn.Conv2d(1, 1, kernel_size=3),\n",
"        nn.MaxPool2d(2),\n",
"        nn.ReLU(),\n",
"        nn.Flatten(),\n",
"        nn.Linear(1, 1),\n",
"        nn.Sigmoid(),\n",
"    ]\n",
"    net = nn.Sequential(*layers).to(device)\n",
"    criterion = nn.BCELoss()\n",
"    return net, criterion, Adam(net.parameters(), lr=1e-2)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "alInllQdGC13",
"outputId": "3ab3e397-93c2-4f14-f325-6fb304bf2d66",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 474
}
},
"source": [
"# Pin the version and use %pip so the package lands in this kernel's environment.\n",
"%pip install -q torch-summary==1.4.3\n",
"from torchsummary import summary\n",
"model, loss_fn, optimizer = get_model()\n",
"# Summarise the layers using the real training batch as the sample input.\n",
"summary(model, X_train);"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting torch_summary\n",
" Downloading https://files.pythonhosted.org/packages/83/49/f9db57bcad7246591b93519fd8e5166c719548c45945ef7d2fc9fcba46fb/torch_summary-1.4.3-py3-none-any.whl\n",
"Installing collected packages: torch-summary\n",
"Successfully installed torch-summary-1.4.3\n",
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"├─Conv2d: 1-1 [-1, 1, 2, 2] 10\n",
"├─MaxPool2d: 1-2 [-1, 1, 1, 1] --\n",
"├─ReLU: 1-3 [-1, 1, 1, 1] --\n",
"├─Flatten: 1-4 [-1, 1] --\n",
"├─Linear: 1-5 [-1, 1] 2\n",
"├─Sigmoid: 1-6 [-1, 1] --\n",
"==========================================================================================\n",
"Total params: 12\n",
"Trainable params: 12\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 0.00\n",
"==========================================================================================\n",
"Input size (MB): 0.00\n",
"Forward/backward pass size (MB): 0.00\n",
"Params size (MB): 0.00\n",
"Estimated Total Size (MB): 0.00\n",
"==========================================================================================\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "NqnAmC52GEz0"
},
"source": [
"def train_batch(x, y, model, opt, loss_fn):\n",
"    \"\"\"Run one optimisation step on a single batch.\n",
"\n",
"    Args:\n",
"        x: input batch fed to `model`.\n",
"        y: targets; with nn.BCELoss they must have the same shape as model(x).\n",
"        model: the network being trained (switched to train mode).\n",
"        opt: optimizer over `model`'s parameters; this is the one stepped.\n",
"        loss_fn: callable (prediction, target) -> scalar loss tensor.\n",
"\n",
"    Returns:\n",
"        The batch loss as a Python float.\n",
"    \"\"\"\n",
"    model.train()\n",
"    prediction = model(x)\n",
"    batch_loss = loss_fn(prediction, y)\n",
"    batch_loss.backward()\n",
"    # Step and reset the optimizer that was passed in, not the notebook-global\n",
"    # `optimizer`; the original silently depended on that global.\n",
"    opt.step()\n",
"    opt.zero_grad()\n",
"    return batch_loss.item()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GPFRvgAlGIbp"
},
"source": [
"# One sample per batch, in dataset order (the DataLoader defaults, spelled out).\n",
"trn_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=1, shuffle=False)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yHALwqudGJzh",
"outputId": "1104beb4-326c-4ad7-8952-a5056c73bf9b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 74
}
},
"source": [
"# Train for 2000 epochs; each epoch visits both samples, one step per sample.\n",
"# `batch_loss` is left holding the last step's loss for inspection.\n",
"for epoch in range(2000):\n",
"    for x, y in trn_dl:\n",
"        batch_loss = train_batch(x, y, model, optimizer, loss_fn)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:529: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])) is deprecated. Please ensure they have the same size.\n",
" return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "YMEA3dcUGMA2",
"outputId": "db7abd8c-c17a-4b1c-9f4d-37795420c05f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"# Prediction for the first training image (the 0:1 slice keeps the batch dim).\n",
"first_image = X_train[0:1]\n",
"model(first_image)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0.0042]], grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "eHErutP4GNxI"
},
"source": [
""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Zw-Ou9J4Gail",
"outputId": "3b16cd39-beef-429a-947d-3d2cc86d0398",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 126
}
},
"source": [
"# Top-level layers of the Sequential model, in forward order.\n",
"[*model.children()]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1)),\n",
" MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),\n",
" ReLU(),\n",
" Flatten(),\n",
" Linear(in_features=1, out_features=1, bias=True),\n",
" Sigmoid()]"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3ACidFxRGa_0"
},
"source": [
"# Among the layers listed above, only Conv2d and Linear have a `weight`;\n",
"# unpack their (weight, bias) tensors in forward order.\n",
"(cnn_w, cnn_b), (lin_w, lin_b) = [\n",
"    (module.weight.data, module.bias.data)\n",
"    for module in model.children()\n",
"    if hasattr(module, 'weight')\n",
"]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Lj-SiW6sGcyy"
},
"source": [
"# Spatial sizes (last two dims) of the input images and of the conv kernel.\n",
"(h_im, w_im), (h_conv, w_conv) = X_train.shape[2:], cnn_w.shape[2:]\n",
"# No padding and stride 1, so the conv output is (H - kh + 1) x (W - kw + 1).\n",
"sumprod = torch.zeros(h_im - h_conv + 1, w_im - w_conv + 1)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RPPgrk51GgL1"
},
"source": [
"# Reproduce the Conv2d forward pass by hand on the first image: slide the\n",
"# kernel over every valid offset and take the elementwise product-sum plus bias.\n",
"# The kernel is hoisted out of the loop, and its size comes from h_conv/w_conv\n",
"# instead of a hard-coded 3.\n",
"model_filter = cnn_w[0, 0]  # (h_conv, w_conv) kernel of the single in/out channel pair\n",
"for i in range(h_im - h_conv + 1):\n",
"    for j in range(w_im - w_conv + 1):\n",
"        img_subset = X_train[0, 0, i:i + h_conv, j:j + w_conv]\n",
"        val = torch.sum(img_subset * model_filter) + cnn_b\n",
"        sumprod[i, j] = val"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GqaZ2WlpGjTN",
"outputId": "b6610f5d-d2eb-499e-e8d0-a0c2db5f9b74",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 54
}
},
"source": [
"# ReLU, in place. The model applies ReLU after max-pooling; doing it first is\n",
"# equivalent because ReLU is monotonic, so max(relu(x)) == relu(max(x)).\n",
"sumprod.clamp_(min=0)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 0.],\n",
" [0., 0.]])"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vpb3vZvuGkyX"
},
"source": [
"# MaxPool2d(2) over the 2x2 feature map reduces to its single global maximum.\n",
"pooling_layer_output = sumprod.max()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NfeX-EYuGmJ3"
},
"source": [
"# Linear(1, 1) applied to the pooled scalar: w * x + b.\n",
"intermediate_output_value = lin_b + lin_w * pooling_layer_output"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rwA6srUgGniP",
"outputId": "8893187a-7f02-4e23-e6ea-6eb7f1239594",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 72
}
},
"source": [
"# Final sigmoid. torch.sigmoid replaces the deprecated F.sigmoid. The result\n",
"# should match the model(X_train[:1]) prediction computed earlier.\n",
"torch.sigmoid(intermediate_output_value)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"tensor([[0.0042]])\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py:1625: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "h6QhLf0xGpIE"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}