66 KiB
66 KiB
import pandas as pd
import sklearn.model_selection
import requests
url = "https://huggingface.co/datasets/mstz/wine/raw/main/Wine_Quality_Data.csv"
save_path = "Wine_Quality_Data.csv"
# Download the CSV once and cache it locally. Without a timeout, requests.get
# can block indefinitely on a stalled connection.
response = requests.get(url, timeout=30)
response.raise_for_status()  # fail loudly on HTTP errors instead of saving an error page
with open(save_path, "wb") as f:
    f.write(response.content)
# Read back from the same path that was just written (no duplicated literal).
wine_dataset = pd.read_csv(save_path)
wine_dataset.head()  # preview the data
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.9978 | 3.51 | 0.56 | 9.4 | 5 | red |
1 | 7.8 | 0.88 | 0.00 | 2.6 | 0.098 | 25.0 | 67.0 | 0.9968 | 3.20 | 0.68 | 9.8 | 5 | red |
2 | 7.8 | 0.76 | 0.04 | 2.3 | 0.092 | 15.0 | 54.0 | 0.9970 | 3.26 | 0.65 | 9.8 | 5 | red |
3 | 11.2 | 0.28 | 0.56 | 1.9 | 0.075 | 17.0 | 60.0 | 0.9980 | 3.16 | 0.58 | 9.8 | 6 | red |
4 | 7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11.0 | 34.0 | 0.9978 | 3.51 | 0.56 | 9.4 | 5 | red |
# Encode the colour label numerically: red -> 1, white -> 0.
# Series.map is used instead of Series.replace: replace's implicit object->int
# downcast is deprecated in pandas 2.2+, and it would silently leave any
# unexpected label as a string. map yields an integer column directly, and
# unmapped values surface as NaN, which is rejected explicitly below.
encoded_color = wine_dataset['color'].map({'red': 1, 'white': 0})
if encoded_color.isna().any():
    unexpected = sorted(set(wine_dataset.loc[encoded_color.isna(), 'color'].astype(str)))
    raise ValueError(f"Unexpected color values: {unexpected}")
wine_dataset['color'] = encoded_color
wine_dataset.describe(include='all')
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 |
mean | 7.215307 | 0.339666 | 0.318633 | 5.443235 | 0.056034 | 30.525319 | 115.744574 | 0.994697 | 3.218501 | 0.531268 | 10.491801 | 5.818378 | 0.246114 |
std | 1.296434 | 0.164636 | 0.145318 | 4.757804 | 0.035034 | 17.749400 | 56.521855 | 0.002999 | 0.160787 | 0.148806 | 1.192712 | 0.873255 | 0.430779 |
min | 3.800000 | 0.080000 | 0.000000 | 0.600000 | 0.009000 | 1.000000 | 6.000000 | 0.987110 | 2.720000 | 0.220000 | 8.000000 | 3.000000 | 0.000000 |
25% | 6.400000 | 0.230000 | 0.250000 | 1.800000 | 0.038000 | 17.000000 | 77.000000 | 0.992340 | 3.110000 | 0.430000 | 9.500000 | 5.000000 | 0.000000 |
50% | 7.000000 | 0.290000 | 0.310000 | 3.000000 | 0.047000 | 29.000000 | 118.000000 | 0.994890 | 3.210000 | 0.510000 | 10.300000 | 6.000000 | 0.000000 |
75% | 7.700000 | 0.400000 | 0.390000 | 8.100000 | 0.065000 | 41.000000 | 156.000000 | 0.996990 | 3.320000 | 0.600000 | 11.300000 | 6.000000 | 0.000000 |
max | 15.900000 | 1.580000 | 1.660000 | 65.800000 | 0.611000 | 289.000000 | 440.000000 | 1.038980 | 4.010000 | 2.000000 | 14.900000 | 9.000000 | 1.000000 |
# Bar chart of the number of samples per colour class (1 = red, 0 = white).
wine_dataset["color"].value_counts().plot.bar()
<Axes: >
# Standard deviation of the raw fixed_acidity column (before normalization).
wine_dataset["fixed_acidity"].std()
1.2964337577998153
import numpy as np
# Check for missing values: np.where returns the (row, column) index arrays of
# every null cell, so two empty arrays mean there are no missing values.
np.where(pd.isnull(wine_dataset))
(array([], dtype=int64), array([], dtype=int64))
# Max-abs scaling: divide every column by its largest absolute value, so values
# land in [-1, 1] (here [0, 1], since every column is non-negative).
# Columns whose maximum is 0 keep a scale of 1, so they stay unchanged instead
# of turning into NaN (0 / 0).
# NOTE(review): the scale is computed on the full dataset before the
# train/test split, so test-set statistics leak into preprocessing. It also
# rescales the `quality` feature and the 0/1 `color` label (a no-op, max is 1).
column_scale = wine_dataset.abs().max().replace(0, 1)
wine_dataset = wine_dataset / column_scale
wine_dataset.describe(include='all')  # check the values after normalization
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 | 6497.000000 |
mean | 0.453793 | 0.214978 | 0.191948 | 0.082724 | 0.091708 | 0.105624 | 0.263056 | 0.957378 | 0.802619 | 0.265634 | 0.704148 | 0.646486 | 0.246114 |
std | 0.081537 | 0.104200 | 0.087541 | 0.072307 | 0.057338 | 0.061417 | 0.128459 | 0.002886 | 0.040097 | 0.074403 | 0.080048 | 0.097028 | 0.430779 |
min | 0.238994 | 0.050633 | 0.000000 | 0.009119 | 0.014730 | 0.003460 | 0.013636 | 0.950076 | 0.678304 | 0.110000 | 0.536913 | 0.333333 | 0.000000 |
25% | 0.402516 | 0.145570 | 0.150602 | 0.027356 | 0.062193 | 0.058824 | 0.175000 | 0.955110 | 0.775561 | 0.215000 | 0.637584 | 0.555556 | 0.000000 |
50% | 0.440252 | 0.183544 | 0.186747 | 0.045593 | 0.076923 | 0.100346 | 0.268182 | 0.957564 | 0.800499 | 0.255000 | 0.691275 | 0.666667 | 0.000000 |
75% | 0.484277 | 0.253165 | 0.234940 | 0.123100 | 0.106383 | 0.141869 | 0.354545 | 0.959585 | 0.827930 | 0.300000 | 0.758389 | 0.666667 | 0.000000 |
max | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
wine_dataset["fixed_acidity"].nlargest(10) # sanity check: do the highest (normalized) values look plausible?
652 1.000000 442 0.981132 557 0.981132 554 0.974843 555 0.974843 243 0.943396 244 0.943396 544 0.899371 3125 0.893082 374 0.880503 Name: fixed_acidity, dtype: float64
from sklearn.model_selection import train_test_split
# Split into train (90%) and held-out (10%) sets, stratified on colour so both
# keep the same red/white ratio; then show the training class balance.
wine_train, wine_test = train_test_split(
    wine_dataset,
    test_size=0.1,
    random_state=1,
    stratify=wine_dataset["color"],
)
wine_train["color"].value_counts()
0.0 4408 1.0 1439 Name: color, dtype: int64
# Class balance of the held-out 10% (before it is split into test/validation).
wine_test["color"].value_counts()
0.0 490 1.0 160 Name: color, dtype: int64
# Halve the held-out set: one half stays the test set, the other becomes the
# validation set. Stratify on colour so both halves keep the red/white ratio.
wine_test, wine_val = sklearn.model_selection.train_test_split(
    wine_test,
    test_size=0.5,
    stratify=wine_test["color"],
    random_state=1,
)
wine_test["color"].value_counts()
0.0 245 1.0 80 Name: color, dtype: int64
# Class balance of the validation split.
wine_val["color"].value_counts()
0.0 245 1.0 80 Name: color, dtype: int64
import seaborn as sns
# Use seaborn's default theme for all subsequent plots.
sns.set_theme()
# Number of columns: 12 feature columns plus the colour label.
wine_dataset.shape[1]
13
# Pairwise scatter plot of all columns coloured by wine colour; left disabled.
# NOTE(review): presumably disabled because it is slow for 13 columns — confirm.
#sns.pairplot(data=wine_dataset, hue="color")
# Summary statistics of the test split, for comparison with train/validation.
wine_test.describe()
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 |
mean | 0.460126 | 0.209883 | 0.197294 | 0.083839 | 0.096352 | 0.105307 | 0.272028 | 0.957685 | 0.799770 | 0.266477 | 0.691389 | 0.636239 | 0.246154 |
std | 0.087321 | 0.100971 | 0.086532 | 0.072172 | 0.066017 | 0.061895 | 0.131981 | 0.002780 | 0.038640 | 0.082243 | 0.073293 | 0.088732 | 0.431433 |
min | 0.308176 | 0.066456 | 0.000000 | 0.010638 | 0.026187 | 0.003460 | 0.020455 | 0.952030 | 0.698254 | 0.115000 | 0.577181 | 0.333333 | 0.000000 |
25% | 0.408805 | 0.139241 | 0.156627 | 0.027356 | 0.062193 | 0.058824 | 0.188636 | 0.955322 | 0.773067 | 0.215000 | 0.630872 | 0.555556 | 0.000000 |
50% | 0.440252 | 0.189873 | 0.186747 | 0.048632 | 0.078560 | 0.100346 | 0.275000 | 0.957978 | 0.795511 | 0.250000 | 0.671141 | 0.666667 | 0.000000 |
75% | 0.484277 | 0.240506 | 0.246988 | 0.121581 | 0.116203 | 0.145329 | 0.356818 | 0.959787 | 0.822943 | 0.305000 | 0.738255 | 0.666667 | 0.000000 |
max | 0.943396 | 0.715190 | 0.469880 | 0.303191 | 0.764321 | 0.479239 | 0.781818 | 0.966034 | 0.895262 | 0.975000 | 0.906040 | 0.888889 | 1.000000 |
# Summary statistics of the training split.
wine_train.describe()
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 | 5847.000000 |
mean | 0.453724 | 0.215128 | 0.192091 | 0.082877 | 0.091656 | 0.105899 | 0.262834 | 0.957374 | 0.802637 | 0.265601 | 0.704572 | 0.646846 | 0.246109 |
std | 0.081597 | 0.104319 | 0.087166 | 0.072487 | 0.057502 | 0.061908 | 0.128388 | 0.002899 | 0.040030 | 0.074400 | 0.080399 | 0.097212 | 0.430780 |
min | 0.238994 | 0.050633 | 0.000000 | 0.009119 | 0.014730 | 0.003460 | 0.013636 | 0.950076 | 0.678304 | 0.110000 | 0.536913 | 0.333333 | 0.000000 |
25% | 0.402516 | 0.145570 | 0.150602 | 0.027356 | 0.062193 | 0.058824 | 0.175000 | 0.955110 | 0.775561 | 0.215000 | 0.637584 | 0.555556 | 0.000000 |
50% | 0.440252 | 0.183544 | 0.186747 | 0.045593 | 0.076923 | 0.100346 | 0.268182 | 0.957555 | 0.800499 | 0.255000 | 0.691275 | 0.666667 | 0.000000 |
75% | 0.484277 | 0.259494 | 0.234940 | 0.123100 | 0.106383 | 0.141869 | 0.354545 | 0.959585 | 0.827930 | 0.300000 | 0.758389 | 0.666667 | 0.000000 |
max | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
# Summary statistics of the validation split.
wine_val.describe()
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality | color | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 | 325.000000 |
mean | 0.448708 | 0.217381 | 0.184022 | 0.078864 | 0.088017 | 0.100985 | 0.258073 | 0.957147 | 0.805141 | 0.265385 | 0.709269 | 0.650256 | 0.246154 |
std | 0.073960 | 0.105388 | 0.094736 | 0.069232 | 0.043159 | 0.051174 | 0.126120 | 0.002746 | 0.042584 | 0.065946 | 0.079198 | 0.101225 | 0.431433 |
min | 0.301887 | 0.075949 | 0.000000 | 0.012158 | 0.026187 | 0.006920 | 0.018182 | 0.950731 | 0.683292 | 0.150000 | 0.570470 | 0.333333 | 0.000000 |
25% | 0.402516 | 0.145570 | 0.138554 | 0.028875 | 0.062193 | 0.058824 | 0.179545 | 0.954879 | 0.775561 | 0.215000 | 0.637584 | 0.555556 | 0.000000 |
50% | 0.433962 | 0.177215 | 0.186747 | 0.042553 | 0.076923 | 0.100346 | 0.256818 | 0.957189 | 0.805486 | 0.260000 | 0.697987 | 0.666667 | 0.000000 |
75% | 0.484277 | 0.253165 | 0.234940 | 0.117021 | 0.101473 | 0.138408 | 0.356818 | 0.959306 | 0.830424 | 0.305000 | 0.758389 | 0.666667 | 0.000000 |
max | 0.798742 | 0.696203 | 0.602410 | 0.303191 | 0.436989 | 0.280277 | 0.575000 | 0.962935 | 0.935162 | 0.490000 | 0.953020 | 0.888889 | 1.000000 |
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
class TabularDataset(Dataset):
    """Map-style Dataset over a DataFrame whose last column is the target.

    The frame is converted once to a float32 NumPy array. Each item is a
    ``(features, target)`` pair of tensors: ``features`` holds every column
    except the last, and ``target`` is the last column's value.
    """

    def __init__(self, data):
        self.data = data.values.astype('float32')

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        features = self.data[index, :-1]
        target = self.data[index, -1]
        return torch.tensor(features), torch.tensor(target)
batch_size = 64
# Wrap the train and test splits in the Dataset adapter.
train_dataset = TabularDataset(wine_train)
test_dataset = TabularDataset(wine_test)
# Shuffle training batches every epoch; keep evaluation order fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
class TabularModel(nn.Module):
    """Two-layer MLP classifier (Linear -> ReLU -> Linear) for tabular features.

    ``forward`` returns raw, unnormalized class scores (logits). This is what
    ``nn.CrossEntropyLoss`` expects, since it applies log-softmax internally.
    Applying ``Softmax`` inside ``forward`` as well would squash the scores
    twice, bound the achievable loss, and weaken gradients. Use
    ``predict_proba`` to obtain class probabilities.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the hidden layer.
        output_dim: number of output classes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Kept for predict_proba and existing references to the attribute;
        # deliberately NOT applied in forward() (see class docstring).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for a 2-D input x."""
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)

    def predict_proba(self, x):
        """Return class probabilities (softmax over the class dimension)."""
        return self.softmax(self.forward(x))
# Build the model, loss and optimizer once. (Previously they were built twice
# and the first set was discarded.) The network sees every column except the
# last one (the colour label) and outputs one score per class (0 = white, 1 = red).
input_dim = wine_train.shape[1] - 1
hidden_dim = 32
output_dim = 2
learning_rate = 0.01
model = TabularModel(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 10
for epoch in range(num_epochs):
    # Training mode: a no-op for the current layers, but keeps the loop
    # correct if dropout/batch-norm layers are added to the model.
    model.train()
    running_loss = 0.0
    for inputs, labels in train_dataloader:
        # CrossEntropyLoss needs integer class indices. .long() keeps the
        # tensor on its current device, whereas .type(torch.LongTensor) forces CPU.
        labels = labels.long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean per-batch loss on every other epoch (epochs 1, 3, 5, ...).
    if (epoch % 2) == 0:
        print(f'Epoch {epoch + 1}, loss: {running_loss / len(train_dataloader):.4f}')
print('Finished Training')
Epoch 1, loss: 0.4864 Epoch 3, loss: 0.3413 Epoch 5, loss: 0.3345 Epoch 7, loss: 0.3337 Epoch 9, loss: 0.3331 Finished Training
# Classification accuracy of the trained model on the test split.
correct = 0
total = 0
model.eval()  # switch to inference mode before evaluating
with torch.no_grad():
    for inputs, labels in test_dataloader:
        outputs = model(inputs.float())
        # Predicted class = index of the highest score (same for logits or probabilities).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # TabularDataset yields float32 labels; compare them as integer class ids.
        correct += (predicted == labels.long()).sum().item()
if total == 0:
    # Guard against ZeroDivisionError when the test split is empty.
    print('Accuracy on test set: n/a (empty test set)')
else:
    print('Accuracy on test set: %d %%' % (100 * correct / total))
Accuracy on test set: 98 %