diff --git a/DL-model.py b/DL-model.py index 9339ad1..0bf5930 100644 --- a/DL-model.py +++ b/DL-model.py @@ -1,5 +1,5 @@ ## Klasyfikacja jakości diamentu - +import os import pandas as pd import numpy as np import pickle @@ -10,6 +10,9 @@ from tensorflow.keras.callbacks import History from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder from tensorflow.keras.utils import to_categorical +#Wyświetlenie zbioru danych +epochs = int(os.environ.get('EPOCHS', 10)) + # Wczytanie danych data_train = pd.read_csv('dane/diamonds_train.csv') data_test = pd.read_csv('dane/diamonds_test.csv') @@ -74,7 +77,7 @@ model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['ac # Trenowanie history = History() -model.fit(X_train_scaled, y_train_encoded, epochs=10, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history]) +model.fit(X_train_scaled, y_train_encoded, epochs=epochs, batch_size=32, validation_data=(X_val_scaled, y_val_encoded), callbacks=[history]) # Zapisywanie modelu do pliku saved_model = [model,