"""Run the fine-tuned RoBERTa year model over the dev and test sets.

Predictions (one year per line) are written to ``../dev-0/out.tsv`` and
``../test-A/out.tsv``.
"""
import pickle
import torch
from transformers import AutoTokenizer, RobertaModel, RobertaTokenizer
from regressor_head import RegressorHead
from classification_head import YearClassificationHead
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from config import *  # NOTE(review): presumably provides MIN_YEAR used below — confirm.

# NOTE: pickle.load executes arbitrary code from the file; these files are
# assumed to be trusted local artifacts produced by the training pipeline.
with open('eval_dataset_full.pickle', 'rb') as f_p:
    eval_dataset_full = pickle.load(f_p)

with open('test_dataset_A.pickle', 'rb') as f_p:
    test_dataset = pickle.load(f_p)

# Fall back to CPU when no GPU is present instead of crashing in .to('cuda').
# NOTE(review): if the pickled model holds CUDA tensors, unpickling it may
# itself still require a GPU — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open('./roberta_year_prediction/epoch_best', 'rb') as f:
    model = pickle.load(f)
model.eval()
model.to(device)

# LeakyReLU with slope 0.0 behaves like ReLU.
lrelu = torch.nn.LeakyReLU(0.0)


def hard_clip(t):
    """Clamp ``t`` elementwise into [0, 1] using two ReLU-style steps.

    Not used by the prediction path below; kept for compatibility.
    """
    t = lrelu(t)                  # max(t, 0)
    t = -lrelu(-t + 1) + 1        # min(t, 1)
    return t


with open('scalers.pickle', 'rb') as f_scaler:
    scalers = pickle.load(f_scaler)


def transform_batch(batch):
    """Turn a default-collated batch into model inputs plus labels.

    ``input_ids`` and ``attention_mask`` are stacked and transposed, then moved
    to ``device``. All other keys are deleted from ``batch`` in place.

    Returns:
        ``(batch, labels)``, where ``labels`` is ``batch['year']`` on ``device``.
    """
    # Assumes the default collate produced a list of per-position tensors,
    # each of shape (batch,); stacking gives (seq, batch), and permute makes
    # it (batch, seq) — TODO confirm against the dataset format.
    batch['input_ids'] = torch.stack(batch['input_ids']).permute(1, 0).to(device)
    batch['attention_mask'] = torch.stack(batch['attention_mask']).permute(1, 0).to(device)
    labels = batch['year'].to(device)
    # Compute the key set up front so deleting from the dict is safe.
    for key in set(batch) - {'input_ids', 'attention_mask'}:
        del batch[key]
    return batch, labels


@torch.no_grad()  # inference only: don't build/retain autograd graphs
def predict(dataset, out_f, batch_size=20):
    """Predict a year for every example in ``dataset`` and write them to ``out_f``.

    Args:
        dataset: dataset accepted by ``DataLoader`` whose items provide
            ``input_ids``, ``attention_mask`` and ``year``.
        out_f: path of the output file; one predicted year per line.
        batch_size: number of examples per forward pass (default 20).
    """
    eval_dataloader = DataLoader(dataset, batch_size=batch_size)
    outputs = []
    for batch in tqdm(eval_dataloader):
        batch, _labels = transform_batch(batch)
        o = model(**batch)[0]
        o = model.regressor_head(o)
        # Take the highest-scoring class along dim 1.
        outputs.extend(torch.argmax(o, 1).tolist())
    # Assumes class index i corresponds to year MIN_YEAR + i — TODO confirm.
    outputs = [a + MIN_YEAR for a in outputs]
    with open(out_f, 'w') as f_out:
        f_out.writelines(f'{o}\n' for o in outputs)


predict(eval_dataset_full, '../dev-0/out.tsv')
predict(test_dataset, '../test-A/out.tsv')