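# Word-gap prediction with GPT-2: for every tab-separated input line, take the
# left context and the right context around a gap, propose the most probable
# next tokens from the left context, then re-rank each candidate by how well
# the following right-context token fits after it. Each line is printed as
# "word:prob ... :rest", where :rest is the probability mass not covered by
# the listed candidates (the assumed output format for this kind of
# gap-prediction task; it is not spelled out in the source).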
import sys

import regex as re
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
model.eval()
# Sample input line for local testing; swap the loops below to use it.
a = ['I took part in many conferences and competitions at the \t and international']

for line in sys.stdin:
# for line in a:
    # Fields are tab-separated; the gap to fill sits between the last two fields.
    input_text = line.split('\t')[-2].rstrip()
    right_context = line.split('\t')[-1].rstrip()

    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    # One forward pass over the left context; no gradients needed at inference.
    with torch.no_grad():
        outputs = model(input_ids)

    # Distribution over the vocabulary for the token that follows the left context.
    token_logits = outputs.logits[:, -1, :]
    probs = torch.nn.functional.softmax(token_logits, dim=-1)[0]

    # Keep the four most probable next tokens as gap candidates.
    top = torch.topk(probs, 4)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = [tokenizer.decode([x]) for x in top_indices]
    # The first right-context token is appended after each candidate; the second
    # one is the token whose probability re-scores the candidate. This assumes
    # the right context encodes to at least two tokens.
    right_encoded = tokenizer.encode(right_context, return_tensors='pt').to(device)
    first_word = right_encoded[0][0].unsqueeze(0).unsqueeze(0)
    second_word = right_encoded[0][1]
    string_to_print = ''
    sum_probs = 0
    for p, w_i, w in zip(top_probs, top_indices, top_words):
        # Skip candidates without any letter (pure punctuation or whitespace).
        if re.search(r'\p{L}+', w):
            # Score the two-token sequence [candidate, first right-context token]
            # and read off the probability of the second right-context token.
            buff = torch.tensor([[w_i]]).to(device)
            input_ids = torch.cat((buff, first_word), dim=-1)
            with torch.no_grad():
                outputs = model(input_ids)

            token_logits = outputs.logits[:, -1, :]
            probs = torch.nn.functional.softmax(token_logits, dim=-1)[0]

            # Average the left-context probability with the right-context fit.
            new_prob = (p + probs[second_word].item()) * 0.5
            string_to_print += f"{w}:{new_prob} "
            sum_probs += new_prob
    # Whatever probability mass is left over goes to the unknown-word bucket.
    unknown_prob = 1 - sum_probs
    string_to_print += f":{unknown_prob}"
    string_to_print = re.sub(' +', ' ', string_to_print)
    print(string_to_print.strip())
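# A minimal usage sketch, assuming the script is saved as predict.py and fed
# tab-separated "left-context<TAB>right-context" lines on stdin:
#
#   printf 'I took part in many conferences at the\tand international level\n' \
#       | python predict.py
#
# Expected-shape output (values are illustrative, not real model output):
#
#   national:0.23 local:0.11 regional:0.07 :0.59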