527 lines
15 KiB
Plaintext
527 lines
15 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "e1c5e25d",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import jovian\n",
|
||
|
"import torchvision\n",
|
||
|
"import matplotlib\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import pandas as pd\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import seaborn as sns\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"from torchvision.datasets.utils import download_url\n",
|
||
|
"from torch.utils.data import DataLoader, TensorDataset, random_split\n",
|
||
|
"import random"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "c77ff6aa",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div>\n",
|
||
|
"<style scoped>\n",
|
||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
" vertical-align: middle;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe tbody tr th {\n",
|
||
|
" vertical-align: top;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe thead th {\n",
|
||
|
" text-align: right;\n",
|
||
|
" }\n",
|
||
|
"</style>\n",
|
||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
" <thead>\n",
|
||
|
" <tr style=\"text-align: right;\">\n",
|
||
|
" <th></th>\n",
|
||
|
" <th>fixed acidity</th>\n",
|
||
|
" <th>volatile acidity</th>\n",
|
||
|
" <th>citric acid</th>\n",
|
||
|
" <th>residual sugar</th>\n",
|
||
|
" <th>chlorides</th>\n",
|
||
|
" <th>free sulfur dioxide</th>\n",
|
||
|
" <th>total sulfur dioxide</th>\n",
|
||
|
" <th>density</th>\n",
|
||
|
" <th>pH</th>\n",
|
||
|
" <th>sulphates</th>\n",
|
||
|
" <th>alcohol</th>\n",
|
||
|
" <th>quality</th>\n",
|
||
|
" </tr>\n",
|
||
|
" </thead>\n",
|
||
|
" <tbody>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>0</th>\n",
|
||
|
" <td>7.4</td>\n",
|
||
|
" <td>0.70</td>\n",
|
||
|
" <td>0.00</td>\n",
|
||
|
" <td>1.9</td>\n",
|
||
|
" <td>0.076</td>\n",
|
||
|
" <td>11.0</td>\n",
|
||
|
" <td>34.0</td>\n",
|
||
|
" <td>0.9978</td>\n",
|
||
|
" <td>3.51</td>\n",
|
||
|
" <td>0.56</td>\n",
|
||
|
" <td>9.4</td>\n",
|
||
|
" <td>5</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>1</th>\n",
|
||
|
" <td>7.8</td>\n",
|
||
|
" <td>0.88</td>\n",
|
||
|
" <td>0.00</td>\n",
|
||
|
" <td>2.6</td>\n",
|
||
|
" <td>0.098</td>\n",
|
||
|
" <td>25.0</td>\n",
|
||
|
" <td>67.0</td>\n",
|
||
|
" <td>0.9968</td>\n",
|
||
|
" <td>3.20</td>\n",
|
||
|
" <td>0.68</td>\n",
|
||
|
" <td>9.8</td>\n",
|
||
|
" <td>5</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>2</th>\n",
|
||
|
" <td>7.8</td>\n",
|
||
|
" <td>0.76</td>\n",
|
||
|
" <td>0.04</td>\n",
|
||
|
" <td>2.3</td>\n",
|
||
|
" <td>0.092</td>\n",
|
||
|
" <td>15.0</td>\n",
|
||
|
" <td>54.0</td>\n",
|
||
|
" <td>0.9970</td>\n",
|
||
|
" <td>3.26</td>\n",
|
||
|
" <td>0.65</td>\n",
|
||
|
" <td>9.8</td>\n",
|
||
|
" <td>5</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>3</th>\n",
|
||
|
" <td>11.2</td>\n",
|
||
|
" <td>0.28</td>\n",
|
||
|
" <td>0.56</td>\n",
|
||
|
" <td>1.9</td>\n",
|
||
|
" <td>0.075</td>\n",
|
||
|
" <td>17.0</td>\n",
|
||
|
" <td>60.0</td>\n",
|
||
|
" <td>0.9980</td>\n",
|
||
|
" <td>3.16</td>\n",
|
||
|
" <td>0.58</td>\n",
|
||
|
" <td>9.8</td>\n",
|
||
|
" <td>6</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>4</th>\n",
|
||
|
" <td>7.4</td>\n",
|
||
|
" <td>0.70</td>\n",
|
||
|
" <td>0.00</td>\n",
|
||
|
" <td>1.9</td>\n",
|
||
|
" <td>0.076</td>\n",
|
||
|
" <td>11.0</td>\n",
|
||
|
" <td>34.0</td>\n",
|
||
|
" <td>0.9978</td>\n",
|
||
|
" <td>3.51</td>\n",
|
||
|
" <td>0.56</td>\n",
|
||
|
" <td>9.4</td>\n",
|
||
|
" <td>5</td>\n",
|
||
|
" </tr>\n",
|
||
|
" </tbody>\n",
|
||
|
"</table>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
" fixed acidity volatile acidity citric acid residual sugar chlorides \\\n",
|
||
|
"0 7.4 0.70 0.00 1.9 0.076 \n",
|
||
|
"1 7.8 0.88 0.00 2.6 0.098 \n",
|
||
|
"2 7.8 0.76 0.04 2.3 0.092 \n",
|
||
|
"3 11.2 0.28 0.56 1.9 0.075 \n",
|
||
|
"4 7.4 0.70 0.00 1.9 0.076 \n",
|
||
|
"\n",
|
||
|
" free sulfur dioxide total sulfur dioxide density pH sulphates \\\n",
|
||
|
"0 11.0 34.0 0.9978 3.51 0.56 \n",
|
||
|
"1 25.0 67.0 0.9968 3.20 0.68 \n",
|
||
|
"2 15.0 54.0 0.9970 3.26 0.65 \n",
|
||
|
"3 17.0 60.0 0.9980 3.16 0.58 \n",
|
||
|
"4 11.0 34.0 0.9978 3.51 0.56 \n",
|
||
|
"\n",
|
||
|
" alcohol quality \n",
|
||
|
"0 9.4 5 \n",
|
||
|
"1 9.8 5 \n",
|
||
|
"2 9.8 5 \n",
|
||
|
"3 9.8 6 \n",
|
||
|
"4 9.4 5 "
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Read the red-wine quality table from the CSV in the working directory.\n",
|
||
|
"dataframe_raw = pd.read_csv(filepath_or_buffer=\"winequality-red.csv\")\n",
|
||
|
"# Preview the first five rows (the default row count for head()).\n",
|
||
|
"dataframe_raw.head(n=5)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "99f42861",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(['fixed acidity',\n",
|
||
|
" 'volatile acidity',\n",
|
||
|
" 'citric acid',\n",
|
||
|
" 'residual sugar',\n",
|
||
|
" 'chlorides',\n",
|
||
|
" 'free sulfur dioxide',\n",
|
||
|
" 'total sulfur dioxide',\n",
|
||
|
" 'density',\n",
|
||
|
" 'pH',\n",
|
||
|
" 'sulphates',\n",
|
||
|
" 'alcohol'],\n",
|
||
|
" ['quality'])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# The label column is named explicitly; every column except the last is a feature.\n",
|
||
|
"output_cols = ['quality']\n",
|
||
|
"input_cols = dataframe_raw.columns[:-1].tolist()\n",
|
||
|
"(input_cols, output_cols)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "87011c12",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(array([[ 7.4 , 0.7 , 0. , ..., 3.51 , 0.56 , 9.4 ],\n",
|
||
|
" [ 7.8 , 0.88 , 0. , ..., 3.2 , 0.68 , 9.8 ],\n",
|
||
|
" [ 7.8 , 0.76 , 0.04 , ..., 3.26 , 0.65 , 9.8 ],\n",
|
||
|
" ...,\n",
|
||
|
" [ 6.3 , 0.51 , 0.13 , ..., 3.42 , 0.75 , 11. ],\n",
|
||
|
" [ 5.9 , 0.645, 0.12 , ..., 3.57 , 0.71 , 10.2 ],\n",
|
||
|
" [ 6. , 0.31 , 0.47 , ..., 3.39 , 0.66 , 11. ]]),\n",
|
||
|
" array([[5],\n",
|
||
|
" [5],\n",
|
||
|
" [5],\n",
|
||
|
" ...,\n",
|
||
|
" [6],\n",
|
||
|
" [5],\n",
|
||
|
" [6]], dtype=int64))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def dataframe_to_arrays(dataframe, input_columns=None, output_columns=None):\n",
|
||
|
"    \"\"\"Split a DataFrame into NumPy arrays of model inputs and targets.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        dataframe: Table to convert; it is left unmodified.\n",
|
||
|
"        input_columns: Feature column names; defaults to the notebook-level ``input_cols``.\n",
|
||
|
"        output_columns: Target column names; defaults to the notebook-level ``output_cols``.\n",
|
||
|
"\n",
|
||
|
"    Returns:\n",
|
||
|
"        Tuple ``(inputs_array, targets_array)`` with one row per DataFrame row.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    if input_columns is None:\n",
|
||
|
"        input_columns = input_cols\n",
|
||
|
"    if output_columns is None:\n",
|
||
|
"        output_columns = output_cols\n",
|
||
|
"    # Convert the frame that was passed in (not the global ``dataframe_raw``),\n",
|
||
|
"    # copying so the returned arrays never alias the caller's data.\n",
|
||
|
"    inputs_array = dataframe[list(input_columns)].to_numpy(copy=True)\n",
|
||
|
"    targets_array = dataframe[list(output_columns)].to_numpy(copy=True)\n",
|
||
|
"    return inputs_array, targets_array\n",
|
||
|
"\n",
|
||
|
"# Convert the raw table once; the bare tuple below renders both arrays as the cell output.\n",
|
||
|
"(inputs_array, targets_array) = dataframe_to_arrays(dataframe=dataframe_raw)\n",
|
||
|
"(inputs_array, targets_array)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "705fb5b7",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(tensor([[ 7.4000, 0.7000, 0.0000, ..., 3.5100, 0.5600, 9.4000],\n",
|
||
|
" [ 7.8000, 0.8800, 0.0000, ..., 3.2000, 0.6800, 9.8000],\n",
|
||
|
" [ 7.8000, 0.7600, 0.0400, ..., 3.2600, 0.6500, 9.8000],\n",
|
||
|
" ...,\n",
|
||
|
" [ 6.3000, 0.5100, 0.1300, ..., 3.4200, 0.7500, 11.0000],\n",
|
||
|
" [ 5.9000, 0.6450, 0.1200, ..., 3.5700, 0.7100, 10.2000],\n",
|
||
|
" [ 6.0000, 0.3100, 0.4700, ..., 3.3900, 0.6600, 11.0000]]),\n",
|
||
|
" tensor([[5.],\n",
|
||
|
" [5.],\n",
|
||
|
" [5.],\n",
|
||
|
" ...,\n",
|
||
|
" [6.],\n",
|
||
|
" [5.],\n",
|
||
|
" [6.]]))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Wrap the NumPy arrays as tensors and cast both to 32-bit floats.\n",
|
||
|
"inputs = torch.from_numpy(inputs_array).to(dtype=torch.float32)\n",
|
||
|
"targets = torch.from_numpy(targets_array).to(dtype=torch.float32)\n",
|
||
|
"(inputs, targets)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "71f14b4a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<torch.utils.data.dataset.TensorDataset at 0x16db5c32af0>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Pair each input row with its target so the two can be indexed and batched together.\n",
|
||
|
"dataset = TensorDataset(inputs, targets)\n",
|
||
|
"dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "c4f8cd40",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Derive the validation size from the dataset so the split keeps working if the\n",
|
||
|
"# row count changes, and seed the split so train/validation membership is reproducible.\n",
|
||
|
"SPLIT_SEED = 42\n",
|
||
|
"train_size = 1300\n",
|
||
|
"val_size = len(dataset) - train_size\n",
|
||
|
"train_ds, val_ds = random_split(\n",
|
||
|
"    dataset,\n",
|
||
|
"    [train_size, val_size],\n",
|
||
|
"    generator=torch.Generator().manual_seed(SPLIT_SEED),\n",
|
||
|
")\n",
|
||
|
"batch_size=50\n",
|
||
|
"# Shuffle only the training batches.\n",
|
||
|
"train_loader = DataLoader(train_ds, batch_size, shuffle=True)\n",
|
||
|
"val_loader = DataLoader(val_ds, batch_size)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "56f75067",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class WineQuality(nn.Module):\n",
|
||
|
"    \"\"\"Single linear layer trained and validated with L1 (mean absolute error) loss.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        in_features: Input width; defaults to the notebook-level ``input_size``.\n",
|
||
|
"        out_features: Output width; defaults to the notebook-level ``output_size``.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"\n",
|
||
|
"    def __init__(self, in_features=None, out_features=None):\n",
|
||
|
"        super().__init__()\n",
|
||
|
"        # Fall back to the notebook globals so a bare ``WineQuality()`` keeps working.\n",
|
||
|
"        if in_features is None:\n",
|
||
|
"            in_features = input_size\n",
|
||
|
"        if out_features is None:\n",
|
||
|
"            out_features = output_size\n",
|
||
|
"        self.linear = nn.Linear(in_features, out_features)\n",
|
||
|
"\n",
|
||
|
"    def forward(self, xb):\n",
|
||
|
"        \"\"\"Return the linear layer's output for the batch ``xb``.\"\"\"\n",
|
||
|
"        return self.linear(xb)\n",
|
||
|
"\n",
|
||
|
"    def _batch_loss(self, batch):\n",
|
||
|
"        \"\"\"Compute the L1 loss between predictions and targets for one batch.\"\"\"\n",
|
||
|
"        inputs, targets = batch\n",
|
||
|
"        # Generate predictions, then calculate the loss against the targets.\n",
|
||
|
"        return F.l1_loss(self(inputs), targets)\n",
|
||
|
"\n",
|
||
|
"    def training_step(self, batch):\n",
|
||
|
"        \"\"\"Return the differentiable loss for one ``(inputs, targets)`` batch.\"\"\"\n",
|
||
|
"        return self._batch_loss(batch)\n",
|
||
|
"\n",
|
||
|
"    def validation_step(self, batch):\n",
|
||
|
"        \"\"\"Return ``{'val_loss': loss}`` for one batch, detached from the graph.\"\"\"\n",
|
||
|
"        return {'val_loss': self._batch_loss(batch).detach()}\n",
|
||
|
"\n",
|
||
|
"    def validation_epoch_end(self, outputs):\n",
|
||
|
"        \"\"\"Average the per-batch validation losses into a single Python float.\"\"\"\n",
|
||
|
"        batch_losses = [x['val_loss'] for x in outputs]\n",
|
||
|
"        epoch_loss = torch.stack(batch_losses).mean()\n",
|
||
|
"        return {'val_loss': epoch_loss.item()}\n",
|
||
|
"\n",
|
||
|
"    def epoch_end(self, epoch, result, num_epochs):\n",
|
||
|
"        \"\"\"Print the validation loss every 100th epoch and on the final epoch.\"\"\"\n",
|
||
|
"        if (epoch+1) % 100 == 0 or epoch == num_epochs-1:\n",
|
||
|
"            print(\"Epoch [{}], val_loss: {:.4f}\".format(epoch+1, result['val_loss']))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "57f354ce",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Layer dimensions follow from the number of feature and target columns.\n",
|
||
|
"input_size, output_size = len(input_cols), len(output_cols)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "4a926cfa",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Instantiate the model; its layer sizes are read from the notebook-level input_size/output_size.\n",
|
||
|
"model=WineQuality()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "3df1733d",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def evaluate(model, val_loader):\n",
|
||
|
"    \"\"\"Run the model over ``val_loader`` and return its aggregated validation result.\n",
|
||
|
"\n",
|
||
|
"    Gradients are disabled and the model is switched to eval mode for the pass;\n",
|
||
|
"    the previous train/eval mode is restored afterwards.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        model: Module exposing ``validation_step`` and ``validation_epoch_end``.\n",
|
||
|
"        val_loader: Iterable of batches accepted by ``model.validation_step``.\n",
|
||
|
"\n",
|
||
|
"    Returns:\n",
|
||
|
"        Whatever ``model.validation_epoch_end`` returns for the per-batch results.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    was_training = model.training\n",
|
||
|
"    model.eval()\n",
|
||
|
"    try:\n",
|
||
|
"        # No graph is needed for validation, so skip autograd bookkeeping.\n",
|
||
|
"        with torch.no_grad():\n",
|
||
|
"            outputs = [model.validation_step(batch) for batch in val_loader]\n",
|
||
|
"    finally:\n",
|
||
|
"        model.train(was_training)\n",
|
||
|
"    return model.validation_epoch_end(outputs)\n",
|
||
|
"\n",
|
||
|
"def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):\n",
|
||
|
"    \"\"\"Train ``model`` for ``epochs`` passes and validate after each one.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        epochs: Number of passes over ``train_loader``.\n",
|
||
|
"        lr: Learning rate handed to ``opt_func``.\n",
|
||
|
"        model: Module exposing ``training_step``, ``validation_step``,\n",
|
||
|
"            ``validation_epoch_end`` and ``epoch_end``.\n",
|
||
|
"        train_loader: Iterable of training batches.\n",
|
||
|
"        val_loader: Iterable of validation batches.\n",
|
||
|
"        opt_func: Optimizer factory called as ``opt_func(parameters, lr)``.\n",
|
||
|
"\n",
|
||
|
"    Returns:\n",
|
||
|
"        List with one validation result per epoch, in epoch order.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    history = []\n",
|
||
|
"    optimizer = opt_func(model.parameters(), lr)\n",
|
||
|
"    for epoch in range(epochs):\n",
|
||
|
"        for batch in train_loader:\n",
|
||
|
"            # Clear gradients first so stale values left on the parameters\n",
|
||
|
"            # before this call cannot leak into the update.\n",
|
||
|
"            optimizer.zero_grad()\n",
|
||
|
"            loss = model.training_step(batch)\n",
|
||
|
"            loss.backward()\n",
|
||
|
"            optimizer.step()\n",
|
||
|
"        result = evaluate(model, val_loader)\n",
|
||
|
"        model.epoch_end(epoch, result, epochs)\n",
|
||
|
"        history.append(result)\n",
|
||
|
"    return history"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "3ed5f872",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch [100], val_loss: 5.1117\n",
|
||
|
"Epoch [200], val_loss: 1.7651\n",
|
||
|
"Epoch [300], val_loss: 1.4800\n",
|
||
|
"Epoch [400], val_loss: 1.3942\n",
|
||
|
"Epoch [500], val_loss: 1.3119\n",
|
||
|
"Epoch [600], val_loss: 1.2326\n",
|
||
|
"Epoch [700], val_loss: 1.1571\n",
|
||
|
"Epoch [800], val_loss: 1.0863\n",
|
||
|
"Epoch [900], val_loss: 1.0224\n",
|
||
|
"Epoch [1000], val_loss: 0.9642\n",
|
||
|
"Epoch [1100], val_loss: 0.9100\n",
|
||
|
"Epoch [1200], val_loss: 0.8617\n",
|
||
|
"Epoch [1300], val_loss: 0.8200\n",
|
||
|
"Epoch [1400], val_loss: 0.7816\n",
|
||
|
"Epoch [1500], val_loss: 0.7484\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Hyperparameters for this run: 1500 epochs at a learning rate of 1e-6.\n",
|
||
|
"epochs, lr = 1500, 1e-6\n",
|
||
|
"history5 = fit(epochs=epochs, lr=lr, model=model, train_loader=train_loader, val_loader=val_loader)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "413ab394",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def predict_single(input, target, model):\n",
|
||
|
"    \"\"\"Print and return the model's prediction for one sample.\n",
|
||
|
"\n",
|
||
|
"    Args:\n",
|
||
|
"        input: Feature tensor for a single sample, without a batch dimension.\n",
|
||
|
"        target: Expected value for the sample; only used in the printout.\n",
|
||
|
"        model: Callable mapping a batch of inputs to a batch of predictions.\n",
|
||
|
"\n",
|
||
|
"    Returns:\n",
|
||
|
"        The prediction for the sample, detached from the autograd graph.\n",
|
||
|
"    \"\"\"\n",
|
||
|
"    # The model is called on batches, so wrap the sample in a batch of one.\n",
|
||
|
"    inputs = input.unsqueeze(0)\n",
|
||
|
"    # Inference only: skip building the autograd graph.\n",
|
||
|
"    with torch.no_grad():\n",
|
||
|
"        predictions = model(inputs)\n",
|
||
|
"    prediction = predictions[0].detach()\n",
|
||
|
"    print(\"Target: \", target, \"----- Prediction: \", prediction)\n",
|
||
|
"    return prediction"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "b1ab4522",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Target: tensor([6.]) ----- Prediction: tensor([5.1011])\n",
|
||
|
"Target: tensor([6.]) ----- Prediction: tensor([7.1398])\n",
|
||
|
"Target: tensor([5.]) ----- Prediction: tensor([5.1009])\n",
|
||
|
"Target: tensor([6.]) ----- Prediction: tensor([5.2282])\n",
|
||
|
"Target: tensor([5.]) ----- Prediction: tensor([4.8219])\n",
|
||
|
"Target: tensor([6.]) ----- Prediction: tensor([4.8082])\n",
|
||
|
"Target: tensor([7.]) ----- Prediction: tensor([5.0764])\n",
|
||
|
"Target: tensor([5.]) ----- Prediction: tensor([6.3668])\n",
|
||
|
"Target: tensor([6.]) ----- Prediction: tensor([5.0642])\n",
|
||
|
"Target: tensor([5.]) ----- Prediction: tensor([5.4656])\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Draw 10 random validation samples and print target vs. prediction for each.\n",
|
||
|
"for i in random.sample(range(len(val_ds)), k=10):\n",
|
||
|
"    input_, target = val_ds[i]\n",
|
||
|
"    predict_single(input=input_, target=target, model=model)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "0237aad2",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.7"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|