{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "save_and_load_pytorch_model.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:43.317635Z",
"start_time": "2020-09-25T19:46:43.315050Z"
},
"id": "D5_lUQ_JzxNQ"
},
"source": [
"x = [[1,2],[3,4],[5,6],[7,8]]\n",
"y = [[3],[7],[11],[15]]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:43.632085Z",
"start_time": "2020-09-25T19:46:43.319154Z"
},
"id": "TG0fNwONz6yn"
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"from torch.utils.data import Dataset, DataLoader\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:43.636696Z",
"start_time": "2020-09-25T19:46:43.633237Z"
},
"id": "f4-xTYoCz8U9"
},
"source": [
"class MyDataset(Dataset):\n",
" def __init__(self, x, y):\n",
" self.x = torch.tensor(x).float().to(device)\n",
" self.y = torch.tensor(y).float().to(device)\n",
" def __getitem__(self, ix):\n",
" return self.x[ix], self.y[ix]\n",
" def __len__(self): \n",
" return len(self.x)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:45.210534Z",
"start_time": "2020-09-25T19:46:43.638037Z"
},
"id": "WeBe83XQz9we"
},
"source": [
"ds = MyDataset(x, y)\n",
"dl = DataLoader(ds, batch_size=2, shuffle=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:45.214494Z",
"start_time": "2020-09-25T19:46:45.211517Z"
},
"id": "Vcg57P86z_oF"
},
"source": [
"model = nn.Sequential(\n",
" nn.Linear(2, 8),\n",
" nn.ReLU(),\n",
" nn.Linear(8, 1)\n",
").to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:46.908328Z",
"start_time": "2020-09-25T19:46:45.215657Z"
},
"id": "7FGa-UWK0BIX",
"outputId": "570c4f77-ef48-46c7-85b9-49b41eec4088",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
}
},
"source": [
"!pip install torch_summary\n",
"from torchsummary import summary"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: torch_summary in /home/yyr/anaconda3/lib/python3.7/site-packages (1.4.1)\n",
"\u001b[33mWARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.\n",
"You should consider upgrading via the '/home/yyr/anaconda3/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:46.921020Z",
"start_time": "2020-09-25T19:46:46.909862Z"
},
"id": "UVZlHyXh0Fyd",
"outputId": "1b7c50ea-f954-4a56-8eb0-8095891c943c",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 595
}
},
"source": [
"summary(model, torch.zeros(1,2));"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"├─Linear: 1-1 [-1, 8] 24\n",
"├─ReLU: 1-2 [-1, 8] --\n",
"├─Linear: 1-3 [-1, 1] 9\n",
"==========================================================================================\n",
"Total params: 33\n",
"Trainable params: 33\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 0.00\n",
"==========================================================================================\n",
"Input size (MB): 0.00\n",
"Forward/backward pass size (MB): 0.00\n",
"Params size (MB): 0.00\n",
"Estimated Total Size (MB): 0.00\n",
"==========================================================================================\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"├─Linear: 1-1 [-1, 8] 24\n",
"├─ReLU: 1-2 [-1, 8] --\n",
"├─Linear: 1-3 [-1, 1] 9\n",
"==========================================================================================\n",
"Total params: 33\n",
"Trainable params: 33\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 0.00\n",
"==========================================================================================\n",
"Input size (MB): 0.00\n",
"Forward/backward pass size (MB): 0.00\n",
"Params size (MB): 0.00\n",
"Estimated Total Size (MB): 0.00\n",
"=========================================================================================="
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:46.997392Z",
"start_time": "2020-09-25T19:46:46.922234Z"
},
"id": "NDHfUDbW0Lh_",
"outputId": "c0f4620b-4479-4ecc-d3e2-77e7d067b5d8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"loss_func = nn.MSELoss()\n",
"from torch.optim import SGD\n",
"opt = SGD(model.parameters(), lr = 0.001)\n",
"import time\n",
"loss_history = []\n",
"start = time.time()\n",
"for _ in range(50):\n",
" for ix, iy in dl:\n",
" opt.zero_grad()\n",
" loss_value = loss_func(model(ix),iy)\n",
" loss_value.backward()\n",
" opt.step()\n",
" loss_history.append(loss_value)\n",
"end = time.time()\n",
"print(end - start)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"0.07127761840820312\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:45:14.978405Z",
"start_time": "2020-09-25T19:45:14.976623Z"
},
"id": "JrHJXeCl2FHm"
},
"source": [
"### Saving"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:47.132743Z",
"start_time": "2020-09-25T19:46:46.998346Z"
},
"id": "FwNYJ83V2FHp",
"outputId": "c2f89080-36c2-4c2a-8f05-960990921fe1"
},
"source": [
"save_path = 'mymodel.pth'\n",
"torch.save(model.state_dict(), save_path)\n",
"!du -hsc {save_path} # size of the model on disk"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"4.0K\tmymodel.pth\r\n",
"4.0K\ttotal\r\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ew37R8X2FHr"
},
"source": [
"### Loading"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:47.153931Z",
"start_time": "2020-09-25T19:46:47.138011Z"
},
"id": "93-_e2N62FHr",
"outputId": "68c8fa79-e492-4a92-ec33-d4cff2dbf4ba"
},
"source": [
"load_path = 'mymodel.pth'\n",
"model.load_state_dict(torch.load(load_path))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qiqv1PFH2FHu"
},
"source": [
"### Predictions"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:47.165958Z",
"start_time": "2020-09-25T19:46:47.158191Z"
},
"id": "-Y-j0JeW0WKz"
},
"source": [
"val = [[8,9],[10,11],[1.5,2.5]]\n",
"val = torch.tensor(val).float()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:47.184080Z",
"start_time": "2020-09-25T19:46:47.170476Z"
},
"id": "KdNMIy4u0Xkt",
"outputId": "1fc3883d-0692-409d-ecb8-d5dd98583285",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"model(val.to(device))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[16.5265],\n",
" [20.2101],\n",
" [ 4.5547]], device='cuda:0', grad_fn=)"
]
},
"metadata": {
"tags": []
},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-25T19:46:47.193074Z",
"start_time": "2020-09-25T19:46:47.186259Z"
},
"id": "FCagMOUM2FHz",
"outputId": "9dcc6908-9188-443b-dc8c-57058e2087cf"
},
"source": [
"val.sum(-1)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([17., 21., 4.])"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "bUdMNahe2FH1"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}