Compare commits

...

2 Commits

Author SHA1 Message Date
2987fbdbe9 gonito.yaml 2023-06-08 03:00:10 +02:00
4a69a01029 gpt 2023-06-08 02:58:45 +02:00
6 changed files with 17979 additions and 17945 deletions

File diff suppressed because one or more lines are too long

10519
dev-0/out-top=50.tsv Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,8 @@
description: nn with bagging
description: gpt2-small
tags:
- neural-network
- ngram
- gpt2
params:
epochs: 1
learning-rate: 0.0001
vocab_size: 25000
embed_size: 300
bagging_left_ctx: 25
ngram_left_ctx: 7
bagging_right_ctx: 25
ngram_right_ctx: 3
hidden_size: 150
batch_size: 4000
unwanted-params:
- model-file
- vocab-file

File diff suppressed because one or more lines are too long

7414
test-A/out-top=50.tsv Normal file

File diff suppressed because it is too large Load Diff

44
zad121.py Normal file
View File

@@ -0,0 +1,44 @@
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import lzma
# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
top = 50
model_name = "gpt2"
# Line counts of the two input files; used only for the progress display.
expected_lines = {'dev-0': 10519, 'test-A': 7414}


def format_predictions(tokens, probs):
    """Render a next-word distribution as ``word:prob word:prob ... :rest``.

    Args:
        tokens: decoded candidate tokens, in the order they should be emitted.
        probs: probabilities (plain floats) matching ``tokens`` one-to-one.

    Returns:
        A single line (no trailing newline). Every kept token is stripped of
        surrounding whitespace; the final ``:rest`` entry carries the
        probability mass not assigned to any listed word.

    Candidates that cannot be represented in this format are dropped and
    their mass is folded into ``:rest``: tokens that are empty after
    stripping (they would collide with the ``:rest`` entry), and tokens
    containing ``:`` or whitespace (they would break the field separators).
    Tokens that become identical after stripping are merged into one entry
    so that no word is listed twice.
    """
    merged = {}
    for token, prob in zip(tokens, probs):
        word = token.strip()
        if len(word.split()) != 1 or ':' in word:
            continue
        merged[word] = merged.get(word, 0.0) + prob
    # Clamp: float rounding may push the listed mass marginally above 1.
    remainder = max(1.0 - sum(merged.values()), 0.0)
    parts = [f'{word}:{prob}' for word, prob in merged.items()]
    parts.append(f':{remainder}')
    return ' '.join(parts)


def predict_next_tokens(model, tokenizer, prefix, device, k=top):
    """Return the ``k`` most probable next tokens after ``prefix``.

    Args:
        model: a causal LM whose first output is the logits tensor.
        tokenizer: tokenizer matching ``model``.
        prefix: text preceding the position to predict.
        device: torch device the model lives on.
        k: number of candidates to return.

    Returns:
        ``(tokens, probs)``: decoded token strings and their probabilities
        as Python floats, sorted by decreasing probability.
    """
    ids = tokenizer.encode(prefix)
    if not ids:
        # An empty prompt yields a zero-length input the model cannot
        # process; condition on the beginning-of-text token instead.
        ids = [tokenizer.bos_token_id]
    # Keep only the most recent context that fits the model's position
    # embeddings; longer inputs would raise an indexing error.
    ids = ids[-model.config.n_positions:]
    inputs = torch.tensor([ids], device=device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        logits = model(inputs)[0][0][-1]
    values, indices = torch.softmax(logits, dim=0).topk(k)
    tokens = [tokenizer.decode([idx]) for idx in indices.tolist()]
    return tokens, values.tolist()


def process_folder(model, tokenizer, device, folder_name):
    """Write ``<folder_name>/out-top=<top>.tsv`` from ``<folder_name>/in.tsv.xz``.

    One output line is produced per input line. The prompt is taken from
    tab-separated column 6 (presumably the left context -- confirm against
    the dataset description), with literal ``\\n`` sequences replaced by
    spaces.
    """
    linecount = expected_lines.get(folder_name)
    in_path = f'{folder_name}/in.tsv.xz'
    out_path = f'{folder_name}/out-top={top}.tsv'
    # Both files are context-managed so they are closed even on failure.
    with lzma.open(in_path, mode='rt', encoding='utf-8') as src, \
            open(out_path, 'w', encoding='utf-8') as out:
        for processed_lines, line in enumerate(src, start=1):
            separated = line.split('\t')
            prefix = separated[6].replace(r'\n', ' ')
            tokens, probs = predict_next_tokens(model, tokenizer, prefix, device)
            out.write(format_predictions(tokens, probs) + '\n')
            # Report after counting the line so the display ends at 100%.
            if linecount:
                progress = f'{(processed_lines / linecount) * 100:.2f}%'
            else:
                progress = f'{processed_lines} lines'
            print(f'\r{folder_name} : {progress}', end='')
    print()


def main():
    """Load GPT-2 and generate predictions for the dev and test folders."""
    # Fall back to the CPU when no GPU is present instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model.to(device)
    for folder_name in ['dev-0', 'test-A']:
        process_folder(model, tokenizer, device, folder_name)


if __name__ == '__main__':
    main()