260 lines
4.5 MiB
Plaintext
260 lines
4.5 MiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import datetime\n",
|
||
|
"import os\n",
|
||
|
"import sys\n",
|
||
|
"import time\n",
|
||
|
"\n",
|
||
|
"from torch.optim import Optimizer\n",
|
||
|
"from torch.utils.data import DataLoader\n",
|
||
|
"\n",
|
||
|
"from inference import infer, evaluate\n",
|
||
|
"from metrics import Metrics\n",
|
||
|
"\n",
|
||
|
"# Hack for module imports\n",
|
||
|
"module_path = os.path.abspath(os.path.join('../data'))\n",
|
||
|
"if module_path not in sys.path:\n",
|
||
|
" sys.path.append(module_path)\n",
|
||
|
"\n",
|
||
|
"from loaders import FlatsDatasetLoader\n",
|
||
|
"\n",
|
||
|
"import numpy as np\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"from tqdm import tqdm\n",
|
||
|
"\n",
|
||
|
"import plots"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Will operate on device cuda\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Prefer a CUDA device when PyTorch can see one, otherwise fall back to the CPU.\n",
|
||
|
"if torch.cuda.is_available():\n",
|
||
|
"    device = torch.device('cuda')\n",
|
||
|
"else:\n",
|
||
|
"    device = torch.device('cpu')\n",
|
||
|
"print(f'Will operate on device {device}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Loading dataset from files...\n",
|
||
|
"Broken file: ..\\data\\images\\train\\Industrial\\646.jpg\n",
|
||
|
"Done. Creating PyTorch datasets...\n",
|
||
|
"Done.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Target size handed to FlatsDatasetLoader(resize_to=...); conduct_experiment also passes it to plots.plot_metrics.\n",
|
||
|
"image_size = 200\n",
|
||
|
"\n",
|
||
|
"# NOTE(review): images_dir is relative to the notebook's working directory.\n",
|
||
|
"data_loader = FlatsDatasetLoader(images_dir='../data/images', resize_to=image_size,\n",
|
||
|
"                                 device=device, batch_size=100)\n",
|
||
|
"data_loader.load(verbose=True)\n",
|
||
|
"\n",
|
||
|
"train_loader, test_loader = data_loader.get_train_loader(), data_loader.get_test_loader()\n",
|
||
|
"classes_count = data_loader.get_classes_count()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def train(\n",
|
||
|
"    model: nn.Module,\n",
|
||
|
"    train_data: DataLoader,\n",
|
||
|
"    test_data: DataLoader,\n",
|
||
|
"    optimizer_fn: Optimizer,\n",
|
||
|
"    loss_fn,\n",
|
||
|
"    epochs: int\n",
|
||
|
") -> 'tuple[Metrics, Metrics]':\n",
|
||
|
"    \"\"\"Train `model` for `epochs` epochs, evaluating on `test_data` after each epoch.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        model: network to optimise; its parameters are updated in place.\n",
|
||
|
"        train_data: batches passed to `infer`; `data[1]` is taken as the labels.\n",
|
||
|
"        test_data: loader handed to `evaluate` once per epoch.\n",
|
||
|
"        optimizer_fn: optimizer that updates `model`'s parameters.\n",
|
||
|
"        loss_fn: criterion forwarded to `infer` and `evaluate`.\n",
|
||
|
"        epochs: number of passes over `train_data`.\n",
|
||
|
"\n",
|
||
|
"    Returns:\n",
|
||
|
"        `(test_metrics, train_metrics)`, each receiving one `add_new` call per epoch.\n",
|
||
|
"        Note the order: test metrics first, train metrics second.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    test_metrics = Metrics()\n",
|
||
|
"    train_metrics = Metrics()\n",
|
||
|
"    for _ in tqdm(range(epochs), total=epochs):\n",
|
||
|
"        # Re-enable training mode at the start of *every* epoch. `evaluate` runs at the\n",
|
||
|
"        # end of each epoch and may switch the model to eval mode (TODO confirm in\n",
|
||
|
"        # inference.py); calling model.train() only once, before the loop, would then\n",
|
||
|
"        # leave BatchNorm/dropout in eval mode for every epoch after the first.\n",
|
||
|
"        model.train()\n",
|
||
|
"        pred_batches, label_batches, train_losses = [], [], []\n",
|
||
|
"        for data in train_data:\n",
|
||
|
"            optimizer_fn.zero_grad()\n",
|
||
|
"            output, loss = infer(data, model, loss_fn, device)\n",
|
||
|
"            # argmax of the logits equals argmax of their softmax (softmax is monotonic),\n",
|
||
|
"            # so the extra softmax is skipped. detach()/cpu() keep this off the autograd\n",
|
||
|
"            # graph and also work when the labels live on the GPU.\n",
|
||
|
"            pred_batches.append(output.detach().argmax(dim=1).cpu().numpy())\n",
|
||
|
"            label_batches.append(data[1].detach().cpu().numpy())\n",
|
||
|
"            loss.backward()\n",
|
||
|
"            train_losses.append(loss.item())\n",
|
||
|
"            optimizer_fn.step()\n",
|
||
|
"        # Concatenate once per epoch (the per-batch np.concatenate was quadratic). The\n",
|
||
|
"        # leading empty float array keeps the previous float64 result dtype and makes an\n",
|
||
|
"        # empty loader yield an empty array instead of raising.\n",
|
||
|
"        train_outs = np.concatenate([np.array([]), *pred_batches])\n",
|
||
|
"        ys = np.concatenate([np.array([]), *label_batches])\n",
|
||
|
"        train_metrics.add_new(train_outs, ys, train_losses)\n",
|
||
|
"        test_trues, test_preds, test_losses = evaluate(model, test_data, loss_fn, device)\n",
|
||
|
"        test_metrics.add_new(test_preds, test_trues, test_losses)\n",
|
||
|
"\n",
|
||
|
"    return test_metrics, train_metrics"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def conduct_experiment(title, model, n_epochs, optimizer):\n",
|
||
|
"    \"\"\"Train `model`, save it under models/<title>/, then plot metrics and misclassified images.\n",
|
||
|
"\n",
|
||
|
"    Relies on notebook globals defined in earlier cells: `train_loader`, `test_loader`,\n",
|
||
|
"    `data_loader`, `image_size` and `device`.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        title: experiment name; used as the sub-directory of `models/` and passed\n",
|
||
|
"            first to `plots.plot_metrics`.\n",
|
||
|
"        model: network to train (its parameters are updated in place).\n",
|
||
|
"        n_epochs: number of training epochs.\n",
|
||
|
"        optimizer: optimizer that updates `model`'s parameters.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    criterion = nn.CrossEntropyLoss()\n",
|
||
|
"\n",
|
||
|
"    # perf_counter is monotonic, so the measured duration is not skewed by clock changes.\n",
|
||
|
"    start = time.perf_counter()\n",
|
||
|
"    test_metrics, train_metrics = train(model, train_loader, test_loader, optimizer, criterion, n_epochs)\n",
|
||
|
"    end = time.perf_counter()\n",
|
||
|
"\n",
|
||
|
"    model_directory = os.path.join('models', title)\n",
|
||
|
"    # makedirs also creates the parent `models/` directory when it does not exist yet;\n",
|
||
|
"    # plain os.mkdir raised FileNotFoundError there, after training, so nothing was saved.\n",
|
||
|
"    os.makedirs(model_directory, exist_ok=True)\n",
|
||
|
"    # NOTE(review): minute resolution - two runs within the same minute overwrite each other.\n",
|
||
|
"    path = os.path.join(model_directory, f'{datetime.datetime.now().strftime(\"%y-%b-%d-%H-%M\")}.pt')\n",
|
||
|
"    # Pickles the whole module object; loading it back needs the same class definitions importable.\n",
|
||
|
"    torch.save(model, path)\n",
|
||
|
"    print(f'Model saved in {path}')\n",
|
||
|
"\n",
|
||
|
"    plots.plot_metrics(title, test_metrics, train_metrics, n_epochs, end - start, image_size, device)\n",
|
||
|
"    _, preds, _ = evaluate(model, test_loader, criterion, device)\n",
|
||
|
"    labels = data_loader.get_label_names()\n",
|
||
|
"    plots.show_missclassified(test_loader.dataset, preds, labels)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Using cache found in C:\\Users\\komar/.cache\\torch\\hub\\pytorch_vision_v0.10.0\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# ResNet-18 from the torchvision hub repo pinned at v0.10.0. No `pretrained` flag is passed;\n",
|
||
|
"# num_classes is forwarded to the resnet18 constructor so the output size matches the dataset.\n",
|
||
|
"resnet_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18',\n",
|
||
|
"                              num_classes=data_loader.get_classes_count())\n",
|
||
|
"resnet_model = resnet_model.to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"100%|██████████| 5/5 [02:25<00:00, 29.10s/it]\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model saved in models\\ResNet18\\23-Jan-09-20-23.pt\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"C:\\Users\\komar\\Documents\\University\\flats\\experiments\\plots.py:70: UserWarning: You have mixed positional and keyword arguments, some input may be discarded.\n",
|
||
|
" fig.legend(axis, labels=['test', 'train'], loc=\"lower center\")\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 800x800 with 5 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxMAAAMZCAYAAAB7wWk2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAxOAAAMTgF/d4wjAADzaElEQVR4nOzdd3hTZfsH8O/J7KZ7yRYb9t4FBEQEZCiKIkMFERy4eQX1dYIKCIKg+CqiiAMQRRyAPxUEyt5D9qZAW7p3s87z+yNtaNq0tGnaJPT7ua5eTXKenHPn9KTn3OdZkhBCgIiIiIiIqJIUrg6AiIiIiIg8E5MJIiIiIiJyCJMJIiIiIiJyCJMJIiIiIiJyCJMJIiIiIiJyCJMJIiIiIiJyCJMJIiIiIiJyiMrVARARebK+ffviypUr5ZbRaDQIDAxEkyZN0L9/fzzwwANQKpU1FKF9Y8eOxe7du6FWq/Hjjz+iadOm5ZY/e/YsBg0aBABYtmwZunTp4rRYkpOTsX79ejz88MMVKv/bb79hypQpePzxxzFlypRyy54+fRpffvkldu7cieTkZGg0GjRp0gQDBw7EqFGjoNVqnfERiIhqLdZMEBE5gVarRWhoaKmfkJAQSJKEa9euYfv27XjrrbcwZswY6PV6V4cMADAajZg2bRqMRqNLtv/rr7/irrvuwu+//16h8idOnMD06dMrVHbdunUYPnw4Vq9ejatXr0KtViM/Px+HDh3CzJkzMWLECKSmplYlfCKiWo/JBBGREwwaNAjbtm0r9bN9+3YcOnQIv/zyC3r37g0A2L9/P2bPnu3agIs5fvw4Pv30U5dse/v27cjNza1Q2cOHD+PRRx9FZmbmDcteuHABL7/8MgwGA2JjY7Fu3TocOHAABw8exPTp0+Hl5YWTJ0/i5ZdfrupHICKq1ZhMEBFVM0mS0LRpUyxcuBAxMTEAgB9//BFZWVkujuy6zz77DMeOHXN1GHbJsoyvv/4ao0aNQnp6eoXes2LFChiNRkREROCTTz7BrbfeCsBSg/TAAw/gpZdeAgBs3boVly9frrbYiYhudkwmiIhqiEajwciRIwEABQUFOHr0qIsjAtq3bw9/f3+YTCZMnToVBoPB1SHZ2L9/P4YNG4b33nsPRqMRffr0QVhY2A3fd+bMGQBA06ZN4e3tXWp5586drY+vXr3qvICJiGoZdsAmIqpBdevWtT62117/0KFDWLZsGfbu3YvU1FT4+Pjgtttuw6BBgzBixAhoNJpS7zEYDFixYgXWr1+P06dPo6CgAHXq1EGzZs0wePBgDBkypMwO35GRkbjvvvvw2muv4dSpU/jkk0/wwgsvOPTZzp49i6+++go7duzAtWvXoNFo0KhRI/Tv3x+jRo2Cn5+ftezq1avxyiuv2HxunU4HADh58qT19fXr1+PUqVPw9fXFCy+8gDFjxuCOO+64YSwBAQEAgGPHjiE/P79UQnH48GEAgEKhQIMGDRz6vERExJoJIqIadfHiRevjknfYP/zwQzzwwAP4/fffkZiYCK1Wi7y8POzduxfvvPMOHnjgASQmJtq8x2AwYPz48Xj33Xexf/9+5OXlwdvbGykpKYiLi8PUqVMxadIkmM3mMmO6//77cfvttwMAvvjiCxw5cqTSn+u7777DkCFDsGrVKly+fBkqlQpGoxFHjhzB3LlzMXToUJw6dcpa3svLC6GhodbRlNRqtbXTenF16tTB+PHj8ffff2Ps2LGQJKlC8fTt2xeAZaSoyZMn4/z58wAs++uXX37BzJkzAQAPPPAAIiIiKv15iYjIgskEEVENyc/Px/LlywEAgYGBaNu2rXXZ0qVL8dlnn0Gr1WLy5MnYvn079u3bh4MHD2Lx4sVo2LAhjh8/jieeeMJm5KXvv/8ee/bsQUBAAD799FMcOXIEe/bswb59+/DUU08BAOLi4vDrr7+WG9v06dMREBAAk8mEV155pVLNnf7880+88847EEJg7N
ix+Oeff6ydnb///nu0atUKV65cwcSJE62dp4s6rBcNN9u8eXNrp/XiJk+ejKlTpyI4OLjC8RStf/jw4QAs/SIGDBiAdu3aoW3btnj55ZehVqvxwgsv4M0336zUeomIyBaTCSKiamYwGLBz506MGzcO586dAwC89NJL1rvyGRkZ+OijjwAAM2fOxDPPPIOQkBAAgEqlQq9evbB06VL4+vri+PHjNonBjh07AAD33HMP+vbta23O5Ofnh+eeew6xsbHw9fXFiRMnyo0xIiICr776KgDL3AwLFiyo0GczmUx4//33AQAvvvgi/vvf/yI6OhqApQlRhw4dsHTpUkRFRSEhIQFff/11hdZbVQqFAu+//z6mTZtmrc3Iy8uz1tCYzWYYjUaXDYlLRHSzYJ8JIiInWLduHeLi4kq9bjAYkJ2dDSEEAMtoQi+++CIeeOABa5k///wTeXl5qFOnDgYOHGh3/VFRUejVqxfWr1+Pv/76C/fddx8AwMfHBwCwbds2JCUllWqy87///c9uPwt77r33Xvzf//0f/vnnH3z55Ze488470aZNm3Lfs2vXLmsH5qLO5SX5+fnh7rvvxhdffIG//voLzz77bIXiqQq9Xo9p06Zh3bp16NGjB55//nnodDpkZGRg3bp1mDdvHj7++GMcPnwYn3zySYX3ERER2WIyQUTkBHq9vtyJ6Lp27YoePXpg6NChpS74Dxw4AADIyclBjx49ylxHXl4egOsjFQGWGol169bh7NmzuOOOO9CpUyd069YNsbGxaN68eaUvkt955x0MHjwYmZmZmDZtGtasWVPuLNFFsQPAgAEDyixXUFAAwNJJWwhR4b4Pjpo6dSrWr1+Pnj174vPPP4dCYamIDw8Px6OPPgqdTofx48djy5Yt+PHHHzFq1KhqjYeI6GbFZk5ERE5w77334uTJk9afQ4cOYdGiRdZ5JS5cuICmTZva7eybkpICwNL0JiUlpcyfomSi+PwUt99+O6ZPnw5fX18YjUZs374dc+fOxfDhw9GzZ0+8+eabNqMj3Uh4eDhee+01AMC5c+cwf/78cssXxV70uKyfnJwc62es6CR1jjp58iTWr18PAJgyZYo1kSiuW7du6NevHwBg1apV1RoPEdHNjDUTRETVwMvLC3fccQe6dOmCsWPH4tixY3jyyScxf/5860VsEVmWAQBt27bFypUrK72tBx54AIMGDcLGjRuxadMm7NixA2lpaUhOTsaKFSuwatUqvP322xgxYkSF1jds2DD8+eef+Pvvv7F06VL079/fOtRqSUWxh4WFYevWrZWOvToUDfvq4+ODpk2bllmuU6dO+PPPP3HhwoUaioyI6ObDmgkiomrk5+eHTz75BIGBgTAajZgyZUqpmoKgoCAAQEJCQpW2M3ToUHz44YfYvn07fv31V7z44osICgqC2WzGe++9V6kZt99++20EBgZClmVMmzatzCZcRbGnp6dbmzK5WtEs2UX9VMqiUlnupxkMhhuWJSIi+5hMEBFVs+joaLz77rsALMPDvvDCCzYX561atQIAJCUl2czFUNLjjz+OIUOGYO7cuQAs/RC+//57zJgxw2a0JkmSoNPpMGnSJMyZMweApb9FeesuKTQ0FK+//joASxOtefPm2S1XFLvJZLKOLGXPG2+8gUGDBmHatGkVjsFRRU3J8vPzyx3F6t9//wUA1K9fv9r7cBAR3ayYTBAR1YB+/fpZR2A6e/YsPvzwQ+uyAQMGQK1WAwDmzJljbTpU3ObNm7FlyxacOnUK9evXB2CZ6O2jjz7CN998U2bzqOIXyUW1CBU1ePBg9O/fHwCwZcsWu2V69OhhnQPio48+sluDceLECaxevRpnz55FZGSkzbKi/gzlTapXWT169LB2PF+4cKHdMufOncPatWsBwDrXBRERVR6TCSKiGvLqq68iKioKALBs2TLs2bMHgOVO+oQJEwBYkoann37a2o7fYDBgzZo1ePHFFwEAMTExGDJkCABAqV
Rah2Ndvnw5Pv74Y2RkZACwXJzv3LkTb7zxBgBLDcKtt95a6ZjfeuutcpMQLy8va2zHjx/HI488gqNHjwKw1Fb8888
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": "<Figure size 2000x3000 with 45 Axes>",
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABjAAAAkwCAYAAAAqL3YXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9d7wuV1n3j79XmZm77L1PTw8JOfQuCIpIk45SREAeQMijIn5BRUUQULoIiBA6JKCgkfKiCPjiUYoKPxBBUUQfHmpCKCknOXWXu83MWtfvj7Wm3LucJBByNsl88jrZd5l7Zs2ata51XZ+rLCUiQocOHTp06NChQ4cOHTp06NChQ4cOHTp06NChwzaCPtEN6NChQ4cOHTp06NChQ4cOHTp06NChQ4cOHTp0WI/OgdGhQ4cOHTp06NChQ4cOHTp06NChQ4cOHTp02HboHBgdOnTo0KFDhw4dOnTo0KFDhw4dOnTo0KFDh22HzoHRoUOHDh06dOjQoUOHDh06dOjQoUOHDh06dNh26BwYHTp06NChQ4cOHTp06NChQ4cOHTp06NChQ4dth86B0aFDhw4dOnTo0KFDhw4dOnTo0KFDhw4dOnTYdugcGB06dOjQoUOHDh06dOjQoUOHDh06dOjQoUOHbYfOgdGhQ4cOHTp06NChQ4cOHTp06NChQ4cOHTp02HboHBgdOnTo0KFDhw4dOnTo0KFDhw4dOnTo0KFDh22HzoHRocMPiE9/+tMopfj0pz99optyrXGf+9yH+9znPie6GR06dNjm+FHIuXe+850opfjOd75znZ2zQ4cOHX7U+FHpfWeffTbnnnvudXrOa4Nzzz2Xs88++4Rdv0OHDtsXnR7YoUOHDhvRcYEnBjc6B8ab3/xmlFL81E/91LX+7eWXX86LXvQivvzlL2/47txzz0UpVf9bWFjgnHPO4dGPfjQf/OAH8d5fB63v8IPgT//0T/nwhz98wq5fKWnVP2stp59+Oueeey6XXXbZj+y6X/3qV3nRi17UKYc3QnRy7saHEynnHv7whzMYDFhdXd3ymCc84Qmkacrhw4d/qGt1cq3DtUUnD298ONF6X4WLL76Ypz71qZxzzjn0ej2Wlpa4xz3uwete9zomk8mJbl6HGzA6uXfjQ6cHduhw9ehk440PJ1on7LjA6xY3OgfGu971Ls4++2z+/d//nYsuuuha/fbyyy/nxS9+8aZCCyDLMi688EIuvPBCzjvvPB7/+MfzrW99i0c/+tHc7373Y2Vl5Tq4gw7XFidaaFV4yUtewoUXXshb3/pWHvKQh/A3f/M33Pve92Y6nf5IrvfVr36VF7/4xZsKrU984hN84hOf+JFct8OJRyfnbnw4kXLuCU94ApPJhA996EObfj8ej/nIRz7Cgx/8YPbs2cOv/MqvMJlMOOuss671tY4n1zp02AydPLzxYTvoff/n//wfbn/72/O+972Phz3sYbzhDW/g5S9/OTe5yU141rOexTOe8YwT2r423va2t/GNb3zjRDejw3WITu7d+NDpgR06XD062Xjjw3bQCaHjAq8r3KgcGJdccgn/+q//ymte8xr27dvHu971rmv0u7IsyfP8ao+z1vLEJz6RJz7xiTzlKU/hT/7kT/jv//5vXv7yl/PpT3+apzzlKT/sLXT4McZDHvIQnvjEJ/Lrv/7rvP3tb+cP/uAPuPjii/m7v/u7670taZqSpun1ft0OP3p0cq7D9Y2HP/zhLC4u8u53v3vT7z/ykY8wGo14whOeAIAxhl6vh1Lq+mxmhxshOnnY4UTgkksu4XGPexxnnXUWX/3qV3nd617HU57yFJ7+9Kfznve8h69+9avc9ra3PdHNrJEkCVmWnehmdLiO0Mm9Dtc3Oj2ww48DOtnY4USi4wKvI8iNCC996Utl165dMpvN5P/7//4/ufnNb77hmEsuuUQAedWrXiXnnXeenHPOOaK1lvPOO0+ADf/e8Y53iIjIk5/8ZB
kOh1te+4EPfKAopeQb3/jG3Od///d/Lz/7sz8rg8FAFhYW5KEPfah85Stf2fD7r33ta/KYxzxG9u7dK71eT25xi1vI8573vLljvvSlL8mDH/xgWVxclOFwKD/3cz8nn//853+AnmrwhS98QR70oAfJ0tKS9Pt9ude97iX/8i//Un//1a9+VXq9nvzKr/zK3O8++9nPitZanv3sZ9efnXXWWfLzP//z8vGPf1zueMc7SpZlcutb31o++MEPbrju0aNH5RnPeIacccYZkqap7N+/X17xileIc27uOOecvPa1r5Xb3e52kmWZ7N27Vx70oAfJF7/4RRGRTZ/Zk5/85Pr3l156qfzv//2/5aSTTpI0TeU2t7mN/MVf/MWG9nz/+9+XRzziETIYDGTfvn3yu7/7u/Kxj31MAPnUpz513D58xzveIUDdpgof/ehHBZA//dM/rT+7973vLfe+9703nOPJT36ynHXWWXOfvec975E73/nOsrCwIIuLi3K7291OXvva185dc/2/qq2bXWcymcgLX/hCufnNby5Zlskpp5wiv/iLvygXXXTRce+vw/ZCJ+euPTo5F/DDyLknP/nJYq2VK6+8csN3v/ALvyCLi4syHo9FpJFPl1xyyYZ+++xnPyt3vetdJcsyuelNbyp/9Vd/VR9zdXLtwx/+sDz0oQ+VU089VdI0lXPOOUde8pKXSFmWG9r0xje+UW5605tKr9eTu971rvKZz3xmU7k4nU7lBS94gezfv1/SNJUzzjhDnvWsZ8l0Oj1uf3TYHujk4bVHJw8Dfhh5+Ju/+ZsCyOc+97njHtfup3YbDx8+LM985jPldre7nQyHQ1lcXJQHP/jB8uUvf3nDb1//+tfLbW5zG+n3+7Jz5065y13uIu9617vq71dWVuQZz3iGnHXWWZKmqezbt0/uf//7y3/+53/Wx2ymY15dP3fYvujk3rVHJ/cCOj2w0wNvyOhk47VHJxsDOi5w++BG5cC41a1uJb/2a78mIiKf+cxnBJB///d/nzumElq3uc1t5JxzzpFXvOIVct5558l3vvMdeclLXiKA/MZv/IZceOGFcuGFF8rFF18sIlcvtC688EIB5I1vfGP92V//9V+LUkoe/OAHyxve8AZ55StfKWeffbbs3LlzbkH/7//+b1laWpI9e/bIc5/7XDn//PPl2c9+ttz+9revj/nKV74iw+FQTj31VHnpS18qr3jFK+SmN72pZFkmX/jCF36g/vqnf/onSdNU7n73u8urX/1qOe+88+QOd7iDpGkq//Zv/1Yf96pXvUoA+chHPiIiImtra7J//365zW1uM7e4n3XWWXKLW9xCdu7cKc95znPkNa95jdz+9rcXrbV84hOfqI8bjUZyhzvcQfbs2SPPe97z5K1vfas86UlPEqWUPOMZz5hr47nnniuAPOQhD5HXvva18ud//ufyiEc8Qt7whjfU/Z5lmdzznvesn9m//uu/iojIgQMH5IwzzpAzzzxTXvKSl8hb3vIWefjDHy6AnHfeefU1xuOx3OIWt5BeryfPfvaz5bWvfa3c5S53kTvc4Q4/lNB64xvfKIC85S1vqT+7pkLrE5/4hAByv/vdT970pjfJm970Jvmt3/otecxjHiMiIhdffLH8zu/8jgDyvOc9r773AwcObHqdsizlfve7nwDyuMc9Tt74xjfKy1/+cvm5n/s5+fCHP3zc++uwvdDJuWuHTs4F/LByrpJJVZsqHD58WJIkkSc96Un1Z1sZrre85S3l5JNPluc973nyxje+Ue585zuLUqpW5K9Orj3ykY+Uxz72sfKqV71K3vKWt8hjHvMYAeQP/uAP5tr05je/WQC55z3vKa9//evl93//92X37t2yf//+ObnonJMHPvCBMhgM5Hd/93fl/PPPl9/6rd8Sa6084hGPOG5/dNge6OThtUMnDwN+WHl4+umnyznnnHON+329A+OLX/yi7N+/X57znOfI+eefLy95yUvk9NNPlx07ds
hll11WH3fBBRcIII9+9KPl/PPPl9e97nXya7/2a/I7v/M79TGPf/zjJU1T+f3f/315+9vfLq985SvlYQ97mPzN3/x
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Adam with its default hyper-parameters over all ResNet-18 parameters, 5 epochs.\n",
|
||
|
"conduct_experiment(\n",
|
||
|
"    title=\"ResNet18\",\n",
|
||
|
"    model=resnet_model,\n",
|
||
|
"    n_epochs=5,\n",
|
||
|
"    optimizer=torch.optim.Adam(resnet_model.parameters()),\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"collapsed": false
|
||
|
},
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 2
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython2",
|
||
|
"version": "2.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 0
|
||
|
}
|