fix tone
This commit is contained in:
parent
7c68d0ce7b
commit
ddb3326084
@ -4,7 +4,9 @@ import torch
|
||||
model = BertForSequenceClassification.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
||||
tokenizer = BertTokenizer.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
||||
|
||||
labels = ['controversial', 'intriguing', 'formal']
|
||||
labels = ['Młodzieżowy', 'Intrygujący', 'Formalny']
|
||||
neg_labels = ['Neutralny wiekowo', 'Mało intrygujący', 'Nieformalny']
|
||||
tmp = [0,1,2]
|
||||
|
||||
def clear_data(data):
|
||||
data = [i.strip() for i in data]
|
||||
@ -30,7 +32,9 @@ def predict_labels(texts):
|
||||
predicted_labels = []
|
||||
for pred in predictions:
|
||||
label_indices = torch.nonzero(pred).flatten().tolist()
|
||||
predicted_labels.append([labels[i] for i in label_indices])
|
||||
difference = list(set(tmp) - set(label_indices))
|
||||
|
||||
predicted_labels.append([labels[i] for i in label_indices]+[neg_labels[i] for i in difference])
|
||||
|
||||
return predicted_labels
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user