ium_444354/pytorch/pytorch.ipynb

532 lines
16 KiB
Plaintext
Raw Normal View History

2022-04-09 20:40:50 +02:00
{
"cells": [
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 18,
2022-04-09 20:40:50 +02:00
"id": "e1c5e25d",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import jovian\n",
"import torchvision\n",
"import matplotlib\n",
"import torch.nn as nn\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import torch.nn.functional as F\n",
"from torchvision.datasets.utils import download_url\n",
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
2022-04-17 19:33:40 +02:00
"import random\n",
"import os\n",
"import sys"
2022-04-09 20:40:50 +02:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c77ff6aa",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>fixed acidity</th>\n",
" <th>volatile acidity</th>\n",
" <th>citric acid</th>\n",
" <th>residual sugar</th>\n",
" <th>chlorides</th>\n",
" <th>free sulfur dioxide</th>\n",
" <th>total sulfur dioxide</th>\n",
" <th>density</th>\n",
" <th>pH</th>\n",
" <th>sulphates</th>\n",
" <th>alcohol</th>\n",
" <th>quality</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>7.8</td>\n",
" <td>0.88</td>\n",
" <td>0.00</td>\n",
" <td>2.6</td>\n",
" <td>0.098</td>\n",
" <td>25.0</td>\n",
" <td>67.0</td>\n",
" <td>0.9968</td>\n",
" <td>3.20</td>\n",
" <td>0.68</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>7.8</td>\n",
" <td>0.76</td>\n",
" <td>0.04</td>\n",
" <td>2.3</td>\n",
" <td>0.092</td>\n",
" <td>15.0</td>\n",
" <td>54.0</td>\n",
" <td>0.9970</td>\n",
" <td>3.26</td>\n",
" <td>0.65</td>\n",
" <td>9.8</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>11.2</td>\n",
" <td>0.28</td>\n",
" <td>0.56</td>\n",
" <td>1.9</td>\n",
" <td>0.075</td>\n",
" <td>17.0</td>\n",
" <td>60.0</td>\n",
" <td>0.9980</td>\n",
" <td>3.16</td>\n",
" <td>0.58</td>\n",
" <td>9.8</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>7.4</td>\n",
" <td>0.70</td>\n",
" <td>0.00</td>\n",
" <td>1.9</td>\n",
" <td>0.076</td>\n",
" <td>11.0</td>\n",
" <td>34.0</td>\n",
" <td>0.9978</td>\n",
" <td>3.51</td>\n",
" <td>0.56</td>\n",
" <td>9.4</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
"0 7.4 0.70 0.00 1.9 0.076 \n",
"1 7.8 0.88 0.00 2.6 0.098 \n",
"2 7.8 0.76 0.04 2.3 0.092 \n",
"3 11.2 0.28 0.56 1.9 0.075 \n",
"4 7.4 0.70 0.00 1.9 0.076 \n",
"\n",
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
"0 11.0 34.0 0.9978 3.51 0.56 \n",
"1 25.0 67.0 0.9968 3.20 0.68 \n",
"2 15.0 54.0 0.9970 3.26 0.65 \n",
"3 17.0 60.0 0.9980 3.16 0.58 \n",
"4 11.0 34.0 0.9978 3.51 0.56 \n",
"\n",
" alcohol quality \n",
"0 9.4 5 \n",
"1 9.8 5 \n",
"2 9.8 5 \n",
"3 9.8 6 \n",
"4 9.4 5 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the raw red-wine quality table and preview its first rows.\n",
"dataframe_raw = pd.read_csv(\"winequality-red.csv\")\n",
"dataframe_raw.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "99f42861",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['fixed acidity',\n",
" 'volatile acidity',\n",
" 'citric acid',\n",
" 'residual sugar',\n",
" 'chlorides',\n",
" 'free sulfur dioxide',\n",
" 'total sulfur dioxide',\n",
" 'density',\n",
" 'pH',\n",
" 'sulphates',\n",
" 'alcohol'],\n",
" ['quality'])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_cols = ['quality']\n",
"# Features are every column that is not a target, so this stays correct even if\n",
"# the target column is not the last one in the CSV.\n",
"input_cols = [col for col in dataframe_raw.columns if col not in output_cols]\n",
"input_cols, output_cols"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "87011c12",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 7.4 , 0.7 , 0. , ..., 3.51 , 0.56 , 9.4 ],\n",
" [ 7.8 , 0.88 , 0. , ..., 3.2 , 0.68 , 9.8 ],\n",
" [ 7.8 , 0.76 , 0.04 , ..., 3.26 , 0.65 , 9.8 ],\n",
" ...,\n",
" [ 6.3 , 0.51 , 0.13 , ..., 3.42 , 0.75 , 11. ],\n",
" [ 5.9 , 0.645, 0.12 , ..., 3.57 , 0.71 , 10.2 ],\n",
" [ 6. , 0.31 , 0.47 , ..., 3.39 , 0.66 , 11. ]]),\n",
" array([[5],\n",
" [5],\n",
" [5],\n",
" ...,\n",
" [6],\n",
" [5],\n",
" [6]], dtype=int64))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def dataframe_to_arrays(dataframe, in_cols=None, out_cols=None):\n",
"    \"\"\"Return (inputs_array, targets_array) as NumPy arrays taken from `dataframe`.\n",
"\n",
"    Args:\n",
"        dataframe: source DataFrame; it is not modified.\n",
"        in_cols: feature column names; defaults to the notebook-level `input_cols`.\n",
"        out_cols: target column names; defaults to the notebook-level `output_cols`.\n",
"    \"\"\"\n",
"    in_cols = input_cols if in_cols is None else in_cols\n",
"    out_cols = output_cols if out_cols is None else out_cols\n",
"    # copy=True yields arrays independent of (and not aliasing) the DataFrame's memory.\n",
"    inputs_array = dataframe[in_cols].to_numpy(copy=True)\n",
"    targets_array = dataframe[out_cols].to_numpy(copy=True)\n",
"    return inputs_array, targets_array\n",
"\n",
"inputs_array, targets_array = dataframe_to_arrays(dataframe_raw)\n",
"inputs_array, targets_array"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "705fb5b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 7.4000, 0.7000, 0.0000, ..., 3.5100, 0.5600, 9.4000],\n",
" [ 7.8000, 0.8800, 0.0000, ..., 3.2000, 0.6800, 9.8000],\n",
" [ 7.8000, 0.7600, 0.0400, ..., 3.2600, 0.6500, 9.8000],\n",
" ...,\n",
" [ 6.3000, 0.5100, 0.1300, ..., 3.4200, 0.7500, 11.0000],\n",
" [ 5.9000, 0.6450, 0.1200, ..., 3.5700, 0.7100, 10.2000],\n",
" [ 6.0000, 0.3100, 0.4700, ..., 3.3900, 0.6600, 11.0000]]),\n",
" tensor([[5.],\n",
" [5.],\n",
" [5.],\n",
" ...,\n",
" [6.],\n",
" [5.],\n",
" [6.]]))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cast both arrays to float32 tensors; targets must be float for F.l1_loss against\n",
"# the model's float predictions.\n",
"inputs = torch.from_numpy(inputs_array).float()\n",
"targets = torch.from_numpy(targets_array).float()\n",
"inputs, targets"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "71f14b4a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2022-04-17 19:33:40 +02:00
"<torch.utils.data.dataset.TensorDataset at 0x1f334183760>"
2022-04-09 20:40:50 +02:00
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pair each feature row with its target row so a DataLoader can batch them together.\n",
"dataset = TensorDataset(inputs, targets)\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c4f8cd40",
"metadata": {},
"outputs": [],
"source": [
"# Train on 1300 randomly chosen samples and validate on the rest. The validation\n",
"# size is derived from len(dataset) so the split stays valid if the CSV changes.\n",
"TRAIN_SIZE = 1300\n",
"SPLIT_SEED = 42\n",
"val_size = len(dataset) - TRAIN_SIZE\n",
"assert val_size > 0, \"dataset too small for the requested training size\"\n",
"# A seeded generator makes the split identical on every Restart & Run All.\n",
"train_ds, val_ds = random_split(\n",
"    dataset, [TRAIN_SIZE, val_size],\n",
"    generator=torch.Generator().manual_seed(SPLIT_SEED),\n",
")\n",
"batch_size = 50\n",
"train_loader = DataLoader(train_ds, batch_size, shuffle=True)\n",
"val_loader = DataLoader(val_ds, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "56f75067",
"metadata": {},
"outputs": [],
"source": [
"class WineQuality(nn.Module):\n",
"    \"\"\"Single linear layer mapping wine features to a quality estimate, trained with L1 loss.\"\"\"\n",
"\n",
"    def __init__(self, in_features=None, out_features=None):\n",
"        \"\"\"Build the linear layer.\n",
"\n",
"        Args:\n",
"            in_features: number of input features; defaults to the notebook-level\n",
"                `input_size`, looked up when the model is constructed.\n",
"            out_features: number of outputs; defaults to the notebook-level `output_size`.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        if in_features is None:\n",
"            in_features = input_size\n",
"        if out_features is None:\n",
"            out_features = output_size\n",
"        self.linear = nn.Linear(in_features, out_features)\n",
"\n",
"    def forward(self, xb):\n",
"        return self.linear(xb)\n",
"\n",
"    def _batch_loss(self, batch):\n",
"        \"\"\"L1 loss of the model's predictions on one (inputs, targets) batch.\"\"\"\n",
"        inputs, targets = batch\n",
"        return F.l1_loss(self(inputs), targets)\n",
"\n",
"    def training_step(self, batch):\n",
"        \"\"\"Loss for one training batch, with gradients attached.\"\"\"\n",
"        return self._batch_loss(batch)\n",
"\n",
"    def validation_step(self, batch):\n",
"        \"\"\"Detached loss for one validation batch, wrapped in a dict for aggregation.\"\"\"\n",
"        return {'val_loss': self._batch_loss(batch).detach()}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        \"\"\"Average per-batch validation losses into a single float.\n",
"\n",
"        The average is unweighted: a smaller final batch counts the same as a full one.\n",
"        \"\"\"\n",
"        batch_losses = [x['val_loss'] for x in outputs]\n",
"        epoch_loss = torch.stack(batch_losses).mean()\n",
"        return {'val_loss': epoch_loss.item()}\n",
"\n",
"    def epoch_end(self, epoch, result, num_epochs, log_every=100):\n",
"        \"\"\"Print the validation loss every `log_every` epochs and on the final epoch.\"\"\"\n",
"        if (epoch + 1) % log_every == 0 or epoch == num_epochs - 1:\n",
"            print(\"Epoch [{}], val_loss: {:.4f}\".format(epoch + 1, result['val_loss']))"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 9,
2022-04-09 20:40:50 +02:00
"id": "57f354ce",
"metadata": {},
"outputs": [],
"source": [
"# Model dimensions follow from the chosen feature and target columns.\n",
"input_size, output_size = len(input_cols), len(output_cols)"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 10,
2022-04-09 20:40:50 +02:00
"id": "4a926cfa",
"metadata": {},
"outputs": [],
"source": [
"# Fresh, untrained model instance.\n",
"model = WineQuality()"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 11,
2022-04-09 20:40:50 +02:00
"id": "3df1733d",
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def evaluate(model, val_loader):\n",
"    \"\"\"Return {'val_loss': float}, the mean per-batch validation loss of `model`.\n",
"\n",
"    Runs with autograd disabled: scoring needs no computation graph.\n",
"    \"\"\"\n",
"    outputs = [model.validation_step(batch) for batch in val_loader]\n",
"    return model.validation_epoch_end(outputs)\n",
"\n",
"\n",
"def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):\n",
"    \"\"\"Train `model` and return one validation-result dict per epoch.\n",
"\n",
"    Args:\n",
"        epochs: number of passes over `train_loader`.\n",
"        lr: learning rate, passed positionally to `opt_func`.\n",
"        model: provides training_step / validation_step / validation_epoch_end / epoch_end.\n",
"        train_loader, val_loader: iterables of (inputs, targets) batches.\n",
"        opt_func: optimizer class, called as opt_func(model.parameters(), lr).\n",
"    \"\"\"\n",
"    history = []\n",
"    optimizer = opt_func(model.parameters(), lr)\n",
"    for epoch in range(epochs):\n",
"        for batch in train_loader:\n",
"            # Clear before backward so gradients left over from outside fit() never leak in.\n",
"            optimizer.zero_grad()\n",
"            loss = model.training_step(batch)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"        result = evaluate(model, val_loader)\n",
"        model.epoch_end(epoch, result, epochs)\n",
"        history.append(result)\n",
"    return history"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 12,
2022-04-09 20:40:50 +02:00
"id": "3ed5f872",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-04-17 19:33:40 +02:00
"Epoch [100], val_loss: 4.1732\n",
"Epoch [200], val_loss: 1.6444\n",
"Epoch [300], val_loss: 1.4860\n",
"Epoch [400], val_loss: 1.4119\n",
"Epoch [500], val_loss: 1.3407\n",
"Epoch [600], val_loss: 1.2709\n",
"Epoch [700], val_loss: 1.2045\n",
"Epoch [800], val_loss: 1.1401\n",
"Epoch [900], val_loss: 1.0783\n",
"Epoch [1000], val_loss: 1.0213\n",
"Epoch [1100], val_loss: 0.9678\n",
"Epoch [1200], val_loss: 0.9186\n",
"Epoch [1300], val_loss: 0.8729\n",
"Epoch [1400], val_loss: 0.8320\n",
"Epoch [1500], val_loss: 0.7959\n"
2022-04-09 20:40:50 +02:00
]
}
],
"source": [
"# Plain SGD with a very small step size; note the features are not standardized\n",
"# anywhere above.\n",
"epochs, lr = 1500, 1e-6\n",
"history5 = fit(epochs, lr, model, train_loader, val_loader)"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 27,
2022-04-09 20:40:50 +02:00
"id": "413ab394",
"metadata": {},
"outputs": [],
"source": [
"def predict_single(input, target, model):\n",
"    \"\"\"Run `model` on one sample and return a 'Target: ...----- Prediction: ...' line.\"\"\"\n",
"    # Add a leading batch dimension, then take the single output row back out.\n",
"    prediction = model(input.unsqueeze(0))[0].detach()\n",
"    return f\"Target: {target!s}----- Prediction: {prediction!s}\\n\""
2022-04-09 20:40:50 +02:00
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 32,
2022-04-09 20:40:50 +02:00
"id": "b1ab4522",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-04-17 19:33:40 +02:00
"Target: tensor([5.])----- Prediction: tensor([4.9765])\n",
"Target: tensor([5.])----- Prediction: tensor([6.6649])\n",
"Target: tensor([5.])----- Prediction: tensor([5.2627])\n",
"Target: tensor([7.])----- Prediction: tensor([5.7054])\n",
"Target: tensor([5.])----- Prediction: tensor([5.1168])\n",
"Target: tensor([7.])----- Prediction: tensor([5.3928])\n",
"Target: tensor([5.])----- Prediction: tensor([4.8501])\n",
"Target: tensor([4.])----- Prediction: tensor([5.4210])\n",
"Target: tensor([5.])----- Prediction: tensor([4.6719])\n",
"Target: tensor([5.])----- Prediction: tensor([7.8635])\n"
2022-04-09 20:40:50 +02:00
]
}
],
"source": [
"# Draw 10 random validation samples and print each target next to its prediction.\n",
"for i in random.sample(range(len(val_ds)), k=10):\n",
"    input_, target = val_ds[i]\n",
"    print(predict_single(input_, target, model), end=\"\")"
]
},
{
"cell_type": "code",
2022-04-17 19:33:40 +02:00
"execution_count": 36,
"id": "a754aaff",
2022-04-09 20:40:50 +02:00
"metadata": {},
"outputs": [],
2022-04-17 19:33:40 +02:00
"source": [
"# Write one prediction line per validation sample to result.txt.\n",
"with open(\"result.txt\", \"w+\") as out_file:\n",
"    out_file.writelines(\n",
"        predict_single(*val_ds[i], model) for i in range(len(val_ds))\n",
"    )"
]
2022-04-09 20:40:50 +02:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}