english roberta large no finetune
This commit is contained in:
parent
bc0be042bf
commit
b3e482c2ed
71190
dev-0/out.tsv
71190
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
62
roberta_large_no_finetune/predict.py
Normal file
62
roberta_large_no_finetune/predict.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from fairseq.models.roberta import RobertaModel
|
||||
from fairseq import hub_utils
|
||||
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
|
||||
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
|
||||
|
||||
roberta.eval()
|
||||
roberta.cuda()
|
||||
|
||||
|
||||
preds = roberta.fill_mask('I like <mask> and apples', topk=3)
|
||||
#import pdb; pdb.set_trace()
|
||||
|
||||
# raise CUDA RuntimeError from which
|
||||
# the process does not recover
|
||||
BLACKLIST = ['aeeadb08042bbd49dcbefcefa1f13806',
|
||||
'01ba303704bb62bcb59f8cb7cb5663d7',
|
||||
'98bdfa711364f45f1bcffb1359793614',
|
||||
'a9da7950abcbd531a5207c04c3bdc840',
|
||||
'4cd7f730ee72451406afa89c5c6431d6',
|
||||
]
|
||||
|
||||
def predict(f_in_path,f_out_path):
|
||||
f_in = open(f_in_path,'r', newline='\n')
|
||||
f_out = open(f_out_path,'w', newline='\n')
|
||||
|
||||
for line in tqdm(f_in,total = 88000):
|
||||
id,_, before, after = line.split('\t')
|
||||
before = before.replace('\\n', '\n')
|
||||
after = after.replace('\\n', '\n')
|
||||
before = ' '.join(before.split(' ')[-40:]) # tu można poprawić, żeby śmigał na tokenal spm a nie zakładał że jest jak ze spacjami
|
||||
after = ' '.join(after.split(' ')[:40])
|
||||
input = before + ' <mask> ' + after
|
||||
try:
|
||||
if id in BLACKLIST:
|
||||
f_out.write(':1\n')
|
||||
continue
|
||||
preds = roberta.fill_mask(input, topk=10)
|
||||
hyps = []
|
||||
probs_sum = 0.0
|
||||
for pred in preds:
|
||||
if pred[2] == '<unk>':
|
||||
continue
|
||||
hyps.append(pred[2].rstrip().lstrip() + ':' + str(pred[1]))
|
||||
probs_sum += pred[1]
|
||||
hyps.append(':' + str(1 - probs_sum))
|
||||
preds_line = ' '.join(hyps)
|
||||
f_out.write(preds_line + '\n')
|
||||
except RuntimeError:
|
||||
import pdb ; pdb.set_trace()
|
||||
print('RUNTIMEERROR')
|
||||
f_out.write(':1\n')
|
||||
|
||||
f_out.close()
|
||||
|
||||
predict('../dev-0/in.tsv', '../dev-0/out.tsv')
|
||||
predict('../test-A/in.tsv', '../test-A/out.tsv')
|
52497
test-A/out.tsv
52497
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user