{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import DataLoader, random_split\n",
"from torchvision import datasets, transforms, utils\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Parameters\n",
"\n",
"IMG_SIZE = 224\n",
"BATCH_SIZE = 32\n",
"IMG_SHOW_NUM = 6"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Data Augmentation: resize, normalize (ImageNet mean and std)\n",
"\n",
"transformer = transforms.Compose([\n",
" transforms.Resize(size = (IMG_SIZE, IMG_SIZE), antialias = True),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),\n",
"])\n",
"\n",
"testTransformer = transforms.Compose([\n",
" transforms.Resize(size = (IMG_SIZE, IMG_SIZE), antialias = True),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),\n",
"])"
]
},
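{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Augmentation sketch (illustrative, not used later in this notebook): the transformer above\n",
"# only resizes and normalizes. If random augmentation is wanted, it would go in the train\n",
"# pipeline only, e.g. flips and small rotations; the test pipeline should stay deterministic.\n",
"\n",
"augmentedTransformer = transforms.Compose([\n",
"    transforms.Resize(size = (IMG_SIZE, IMG_SIZE), antialias = True),\n",
"    transforms.RandomHorizontalFlip(p = 0.5),\n",
"    transforms.RandomVerticalFlip(p = 0.5),\n",
"    transforms.RandomRotation(degrees = 15),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),\n",
"])"
]
},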
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Dataset from: https://www.kaggle.com/datasets/fanconic/skin-cancer-malignant-vs-benign\n",
"\n",
"trainPath = \"./train\"\n",
"testPath = \"./test\"\n",
"\n",
"trainData = datasets.ImageFolder(root = trainPath, transform = transformer)\n",
"testData = datasets.ImageFolder(root = testPath, transform = testTransformer)"
]
},
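{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ImageFolder infers one class per sub-folder. Assumed layout for this Kaggle dataset\n",
"# (matching the class names printed later in the notebook):\n",
"#   ./train/Benign/*.jpg   ./train/Malignant/*.jpg\n",
"#   ./test/Benign/*.jpg    ./test/Malignant/*.jpg\n",
"# The class -> index mapping is derived from the sorted folder names:\n",
"\n",
"print(trainData.class_to_idx)\n",
"print(testData.class_to_idx)"
]
},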
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"trainLoader = DataLoader(trainData, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)\n",
"testLoader = DataLoader(testData, batch_size = BATCH_SIZE, shuffle = False, num_workers = 4)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 700x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show images\n",
"\n",
"def show_images(imgs, title):\n",
" grid = utils.make_grid(imgs, nrow = 3, padding = 2, normalize=True)\n",
" plt.figure(figsize = (7, 10))\n",
" plt.imshow(np.transpose(grid, (1, 2, 0)))\n",
" plt.title(title)\n",
"\n",
"dataiter = iter(trainLoader)\n",
"images, labels = next(dataiter)\n",
"show_images(images[:IMG_SHOW_NUM], title = \"Training images\")\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train data statistics\n",
"Number of samples: 11879\n",
"Number of classes: 2\n",
"Classes: ['Benign', 'Malignant']\n",
"Shape of the data: torch.Size([3, 224, 224])\n"
]
}
],
"source": [
"# Train data statistics: number of samples, number of classes, classes, shape of the data\n",
"\n",
"print(\"Train data statistics\")\n",
"print(\"Number of samples: \", len(trainData))\n",
"print(\"Number of classes: \", len(trainData.classes))\n",
"print(\"Classes: \", trainData.classes)\n",
"print(\"Shape of the data: \", trainData[0][0].shape)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test data statistics\n",
"Number of samples: 2000\n",
"Number of classes: 2\n",
"Classes: ['Benign', 'Malignant']\n",
"Shape of the data: torch.Size([3, 224, 224])\n"
]
}
],
"source": [
"# Test data statistics: number of samples, number of classes, classes, shape of the data\n",
"\n",
"print(\"Test data statistics\")\n",
"print(\"Number of samples: \", len(testData))\n",
"print(\"Number of classes: \", len(testData.classes))\n",
"print(\"Classes: \", testData.classes)\n",
"print(\"Shape of the data: \", testData[0][0].shape)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'Benign': 6289, 'Malignant': 5590}\n"
]
}
],
"source": [
"# Number of samples per class in train data\n",
"\n",
"class_dict = dict.fromkeys(trainData.class_to_idx, 0)\n",
"\n",
"# Read the labels from trainData.targets instead of indexing trainData[i],\n",
"# which would load and transform every image just to fetch its label\n",
"for label in trainData.targets:\n",
"    class_dict[trainData.classes[label]] += 1\n",
"\n",
"print(class_dict)\n",
"# The next cell sketches inverse-frequency class weights based on these counts"
]
},
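{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Class-weight sketch: the counts above show a mild imbalance between Benign and Malignant.\n",
"# One common remedy is to weight the training loss inversely to class frequency. This is\n",
"# illustrative only and is not used elsewhere in this notebook; torch is imported here\n",
"# because the notebook has not imported it directly.\n",
"\n",
"import torch\n",
"\n",
"counts = torch.tensor([class_dict[c] for c in trainData.classes], dtype = torch.float)\n",
"class_weights = counts.sum() / (len(counts) * counts)  # inverse-frequency weights, mean ~ 1\n",
"print(dict(zip(trainData.classes, class_weights.tolist())))\n",
"# These weights could later be passed to e.g. torch.nn.CrossEntropyLoss(weight = class_weights)"
]
},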
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean: tensor([1.0394, 0.4445, 0.5912])\n",
"Std: tensor([0.5279, 0.5983, 0.6442])\n"
]
}
],
"source": [
"# Mean and std of the train data\n",
"\n",
"mean = 0\n",
"std = 0\n",
"for images, _ in trainLoader:\n",
" batch_samples = images.size(0)\n",
" images = images.view(batch_samples, images.size(1), -1)\n",
" mean += images.mean(2).sum(0)\n",
" std += images.std(2).sum(0)\n",
"\n",
"mean /= len(trainLoader.dataset)\n",
"std /= len(trainLoader.dataset)\n",
"\n",
"print(\"Mean: \", mean)\n",
"print(\"Std: \", std)"
]
},
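{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Exact-statistics sketch: the cell above averages per-image channel means and stds, which\n",
"# only approximates the true pooled std. This variant accumulates per-channel pixel sums and\n",
"# squared sums over all pixels and uses Var[x] = E[x^2] - E[x]^2. Like the cell above, it runs\n",
"# on the already-normalized tensors produced by the transformer.\n",
"\n",
"import torch\n",
"\n",
"channel_sum = torch.zeros(3)\n",
"channel_sq_sum = torch.zeros(3)\n",
"num_pixels = 0\n",
"\n",
"for images, _ in trainLoader:\n",
"    images = images.view(images.size(0), images.size(1), -1)  # (batch, channels, H*W)\n",
"    channel_sum += images.sum(dim = (0, 2))\n",
"    channel_sq_sum += (images ** 2).sum(dim = (0, 2))\n",
"    num_pixels += images.size(0) * images.size(2)\n",
"\n",
"exact_mean = channel_sum / num_pixels\n",
"exact_std = (channel_sq_sum / num_pixels - exact_mean ** 2).sqrt()\n",
"print(\"Exact mean: \", exact_mean)\n",
"print(\"Exact std: \", exact_std)"
]
},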
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of samples in train data: 9503\n",
"Number of samples in validation data: 2376\n"
]
}
],
"source": [
"# Split train data into train and validation data\n",
"\n",
"trainData, valData = random_split(trainData, [int(0.8*len(trainData)), len(trainData) - int(0.8*len(trainData))])\n",
"valLoader = DataLoader(valData, batch_size = BATCH_SIZE, shuffle = False, num_workers = 4)\n",
"\n",
"print(\"Number of samples in train data: \", len(trainData))\n",
"print(\"Number of samples in validation data: \", len(valData))"
]
}
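,
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reproducibility sketch: random_split above uses a fresh RNG state, so the train/validation\n",
"# membership changes on every run. Passing a seeded generator makes the split deterministic.\n",
"# This cell is a standalone illustration (the seed 42 is arbitrary): it rebuilds the ImageFolder\n",
"# instead of reusing the trainData variable, which was already overwritten by the split above.\n",
"\n",
"import torch\n",
"\n",
"fullTrainData = datasets.ImageFolder(root = trainPath, transform = transformer)\n",
"seededTrain, seededVal = random_split(\n",
"    fullTrainData,\n",
"    [int(0.8 * len(fullTrainData)), len(fullTrainData) - int(0.8 * len(fullTrainData))],\n",
"    generator = torch.Generator().manual_seed(42),  # fixed seed -> same split every run\n",
")\n",
"print(len(seededTrain), len(seededVal))"
]
}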
],
"metadata": {
"kernelspec": {
"display_name": "dl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}