left context GPT2
This commit is contained in:
parent
1f742b4802
commit
13e970d70f
10519
dev-0/out.tsv
Normal file
10519
dev-0/out.tsv
Normal file
File diff suppressed because one or more lines are too long
11
gonito.yaml
11
gonito.yaml
@ -1,10 +1,3 @@
|
|||||||
description: trigram model
|
description: GPT2
|
||||||
tags:
|
tags:
|
||||||
- neural-network
|
- decoder-only
|
||||||
- trigram
|
|
||||||
params:
|
|
||||||
epochs: 1
|
|
||||||
learning-rate: 0.0001
|
|
||||||
vocab_size: 40000
|
|
||||||
embed_size: 300
|
|
||||||
hidden_size: 256
|
|
||||||
|
39
main.py
Normal file
39
main.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
|
||||||
|
model_name = 'gpt2-xl'
|
||||||
|
model = GPT2LMHeadModel.from_pretrained(model_name)
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
|
||||||
|
device = 'cuda:0'
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
#with open("out-dev.tsv", "w") as file:
|
||||||
|
with open("out-test.tsv", "w") as file:
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip("\n")
|
||||||
|
fields = line.split("\t")
|
||||||
|
left_context = fields[6]
|
||||||
|
left_context = left_context.replace("\\n", " ")
|
||||||
|
inputs = tokenizer.encode(left_context, return_tensors="pt").to(device)
|
||||||
|
outputs = model(inputs)
|
||||||
|
z_dist = outputs[0][0][-1]
|
||||||
|
prob_dist = torch.softmax(z_dist, dim=0)
|
||||||
|
topk_values, topk_indices = prob_dist.topk(300)
|
||||||
|
unk_bonus = 1 - sum(topk_values)
|
||||||
|
# print(f"{topk_values=}")
|
||||||
|
# print(f"{topk_indices=}")
|
||||||
|
result =r""
|
||||||
|
for v, idx in zip(topk_values, topk_indices):
|
||||||
|
token = tokenizer.decode([idx])
|
||||||
|
token =str(token).strip(" ")
|
||||||
|
if token.isalnum():
|
||||||
|
# print(f"{v} {idx} {token}")
|
||||||
|
result = result + token + ":"+str(v.item())+" "
|
||||||
|
else:
|
||||||
|
unk_bonus+=v.item()
|
||||||
|
#result = result.replace("\n",r"\n").replace("\t",r"\t")
|
||||||
|
result+=f":{unk_bonus}"
|
||||||
|
file.write(result+"\n")
|
||||||
|
|
||||||
|
|
7414
test-A/out.tsv
Normal file
7414
test-A/out.tsv
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user