hf_roberta_base_as_in_ireland

This commit is contained in:
kubapok 2021-12-24 15:15:39 +01:00
parent a8101bd385
commit 46d0cd7a38
3 changed files with 19015 additions and 19014 deletions

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@ with open('test_dataset_A.pickle','rb') as f_p:
test_dataset = pickle.load(f_p) test_dataset = pickle.load(f_p)
device = 'cuda' device = 'cuda'
model = AutoModelForSequenceClassification.from_pretrained('./roberta_year_prediction/epoch_best') model = AutoModelForSequenceClassification.from_pretrained('./roberta_year_prediction/epoch_best')
model.eval() model.eval()
model.to(device) model.to(device)
@ -25,7 +26,7 @@ with open('scalers.pickle', 'rb') as f_scaler:
scalers = pickle.load(f_scaler) scalers = pickle.load(f_scaler)
def predict(dataset, out_f): def predict(dataset, out_f):
eval_dataloader = DataLoader(dataset, batch_size=50) eval_dataloader = DataLoader(dataset, batch_size=10)
outputs = [] outputs = []
progress_bar = tqdm(range(len(eval_dataloader))) progress_bar = tqdm(range(len(eval_dataloader)))

File diff suppressed because it is too large Load Diff