add tone and style services
This commit is contained in:
parent
c511181eb5
commit
7c68d0ce7b
@ -8,9 +8,13 @@ def create_app():
|
||||
from application.services.sentiment_service import sentiment_service
|
||||
from application.services.errors_service import errors_service
|
||||
from application.services.irony_service import irony_service
|
||||
from application.services.style_services import style_service
|
||||
from application.services.tone_services import tone_service
|
||||
|
||||
application.register_blueprint(sentiment_service)
|
||||
application.register_blueprint(errors_service)
|
||||
application.register_blueprint(irony_service)
|
||||
application.register_blueprint(style_service)
|
||||
application.register_blueprint(tone_service)
|
||||
|
||||
return application
|
33
application/functions/style.py
Normal file
33
application/functions/style.py
Normal file
@ -0,0 +1,33 @@
|
||||
from transformers import pipeline
|
||||
import re
|
||||
|
||||
# Shared Hugging Face text-classification pipeline used by style_prediction().
# It pairs the stylistic-errors model with the Polish uncased BERT tokenizer
# named below and is built once, when this module is first imported.
pipe = pipeline('text-classification', model="jagiyahh/simple-polish-stylistic-errors", tokenizer = 'dkleczek/bert-base-polish-uncased-v1')
|
||||
|
||||
def style_prediction(data):
    """Classify ``data`` with the module-level ``pipe`` pipeline.

    Args:
        data: Input handed to the pipeline unchanged (the service passes the
            list of cleaned sentences).

    Returns:
        Whatever the pipeline call returns, without post-processing.
    """
    return pipe(data)
|
||||
|
||||
def clear_data(data):
    """Normalise the sentences of a request payload for the style classifier.

    Every character other than ASCII letters, Polish diacritics, spaces and
    apostrophes is removed; each sentence is then stripped of surrounding
    whitespace and lower-cased. Sentences that end up empty are dropped.

    Args:
        data: Mapping with a ``'sentences'`` key holding an iterable of
            strings.

    Returns:
        list[str]: The cleaned, non-empty sentences in their original order.

    Raises:
        KeyError: If ``data`` has no ``'sentences'`` key.
    """
    cleaned = []
    for sentence in data['sentences']:
        text = re.sub(r"[^A-Za-zżźćńółęąśŻŹĆĄŚĘŁÓŃ ']+", r"", sentence)
        # Strip BEFORE the emptiness check: a sentence such as "12 34" is
        # reduced to " " by the regex, and filtering first would let it
        # through as '' once stripped.
        text = text.strip().lower()
        if text:
            cleaned.append(text)
    return cleaned
|
||||
|
||||
def count_predictions(predictions):
    """Tally how many predictions carry each of the two style labels.

    Args:
        predictions: Iterable of mappings, each with a ``'label'`` key.

    Returns:
        dict: ``{'stylistically_positive': <count of 'LABEL_0'>,
        'stylistically_negative': <count of 'LABEL_1'>}``. Predictions with
        any other label are ignored.

    Raises:
        KeyError: If a prediction has no ``'label'`` key.
    """
    # Result-key for each label the classifier is expected to emit.
    key_for_label = {
        'LABEL_0': 'stylistically_positive',
        'LABEL_1': 'stylistically_negative',
    }
    # Named ``counts`` rather than ``all`` so the builtin is not shadowed.
    counts = {'stylistically_positive': 0, 'stylistically_negative': 0}

    for prediction in predictions:
        key = key_for_label.get(prediction['label'])
        if key is not None:
            counts[key] += 1

    return counts
|
40
application/functions/tone.py
Normal file
40
application/functions/tone.py
Normal file
@ -0,0 +1,40 @@
|
||||
from transformers import BertTokenizer, BertForSequenceClassification
|
||||
import torch
|
||||
|
||||
# Hugging Face identifier shared by the tone classifier and its tokenizer.
_TONE_MODEL_NAME = 'jagiyahh/simple-polish-tone-recognition'

# Both objects are created once at import time and reused by predict_labels().
model = BertForSequenceClassification.from_pretrained(_TONE_MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(_TONE_MODEL_NAME)

# Tone names; predict_labels() indexes this list by model output position.
labels = ['controversial', 'intriguing', 'formal']
|
||||
|
||||
def clear_data(data):
    """Return the sentences stripped of surrounding whitespace and lower-cased.

    Args:
        data: Iterable of strings.

    Returns:
        list[str]: One normalised string per input item, in the same order
        (nothing is filtered out).
    """
    return [sentence.strip().lower() for sentence in data]
|
||||
|
||||
|
||||
def predict_labels(texts, threshold=0.5):
    """Predict zero or more tone labels for each text (multi-label).

    The texts are tokenised as one padded, truncated batch and passed through
    the module-level ``model``. A sigmoid is applied to the logits, and every
    label whose probability is strictly greater than ``threshold`` is reported
    for that text.

    Args:
        texts: List of strings to classify.
        threshold: Probability a label must exceed to be reported. Defaults
            to 0.5, the value that was previously hard-coded.

    Returns:
        list[list[str]]: For each input text, the names from ``labels`` whose
        probability exceeded ``threshold`` (possibly an empty list).
    """
    # Nothing to classify: skip the tokenizer and model, which are not meant
    # to be run on an empty batch.
    if isinstance(texts, (list, tuple)) and not texts:
        return []

    encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(
            encodings['input_ids'],
            attention_mask=encodings['attention_mask'],
        )

    # Independent sigmoid per label, not a softmax across labels.
    probabilities = torch.sigmoid(outputs.logits)
    predictions = (probabilities > threshold).int()

    return [
        [labels[index] for index in torch.nonzero(row).flatten().tolist()]
        for row in predictions
    ]
|
||||
|
||||
def tone_prediction(data):
    """Return the tone labels that ``predict_labels`` assigns to ``data``.

    Args:
        data: Texts to classify, forwarded unchanged.

    Returns:
        The result of ``predict_labels(data)``.
    """
    return predict_labels(data)
|
16
application/services/style_services.py
Normal file
16
application/services/style_services.py
Normal file
@ -0,0 +1,16 @@
|
||||
from flask import(
|
||||
request,
|
||||
jsonify,
|
||||
Blueprint,
|
||||
)
|
||||
from application.functions.style import style_prediction, clear_data
|
||||
|
||||
# Blueprint holding the style-analysis route defined in this module.
style_service = Blueprint("style_service", __name__)
|
||||
|
||||
@style_service.route("/get_style_data", methods=['POST'])
def get_data():
    """Classify the style of the sentences posted as JSON.

    Expects a JSON object of the form ``{"sentences": ["...", ...]}``.

    Returns:
        A JSON response ``{"predictions": [...]}`` with the classifier output
        for the cleaned sentences, or ``{"error": "..."}`` with HTTP status
        400 when the payload does not have the expected shape.
    """
    data = request.get_json()

    # Reject malformed payloads explicitly; otherwise the lookups in
    # clear_data() raise and the client only sees an HTTP 500.
    sentences = data.get('sentences') if isinstance(data, dict) else None
    if not isinstance(sentences, list) or not all(
        isinstance(sentence, str) for sentence in sentences
    ):
        return jsonify(
            {"error": "Expected a JSON object with a 'sentences' list of strings."}
        ), 400

    data_clear = clear_data(data)
    # Cleaning may discard every sentence; do not run the model on nothing.
    predictions = style_prediction(data_clear) if data_clear else []

    return jsonify({"predictions": predictions})
|
16
application/services/tone_services.py
Normal file
16
application/services/tone_services.py
Normal file
@ -0,0 +1,16 @@
|
||||
from flask import(
|
||||
request,
|
||||
jsonify,
|
||||
Blueprint,
|
||||
)
|
||||
from application.functions.tone import tone_prediction, clear_data
|
||||
|
||||
# Blueprint holding the tone-recognition route defined in this module.
tone_service = Blueprint("tone_service", __name__)
|
||||
|
||||
@tone_service.route("/get_tone_data", methods=['POST'])
def get_data():
    """Predict tone labels for the sentences posted as JSON.

    Expects a JSON object of the form ``{"sentences": ["...", ...]}``.

    Returns:
        A JSON response ``{"predictions": [...]}`` with one list of tone
        labels per sentence, or ``{"error": "..."}`` with HTTP status 400
        when the payload does not have the expected shape.
    """
    data = request.get_json()

    # Reject malformed payloads explicitly; otherwise the ``data['sentences']``
    # lookup raises and the client only sees an HTTP 500.
    sentences = data.get('sentences') if isinstance(data, dict) else None
    if not isinstance(sentences, list) or not all(
        isinstance(sentence, str) for sentence in sentences
    ):
        return jsonify(
            {"error": "Expected a JSON object with a 'sentences' list of strings."}
        ), 400

    data_clear = clear_data(sentences)
    # An empty batch has nothing to classify; do not run the model on it.
    predictions = tone_prediction(data_clear) if data_clear else []

    return jsonify({"predictions": predictions})
|
Loading…
Reference in New Issue
Block a user