exact_data2023/application/functions/tone.py

44 lines
1.4 KiB
Python
Raw Permalink Normal View History

2023-06-14 00:03:39 +02:00
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Hugging Face Hub checkpoint that supplies both the fine-tuned BERT
# sequence classifier and its matching tokenizer.
_CHECKPOINT = 'jagiyahh/simple-polish-tone-recognition'
model = BertForSequenceClassification.from_pretrained(_CHECKPOINT)
tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)

# predict_labels uses labels[i] when output dimension i is predicted present
# and neg_labels[i] when it is predicted absent.
# NOTE(review): presumably this order matches the checkpoint's
# model.config.id2label -- confirm against the model card.
labels = [
    'Młodzieżowy',   # "youthful"
    'Intrygujący',   # "intriguing"
    'Formalny',      # "formal"
]
neg_labels = [
    'Neutralny wiekowo',  # "age-neutral"
    'Mało intrygujący',   # "not very intriguing"
    'Nieformalny',        # "informal"
]
# Indices of every tone dimension, i.e. [0, 1, 2].
tmp = list(range(len(labels)))
2023-06-14 00:03:39 +02:00
def clear_data(data):
    """Normalise raw text entries.

    Each string has its leading/trailing whitespace removed and is then
    lower-cased.

    Args:
        data: Iterable of strings.

    Returns:
        A new list holding the normalised strings, in their original order.
    """
    return [entry.strip().lower() for entry in data]
def predict_labels(texts, threshold=0.5):
    """Predict tone labels for a batch of texts with the multi-label BERT model.

    Each output dimension is scored independently with a sigmoid. A dimension
    counts as present when its probability is strictly greater than
    ``threshold``.

    Args:
        texts: Batch of input strings, passed straight to ``tokenizer``.
        threshold: Probability cut-off for a dimension to count as present.
            Defaults to 0.5, the value this function always used before.

    Returns:
        One list per input text. Each list holds the ``labels`` entries for the
        dimensions predicted present (in dimension order), followed by the
        ``neg_labels`` entries for the dimensions predicted absent (also in
        dimension order). An empty list/tuple input yields ``[]``.
        NOTE(review): assumes the model emits one logit per entry of ``labels``.
    """
    # Nothing to classify: avoid feeding an empty batch to the tokenizer/model.
    if isinstance(texts, (list, tuple)) and not texts:
        return []
    encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
    # Keep inputs on the model's device (a no-op when the model is on CPU).
    input_ids = encodings['input_ids'].to(model.device)
    attention_mask = encodings['attention_mask'].to(model.device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
    # Independent per-dimension sigmoid + threshold (multi-label), not softmax.
    present = (torch.sigmoid(logits) > threshold).tolist()
    predicted_labels = []
    for row in present:
        # Walk dimensions in index order so the result is deterministic; the
        # previous set-difference relied on set iteration order for negatives.
        positives = [name for name, hit in zip(labels, row) if hit]
        negatives = [name for name, hit in zip(neg_labels, row) if not hit]
        predicted_labels.append(positives + negatives)
    return predicted_labels
def tone_prediction(data):
    """Public entry point: return the tone labels for each text in ``data``.

    Delegates to ``predict_labels`` and hands back its result unchanged.
    """
    return predict_labels(data)