Merge commit 'a21a1866550375f0f8c4c57dc707680d350b245b'
This commit is contained in:
commit
909c1e8a4c
36
bigram.py
36
bigram.py
@ -1,36 +0,0 @@
|
|||||||
import torch
|
|
||||||
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
||||||
import sys
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
|
|
||||||
model = AutoModelForMaskedLM.from_pretrained("roberta-base")
|
|
||||||
|
|
||||||
for line in sys.stdin:
|
|
||||||
line_splited = line.split("\t")
|
|
||||||
left_context = line_splited[6].split(" ")[-1]
|
|
||||||
right_context = line_splited[7].split(" ")[0]
|
|
||||||
|
|
||||||
word = "[MASK]"
|
|
||||||
|
|
||||||
text = f"{left_context} {word} {right_context}"
|
|
||||||
|
|
||||||
input_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt", max_length=512, truncation=True)
|
|
||||||
|
|
||||||
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(input_ids)
|
|
||||||
predictions = outputs[0][0, mask_token_index].softmax(dim=0)
|
|
||||||
|
|
||||||
top_k = 1000
|
|
||||||
top_k_tokens = torch.topk(predictions, top_k).indices.tolist()
|
|
||||||
result = ''
|
|
||||||
prob_sum = 0
|
|
||||||
for token in top_k_tokens:
|
|
||||||
word = tokenizer.convert_ids_to_tokens([token])[0]
|
|
||||||
prob = predictions[token].item()
|
|
||||||
prob_sum += prob
|
|
||||||
result += f"{word}:{prob} "
|
|
||||||
diff = 1.0 - prob_sum
|
|
||||||
result += f":{diff}"
|
|
||||||
print(result)
|
|
Loading…
Reference in New Issue
Block a user