challenging-america-word-ga.../hf.py
2023-06-08 13:01:08 +02:00

44 lines
1.2 KiB
Python

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM
import sys
import regex as re
import sys
# Gap-prediction script: reads tab-separated lines from stdin, predicts the
# word following the left context (second-to-last column) with GPT-2, and
# prints one "word:prob word:prob ... :rest" line per input line to stdout.

# Number of most likely next-token candidates considered per input line.
TOP_K = 10

# Printed when no top-k candidate is a usable word, so every input line still
# yields exactly one output line.
FALLBACK_LINE = "the:0.2 a:0.3 :0.5"


def is_word(text):
    r"""Return True if *text* contains at least one Unicode letter.

    Same test as ``regex.search(r'\p{L}+', text)``: ``str.isalpha`` is true
    exactly for characters in the Unicode "Letter" categories
    (Lu, Ll, Lt, Lm, Lo).
    """
    return any(ch.isalpha() for ch in text)


def left_context(line):
    """Return the left context (second-to-last tab-separated field) of *line*.

    Trailing whitespace is stripped. A line with fewer than two fields yields
    ``''`` instead of raising, so one malformed line does not abort the run
    and break the one-output-line-per-input-line alignment.
    """
    fields = line.split('\t')
    if len(fields) < 2:
        return ''
    return fields[-2].rstrip()


def format_prediction(candidates, fallback=FALLBACK_LINE):
    """Format ``(token_text, prob)`` pairs as one "word:prob ... :rest" line.

    Token texts are stripped of surrounding whitespace (GPT-2 tokens usually
    carry a leading space) and dropped unless they contain a letter. Texts
    that become the same word (e.g. " the" and "the") are merged by summing
    their probabilities. The trailing ":rest" entry carries the probability
    mass not assigned to any listed word.

    Args:
        candidates: iterable of ``(text, prob)`` pairs, most likely first.
        fallback: line returned when no candidate survives filtering.

    Returns:
        The formatted line, without a trailing newline.
    """
    merged = {}  # dicts preserve insertion order, so words keep input order
    for text, prob in candidates:
        word = text.strip()
        if is_word(word):
            merged[word] = merged.get(word, 0.0) + prob
    if not merged:
        return fallback
    parts = [f"{word}:{prob}" for word, prob in merged.items()]
    parts.append(f":{1 - sum(merged.values())}")
    return ' '.join(parts)


def predict_line(line, tokenizer, model, device, top_k=TOP_K):
    """Predict the gap word for one input line and return the output line.

    The left context is tokenized and truncated to the model's context window
    (``config.n_positions``; 1024 for GPT-2). Longer inputs would make the
    forward pass fail. The next-token distribution at the last position is
    reduced to the ``top_k`` most likely tokens, and special tokens are
    dropped before formatting.
    """
    ids = tokenizer.encode(left_context(line))
    if not ids:
        # An empty context leaves no position to predict from; start from BOS.
        ids = [tokenizer.bos_token_id]
    max_len = getattr(model.config, 'n_positions', None) or 1024
    # Keep the tokens closest to the gap when the context is too long.
    input_ids = torch.tensor([ids[-max_len:]], device=device)
    with torch.no_grad():
        logits = model(input_ids).logits[0, -1]
    probs = torch.softmax(logits, dim=-1)
    top = torch.topk(probs, top_k)
    # Special tokens such as GPT-2's "<|endoftext|>" decode to text that
    # contains letters, so they must be excluded by id rather than by text.
    special_ids = set(tokenizer.all_special_ids)
    candidates = [
        (tokenizer.decode(idx), prob)
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
        if idx not in special_ids
    ]
    return format_prediction(candidates)


def main():
    """Load GPT-2 and print one prediction line per stdin line."""
    # Prefer the GPU, but still run when CUDA is unavailable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2').to(device)
    for line in sys.stdin:
        print(predict_line(line, tokenizer, model, device))


if __name__ == '__main__':
    main()