"""Gap-filling predictions with a fairseq RoBERTa checkpoint.

Each input line holds a document id, an unused column, a left context and a
right context (tab-separated). The gap between the two contexts is replaced
with RoBERTa's ``<mask>`` token. The top predicted words are written as
``word:prob word:prob ... :remainder``, where ``remainder`` is the
probability mass not covered by the listed words.
"""
import os
import sys

import torch
from fairseq import hub_utils
from fairseq.models.roberta import RobertaModel, RobertaHubInterface
from tqdm import tqdm

roberta = RobertaModel.from_pretrained('checkpoint_final')
roberta.eval()
roberta.cuda()

# Smoke test. fill_mask asserts that its input contains exactly one <mask>
# token, so the warm-up sentence must include one.
preds = roberta.fill_mask('I like <mask> and apples', topk=3)

# Document ids whose inputs raise a CUDA RuntimeError from which the process
# does not recover. They receive the fallback prediction without calling the
# model.
BLACKLIST = ['aeeadb08042bbd49dcbefcefa1f13806',
             '01ba303704bb62bcb59f8cb7cb5663d7',
             '98bdfa711364f45f1bcffb1359793614',
             'a9da7950abcbd531a5207c04c3bdc840',
             '4cd7f730ee72451406afa89c5c6431d6',
             ]

# Number of whitespace-separated words of context kept on each side of the gap.
_CONTEXT_WORDS = 40
# Output line used when no model prediction is available: all probability
# mass goes to the "other" bucket.
_FALLBACK_LINE = ':1\n'


def _build_masked_input(before, after, context_words=_CONTEXT_WORDS):
    """Return the model input: trimmed left context, <mask>, trimmed right context.

    Literal ``\\n`` escapes in the TSV fields are turned back into real newlines.
    """
    before = before.replace('\\n', '\n')
    after = after.replace('\\n', '\n')
    # TODO: trim on the model's subword (SPM/BPE) tokens instead of assuming
    # whitespace-separated words.
    before = ' '.join(before.split(' ')[-context_words:])
    after = ' '.join(after.split(' ')[:context_words])
    return before + ' <mask> ' + after


def _format_predictions(preds):
    """Format fill_mask results as ``word:prob ... :remainder``.

    ``preds`` is the list returned by ``fill_mask``. Each element is a
    (filled sentence, probability, predicted token) tuple. Empty tokens are
    skipped, and their mass stays in the remainder.
    """
    hyps = []
    probs_sum = 0.0
    for _, prob, token in preds:
        if token == '':
            continue
        hyps.append(token.strip() + ':' + str(prob))
        probs_sum += prob
    # Clamp so floating-point rounding can never yield a negative remainder.
    hyps.append(':' + str(max(0.0, 1 - probs_sum)))
    return ' '.join(hyps)


def predict(f_in_path, f_out_path):
    """Write one prediction line to f_out_path for every line of f_in_path.

    Args:
        f_in_path: TSV with columns id, <unused>, left context, right context.
        f_out_path: destination file. It is overwritten, one output line per
            input line.
    """
    with open(f_in_path, 'r', encoding='utf-8', newline='\n') as f_in, \
            open(f_out_path, 'w', encoding='utf-8', newline='\n') as f_out:
        # Pure inference: skip building autograd graphs.
        with torch.no_grad():
            for line in tqdm(f_in, total=88000):
                doc_id, _, before, after = line.rstrip('\n').split('\t')
                if doc_id in BLACKLIST:
                    f_out.write(_FALLBACK_LINE)
                    continue
                masked_input = _build_masked_input(before, after)
                try:
                    line_preds = roberta.fill_mask(masked_input, topk=10)
                except RuntimeError as err:
                    # Log and fall back instead of dropping into a debugger,
                    # so unattended runs keep going.
                    print(f'RUNTIMEERROR for id {doc_id}: {err}', file=sys.stderr)
                    f_out.write(_FALLBACK_LINE)
                    continue
                f_out.write(_format_predictions(line_preds) + '\n')


predict('../dev-0/in.tsv', '../dev-0/out.tsv')
predict('../test-A/in.tsv', '../test-A/out.tsv')