44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
import torch
|
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM
|
|
import sys
|
|
import regex as re
|
|
|
|
import sys
|
|
|
|
# Distribution printed when no usable prediction can be made for a line, so
# the output always stays aligned one-to-one with the input lines.
FALLBACK_DISTRIBUTION = "the:0.2 a:0.3 :0.5"

# Number of highest-probability next tokens considered per input line.
TOP_K = 10

# Runs of spaces are collapsed because decoded GPT-2 tokens usually carry a
# leading space of their own.
_SPACE_RUN = re.compile(' +')


def has_letter(text):
    """Return True if *text* contains at least one Unicode letter.

    ``str.isalpha`` is true exactly for characters in the Unicode "Letter"
    categories, i.e. the same set as the ``\\p{L}`` class, without relying
    on ``re`` being aliased to the third-party ``regex`` module.
    """
    return any(ch.isalpha() for ch in text)


def extract_context(line):
    """Return the left-context field of a tab-separated input line.

    The context is the second-to-last tab-separated field with trailing
    whitespace removed. Returns ``None`` when the line has fewer than two
    fields or the context is empty, because neither can be fed to the model.
    """
    fields = line.split('\t')
    if len(fields) < 2:
        return None
    context = fields[-2].rstrip()
    return context or None


def format_predictions(words, probs):
    """Format candidate words and probabilities as ``w:p w:p ... :rest``.

    Candidates containing ``<unk>`` or no letter at all are dropped. The
    trailing ``:rest`` entry carries the probability mass not assigned to
    any kept word. When no candidate survives the filter, the fixed
    ``FALLBACK_DISTRIBUTION`` is returned instead.
    """
    parts = []
    covered = 0
    for word, prob in zip(words, probs):
        if '<unk>' in word:
            continue
        if has_letter(word):
            parts.append(f"{word}:{prob} ")
            covered += prob

    if not parts:
        return FALLBACK_DISTRIBUTION

    # Float rounding can push the sum marginally above 1; never emit a
    # negative probability for the remaining mass.
    rest = max(0.0, 1 - covered)
    parts.append(f":{rest}")
    return _SPACE_RUN.sub(' ', ''.join(parts)).strip()


def main():
    """Read tab-separated lines from stdin and print next-word distributions.

    For every input line exactly one output line is printed: either the
    formatted top-``TOP_K`` GPT-2 predictions for the word following the
    left context, or ``FALLBACK_DISTRIBUTION`` when the line is unusable.
    """
    # Prefer the GPU, but keep the script runnable on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)

    # GPT-2 has a fixed number of position embeddings; longer inputs crash
    # the forward pass, so only the most recent tokens are kept.
    max_context = model.config.n_positions

    for line in sys.stdin:
        context = extract_context(line)
        if context is None:
            print(FALLBACK_DISTRIBUTION)
            continue

        input_ids = tokenizer.encode(context, return_tensors='pt')
        input_ids = input_ids[:, -max_context:]
        if input_ids.shape[1] == 0:
            # Nothing to condition on; the model cannot run on zero tokens.
            print(FALLBACK_DISTRIBUTION)
            continue
        input_ids = input_ids.to(device)

        with torch.no_grad():
            outputs = model(input_ids)

        # Logits at the last position give the next-token distribution.
        token_logits = outputs.logits[:, -1, :]
        probs = torch.nn.functional.softmax(token_logits, dim=1)[0]

        top = torch.topk(probs, TOP_K)
        top_words = [tokenizer.decode(x) for x in top.indices.tolist()]
        top_probs = top.values.tolist()

        print(format_predictions(top_words, top_probs))


if __name__ == "__main__":
    main()