from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Multi-label tone classifier for Polish text, backed by a fine-tuned BERT
# checkpoint. Loading happens at import time (downloads from the HF hub on
# first use).
model = BertForSequenceClassification.from_pretrained('jagiyahh/simple-polish-tone-recognition')
tokenizer = BertTokenizer.from_pretrained('jagiyahh/simple-polish-tone-recognition')

# Positive label names, index-aligned with the model's output logits.
labels = ['Młodzieżowy', 'Intrygujący', 'Formalny']
# Complementary names reported when the corresponding positive label is
# NOT predicted (same index alignment as `labels`).
neg_labels = ['Neutralny wiekowo', 'Mało intrygujący', 'Nieformalny']
# All class indices; derived from `labels` so it stays in sync if the label
# set ever changes (previously hard-coded as [0, 1, 2]).
tmp = list(range(len(labels)))


def clear_data(data):
    """Normalize a list of strings: strip surrounding whitespace, lowercase.

    NOTE(review): defined but never called by predict_labels/tone_prediction
    in this file — confirm whether callers are expected to pre-clean input.
    """
    return [item.strip().lower() for item in data]


def predict_labels(texts):
    """Run multi-label tone prediction over a batch of texts.

    Args:
        texts: list of input strings to classify.

    Returns:
        One list of label names per input text: every positive label whose
        sigmoid probability exceeds 0.5, followed by the complementary
        negative name for each class that fell below the threshold.
    """
    encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
    input_ids = encodings['input_ids']
    attention_mask = encodings['attention_mask']

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Independent sigmoid per class: this is multi-label classification,
    # not a mutually-exclusive softmax.
    probabilities = torch.sigmoid(outputs.logits)
    threshold = 0.5
    predictions = (probabilities > threshold).int()

    predicted_labels = []
    for pred in predictions:
        positive = torch.nonzero(pred).flatten().tolist()
        # Classes below the threshold are reported via their negative name.
        negative = list(set(tmp) - set(positive))
        predicted_labels.append(
            [labels[i] for i in positive] + [neg_labels[i] for i in negative]
        )
    return predicted_labels


def tone_prediction(data):
    """Public entry point: predict tone labels for `data` (list of strings)."""
    return predict_labels(data)