From a21a1866550375f0f8c4c57dc707680d350b245b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20Parafi=C5=84ski?= Date: Wed, 5 Apr 2023 00:56:21 +0200 Subject: [PATCH] =?UTF-8?q?Usu=C5=84=20'bigram.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bigram.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) delete mode 100644 bigram.py diff --git a/bigram.py b/bigram.py deleted file mode 100644 index 265c1f6..0000000 --- a/bigram.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from transformers import AutoTokenizer, AutoModelForMaskedLM -import sys - -tokenizer = AutoTokenizer.from_pretrained("roberta-base") -model = AutoModelForMaskedLM.from_pretrained("roberta-base") - -for line in sys.stdin: - line_splited = line.split("\t") - left_context = line_splited[6].split(" ")[-1] - right_context = line_splited[7].split(" ")[0] - - word = "[MASK]" - - text = f"{left_context} {word} {right_context}" - - input_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt", max_length=512, truncation=True) - - mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0] - - with torch.no_grad(): - outputs = model(input_ids) - predictions = outputs[0][0, mask_token_index].softmax(dim=0) - - top_k = 1000 - top_k_tokens = torch.topk(predictions, top_k).indices.tolist() - result = '' - prob_sum = 0 - for token in top_k_tokens: - word = tokenizer.convert_ids_to_tokens([token])[0] - prob = predictions[token].item() - prob_sum += prob - result += f"{word}:{prob} " - diff = 1.0 - prob_sum - result += f":{diff}" - print(result) \ No newline at end of file