From 5902c72f2ed0c37f2d3948aca376d3b5909587ad Mon Sep 17 00:00:00 2001 From: Marcin Armacki Date: Sun, 14 Jun 2020 15:21:13 +0200 Subject: [PATCH] Implemented model in backend --- .../webapp/prototype/filehandler/functions.py | 41 +++++++++++++++++- .../webapp/prototype/filehandler/labels.pkl | Bin 0 -> 81 bytes backend/webapp/prototype/filehandler/views.py | 5 ++- 3 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 backend/webapp/prototype/filehandler/labels.pkl diff --git a/backend/webapp/prototype/filehandler/functions.py b/backend/webapp/prototype/filehandler/functions.py index 6d39e2c..24300a9 100644 --- a/backend/webapp/prototype/filehandler/functions.py +++ b/backend/webapp/prototype/filehandler/functions.py @@ -1,7 +1,41 @@ +from pandas import DataFrame, concat +from joblib import load +import string +import re +from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, CountVectorizer from prototype.filehandler.models import Forum, Discussion, Post, Paragraph +def count_punct(text): + count = sum([1 for char in text if char in string.punctuation]) + return round(count/(len(text) - text.count(" ")), 3)*100 + +def createLabels(data): + id_to_labels = load('prototype/filehandler/labels.pkl') + df = DataFrame(data['messages'], columns = ['body_text']) + + model = load('prototype/filehandler/model.pkl') + + df['body_len'] = df['body_text'].apply(lambda x: len(x) - x.count(" ")) + df['punct%'] = df['body_text'].apply(lambda x: count_punct(x)) + + transformer = TfidfTransformer() + loaded_vec = CountVectorizer(decode_error = "replace", vocabulary = load('prototype/filehandler/vocabulary.pkl')) + transformed = transformer.fit_transform(loaded_vec.fit_transform(df.body_text).toarray()) + + features = concat([df[['body_len', 'punct%']], DataFrame(transformed.toarray())], axis=1) + + pred = model.predict(features) + labels = list(map(id_to_labels.get, pred)) + + for id, label in zip(data['para_id'], labels): + Paragraph.objects.filter(pk = id).update(label = label) + + return(True) + def addToDatabase(data, file_id): - out = [] + out = {} + para_id = [] + messages = [] forum = Forum(forum_id = data['id'], name = data['name'], document_id = file_id) forum.save() for discussion_ in data['discussions']: @@ -13,7 +47,10 @@ def addToDatabase(data, file_id): for paragraph_ in post_['message']: paragraph = Paragraph(message = paragraph_, label = '', post_id = post.pk) paragraph.save() - out.append(paragraph.pk) + para_id.append(paragraph.pk) + messages.append(paragraph_) + out['para_id'] = para_id + out['messages'] = messages return(out) def listDiscussionsFromFile(id): diff --git a/backend/webapp/prototype/filehandler/labels.pkl b/backend/webapp/prototype/filehandler/labels.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6ca3f02faf0947e683b22a551281cdc023dd42d5 GIT binary patch literal 81 zcmZo*nd-&>0ku;!ycv49@-kC1i%arL@)D