diff --git a/bigram.py b/bigram.py new file mode 100644 index 0000000..265c1f6 --- /dev/null +++ b/bigram.py @@ -0,0 +1,36 @@ +import torch +from transformers import AutoTokenizer, AutoModelForMaskedLM +import sys + +tokenizer = AutoTokenizer.from_pretrained("roberta-base") +model = AutoModelForMaskedLM.from_pretrained("roberta-base") + +for line in sys.stdin: + line_splited = line.split("\t") + left_context = line_splited[6].split(" ")[-1] + right_context = line_splited[7].split(" ")[0] + + word = "[MASK]" + + text = f"{left_context} {word} {right_context}" + + input_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt", max_length=512, truncation=True) + + mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0] + + with torch.no_grad(): + outputs = model(input_ids) + predictions = outputs[0][0, mask_token_index].softmax(dim=0) + + top_k = 1000 + top_k_tokens = torch.topk(predictions, top_k).indices.tolist() + result = '' + prob_sum = 0 + for token in top_k_tokens: + word = tokenizer.convert_ids_to_tokens([token])[0] + prob = predictions[token].item() + prob_sum += prob + result += f"{word}:{prob} " + diff = 1.0 - prob_sum + result += f":{diff}" + print(result) \ No newline at end of file diff --git a/command.sh b/command.sh index c6f3300..07776a4 100644 --- a/command.sh +++ b/command.sh @@ -1,2 +1,2 @@ #!/bin/bash -xzcat test-A/in.tsv.xz| python3 ../lm0.py > test-A/out.tsv \ No newline at end of file +xzcat test-A/in.tsv.xz| python3 bigram.py > test-A/out.tsv \ No newline at end of file