left context GPT2

Adrian Charkiewicz 2023-06-08 23:21:49 +02:00
parent 1f742b4802
commit 13e970d70f
4 changed files with 17974 additions and 9 deletions

dev-0/out.tsv Normal file (10519 lines added)

File diff suppressed because one or more lines are too long


@@ -1,10 +1,3 @@
-description: trigram model
+description: GPT2
 tags:
-- neural-network
-- trigram
-params:
-  epochs: 1
-  learning-rate: 0.0001
-  vocab_size: 40000
-  embed_size: 300
-  hidden_size: 256
+- decoder-only

main.py Normal file (39 lines added)

@@ -0,0 +1,39 @@
import sys
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model_name = 'gpt2-xl'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
device = 'cuda:0'
model.to(device)

# with open("out-dev.tsv", "w") as file:
with open("out-test.tsv", "w") as file:
    for line in sys.stdin:
        line = line.strip("\n")
        fields = line.split("\t")
        # column 7 of the challenge TSV holds the left context of the gap
        left_context = fields[6]
        left_context = left_context.replace("\\n", " ")
        inputs = tokenizer.encode(left_context, return_tensors="pt").to(device)
        outputs = model(inputs)
        # logits for the token following the left context
        z_dist = outputs[0][0][-1]
        prob_dist = torch.softmax(z_dist, dim=0)
        topk_values, topk_indices = prob_dist.topk(300)
        # probability mass left outside the top-300 tokens
        unk_bonus = 1.0 - topk_values.sum().item()
        # print(f"{topk_values=}")
        # print(f"{topk_indices=}")
        result = ""
        for v, idx in zip(topk_values, topk_indices):
            token = tokenizer.decode([idx])
            token = str(token).strip(" ")
            if token.isalnum():
                # print(f"{v} {idx} {token}")
                result = result + token + ":" + str(v.item()) + " "
            else:
                # fold filtered (non-alphanumeric) tokens into the catch-all mass
                unk_bonus += v.item()
        # result = result.replace("\n", r"\n").replace("\t", r"\t")
        result += f":{unk_bonus}"
        file.write(result + "\n")
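Each line the script writes is a space-separated list of token:probability pairs followed by a bare ":mass" entry holding the leftover probability (the non-top-300 mass plus anything filtered out by the isalnum check). A minimal sketch of reading that format back; the helper name and sample values below are illustrative and not part of this commit:

# Sketch: parse one prediction line such as "the:0.21 a:0.07 :0.72"
# and check that the listed probabilities plus the catch-all mass sum to ~1.
def parse_prediction_line(line):
    probs = {}
    rest_mass = 0.0
    for item in line.strip().split(" "):
        token, _, prob = item.rpartition(":")
        if token == "":  # the trailing ":<mass>" entry carries no token
            rest_mass = float(prob)
        else:
            probs[token] = float(prob)
    return probs, rest_mass

probs, rest = parse_prediction_line("the:0.21 a:0.07 :0.72")
assert abs(sum(probs.values()) + rest - 1.0) < 1e-6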

test-A/out.tsv Normal file (7414 lines added)

File diff suppressed because one or more lines are too long