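# Gap-filling with GPT-2: for each line of in.tsv.xz, predict the word missing between
# the left context (column 7) and the right context (column 8). The top-k continuations
# of the left context are re-scored against the first word of the right context, and the
# resulting token:probability pairs are written to out.tsv.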
import lzma

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
torch.cuda.empty_cache()
top = 50  # number of top candidates considered per gap
model_name = "gpt2"
device = torch.device('cuda')

model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.truncation_side = 'left'
model.to(device)
for folder_name in ['dev-0', 'test-A']:
    linecount = 10519 if folder_name == 'dev-0' else 7414
    processed_lines = 0

    f = lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8')
    with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as file:
        for line in f:
            separated = line.split('\t')
            # Column 7 holds the left context of the gap, column 8 the right context.
            prefix = separated[6].replace(r'\n', ' ')
            suffix = separated[7].replace(r'\n', ' ')

            first_next_word = suffix.split()[0]
            # prompt = f'{prefix} [TOKEN] {suffix}\n[TOKEN] = '
            inputs = tokenizer.encode(prefix, return_tensors="pt", truncation=True).to(device)
            output = model(inputs)
            # Distribution over the vocabulary for the token right after the left context.
            probs = torch.softmax(output[0][0][-1], dim=0)

            result = ''
            total = 0
            values, indices = probs.topk(top)
            for val, idx in zip(values, indices):
                final_val = val.item()
                token = tokenizer.decode([idx])
                token = token.strip()
                # Skip pure punctuation and a few frequent sub-word fragments.
                if token in ",<>.?:;\'\"/\\{[]}|_-+=)(&%^*#@!$":
                    continue
                if token in ['ia', 'ix', 'io', 'ik', 'ing']:
                    continue
                # Take the first word of the right context and check whether it appears among
                # the tokens predicted for a prompt built from the left context plus the
                # candidate word for the gap; if so, boost that candidate's probability.
                new_prompt = f'{prefix} {token} '
                new_inputs = tokenizer.encode(new_prompt, return_tensors="pt", truncation=True).to(device)
                new_output = model(new_inputs)
                new_probs = torch.softmax(new_output[0][0][-1], dim=0)
                new_values, new_indices = new_probs.topk(top)
                for new_val, new_idx in zip(new_values, new_indices):
                    if tokenizer.decode([new_idx]).strip() == first_next_word:
                        final_val += new_val.item()
                        break

                total += val.item()
                result += f'{token}:{final_val} '
            result += f':{1 - total}'
            file.write(result + '\n')
            print(f'\r{folder_name} : {(processed_lines/linecount)*100:.2f}%', end='')
            processed_lines += 1
            # print(processed_lines)
    f.close()
    print()
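# Assumed directory layout (inferred from the paths used above):
#   dev-0/in.tsv.xz, test-A/in.tsv.xz   xz-compressed TSV inputs
#   dev-0/out.tsv,   test-A/out.tsv     predictions written by this script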