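"""Inference script: load the pickled fine-tuned RoBERTa year-prediction model
and write one predicted year per line to ../dev-0/out.tsv and ../test-A/out.tsv."""
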
import pickle

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, RobertaModel, RobertaTokenizer

from regressor_head import RegressorHead
from classification_head import YearClassificationHead
from config import *  # expected to provide MIN_YEAR, used in predict() below

with open('eval_dataset_full.pickle', 'rb') as f_p:
    eval_dataset_full = pickle.load(f_p)

with open('test_dataset_A.pickle', 'rb') as f_p:
    test_dataset = pickle.load(f_p)

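# The checkpoint is a fully pickled model object (not a state_dict), so the
# custom head classes imported above must be importable for pickle.load to
# resolve them.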
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open('./roberta_year_prediction/epoch_best', 'rb') as f:
    model = pickle.load(f)

model.eval()
model.to(device)

# LeakyReLU with a negative slope of 0.0 is just ReLU.
lrelu = torch.nn.LeakyReLU(0.0)

def hard_clip(t):
    # Clamp t to [0, 1]: the first ReLU cuts off values below 0, the mirrored
    # ReLU cuts off values above 1.
    t = lrelu(t)
    t = -lrelu(-t + 1) + 1
    return t
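# Illustrative equivalence: hard_clip behaves like torch.clamp(t, 0.0, 1.0),
# e.g. hard_clip(torch.tensor([-0.5, 0.3, 1.7])) -> tensor([0.0, 0.3, 1.0]).
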
# Note: scalers (like hard_clip above) is not used anywhere in this script.
with open('scalers.pickle', 'rb') as f_scaler:
    scalers = pickle.load(f_scaler)

def transform_batch(batch):
    # The DataLoader collates each token position into a separate tensor, so
    # stack and permute to recover (batch_size, seq_len) tensors on the device.
    batch['input_ids'] = torch.stack(batch['input_ids']).permute(1, 0).to(device)
    batch['attention_mask'] = torch.stack(batch['attention_mask']).permute(1, 0).to(device)
    labels = batch['year'].to(device)

    # Drop every other column so the batch can be passed to the model as kwargs.
    for c in set(batch.keys()) - {'input_ids', 'attention_mask'}:
        del batch[c]

    return batch, labels
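# Illustrative shapes for batch_size=10 and a (hypothetical) sequence length of
# 128: transform_batch returns ({'input_ids': (10, 128), 'attention_mask':
# (10, 128)}, labels) with labels of shape (10,).
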
def predict(dataset, out_f):
    eval_dataloader = DataLoader(dataset, batch_size=10)
    outputs = []

    progress_bar = tqdm(total=len(eval_dataloader))

    with torch.no_grad():  # inference only, no gradients needed
        for batch in eval_dataloader:
            batch, labels = transform_batch(batch)

            # Hidden states from the base model, then head scores; the argmax
            # treats the head output as one score per candidate year.
            o = model(**batch)[0]
            o = model.regressor_head(o)
            o = torch.argmax(o, 1)

            outputs.extend(o.tolist())
            progress_bar.update(1)

    # Predicted indices are offsets from MIN_YEAR; shift back to calendar years.
    outputs = [a + MIN_YEAR for a in outputs]

    with open(out_f, 'w') as f_out:
        for o in outputs:
            f_out.write(str(o) + '\n')

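# Each call writes one predicted year per line, matching the dev-0/test-A
# split layout implied by the output paths.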
predict(eval_dataset_full, '../dev-0/out.tsv')
predict(test_dataset, '../test-A/out.tsv')