fix tone
This commit is contained in:
parent
7c68d0ce7b
commit
ddb3326084
@ -4,7 +4,9 @@ import torch
|
|||||||
model = BertForSequenceClassification.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
model = BertForSequenceClassification.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
||||||
tokenizer = BertTokenizer.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
tokenizer = BertTokenizer.from_pretrained('jagiyahh/simple-polish-tone-recognition')
|
||||||
|
|
||||||
labels = ['controversial', 'intriguing', 'formal']
|
labels = ['Młodzieżowy', 'Intrygujący', 'Formalny']
|
||||||
|
neg_labels = ['Neutralny wiekowo', 'Mało intrygujący', 'Nieformalny']
|
||||||
|
tmp = [0,1,2]
|
||||||
|
|
||||||
def clear_data(data):
|
def clear_data(data):
|
||||||
data = [i.strip() for i in data]
|
data = [i.strip() for i in data]
|
||||||
@ -30,7 +32,9 @@ def predict_labels(texts):
|
|||||||
predicted_labels = []
|
predicted_labels = []
|
||||||
for pred in predictions:
|
for pred in predictions:
|
||||||
label_indices = torch.nonzero(pred).flatten().tolist()
|
label_indices = torch.nonzero(pred).flatten().tolist()
|
||||||
predicted_labels.append([labels[i] for i in label_indices])
|
difference = list(set(tmp) - set(label_indices))
|
||||||
|
|
||||||
|
predicted_labels.append([labels[i] for i in label_indices]+[neg_labels[i] for i in difference])
|
||||||
|
|
||||||
return predicted_labels
|
return predicted_labels
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user