transformer

This commit is contained in:
Adrian Charkiewicz 2023-04-05 07:23:48 +02:00
parent 406c59f600
commit ccb5d5aab6
5 changed files with 10599 additions and 10525 deletions

View File

@ -1,8 +1,37 @@
#!/usr/bin/python3 import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys import sys
for line in sys.stdin:
if "United" in line: tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print('States:0.9 :0.1') model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
else:
print('the:0.6 a:0.3 :0.1') for line in sys.stdin:
line_splitted = line.split("\t")
left_context = line_splitted[6].split(" ")[-1]
right_context = line_splitted[7].split(" ")[0]
word = "[MASK]"
text = f"{left_context} {word} {right_context}"
input_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt", max_length=512, truncation=True)
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0]
with torch.inference_mode():
outputs = model(input_ids)
predictions = outputs[0][0, mask_token_index].softmax(dim=0)
top_k = 500
top_k_tokens = torch.topk(predictions, top_k).indices.tolist()
result = ''
prob_sum = 0
for token in top_k_tokens:
word = tokenizer.convert_ids_to_tokens([token])[0]
prob = predictions[token].item()
prob_sum += prob
result += f"{word}:{prob} "
diff = 1.0 - prob_sum
result += f":{diff}"
print(result)

File diff suppressed because one or more lines are too long

37
lm0.py Normal file
View File

@ -0,0 +1,37 @@
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import sys
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
for line in sys.stdin:
line_splitted = line.split("\t")
left_context = line_splitted[6].split(" ")[-1]
right_context = line_splitted[7].split(" ")[0]
word = "[MASK]"
text = f"{left_context} {word} {right_context}"
input_ids = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt", max_length=512, truncation=True)
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1][0]
with torch.inference_mode():
outputs = model(input_ids)
predictions = outputs[0][0, mask_token_index].softmax(dim=0)
top_k = 500
top_k_tokens = torch.topk(predictions, top_k).indices.tolist()
result = ''
prob_sum = 0
for token in top_k_tokens:
word = tokenizer.convert_ids_to_tokens([token])[0]
prob = predictions[token].item()
prob_sum += prob
result += f"{word}:{prob} "
diff = 1.0 - prob_sum
result += f":{diff}"
print(result)

8
lm2.py Normal file
View File

@ -0,0 +1,8 @@
#!/usr/bin/python3
import sys
for line in sys.stdin:
if "United" in line:
print('States:0.9 :0.1')
else:
print('the:0.6 a:0.3 :0.1')