ium_434695/vgsales-mlflow.py
s434695 c01ba17840
Some checks failed
s434695-training/pipeline/head There was a failure building this commit
elo
2021-05-23 19:31:11 +02:00

69 lines
2.1 KiB
Python

import sys
from tensorflow.keras.backend import batch_dot, mean
import pandas as pd
import numpy as np
from six import int2byte
from sklearn import preprocessing
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Activation,Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential
import mlflow
def my_main(epochs, batch_size, data_path='vgsales.csv', model_path='vgsales_model.h5'):
    """Train a small Keras classifier that predicts whether a game's publisher is Nintendo.

    The label is 1 when ``Publisher == 'Nintendo'`` and 0 otherwise. The
    features are whatever columns remain after dropping ``Rank``, ``Name``,
    ``Platform``, ``Year``, ``Genre`` and ``Publisher``. Presumably these are
    the numeric sales columns; confirm against the CSV. The data is split
    80/20 into train/test with a fixed seed (21).

    Args:
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size passed to ``model.fit``.
        data_path: CSV file to read. Defaults to ``'vgsales.csv'``.
        model_path: File the trained Keras model is saved to.
            Defaults to ``'vgsales_model.h5'``.

    Returns:
        Tuple ``(rmse, model, X_train, y_train)``. ``rmse`` is the
        root-mean-squared error between the test labels and the predicted
        probabilities. ``model`` is the trained Keras model. ``X_train`` and
        ``y_train`` are the training features and labels.
    """
    vgsales = pd.read_csv(data_path)
    # Binary target: 1 for Nintendo, 0 for anything else (NaN publishers compare unequal -> 0).
    vgsales['Nintendo'] = (vgsales['Publisher'] == 'Nintendo').astype('int64')
    y = vgsales['Nintendo']
    X = vgsales.drop(['Rank', 'Name', 'Platform', 'Year', 'Genre', 'Publisher', 'Nintendo'], axis=1)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, train_size=0.8, random_state=21)

    model = Sequential()
    model.add(Dense(9, input_dim=X_train.shape[1], kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal', activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Previously this callback was built but never handed to fit(), so it did nothing.
    # NOTE(review): validation_data is the test split, so early stopping consults
    # test-set loss; consider a separate validation split if that matters.
    early_stop = EarlyStopping(monitor="val_loss", mode="min", verbose=1, patience=10)
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size,
              validation_data=(X_test, y_test), callbacks=[early_stop])

    prediction = model.predict(X_test)
    # mean_squared_error returns the MSE; take the square root so the value matches its name.
    rmse = float(np.sqrt(mean_squared_error(y_test, prediction)))
    model.save(model_path)
    return rmse, model, X_train, y_train
# Command-line arguments: [epochs] [batch_size]; defaults are 15 and 16.
epochs = int(sys.argv[1]) if len(sys.argv) > 1 else 15
batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 16

# Train inside an MLflow run, then record hyper-parameters, the RMSE metric and the model.
with mlflow.start_run():
    rmse, model, x_train, y_train = my_main(epochs, batch_size)
    mlflow.log_param("epochs", epochs)
    mlflow.log_param("batch_size", batch_size)
    mlflow.log_metric("rmse", rmse)
    # Infer the output schema from what the model actually emits (sigmoid
    # probabilities), not from the integer training labels.
    signature = mlflow.models.signature.infer_signature(x_train, model.predict(x_train))
    # A few rows suffice as an input example; passing all of x_train would
    # serialize the entire training set into the saved model directory.
    # NOTE(review): MLflow's save_model refuses to write into an existing
    # non-empty path, so reruns in the same workspace may fail on "my_model".
    mlflow.keras.save_model(model, "my_model", signature=signature, input_example=x_train.head(5))