add option to save the model
This commit is contained in:
parent
069d2c3919
commit
ca54c0550a
@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Script creating an ML model for a given corpus.
|
||||
Script creating a Bayesian model for a given corpus.
|
||||
Based on:
|
||||
https://towardsdatascience.com/naive-bayes-document-classification-in-python-e33ff50f937e
|
||||
"""
|
||||
@ -13,64 +13,59 @@ import re
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
|
||||
def normalize(text, links=False, numbers=False):
    """Lower-case *text* and delete every punctuation character from it.

    The characters removed are ``string.punctuation`` plus the typographic
    quotes, dashes and ellipsis listed in the body.

    Args:
        text: The string to normalize.
        links: Currently unused; reserved for link normalization (see TODO).
        numbers: Currently unused; reserved for number normalization
            (see TODO).

    Returns:
        The lower-cased text with all punctuation characters removed.
    """
    # TODO: normalize links and maybe numbers
    punctuation = string.punctuation + "„“”«»‚’-–…"
    # str.translate drops all listed characters in one pass instead of
    # testing each character against the punctuation string in a Python loop.
    return text.lower().translate(str.maketrans('', '', punctuation))
|
||||
|
||||
def train(xs, ys):
    """Fit a bag-of-words multinomial naive Bayes classifier.

    Args:
        xs: Iterable of text documents to learn from.
        ys: Labels, one per document in ``xs``.

    Returns:
        A ``(model, vectorizer)`` tuple: the fitted ``MultinomialNB`` and the
        fitted ``CountVectorizer`` needed to transform new documents before
        passing them to the model.
    """
    # vectorize on default settings
    cv = CountVectorizer()
    # Keep the count matrix sparse: MultinomialNB.fit accepts sparse input,
    # so densifying with .toarray() only costs memory on large vocabularies.
    xs_cv = cv.fit_transform(xs)
    mnb = MultinomialNB().fit(xs_cv, ys)
    return mnb, cv
|
||||
|
||||
|
||||
# arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
"Train model on a given corpus.")
|
||||
parser.add_argument('filename')
|
||||
parser.add_argument('--save-model', dest='save', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
# load and prepare data
|
||||
derp = pd.read_csv(args.filename, sep='\t', header=0)
|
||||
derp['paragraf'] = derp['paragraf'].apply(normalize)
|
||||
df = pd.read_csv(args.filename, sep='\t', header=0)
|
||||
df['paragraf'] = df['paragraf'].apply(normalize)
|
||||
label_dict = {
|
||||
'hipoteza':0,'rzeczowe':1, 'logiczne':2, 'emocjonalne':3, 'inne':4 }
|
||||
derp['tag'] = derp['tag'].apply(lambda x: label_dict[x])
|
||||
df['tag'] = df['tag'].apply(lambda x: label_dict[x])
|
||||
|
||||
# split into sets
|
||||
from sklearn.model_selection import train_test_split
|
||||
par_train, par_test, tag_train, tag_test = train_test_split(derp['paragraf'],
|
||||
derp['tag'], random_state=2)
|
||||
|
||||
# vectorize on default settings
|
||||
# TO DO remove stop-words
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
stop_words = pd.read_csv("stop_words.tsv", sep='\t').iloc[:,0]
|
||||
cv = CountVectorizer()
|
||||
par_train_cv = cv.fit_transform(par_train).toarray()
|
||||
par_test_cv = cv.transform(par_test).toarray()
|
||||
print(len(par_test_cv))
|
||||
if args.save:
    # Save mode: train on the whole corpus, no need for testing.
    model, cv = train(df['paragraf'], df['tag'])
    # Context managers guarantee both pickle files are flushed and closed,
    # even if serialization raises part-way through.
    with open("arg_model.pkl", 'wb') as model_file:
        pickle.dump(model, model_file)
    with open("arg_vector.pkl", 'wb') as vector_file:
        pickle.dump(cv, vector_file)
|
||||
|
||||
from sklearn.naive_bayes import MultinomialNB
|
||||
gnb = MultinomialNB().fit(par_train_cv, tag_train)
|
||||
gnb_predictions = gnb.predict(par_test_cv)
|
||||
|
||||
# accuracy on X_test
|
||||
accuracy = gnb.score(par_test_cv, tag_test)
|
||||
print(accuracy)
|
||||
else:
|
||||
# split into sets
|
||||
par_train, par_test, tag_train, tag_test = train_test_split(df['paragraf'],
|
||||
df['tag'], random_state=2)
|
||||
|
||||
# creating a confusion matrix
|
||||
from sklearn.metrics import confusion_matrix
|
||||
cm = confusion_matrix(tag_test, gnb_predictions)
|
||||
print(cm)
|
||||
model, cv = train(par_train, tag_train)
|
||||
par_test_cv = cv.transform(par_test).toarray()
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
plt.clf()
|
||||
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Greens)
|
||||
classNames = ['hipoteza','rzeczowe', 'logiczne', 'emocjonalne', 'inne']
|
||||
plt.title('schematy argumentacji')
|
||||
plt.ylabel('True label')
|
||||
plt.xlabel('Predicted label')
|
||||
tick_marks = np.arange(len(classNames))
|
||||
plt.xticks(tick_marks, classNames, rotation=45)
|
||||
plt.yticks(tick_marks, classNames)
|
||||
for i in range(5):
|
||||
for j in range(5):
|
||||
plt.text(j,i, str(cm[i][j]))
|
||||
plt.show()
|
||||
accuracy = model.score(par_test_cv, tag_test)
|
||||
print(accuracy)
|
||||
|
||||
predictions = model.predict(par_test_cv)
|
||||
cm = confusion_matrix(tag_test, predictions)
|
||||
print(cm)
|
||||
|
Loading…
Reference in New Issue
Block a user