# ium_444452/Scripts/train_neural_network.py
# (page header from the repository web view removed; it was not valid Python)
#!/usr/bin/python
# blame: 2022-05-02 19:21:28 +02:00
import datetime
# blame: 2022-04-24 22:51:20 +02:00
import os
import pprint
import sys
import pandas as pd
from keras.models import Sequential, load_model
from keras import layers
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# blame: 2022-05-02 22:51:16 +02:00
data_path = ''
num_words = 0
epochs = 0
batch_size = 0
pad_length = 0
# blame: 2022-04-24 22:51:20 +02:00
# blame: 2022-05-02 22:51:16 +02:00
def tokenize(x, x_train, x_test):
    """Fit a Keras Tokenizer on the whole corpus and convert both splits
    to padded integer sequences.

    x is the full corpus the tokenizer vocabulary is built from; x_train
    and x_test are the splits actually converted.  Returns
    (train_x, test_x, vocabulary_length).
    """
    global pad_length, num_words

    word_tokenizer = Tokenizer(num_words=num_words)
    word_tokenizer.fit_on_texts(x)

    # +1 because Keras word indices start at 1; index 0 is reserved for padding.
    vocabulary_length = len(word_tokenizer.word_index) + 1

    train_x, test_x = (
        pad_sequences(word_tokenizer.texts_to_sequences(split),
                      padding='post', maxlen=pad_length)
        for split in (x_train, x_test)
    )
    return train_x, test_x, vocabulary_length
def evaluate_and_save(model, x, y, abs_path):
    """Evaluate *model* on (x, y) and write the metrics to
    'neural_network_evaluation.txt' inside *abs_path*.

    Sigmoid outputs are thresholded at 0.5 to obtain binary predictions.
    """
    loss, accuracy = model.evaluate(x, y, verbose=False)
    # ravel() -> 1-D label vector; sklearn metrics otherwise warn on the
    # (n, 1) column returned by model.predict.
    y_predicted = (model.predict(x) >= 0.5).astype(int).ravel()
    metrics = (
        ('Accuracy: ', accuracy),
        ('Loss: ', loss),
        ('Precision: ', precision_score(y, y_predicted)),
        ('Recall: ', recall_score(y, y_predicted)),
        ('F1: ', f1_score(y, y_predicted)),
        # Previously also labelled 'Accuracy: ', which made the log file
        # ambiguous with the Keras accuracy above.
        ('Accuracy (sklearn): ', accuracy_score(y, y_predicted)),
    )
    with open(os.path.join(abs_path, 'neural_network_evaluation.txt'), "w") as log_file:
        for obj in metrics:
            pprint.pprint(obj, log_file)
def load_trained_model(abs_path, model_name):
    """Load a previously saved Keras model named *model_name* from *abs_path*."""
    model_file = os.path.join(abs_path, model_name)
    return load_model(model_file)
# blame: 2022-05-02 19:21:28 +02:00
def save_model(model):
    """Save *model* under ./model/ with a timestamped name, in HDF5 format.

    The time components are joined with '-' instead of ':' because ':' is
    not a legal filename character on Windows and confuses some HDF5 tools.
    """
    timestamp = datetime.datetime.today().strftime('%d-%b-%Y-%H-%M-%S')
    model_name = 'neural_net_' + timestamp
    model.save(os.path.join(os.getcwd(), 'model', model_name),
               save_format='h5', overwrite=True)
# blame: 2022-04-24 22:51:20 +02:00
def train_model(model, x_train, y_train):
    """Fit *model* on the training split using the globally configured
    epoch count and batch size (set by read_params)."""
    global epochs, batch_size

    fit_kwargs = {'epochs': epochs, 'verbose': False, 'batch_size': batch_size}
    model.fit(x_train, y_train, **fit_kwargs)
# blame: 2022-04-24 22:51:20 +02:00
# blame: 2022-05-02 22:51:16 +02:00
def get_model(vocabulary_length):
    """Build and compile the binary-classification network.

    Architecture: Embedding -> Flatten -> Dense(10, relu) -> Dense(1, sigmoid),
    compiled with Adam + binary cross-entropy.

    NOTE(review): the embedding's output_dim is taken from the global
    batch_size, which is unusual — presumably intentional; confirm.
    """
    global pad_length, batch_size

    net = Sequential()
    for layer in (
        layers.Embedding(input_dim=vocabulary_length,
                         output_dim=batch_size,
                         input_length=pad_length),
        layers.Flatten(),
        layers.Dense(10, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ):
        net.add(layer)
    net.compile(optimizer='adam',
                loss='binary_crossentropy',
                metrics=['accuracy'])
    return net
def split_data(data):
    """Split a dataframe into features ('tokens') and labels ('fraudulent')."""
    return data['tokens'], data['fraudulent']
def load_data(data_path, filename) -> pd.DataFrame:
    """Read the CSV file *filename* located in *data_path* into a DataFrame."""
    csv_path = os.path.join(data_path, filename)
    return pd.read_csv(csv_path)
# blame: 2022-05-02 22:51:16 +02:00
def read_params():
    """Parse the single comma-separated CLI argument into the module-level
    configuration globals.

    Expected format of sys.argv[1]:
        data_path,num_words,epochs,batch_size,pad_length
    Raises ValueError if the argument does not contain exactly five fields.
    """
    global data_path, num_words, epochs, batch_size, pad_length

    path, words, n_epochs, batch, padding = sys.argv[1].split(',')
    data_path = path
    num_words = int(words)
    epochs = int(n_epochs)
    batch_size = int(batch)
    pad_length = int(padding)
# blame: 2022-05-02 22:51:16 +02:00
# blame: 2022-04-24 22:51:20 +02:00
def main():
    """Entry point: read CLI config, load the data splits, train the
    network, persist it, and write evaluation metrics next to the data."""
    read_params()

    abs_data_path = os.path.abspath(data_path)

    # Load the pre-split train/test CSVs and separate features from labels.
    x_train, y_train = split_data(load_data(abs_data_path, 'train_data.csv'))
    x_test, y_test = split_data(load_data(abs_data_path, 'test_data.csv'))

    # Fit the tokenizer on the union of both splits so the vocabulary covers
    # every document, then pad both splits to a common length.
    corpus = pd.concat([x_train, x_test])
    x_train, x_test, vocab_size = tokenize(corpus, x_train, x_test)

    model = get_model(vocab_size)
    train_model(model, x_train, y_train)
    save_model(model)
    evaluate_and_save(model, x_test, y_test, abs_data_path)


if __name__ == '__main__':
    main()