{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "save_and_load_pytorch_model.ipynb", "provenance": [], "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:43.317635Z", "start_time": "2020-09-25T19:46:43.315050Z" }, "id": "D5_lUQ_JzxNQ" }, "source": [ "x = [[1,2],[3,4],[5,6],[7,8]]\n", "y = [[3],[7],[11],[15]]" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:43.632085Z", "start_time": "2020-09-25T19:46:43.319154Z" }, "id": "TG0fNwONz6yn" }, "source": [ "import torch\n", "import torch.nn as nn\n", "import numpy as np\n", "from torch.utils.data import Dataset, DataLoader\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:43.636696Z", "start_time": "2020-09-25T19:46:43.633237Z" }, "id": "f4-xTYoCz8U9" }, "source": [ "class MyDataset(Dataset):\n", " def __init__(self, x, y):\n", " self.x = torch.tensor(x).float().to(device)\n", " self.y = torch.tensor(y).float().to(device)\n", " def __getitem__(self, ix):\n", " return self.x[ix], self.y[ix]\n", " def __len__(self): \n", " return len(self.x)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:45.210534Z", "start_time": "2020-09-25T19:46:43.638037Z" }, "id": "WeBe83XQz9we" }, "source": [ "ds = MyDataset(x, y)\n", "dl = DataLoader(ds, batch_size=2, shuffle=True)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:45.214494Z", "start_time": "2020-09-25T19:46:45.211517Z" }, "id": "Vcg57P86z_oF" }, "source": [ "model = nn.Sequential(\n", " nn.Linear(2, 8),\n", " nn.ReLU(),\n", " nn.Linear(8, 1)\n", ").to(device)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:46.908328Z", "start_time": "2020-09-25T19:46:45.215657Z" }, "id": "7FGa-UWK0BIX", "outputId": "570c4f77-ef48-46c7-85b9-49b41eec4088", "colab": { "base_uri": "https://localhost:8080/", "height": 85 } }, "source": [ "!pip install torch_summary\n", "from torchsummary import summary" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "Requirement already satisfied: torch_summary in /home/yyr/anaconda3/lib/python3.7/site-packages (1.4.1)\n", "\u001b[33mWARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.\n", "You should consider upgrading via the '/home/yyr/anaconda3/bin/python -m pip install --upgrade pip' command.\u001b[0m\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:46.921020Z", "start_time": "2020-09-25T19:46:46.909862Z" }, "id": "UVZlHyXh0Fyd", "outputId": "1b7c50ea-f954-4a56-8eb0-8095891c943c", "colab": { "base_uri": "https://localhost:8080/", "height": 595 } }, "source": [ "summary(model, torch.zeros(1,2));" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "├─Linear: 1-1 [-1, 8] 24\n", "├─ReLU: 1-2 [-1, 8] --\n", "├─Linear: 1-3 [-1, 1] 9\n", "==========================================================================================\n", "Total params: 33\n", "Trainable params: 33\n", "Non-trainable params: 0\n", "Total mult-adds (M): 0.00\n", "==========================================================================================\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.00\n", "Params size (MB): 0.00\n", "Estimated Total Size (MB): 0.00\n", "==========================================================================================\n" ], "name": "stdout" }, { "output_type": "execute_result", "data": { "text/plain": [ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", "├─Linear: 1-1 [-1, 8] 24\n", "├─ReLU: 1-2 [-1, 8] --\n", "├─Linear: 1-3 [-1, 1] 9\n", "==========================================================================================\n", "Total params: 33\n", "Trainable params: 33\n", "Non-trainable params: 0\n", "Total mult-adds (M): 0.00\n", "==========================================================================================\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.00\n", "Params size (MB): 0.00\n", "Estimated Total Size (MB): 0.00\n", "==========================================================================================" ] }, "metadata": { "tags": [] }, "execution_count": 7 } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:46.997392Z", "start_time": "2020-09-25T19:46:46.922234Z" }, "id": "NDHfUDbW0Lh_", "outputId": "c0f4620b-4479-4ecc-d3e2-77e7d067b5d8", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "loss_func = nn.MSELoss()\n", "from torch.optim import SGD\n", "opt = SGD(model.parameters(), lr = 0.001)\n", "import time\n", "loss_history = []\n", "start = time.time()\n", "for _ in range(50):\n", " for ix, iy in dl:\n", " opt.zero_grad()\n", " loss_value = loss_func(model(ix),iy)\n", " loss_value.backward()\n", " opt.step()\n", " loss_history.append(loss_value)\n", "end = time.time()\n", "print(end - start)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "0.07127761840820312\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:45:14.978405Z", "start_time": "2020-09-25T19:45:14.976623Z" }, "id": "JrHJXeCl2FHm" }, "source": [ "### Saving" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:47.132743Z", "start_time": "2020-09-25T19:46:46.998346Z" }, "id": "FwNYJ83V2FHp", "outputId": "c2f89080-36c2-4c2a-8f05-960990921fe1" }, "source": [ "save_path = 'mymodel.pth'\n", "torch.save(model.state_dict(), save_path)\n", "!du -hsc {save_path} # size of the model on disk" ], "execution_count": null, "outputs": [ { "output_type": "stream", "text": [ "4.0K\tmymodel.pth\r\n", "4.0K\ttotal\r\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "1ew37R8X2FHr" }, "source": [ "### Loading" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:47.153931Z", "start_time": "2020-09-25T19:46:47.138011Z" }, "id": "93-_e2N62FHr", "outputId": "68c8fa79-e492-4a92-ec33-d4cff2dbf4ba" }, "source": [ "load_path = 'mymodel.pth'\n", "model.load_state_dict(torch.load(load_path))" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": { "tags": [] }, "execution_count": 10 } ] }, { "cell_type": "markdown", "metadata": { "id": "Qiqv1PFH2FHu" }, "source": [ "### Predictions" ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:47.165958Z", "start_time": "2020-09-25T19:46:47.158191Z" }, "id": "-Y-j0JeW0WKz" }, "source": [ "val = [[8,9],[10,11],[1.5,2.5]]\n", "val = torch.tensor(val).float()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:47.184080Z", "start_time": "2020-09-25T19:46:47.170476Z" }, "id": "KdNMIy4u0Xkt", "outputId": "1fc3883d-0692-409d-ecb8-d5dd98583285", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "model(val.to(device))" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([[16.5265],\n", " [20.2101],\n", " [ 4.5547]], device='cuda:0', grad_fn=)" ] }, "metadata": { "tags": [] }, "execution_count": 12 } ] }, { "cell_type": "code", "metadata": { "ExecuteTime": { "end_time": "2020-09-25T19:46:47.193074Z", "start_time": "2020-09-25T19:46:47.186259Z" }, "id": "FCagMOUM2FHz", "outputId": "9dcc6908-9188-443b-dc8c-57058e2087cf" }, "source": [ "val.sum(-1)" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([17., 21., 4.])" ] }, "metadata": { "tags": [] }, "execution_count": 13 } ] }, { "cell_type": "code", "metadata": { "id": "bUdMNahe2FH1" }, "source": [ "" ], "execution_count": null, "outputs": [] } ] }