81 lines
1.8 KiB
Python
81 lines
1.8 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
# In[1]:
|
|
|
|
|
|
import torch
|
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|
|
|
# Prefer the GPU when one is present, but fall back to the CPU so the script
# can still run on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Use the EOS id as the padding id (GPT-2 presumably ships without a dedicated
# pad token, hence the reuse).
model: GPT2LMHeadModel = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=tokenizer.eos_token_id)

model.to(device)
|
|
|
|
# In[2]:
|
|
|
|
|
|
import lzma
|
|
|
|
|
|
def read_xz_file(fname):
    """Read an xz-compressed UTF-8 text file and return its lines.

    Only the trailing newline is removed from each line. Stripping all
    surrounding whitespace (as ``str.strip`` does) would also remove leading or
    trailing tab characters, which in a TSV file are field separators: a row
    whose first or last column is empty would lose that column and shift or
    truncate the remaining fields.

    Args:
        fname: Path to the ``.xz`` file.

    Returns:
        A list with one string per line of the decompressed file.
    """
    with lzma.open(fname, mode='rt', encoding='utf-8') as f:
        # Iterate the file object directly rather than materializing readlines().
        return [line.rstrip('\n') for line in f]
|
|
|
|
|
|
# In[3]:
|
|
|
|
|
|
# Load the decompressed, line-split raw inputs for both splits (dev first).
dev_input_raw, test_input_raw = (
    read_xz_file(path) for path in ('dev-0/in.tsv.xz', 'test-A/in.tsv.xz')
)
|
|
|
|
|
|
# In[4]:
|
|
|
|
|
|
def get_contexts(input_text):
    """Extract the left and right contexts from one raw TSV input row.

    Literal backslash-n sequences (two characters) in the row are replaced
    with spaces, then the row is split on tabs. Column 6 is the left context
    and column 7 the right context.

    A row with exactly seven columns (right context missing, e.g. because a
    trailing empty field was dropped) yields an empty right context instead of
    crashing.

    Args:
        input_text: One line of the input TSV file.

    Returns:
        A dict ``{'left': <str>, 'right': <str>}``.

    Raises:
        ValueError: If the row has fewer than seven tab-separated columns, so
            no left context can be located.
    """
    # r'\n' is a backslash followed by 'n' (an escaped newline stored in the
    # text), not an actual newline character.
    fields = input_text.replace(r'\n', ' ').split('\t')
    if len(fields) < 7:
        raise ValueError(
            f'expected at least 7 tab-separated fields, got {len(fields)}'
        )
    right = fields[7] if len(fields) > 7 else ''
    return {'left': fields[6], 'right': right}
|
|
|
|
|
|
# One {'left': ..., 'right': ...} dict per dev-set row, in input order.
dev_input_contexts = list(map(get_contexts, dev_input_raw))
|
|
|
|
# In[5]:
|
|
|
|
|
|
# One {'left': ..., 'right': ...} dict per test-set row, in input order.
test_input_contexts = list(map(get_contexts, test_input_raw))
|
|
|
|
# In[6]:
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
# Truncate over-long prompts from the left so that the tokens nearest the gap
# (the end of the left context) are the ones that are kept.
tokenizer.truncation_side = 'left'


def predict_words(dataset):
    """Predict the word following each entry's left context with GPT-2.

    For every entry the left context is encoded, one new token is sampled from
    the model, and the last whitespace-separated word of the decoded sequence
    is emitted as the prediction.

    Args:
        dataset: Iterable of dicts with at least a ``'left'`` key (as produced
            by ``get_contexts``).

    Returns:
        A list of strings, one per entry: ``'<word>:0.99 :0.01'``, or
        ``':1.0'`` when no word could be extracted from the decoded text.
    """
    preds = []
    for entry in tqdm(dataset):
        src = tokenizer.encode(entry['left'], return_tensors='pt', truncation=True).to(device)
        if src.shape[1] == 0:
            # An empty left context encodes to zero tokens, which generate()
            # cannot start from; seed generation with the BOS token instead.
            src = torch.tensor([[tokenizer.bos_token_id]], device=device)

        output = model.generate(
            src,
            attention_mask=torch.ones_like(src),
            max_new_tokens=1,
            do_sample=True,
            top_k=0,
            temperature=0.8,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
        )
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)

        # split() with no argument splits on any whitespace, so a generated
        # newline or tab can never end up inside the word and break the
        # one-prediction-per-line layout of the output file.
        words = decoded.split()
        word = words[-1] if words else ''
        # NOTE(review): ':1.0' (all mass on "other words") replaces the
        # malformed ':0.99 :0.01' that an empty word used to produce.
        preds.append(f'{word}:0.99 :0.01' if word else ':1.0')
    return preds
|
|
|
|
|
|
# In[7]:
|
|
|
|
|
|
# Predict the dev-0 gaps and write one prediction line per input row.
dev_preds = predict_words(dev_input_contexts)

# Write explicitly as UTF-8 rather than the platform's locale-dependent default
# encoding; generated words may contain non-ASCII characters.
with open('dev-0/out.tsv', 'w', encoding='utf-8') as f:
    f.writelines(line + '\n' for line in dev_preds)
|
|
|
|
# In[8]:
|
|
|
|
|
|
# Predict the test-A gaps and write one prediction line per input row.
test_preds = predict_words(test_input_contexts)

# Write explicitly as UTF-8 rather than the platform's locale-dependent default
# encoding; generated words may contain non-ASCII characters.
with open('test-A/out.tsv', 'w', encoding='utf-8') as f:
    f.writelines(line + '\n' for line in test_preds)
|