# ium_z487175/DL-model.py
## Diamond quality classification
import argparse
import pickle

import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import History
from tensorflow.keras.utils import to_categorical
from sacred import Experiment
from sacred.observers import MongoObserver

# Initialize the Sacred experiment and attach a MongoDB observer
ex = Experiment('z-s487175-training', interactive=True, save_git_info=False)
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', db_name='sacred'))
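# Note: the credentials and 172.17.0.1 (the default Docker bridge gateway) are
# environment-specific; the observer will fail if MongoDB is not reachable there.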


@ex.config
def config():
    # Read the number of training epochs from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    args = parser.parse_args()
    epochs = args.epochs


@ex.main
def main(epochs, _run):
    # Load the data
    data_train = pd.read_csv('dane/diamonds_train.csv')
    data_test = pd.read_csv('dane/diamonds_test.csv')
    data_val = pd.read_csv('dane/diamonds_dev.csv')

    # Split into features (X) and labels (y)
    X_train = data_train.drop('cut', axis=1)
    y_train = data_train['cut']
    X_test = data_test.drop('cut', axis=1)
    y_test = data_test['cut']
    X_val = data_val.drop('cut', axis=1)
    y_val = data_val['cut']

    # Encode the target labels as integers, then as one-hot vectors
    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(y_train)
    y_test = label_encoder.transform(y_test)
    y_val = label_encoder.transform(y_val)
    y_train_encoded = to_categorical(y_train)
    y_test_encoded = to_categorical(y_test)
    y_val_encoded = to_categorical(y_val)
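    # to_categorical turns integer class i into a one-hot row,
    # e.g. 2 -> [0., 0., 1., 0., 0.] with 5 classes.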

    # Fit the one-hot encoder on the training set only
    categorical_cols = ['color', 'clarity']
    # `sparse` was renamed to `sparse_output` in scikit-learn 1.2
    encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X_train_encoded = pd.DataFrame(encoder.fit_transform(X_train[categorical_cols]))

    # Encode the attributes of the test and validation sets with the fitted encoder
    X_test_encoded = pd.DataFrame(encoder.transform(X_test[categorical_cols]))
    X_val_encoded = pd.DataFrame(encoder.transform(X_val[categorical_cols]))
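    # With handle_unknown='ignore', categories unseen during fit are encoded
    # as all-zero rows instead of raising an error.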

    # Join the encoded attributes with the numeric features
    X_train_processed = pd.concat([X_train.drop(categorical_cols, axis=1), X_train_encoded], axis=1)
    X_test_processed = pd.concat([X_test.drop(categorical_cols, axis=1), X_test_encoded], axis=1)
    X_val_processed = pd.concat([X_val.drop(categorical_cols, axis=1), X_val_encoded], axis=1)

    # Convert column names to strings (the encoder output uses integer column labels)
    X_train_processed.columns = X_train_processed.columns.astype(str)
    X_test_processed.columns = X_test_processed.columns.astype(str)
    X_val_processed.columns = X_val_processed.columns.astype(str)

    # Scale the features
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train_processed)
    X_test_scaled = scaler.transform(X_test_processed)
    X_val_scaled = scaler.transform(X_val_processed)
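    # StandardScaler applies z = (x - mean) / std, with mean and std estimated
    # on the training set only, so no test/validation statistics leak in.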

    # Build the model: two hidden ReLU layers with dropout, softmax output
    model = Sequential()
    model.add(Dense(128, activation='relu', input_dim=X_train_scaled.shape[1]))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(5, activation='softmax'))
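    # The 5 softmax units match the 5 cut grades, presumably Fair, Good,
    # Very Good, Premium and Ideal as in the standard diamonds dataset.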

    # Compile
    optimizer = Adam(learning_rate=0.0001)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    # Train
    history = History()
    model.fit(X_train_scaled, y_train_encoded, epochs=epochs, batch_size=32,
              validation_data=(X_val_scaled, y_val_encoded), callbacks=[history])
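    # Note: model.fit() already returns an equivalent History object, so the
    # explicit History callback is redundant but harmless here.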

    # Save the model together with the processed data splits
    saved_model = [model,
                   X_train_scaled, y_train_encoded,
                   X_test_scaled, y_test_encoded,
                   X_val_scaled, y_val_encoded,
                   history]
    with open('model_with_data.pickle', 'wb') as file:
        pickle.dump(saved_model, file)
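    # Note: pickling a compiled Keras model only works on some TensorFlow
    # versions; model.save() would be the more portable option.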

    # Log the parameters the training was run with
    _run.info["epochs"] = epochs

    # Log the resulting model file
    ex.add_artifact('model_with_data.pickle', content_type='application/octet-stream')

    # Log the source code used for the training
    ex.add_artifact('DL-model.py')

    # Log the input files
    ex.add_artifact('dane/diamonds_train.csv')
    ex.add_artifact('dane/diamonds_test.csv')
    ex.add_artifact('dane/diamonds_dev.csv')

    # Log the final metrics
    ex.log_scalar('accuracy', history.history['accuracy'][-1])
    ex.log_scalar('val_accuracy', history.history['val_accuracy'][-1])
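    # A possible refinement (not part of the original run): log full per-epoch
    # curves instead of only the final values, e.g.
    #   for step, acc in enumerate(history.history['accuracy']):
    #       _run.log_scalar('accuracy', acc, step)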


if __name__ == "__main__":
    ex.run()
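
# Example invocation (assuming the CSV files exist under dane/):
#   python DL-model.py --epochs 20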