2023-06-14 00:03:39 +02:00
|
|
|
from transformers import BertTokenizer, BertForSequenceClassification
|
|
|
|
import torch
|
|
|
|
|
|
|
|
# Identifier handed to ``from_pretrained``; the classifier and its tokenizer
# are both loaded from this same checkpoint.
_CHECKPOINT = 'jagiyahh/simple-polish-tone-recognition'

model = BertForSequenceClassification.from_pretrained(_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
|
|
|
|
|
2023-06-14 00:28:04 +02:00
|
|
|
# Label emitted for class index i when its probability is above the threshold.
labels = ['Młodzieżowy', 'Intrygujący', 'Formalny']

# Label emitted for class index i when its probability is NOT above the
# threshold; positionally parallel to ``labels``.
neg_labels = ['Neutralny wiekowo', 'Mało intrygujący', 'Nieformalny']

# All class indices. Derived from ``labels`` instead of being written out by
# hand so the index list cannot drift when a label is added or removed.
tmp = list(range(len(labels)))
|
2023-06-14 00:03:39 +02:00
|
|
|
|
|
|
|
def clear_data(data):
    """Normalize texts by trimming surrounding whitespace and lowercasing.

    Args:
        data: Iterable of strings.

    Returns:
        A new list, in input order, where every item has been stripped
        and then converted to lower case.
    """
    cleaned = []
    for text in data:
        cleaned.append(text.strip().lower())
    return cleaned
|
|
|
|
|
|
|
|
def predict_labels(texts, *, threshold=0.5):
    """Predict tone labels for a batch of texts.

    The texts are tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model``. A sigmoid is applied to every logit
    and each resulting probability is compared with ``threshold``
    independently. For every class index in ``tmp`` the output contains
    ``labels[i]`` when the probability is above the threshold and
    ``neg_labels[i]`` otherwise.

    Args:
        texts: The texts to classify; passed unchanged to ``tokenizer``.
            An empty list or tuple yields an empty result without calling
            the tokenizer or the model.
        threshold: Probability cut-off (strictly greater than) for a class
            to count as present. Defaults to 0.5, the value that used to
            be hard-coded.

    Returns:
        A list with one entry per input text. Each entry is a list of label
        strings: the labels of the classes above the threshold first (in
        ascending class index), followed by the negative labels of the
        remaining classes (in the order of ``tmp``).
    """
    # Nothing to classify: skip tokenization and the forward pass entirely.
    if isinstance(texts, (list, tuple)) and not texts:
        return []

    encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(
            encodings['input_ids'],
            attention_mask=encodings['attention_mask'],
        )

    probabilities = torch.sigmoid(outputs.logits)
    predictions = (probabilities > threshold).int()

    predicted_labels = []
    for pred in predictions:
        label_indices = torch.nonzero(pred).flatten().tolist()
        active = set(label_indices)
        # Walk ``tmp`` in order instead of iterating a set difference, so the
        # order of the negative labels does not depend on set iteration order.
        negatives = [neg_labels[i] for i in tmp if i not in active]
        predicted_labels.append([labels[i] for i in label_indices] + negatives)

    return predicted_labels
|
|
|
|
|
|
|
|
def tone_prediction(data):
    """Return the tone labels predicted for ``data``.

    Thin wrapper: ``data`` is forwarded unchanged to ``predict_labels`` and
    that function's result is returned as-is.
    """
    return predict_labels(data)
|