Dodanie obsługi parametru epochs do lab 06 zad 1

This commit is contained in:
Norbert Walkowiak 2023-06-07 21:04:15 +02:00
parent 606f47bb4e
commit 6462a5fce5

View File

@ -1,5 +1,5 @@
## Klasyfikacja jakości diamentu
import os
import pandas as pd
import numpy as np
import pickle
@ -10,6 +10,9 @@ from tensorflow.keras.callbacks import History
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
from tensorflow.keras.utils import to_categorical
#Wyświetlenie zbioru danych
epochs = int(os.environ.get('EPOCHS', 10))
# Wczytanie danych
data_train = pd.read_csv('dane/diamonds_train.csv')
data_test = pd.read_csv('dane/diamonds_test.csv')
@ -74,7 +77,7 @@ model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['ac
# Trenowanie
history = History()
model.fit(X_train_scaled, y_train_encoded, epochs=10, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history])
model.fit(X_train_scaled, y_train_encoded, epochs=epochs, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history])
# Zapisywanie modelu do pliku
saved_model = [model,