hf_roberta_base_as_in_ireland
This commit is contained in:
parent
a8101bd385
commit
46d0cd7a38
21034
dev-0/out.tsv
21034
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
@ -11,6 +11,7 @@ with open('test_dataset_A.pickle','rb') as f_p:
|
|||||||
test_dataset = pickle.load(f_p)
|
test_dataset = pickle.load(f_p)
|
||||||
|
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
|
|
||||||
model = AutoModelForSequenceClassification.from_pretrained('./roberta_year_prediction/epoch_best')
|
model = AutoModelForSequenceClassification.from_pretrained('./roberta_year_prediction/epoch_best')
|
||||||
model.eval()
|
model.eval()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
@ -25,7 +26,7 @@ with open('scalers.pickle', 'rb') as f_scaler:
|
|||||||
scalers = pickle.load(f_scaler)
|
scalers = pickle.load(f_scaler)
|
||||||
|
|
||||||
def predict(dataset, out_f):
|
def predict(dataset, out_f):
|
||||||
eval_dataloader = DataLoader(dataset, batch_size=50)
|
eval_dataloader = DataLoader(dataset, batch_size=10)
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
progress_bar = tqdm(range(len(eval_dataloader)))
|
progress_bar = tqdm(range(len(eval_dataloader)))
|
||||||
|
16992
test-A/out.tsv
16992
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user