diff --git a/DL-model.py b/DL-model.py index 0bf5930..70d408e 100644 --- a/DL-model.py +++ b/DL-model.py @@ -1,5 +1,5 @@ ## Klasyfikacja jakości diamentu -import os +import argparse import pandas as pd import numpy as np import pickle @@ -10,8 +10,11 @@ from tensorflow.keras.callbacks import History from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder from tensorflow.keras.utils import to_categorical -#Wyświetlenie zbioru danych -epochs = int(os.environ.get('EPOCHS', 10)) +# Wczytanie parametru epochs +parser = argparse.ArgumentParser() +parser.add_argument("--epochs", type=int, default=10, help="Number of epochs") +args = parser.parse_args() +epochs = args.epochs # Wczytanie danych data_train = pd.read_csv('dane/diamonds_train.csv')