diff --git a/backend/webapp/prototype/filehandler/functions.py b/backend/webapp/prototype/filehandler/functions.py index 6d39e2c..35b4c5e 100644 --- a/backend/webapp/prototype/filehandler/functions.py +++ b/backend/webapp/prototype/filehandler/functions.py @@ -1,7 +1,41 @@ +from pandas import DataFrame, concat +from joblib import load +import string +import re +from sklearn.feature_extraction.text import TfidfVectorizer, TfidfTransformer, CountVectorizer from prototype.filehandler.models import Forum, Discussion, Post, Paragraph +def count_punct(text): + count = sum([1 for char in text if char in string.punctuation]) + return round(count/(len(text) - text.count(" ")), 3)*100 + +def createLabels(data): + id_to_labels = load('prototype/filehandler/labels.pkl') + df = DataFrame(data['messages'], columns = ['body_text']) + + model = load('prototype/filehandler/model.pkl') + + df['body_len'] = df['body_text'].apply(lambda x: len(x) - x.count(" ")) + df['punct%'] = df['body_text'].apply(lambda x: count_punct(x)) + + transformer = TfidfTransformer() + loaded_vec = CountVectorizer(decode_error = "replace", vocabulary = load('prototype/filehandler/vocabulary.pkl')) + transformed = transformer.fit_transform(loaded_vec.fit_transform(df.body_text).toarray()) + + features = concat([df[['body_len', 'punct%']], DataFrame(transformed.toarray())], axis=1) + + pred = model.predict(features) + labels = list(map(id_to_labels.get, pred)) + + for id, label in zip(data['para_id'], labels): + Paragraph.objects.filter(pk = id).update(label = label) + + return(True) + def addToDatabase(data, file_id): - out = [] + out = {} + para_id = [] + messages = [] forum = Forum(forum_id = data['id'], name = data['name'], document_id = file_id) forum.save() for discussion_ in data['discussions']: @@ -13,7 +47,10 @@ def addToDatabase(data, file_id): for paragraph_ in post_['message']: paragraph = Paragraph(message = paragraph_, label = '', post_id = post.pk) paragraph.save() - out.append(paragraph.pk) + para_id.append(paragraph.pk) + messages.append(paragraph_) + out['para_id'] = para_id + out['messages'] = messages return(out) def listDiscussionsFromFile(id): @@ -30,3 +67,28 @@ def listDiscussionsFromFile(id): discussions_.append(obj) out['discussions'] = discussions_ return(out) + +def listParagraphsFromDiscussion(id): + out = {} + posts = Post.objects.filter(discussion_id = id) + posts_ = [] + for elem in posts: + obj = {} + obj['id'] = elem.pk + obj['parent'] = elem.parent + obj['author'] = elem.author + message = [] + para_id = [] + label = [] + paragraphs = Paragraph.objects.filter(post_id = elem.pk) + for paragraph in paragraphs: + message.append(paragraph.message) + para_id.append(paragraph.pk) + label.append(paragraph.label) + obj['message'] = message + obj['para_id'] = para_id + obj['label'] = label + posts_.append(obj) + out['posts'] = posts_ + return(out) + diff --git a/backend/webapp/prototype/filehandler/labels.pkl b/backend/webapp/prototype/filehandler/labels.pkl new file mode 100644 index 0000000..6ca3f02 Binary files /dev/null and b/backend/webapp/prototype/filehandler/labels.pkl differ diff --git a/backend/webapp/prototype/filehandler/views.py b/backend/webapp/prototype/filehandler/views.py index 05a01f4..68cd55c 100644 --- a/backend/webapp/prototype/filehandler/views.py +++ b/backend/webapp/prototype/filehandler/views.py @@ -7,7 +7,7 @@ from django.http import JsonResponse, HttpResponse from prototype.filehandler.models import Document, Forum from prototype.filehandler.forms import DocumentForm from prototype.filehandler.xmlParser import parseData -from prototype.filehandler.functions import addToDatabase, listDiscussionsFromFile +from prototype.filehandler.functions import addToDatabase, listDiscussionsFromFile, listParagraphsFromDiscussion, createLabels def home(request): @@ -21,7 +21,8 @@ def model_form_upload(request): if form.is_valid(): data = parseData(request.FILES['file']) file_id = (form.save()).pk - addToDatabase(data, file_id) + if not (createLabels(addToDatabase(data, file_id))): + return HttpResponse('Błąd przy dodawaniu informacji do bazy danych/tworzeniu etykiet', status=406) output = listDiscussionsFromFile(file_id) return JsonResponse(output, safe=False) else: @@ -31,3 +32,10 @@ def model_form_upload(request): return render(request, 'core/model_form_upload.html', { 'form' : form }) + +def discussions(request, id): + if request.method == 'GET': + output = listParagraphsFromDiscussion(id) + return JsonResponse(output, safe=False) + else: + return HttpResponse('Nieobsługiwana metoda HTTP', status=406) diff --git a/backend/webapp/prototype/urls.py b/backend/webapp/prototype/urls.py index e1c7532..aeb791b 100644 --- a/backend/webapp/prototype/urls.py +++ b/backend/webapp/prototype/urls.py @@ -24,6 +24,8 @@ urlpatterns = [ path('', views.home, name='home'), path('prototype/form/', views.model_form_upload, name='model_form_upload'), path('admin/', admin.site.urls), + path('discussions/', views.discussions) + ] if settings.DEBUG: