"""Fill-mask inference over tab-separated context files.

Each input line is expected to hold three tab-separated fields (date, left
context, right context). The fields are concatenated and fed to a
``transformers`` fill-mask pipeline. The top predictions are written one
output line per input line, as space-separated ``token:score`` pairs.
"""
from tqdm import tqdm
from transformers import pipeline

# Prediction written when no usable model output could be produced: all
# probability mass on the empty token. This is the same string get_formatted
# returns when the model yields no candidates.
FALLBACK_PREDICTION = ':1'


def get_formatted(text, fill_mask=None, top_k=15):
    """Run the fill-mask model on *text* and format its predictions.

    Args:
        text: Model input. Presumably it contains the tokenizer's mask token
            somewhere in the left/right context (TODO confirm against the
            input files).
        fill_mask: Callable with the fill-mask pipeline's interface, i.e.
            ``fill_mask(text, top_k=...)`` returning an iterable of dicts with
            ``'token_str'`` and ``'score'`` keys. Defaults to the module-level
            ``unmasker``.
        top_k: Number of candidate tokens to request from the model.

    Returns:
        Space-separated ``token:score`` pairs with whitespace-stripped tokens.
        A final entry with an empty token carries the probability mass not
        covered by the returned candidates (``1 - sum(scores)``).
    """
    if fill_mask is None:
        fill_mask = unmasker  # looked up at call time; defined further below
    candidates = fill_mask(text, top_k=top_k)
    scores = {cand['token_str']: cand['score'] for cand in candidates}
    # The right-hand side is evaluated before the '' key is (re)assigned, so
    # the sum covers exactly the model's candidates.
    scores[''] = 1 - sum(scores.values())
    return ' '.join(
        token.strip() + ':' + str(score) for token, score in scores.items()
    )


def _predict_line(line, max_context=400, step=50, min_context=60):
    """Return the formatted prediction for one ``date<TAB>left<TAB>right`` line.

    The first attempt uses the last *max_context* characters of the left
    context and the first *max_context* characters of the right context.
    Whenever the model call raises, the attempt is retried with *step* fewer
    characters on each side. Presumably the failures are over-long inputs;
    this is not verified here.

    FALLBACK_PREDICTION is returned in two cases:
        * the context would drop below *min_context*;
        * the line does not have exactly three tab-separated fields
          (returned immediately, without calling the model).
    """
    try:
        # Strip only the line terminator. A plain rstrip() would also remove
        # the tab before an empty right context and reject the line.
        date, left_text, right_text = line.rstrip('\r\n').split('\t')
    except ValueError:
        print('malformed line, expected 3 tab-separated fields')
        return FALLBACK_PREDICTION

    char_context = max_context
    while True:
        # No separator between date and left context. This presumably matches
        # the format the 'with_date' model was trained on (TODO confirm).
        l_in = date + left_text[-char_context:] + ' ' + right_text[:char_context]
        try:
            return get_formatted(l_in)
        except Exception:
            print('lowering context')
            char_context -= step
            if char_context < min_context:
                print('lower threshold context exceeded')
                return FALLBACK_PREDICTION


def write(f_path_in, f_path_out, total=10_600, max_context=400, step=50,
          min_context=60):
    """Write one prediction line to *f_path_out* per line of *f_path_in*.

    Args:
        f_path_in: Path of the tab-separated input file.
        f_path_out: Path of the output file (overwritten).
        total: Progress-bar length hint passed to tqdm only.
        max_context, step, min_context: Context-shrinking retry schedule
            (see _predict_line).

    Raises:
        ValueError: If *step* is not positive, because the retry loop would
            never terminate.
    """
    if step <= 0:
        raise ValueError('step must be positive')
    # Explicit UTF-8 instead of the locale default. This assumes the TSV files
    # are UTF-8 (TODO confirm).
    with open(f_path_in, encoding='utf-8') as f_in, \
            open(f_path_out, 'w', encoding='utf-8') as f_out:
        for line in tqdm(f_in, total=total):
            prediction = _predict_line(line, max_context, step, min_context)
            f_out.write(prediction + '\n')


model = 'with_date/checkpoint-396000'
unmasker = pipeline('fill-mask', model=model, device=0)

write('./dev-0-date.tsv', '../dev-0/out.tsv')
write('./test-A-date.tsv', '../test-A/out.tsv')