This commit is contained in:
Maciej Ścigacz 2023-06-14 00:28:04 +02:00
parent 7c68d0ce7b
commit ddb3326084

View File

@ -4,7 +4,9 @@ import torch
model = BertForSequenceClassification.from_pretrained('jagiyahh/simple-polish-tone-recognition')
tokenizer = BertTokenizer.from_pretrained('jagiyahh/simple-polish-tone-recognition')
labels = ['controversial', 'intriguing', 'formal']
labels = ['Młodzieżowy', 'Intrygujący', 'Formalny']
neg_labels = ['Neutralny wiekowo', 'Mało intrygujący', 'Nieformalny']
tmp = [0,1,2]
def clear_data(data):
data = [i.strip() for i in data]
@ -30,7 +32,9 @@ def predict_labels(texts):
predicted_labels = []
for pred in predictions:
label_indices = torch.nonzero(pred).flatten().tolist()
predicted_labels.append([labels[i] for i in label_indices])
difference = list(set(tmp) - set(label_indices))
predicted_labels.append([labels[i] for i in label_indices]+[neg_labels[i] for i in difference])
return predicted_labels