ireland-news-word-gap-predi.../roberta_with_year_ft/04_predict.py

26 lines
819 B
Python
Raw Permalink Normal View History

2021-11-04 11:37:51 +01:00
from tqdm import tqdm
from transformers import pipeline
def get_formatted(text, fill_mask=None, top_k=15):
    """Return the top-k fill-mask predictions for *text* as one line of text.

    The result is a space-separated list of ``word:probability`` pairs. It is
    followed by a final ``:rest`` entry holding the probability mass not
    covered by the listed words, e.g. ``"the:0.75 a:0.125 :0.125"``.

    Args:
        text: Passed unchanged to the fill-mask callable. Presumably it
            contains exactly one mask token (TODO confirm). With several
            masks the HF pipeline returns a nested list, which is not
            handled here.
        fill_mask: Callable with the fill-mask pipeline call signature,
            returning dicts with ``'token_str'`` and ``'score'`` keys.
            Defaults to the module-level ``unmasker``.
        top_k: Number of candidate tokens to request.

    Returns:
        The formatted prediction string (no trailing newline).
    """
    if fill_mask is None:
        fill_mask = unmasker  # module-level pipeline created by the script
    scores = {}
    for answer in fill_mask(text, top_k=top_k):
        # Tokens such as ' the' and 'the' become the same word once
        # stripped. Merge their scores so each word is listed only once.
        word = answer['token_str'].strip()
        if not word:
            # A whitespace-only token would print as a second ':p' entry.
            # Leave its mass in the remainder bucket instead.
            continue
        scores[word] = scores.get(word, 0.0) + answer['score']
    # Leftover probability mass; clamp tiny negative float rounding error.
    rest = max(0.0, 1 - sum(scores.values()))
    pairs = [f'{word}:{score}' for word, score in scores.items()]
    pairs.append(f':{rest}')
    return ' '.join(pairs)
def write(f_path_in, f_path_out, total=150_000, encoding='utf-8'):
    """Format every data line of *f_path_in* and write one result per line.

    The first input line is skipped (presumably a header row; TODO confirm
    against the CSV). Each remaining line has its trailing whitespace
    removed and is passed to ``get_formatted``. The result is written to
    *f_path_out* followed by a newline, so output line i matches input
    data line i.

    Args:
        f_path_in: Path of the input file.
        f_path_out: Path of the output file (created or overwritten).
        total: Expected line count; used only by the tqdm progress bar.
        encoding: Text encoding for both files. It is explicit so results
            do not depend on the platform's default locale encoding.
    """
    with open(f_path_in, encoding=encoding) as f_in, open(
        f_path_out, 'w', encoding=encoding
    ) as f_out:
        next(f_in, None)  # skip the first line; tolerate an empty input file
        for line in tqdm(f_in, total=total):
            f_out.write(get_formatted(line.rstrip()) + '\n')
2021-11-04 12:09:55 +01:00
# Model identifier or local path handed to pipeline(). Presumably the
# fine-tuned RoBERTa checkpoint from an earlier step; TODO confirm.
model = 'robertamodel'
# Created at import time; get_formatted() reads this module-level name.
unmasker = pipeline(task='fill-mask', model=model)

# (input file, output file) pairs for each split, processed in this order.
for _split_in, _split_out in (
    ('dev-0_in.csv', '../dev-0/out.tsv'),
    ('test-A_in.csv', '../test-A/out.tsv'),
):
    write(_split_in, _split_out)
del _split_in, _split_out