UMA-projekt/run_models.py

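"""Run and compare sentiment-classification models (naive Bayes, LSTM, BERT).

Loads the train/test splits, visualises the training-set label distribution,
trains each model interactively in turn, and plots their scores side by side.
"""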
import pandas as pd
import numpy as np
from train_naive_bayes import naive_bayes
from train_lstm import lstm
from train_bert import bert
import matplotlib.pyplot as plt
from termcolor import colored


def run():
    # Load the train/test splits (semicolon-separated lines: text;label)
    df_train = pd.read_csv('train.txt', header=None, sep=';', names=['Input', 'Sentiment'], encoding='utf-8')
    df_test = pd.read_csv('test.txt', header=None, sep=';', names=['Input', 'Sentiment'], encoding='utf-8')
    # Print a few examples from each split
    print('TRAIN SET:')
    print(df_train[3:6])
    print()
    print('TEST SET:')
    print(df_test[3:6])
    # Example visualisation: label distribution of the training set
    category_counts = {}
    for value in df_train['Sentiment']:
        category_counts[value] = category_counts.get(value, 0) + 1
    labels, sizes = zip(*sorted(category_counts.items()))
    plt.close('all')
    # autopct receives the slice percentage; the lambda converts it back to an absolute count
    plt.pie(sizes, labels=labels, autopct=lambda pct: '{:.0f}'.format(pct * sum(sizes) / 100))
    plt.title(f'Sentiment analysis model training set with a total of {len(df_train)} examples')
    plt.axis('equal')
    print(colored('### Displaying training set data, close the plot window to continue ###', 'green'))
    plt.show()
    # Map emotion labels from text to integers
    category_mapping = {'anger': 0, 'fear': 1, 'joy': 2, 'love': 3, 'sadness': 4, 'surprise': 5}
    df_train['Sentiment'] = df_train['Sentiment'].map(category_mapping)
    df_test['Sentiment'] = df_test['Sentiment'].map(category_mapping)
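    # Optional sanity check (not in the original script): .map() turns any
    # label missing from category_mapping into NaN, which would fail silently
    # further down the pipeline.
    assert not df_train['Sentiment'].isna().any(), 'unmapped label in train set'
    assert not df_test['Sentiment'].isna().any(), 'unmapped label in test set'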
    # Interactively run all models in turn and store their results
    print(colored('### Press Enter to run the naive Bayes model ###', 'green'))
    input()
    bayes_results = naive_bayes(df_train, df_test)
    print(colored('### Press Enter to run the LSTM model ###', 'green'))
    input()
    lstm_results = lstm(df_train, df_test)
    print(colored('### Press Enter to run the BERT model ###', 'green'))
    input()
    bert_results = bert(df_train, df_test)
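    # NOTE (assumption): the indexing below expects each *_results dict to have
    # the shape of sklearn.metrics.classification_report(..., output_dict=True),
    # i.e. 'accuracy', 'macro avg' and 'weighted avg' top-level keys.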
    # Example visualisation: accuracy and F1 scores for all models
    plt.close()
    y_bayes = [bayes_results['accuracy'], bayes_results['macro avg']['f1-score'], bayes_results['weighted avg']['f1-score']]
    y_lstm = [lstm_results['accuracy'], lstm_results['macro avg']['f1-score'], lstm_results['weighted avg']['f1-score']]
    y_bert = [bert_results['accuracy'], bert_results['macro avg']['f1-score'], bert_results['weighted avg']['f1-score']]
    x = ['Accuracy', 'Macro avg', 'Weighted avg']
    x_axis = np.arange(len(x))
    plt.xticks(x_axis, x)
    plt.ylim(0, 1)
    # Grouped bars: shift each model's group by one bar width (0.2) left/right
    plt.bar(x_axis - 0.2, y_bayes, 0.2, label='Naive Bayes')
    plt.bar(x_axis, y_lstm, 0.2, label='LSTM')
    plt.bar(x_axis + 0.2, y_bert, 0.2, label='BERT')
    plt.xlabel('Metric')
    plt.ylabel('Score')
    plt.title('Accuracy and F1-scores per model', y=1.05)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3, fancybox=True, framealpha=1)
    print(colored('### Displaying accuracy and F1-scores for all models, close the plot window to finish ###', 'green'))
    plt.show()
if __name__ == '__main__':
    run()