challenging-america-word-ga.../challam-roberta-without-date/04_predict.py
2021-12-11 15:03:30 +01:00

49 lines
1.7 KiB
Python

from tqdm import tqdm
from transformers import pipeline
def get_formatted(text):
    """Return the fill-mask predictions for *text* as one ``word:prob`` line.

    The module-level ``unmasker`` is called with ``top_k=15``; each returned
    item is read through its ``'token_str'`` and ``'score'`` keys. The
    result is a space-separated list of ``token:score`` pairs in the order
    the predictions were returned, followed by a final ``:remainder`` entry
    (empty token) holding the probability mass not assigned to any listed
    token.

    Args:
        text: Input string passed unchanged to ``unmasker``.

    Returns:
        A single-line string such as ``"the:0.4 a:0.2 :0.4"``.
    """
    merged = {}
    for prediction in unmasker(text, top_k=15):
        token = prediction['token_str'].strip()
        if not token:
            # A token that is empty after stripping would print as a second
            # ':score' wildcard entry; leave its mass to the remainder.
            continue
        # Tokens differing only by surrounding whitespace collapse to the
        # same printed word, so their scores are summed instead of being
        # emitted as duplicate entries.
        merged[token] = merged.get(token, 0.0) + prediction['score']

    # Everything not listed explicitly goes to the empty-token entry; the
    # clamp guards against a tiny negative value from float rounding.
    remainder = max(0.0, 1 - sum(merged.values()))

    parts = [token + ':' + str(score) for token, score in merged.items()]
    parts.append(':' + str(remainder))
    return ' '.join(parts)
def write(f_path_in, f_path_out, total=10_600):
    """Predict the masked gap for every line of a TSV file.

    Each input line is split on tabs and its last two fields are used as
    the text left and right of the gap. One prediction line (as produced by
    ``get_formatted``) is written to the output file per input line, so the
    two files stay line-aligned. Lines with fewer than two fields, and lines
    for which every prediction attempt fails, get the fallback ``:1``.

    Args:
        f_path_in: Path of the tab-separated input file.
        f_path_out: Path of the output file; it is overwritten.
        total: Expected number of input lines. Only used to size the
            progress bar; the default keeps the previously hard-coded value.
    """
    with open(f_path_in, encoding='utf-8') as f_in, \
            open(f_path_out, 'w', encoding='utf-8') as f_out:
        for line in tqdm(f_in, total=total):
            # Strip only the line terminator before splitting: a bare
            # rstrip() would also remove a trailing tab, so an empty right
            # context would shift every field by one column.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) < 2:
                # Retrying cannot help a malformed line; emit the fallback
                # right away so the output stays aligned with the input.
                print('malformed input line, writing fallback')
                f_out.write(':1\n')
                continue

            left_text = fields[-2]
            # Trailing whitespace of the last field is still dropped, as it
            # was when the whole line was right-stripped.
            right_text = fields[-1].rstrip()
            f_out.write(_predict_with_backoff(left_text, right_text) + '\n')


def _predict_with_backoff(left_text, right_text):
    """Call ``get_formatted`` with a shrinking context until it succeeds.

    The context starts at 400 characters per side and is reduced by 50 after
    each failure; once it would fall below 60 the fallback ``':1'`` is
    returned.
    """
    for char_context in range(400, 59, -50):
        l_in = left_text[-char_context:] + ' <mask> ' + right_text[:char_context]
        try:
            return get_formatted(l_in)
        except Exception:
            # Not a bare except: KeyboardInterrupt and SystemExit must still
            # stop the run instead of being treated as a failed prediction.
            print('lowering context')
    print('lower threshold context exceeded')
    return ':1'
# Checkpoint directory handed to the fill-mask pipeline.
model = 'without_date/checkpoint-395000'

# Shared pipeline instance; get_formatted() reads this module-level name.
# NOTE(review): device=0 is forwarded to transformers.pipeline and is
# expected to select the first GPU - confirm on the target machine.
unmasker = pipeline('fill-mask', model=model, device=0)

# Produce predictions for the dev split first, then the test split.
for in_path, out_path in (
    ('../dev-0/in.tsv', '../dev-0/out.tsv'),
    ('../test-A/in.tsv', '../test-A/out.tsv'),
):
    write(in_path, out_path)