AdamOsiowy123
18e26fed6b
Some checks failed
s444452-training/pipeline/head There was a failure building this commit
121 lines
3.9 KiB
Python
121 lines
3.9 KiB
Python
#!/usr/bin/python
|
|
import os
|
|
import sys
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from keras.models import Sequential
|
|
from keras import layers
|
|
from keras.preprocessing.text import Tokenizer
|
|
from keras.preprocessing.sequence import pad_sequences
|
|
from sacred.observers import MongoObserver
|
|
from sacred.observers import FileStorageObserver
|
|
from sacred import Experiment
|
|
from mlflow.models.signature import infer_signature
|
|
import mlflow
|
|
import logging
|
|
|
|
logging.basicConfig(level=logging.WARN)
logger = logging.getLogger(__name__)

# Tracking back-ends. Both endpoints can be overridden through environment
# variables; the defaults are the addresses this script has always used.
# NOTE(review): the default MongoDB URL embeds a username and password in
# source control. Prefer supplying SACRED_MONGO_URL from the environment
# (e.g. a CI secret) and drop the literal default once that is in place.
MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://172.17.0.1:5000")
SACRED_MONGO_URL = os.environ.get("SACRED_MONGO_URL", "mongodb://admin:IUM_2021@172.17.0.1:27017")

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("s444452")

ex = Experiment(name='s444452_fake_job_classification_training', save_git_info=False)
# Runs are recorded both in MongoDB (database 'sacred') and in a local
# 'my_runs' directory relative to the current working directory.
ex.observers.append(MongoObserver(url=SACRED_MONGO_URL, db_name='sacred'))
ex.observers.append(FileStorageObserver('my_runs'))
|
|
|
|
# Positional command-line arguments:
#   DATA_PATH EPOCHS NUM_WORDS BATCH_SIZE PAD_LENGTH
# They are read at module level because the Sacred config scope copies these
# module-level names into the experiment config. Extra arguments are ignored.
_USAGE = "usage: {} DATA_PATH EPOCHS NUM_WORDS BATCH_SIZE PAD_LENGTH".format(
    sys.argv[0] if sys.argv else "script")

if len(sys.argv) < 6:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit(_USAGE)

data_path = sys.argv[1]
try:
    epochs = int(sys.argv[2])
    num_words = int(sys.argv[3])
    batch_size = int(sys.argv[4])
    pad_length = int(sys.argv[5])
except ValueError as err:
    # Numeric arguments that are not integers: report usage plus the cause.
    raise SystemExit("{}\n{}".format(_USAGE, err)) from err
|
|
|
|
|
|
# Sacred configuration scope: each local assigned below becomes a config
# entry, and the captured functions (tokenize, get_model, train_model, main)
# receive those entries by parameter name when not passed explicitly.
# NOTE(review): executed as a plain function, `x = x` would raise
# UnboundLocalError; this relies on Sacred evaluating the body with the module
# globals (parsed from sys.argv) in scope. Do not add helper locals here, as
# they would become config entries too.
@ex.config
def config():
    data_path = data_path  # directory containing train_data.csv and test_data.csv
    epochs = epochs  # number of training epochs
    num_words = num_words  # Tokenizer num_words limit
    batch_size = batch_size  # training batch size (get_model also uses it as the embedding width)
    pad_length = pad_length  # length every token-id sequence is padded to
|
|
|
|
|
|
@ex.capture
def tokenize(x, x_train, pad_length, num_words):
    """Encode *x_train* as fixed-length sequences of token ids.

    A Keras ``Tokenizer`` with the ``num_words`` limit is fitted on the texts
    in *x*. The texts in *x_train* are then mapped to id sequences and padded
    at the end to *pad_length*. *pad_length* and *num_words* are filled in
    from the Sacred config when not passed explicitly.

    Returns:
        A ``(sequences, vocabulary_length)`` pair, where
        ``vocabulary_length`` is ``len(word_index) + 1``. The ``+ 1``
        presumably reserves id 0 for padding.

    NOTE(review): main() passes train + test texts as *x*, so the vocabulary
    is also built from test text. Confirm this is intended.
    """
    tok = Tokenizer(num_words=num_words)
    tok.fit_on_texts(x)
    id_sequences = tok.texts_to_sequences(x_train)
    vocab_size = 1 + len(tok.word_index)
    padded = pad_sequences(id_sequences, maxlen=pad_length, padding='post')
    return padded, vocab_size
|
|
|
|
|
|
def save_model(model):
    """Save *model* in HDF5 format and attach the file to the Sacred run.

    The model is written to ``<cwd>/model/neural_net``, overwriting any
    previous file. The name has no file extension, so the format is forced
    with ``save_format='h5'``. The same path is then registered via
    ``ex.add_artifact``.
    """
    model_dir = os.path.join(os.getcwd(), 'model')
    # Create the output directory up front so saving does not depend on the
    # installed Keras version creating missing parent directories.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'neural_net')
    model.save(model_path, save_format='h5', overwrite=True)
    ex.add_artifact(model_path)
|
|
|
|
|
|
@ex.capture
def train_model(model, x_train, y_train, epochs, batch_size):
    """Fit *model* on the training data with progress output disabled.

    *epochs* and *batch_size* are filled in from the Sacred config when not
    passed explicitly. The value returned by ``model.fit`` is discarded.
    """
    fit_options = {'epochs': epochs, 'verbose': False, 'batch_size': batch_size}
    model.fit(x_train, y_train, **fit_options)
|
|
|
|
|
|
@ex.capture
def get_model(vocabulary_length, batch_size, pad_length, embedding_dim=None):
    """Build and compile a small binary text classifier.

    Architecture: Embedding -> Flatten -> Dense(10, relu) -> Dense(1, sigmoid),
    compiled with the Adam optimizer, binary cross-entropy loss and an
    accuracy metric.

    Args:
        vocabulary_length: ``input_dim`` of the embedding layer, i.e. the
            vocabulary size returned by ``tokenize()``.
        batch_size: Filled in from the Sacred config. Used as the embedding
            width when *embedding_dim* is not given.
        pad_length: Length of each input sequence (``input_length``).
        embedding_dim: Width of each embedding vector. It defaults to
            *batch_size* to preserve existing behaviour. Pass it (or add it to
            the Sacred config) to decouple model size from the training batch
            size.

    Returns:
        The compiled ``Sequential`` model.
    """
    if embedding_dim is None:
        # Batch size is a training setting, not an architectural one; keep
        # using it only as a backward-compatible fallback.
        embedding_dim = batch_size

    model = Sequential()
    model.add(layers.Embedding(input_dim=vocabulary_length,
                               output_dim=embedding_dim,
                               input_length=pad_length))
    model.add(layers.Flatten())
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model
|
|
|
|
|
|
def split_data(data):
    """Separate *data* into its feature and label columns.

    Returns:
        A ``(data['tokens'], data['fraudulent'])`` pair. A KeyError is raised
        if either column is missing.
    """
    features, labels = data['tokens'], data['fraudulent']
    return features, labels
|
|
|
|
|
|
def load_data(data_path, filename) -> pd.DataFrame:
    """Read the CSV file *filename* from directory *data_path* into a DataFrame."""
    csv_path = os.path.join(data_path, filename)
    return pd.read_csv(csv_path)
|
|
|
|
|
|
@ex.main
def main(data_path, num_words, epochs, batch_size, pad_length, _run):
    """Train the classifier and log it to Sacred and MLflow.

    Sacred calls this with the config entries as arguments. Inside an MLflow
    run it:

    1. logs the hyper-parameters;
    2. loads ``train_data.csv`` and ``test_data.csv`` from *data_path*;
    3. tokenizes the ``tokens`` column;
    4. trains the model and saves it via ``save_model``;
    5. logs the model to MLflow with an inferred signature and an input
       example.

    Args:
        data_path: Directory holding the CSV files. Relative paths are
            resolved against the current working directory.
        num_words, epochs, batch_size, pad_length: Hyper-parameters. They are
            logged to MLflow here and also injected by Sacred into the
            captured helpers.
        _run: Sacred run object, injected by Sacred (not used here).
    """
    with mlflow.start_run() as mlflow_run:
        print("MLflow run experiment_id: {0}".format(mlflow_run.info.experiment_id))
        print("MLflow run artifact_uri: {0}".format(mlflow_run.info.artifact_uri))
        mlflow.log_param("data_path", data_path)
        mlflow.log_param("num_words", num_words)
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("batch_size", batch_size)
        mlflow.log_param("pad_length", pad_length)

        abs_data_path = os.path.abspath(data_path)
        train_data = load_data(abs_data_path, 'train_data.csv')
        test_data = load_data(abs_data_path, 'test_data.csv')
        x_train, y_train = split_data(train_data)
        x_test, _ = split_data(test_data)
        # The tokenizer vocabulary is built from train + test texts; only the
        # training texts are encoded here.
        x_train, vocab_size = tokenize(pd.concat([x_train, x_test]), x_train)
        model = get_model(vocab_size)
        train_model(model, x_train, y_train)
        save_model(model)

        # The model consumes the padded integer id matrix produced by
        # tokenize(), not raw text. The signature and the input example must
        # therefore both come from that encoded matrix; raw text would not
        # match the model's input schema.
        signature = infer_signature(x_train, y_train)
        input_example = x_train[:20]
        mlflow.keras.log_model(model, "model", signature=signature, input_example=input_example)
|
|
|
|
|
|
if __name__ == "__main__":
    # Suppress every warning category for the duration of the run. Warnings
    # already emitted while importing the libraries above are not affected.
    warnings.filterwarnings("ignore")
    ex.run()
|