Dodanie obsługi parametru epochs do lab 06 zad 1
This commit is contained in:
parent
606f47bb4e
commit
6462a5fce5
@ -1,5 +1,5 @@
|
|||||||
## Klasyfikacja jakości diamentu
|
## Klasyfikacja jakości diamentu
|
||||||
|
import os
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pickle
|
import pickle
|
||||||
@ -10,6 +10,9 @@ from tensorflow.keras.callbacks import History
|
|||||||
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
|
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
|
||||||
from tensorflow.keras.utils import to_categorical
|
from tensorflow.keras.utils import to_categorical
|
||||||
|
|
||||||
|
#Wyświetlenie zbioru danych
|
||||||
|
epochs = int(os.environ.get('EPOCHS', 10))
|
||||||
|
|
||||||
# Wczytanie danych
|
# Wczytanie danych
|
||||||
data_train = pd.read_csv('dane/diamonds_train.csv')
|
data_train = pd.read_csv('dane/diamonds_train.csv')
|
||||||
data_test = pd.read_csv('dane/diamonds_test.csv')
|
data_test = pd.read_csv('dane/diamonds_test.csv')
|
||||||
@ -74,7 +77,7 @@ model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['ac
|
|||||||
|
|
||||||
# Trenowanie
|
# Trenowanie
|
||||||
history = History()
|
history = History()
|
||||||
model.fit(X_train_scaled, y_train_encoded, epochs=10, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history])
|
model.fit(X_train_scaled, y_train_encoded, epochs=epochs, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history])
|
||||||
|
|
||||||
# Zapisywanie modelu do pliku
|
# Zapisywanie modelu do pliku
|
||||||
saved_model = [model,
|
saved_model = [model,
|
||||||
|
Loading…
Reference in New Issue
Block a user