ium_434695/train.py
s434695 f40f237a02
Some checks failed
s434695-training/pipeline/head There was a failure building this commit
fix2
2021-05-16 22:00:03 +02:00

75 lines
2.6 KiB
Python
Executable File

import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
import pymongo
import tensorflow as tf
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sklearn import preprocessing
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dense, Activation, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
# Sacred experiment for the vgsales "is it published by Nintendo?" model.
# Runs are recorded both in MongoDB and on disk under ./my_runs.
#
# The Mongo connection string can be overridden through SACRED_MONGO_URL so
# credentials need not live in source control.
# NOTE(review): the fallback below still embeds a password; drop it once the
# CI environment provides SACRED_MONGO_URL.
MONGO_URL = os.environ.get(
    "SACRED_MONGO_URL",
    'mongodb://mongo_user:mongo_password_IUM_2021@172.17.0.1:27017',
)

ex = Experiment("434695-mongo", interactive=False, save_git_info=False)
ex.observers.append(MongoObserver(url=MONGO_URL, db_name='sacred'))
ex.observers.append(FileStorageObserver('my_runs'))
@ex.config
def my_config():
    # Positional CLI arguments: train.py <epochs> <batch_size>.
    # The order matches how the training code consumes them, so the values
    # Sacred records are the ones actually used for model.fit.
    epoch_param = int(sys.argv[1])  # number of training epochs
    batch_param = int(sys.argv[2])  # mini-batch size
def _load_split(csv_path):
    """Read one CSV split and derive the binary ``Nintendo`` label.

    Returns ``(X, y)``. ``y`` is a one-column DataFrame that is 1 where
    ``Publisher == 'Nintendo'`` and 0 otherwise. ``X`` is every column except
    Rank, Name, Platform, Year, Genre and Publisher.
    """
    frame = pd.read_csv(csv_path)
    # NaN publishers compare unequal, so they are labelled 0.
    frame['Nintendo'] = (frame['Publisher'] == 'Nintendo').astype('int64')
    # NOTE(review): 'Nintendo' (the label) is not dropped, so it is also an
    # input feature, i.e. target leakage. It is kept because dropping it
    # changes the input width of the saved model; fix this together with
    # whatever code consumes vgsales_model.h5.
    features = frame.drop(['Rank', 'Name', 'Platform', 'Year', 'Genre', 'Publisher'], axis=1)
    labels = frame[['Nintendo']]
    return features, labels


@ex.capture
def prepare_model(epoch_param, batch_param, _run):
    """Train the Nintendo-publisher classifier and return its test RMSE.

    Sacred injects ``epoch_param`` and ``batch_param`` from the experiment
    config, plus the current ``_run``. Reads ``train.csv``, ``dev.csv`` and
    ``test.csv`` from the working directory. The dev split drives validation
    and early stopping; the test split is used only for the final metric.
    Saves the trained model to ``vgsales_model.h5`` and logs the metric to
    the run as ``rmse``.

    Returns:
        Root-mean-squared error between the test labels and the predicted
        probabilities.
    """
    _run.info["prepare_model_ts"] = str(datetime.now())

    X_train, y_train = _load_split('train.csv')
    X_test, y_test = _load_split('test.csv')
    X_dev, y_dev = _load_split('dev.csv')
    print(X_train.shape[1])

    model = Sequential()
    model.add(Dense(9, input_dim=X_train.shape[1], kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Validate on the dev split (not test) so the test metric stays unbiased,
    # and actually register the early-stopping callback.
    early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=1, patience=10)
    model.fit(
        X_train,
        y_train,
        epochs=epoch_param,
        batch_size=batch_param,
        validation_data=(X_dev, y_dev),
        callbacks=[early_stop],
    )

    prediction = model.predict(X_test)
    # mean_squared_error returns MSE; take the square root to get RMSE.
    rmse = np.sqrt(mean_squared_error(y_test, prediction))
    _run.log_scalar("rmse", rmse)
    model.save('vgsales_model.h5')
    return rmse
@ex.main
def my_main(epoch_param, batch_param):
    # Sacred entry point. The config values are injected here as well, but
    # prepare_model obtains its own copies through @ex.capture, so nothing is
    # forwarded explicitly.
    score = prepare_model()
    print(score)
if __name__ == "__main__":
    # Execute the experiment, then attach the saved model file to it.
    # The guard keeps an import of this module from launching training.
    r = ex.run()
    # NOTE(review): Sacred documents Experiment.add_artifact as usable only
    # during a run, but the run has already finished here. Consider calling
    # _run.add_artifact("vgsales_model.h5") inside the run instead.
    ex.add_artifact("vgsales_model.h5")