444409 GPT2
This commit is contained in:
parent
798d04eb15
commit
9d17f9743d
10519
dev-0/out.tsv
Normal file
10519
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
80
run.py
Normal file
80
run.py
Normal file
@ -0,0 +1,80 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# In[1]:
|
||||
|
||||
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
|
||||
device = torch.device('cuda')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)
|
||||
model.to(device)
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
import lzma
|
||||
|
||||
|
||||
def read_xz_file(fname):
|
||||
with lzma.open(fname, mode='rt', encoding='utf-8') as f:
|
||||
return [line.strip() for line in f.readlines()]
|
||||
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
dev_input_raw = read_xz_file('dev-0/in.tsv.xz')
|
||||
test_input_raw = read_xz_file('test-A/in.tsv.xz')
|
||||
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
def get_contexts(input_text):
|
||||
all_fields = input_text.replace(r'\n', ' ').split('\t')
|
||||
return {'left': all_fields[6], 'right': all_fields[7]}
|
||||
|
||||
|
||||
dev_input_contexts = [get_contexts(input_text) for input_text in dev_input_raw]
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
test_input_contexts = [get_contexts(input_text) for input_text in test_input_raw]
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
tokenizer.truncation_side = 'left'
|
||||
|
||||
|
||||
def predict_words(dataset):
|
||||
preds = []
|
||||
for entry in tqdm(dataset):
|
||||
text = f"{entry['left']}"
|
||||
src = tokenizer.encode(text, return_tensors="pt", truncation=True).to(device)
|
||||
output = model.generate(src, max_length=len(src[0]) + 1, do_sample=True, top_k=0, temperature=0.8,
|
||||
num_return_sequences=1, no_repeat_ngram_size=2)
|
||||
generated_word = tokenizer.decode(output[0], skip_special_tokens=True).split(' ')[-1]
|
||||
preds.append(f'{generated_word.strip()}:0.99 :0.01')
|
||||
return preds
|
||||
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
dev_preds = predict_words(dev_input_contexts)
|
||||
with open('dev-0/out.tsv', 'w') as f:
|
||||
f.writelines(line + '\n' for line in dev_preds)
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
test_preds = predict_words(test_input_contexts)
|
||||
with open('test-A/out.tsv', 'w') as f:
|
||||
f.writelines(line + '\n' for line in test_preds)
|
7414
test-A/out.tsv
Normal file
7414
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user