This commit is contained in:
Mikolaj 2023-06-08 21:10:51 +02:00
parent eb10e5db4a
commit c3b03af282
3 changed files with 17964 additions and 17947 deletions

File diff suppressed because it is too large Load Diff

45
hf.py
View File

@ -3,14 +3,15 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM
import sys
import regex as re
import sys
import pdb
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')
a = ['I took part in many conferences and competitions at the \t and international']
for line in sys.stdin:
# for line in a:
input_text = line.split('\t')[-2].rstrip()
right_context = line.split('\t')[-1].rstrip()
input_ids = tokenizer.encode(input_text, return_tensors='pt').to('cuda')
with torch.no_grad():
@ -20,24 +21,40 @@ for line in sys.stdin:
token_logits = outputs.logits[:, -1, : ]
probs = torch.nn.functional.softmax(token_logits, dim=1)[0]
top = torch.topk(probs, 10)
top = torch.topk(probs, 4)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
top_words = [tokenizer.decode(x) for x in top_indices]
# pdb.set_trace()
right_encoded = tokenizer.encode(right_context, return_tensors='pt').to('cuda')
first_word = right_encoded[0][0].unsqueeze(0).unsqueeze(0)
second_word = right_encoded[0][1]
string_to_print = ''
sum_probs = 0
for w, p in zip(top_words, top_probs):
if '<unk>' in w:
continue
for p, w_i, w in zip(top_probs, top_indices, top_words):
if re.search(r'\p{L}+', w):
string_to_print += f"{w}:{p} "
sum_probs += p
if string_to_print == '':
print(f"the:0.2 a:0.3 :0.5")
continue
# pdb.set_trace()
buff = torch.tensor([w_i]).unsqueeze(0).to('cuda')
input_ids = torch.cat((buff, first_word), dim=-1)
with torch.no_grad():
outputs = model(input_ids)
token_logits = outputs.logits[:, -1, : ]
probs = torch.nn.functional.softmax(token_logits, dim=1)[0]
new_probs = (p + probs[second_word]) * 0.5
string_to_print += f"{w}:{new_probs} "
sum_probs += new_probs
unknow_prob = 1 - sum_probs
string_to_print += f":{unknow_prob}"
string_to_print = re.sub(' +', ' ', string_to_print)
print(string_to_print.rstrip().strip())

File diff suppressed because it is too large Load Diff