roberta_large_no_ft
This commit is contained in:
parent
3030d46c47
commit
df02c30ea0
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
48
roberta_large_no_ft/04_predict.py
Normal file
48
roberta_large_no_ft/04_predict.py
Normal file
@ -0,0 +1,48 @@
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
def get_formatted(text):
|
||||
answers = unmasker(text, top_k=15)
|
||||
answers = {x['token_str']:x['score'] for x in answers}
|
||||
empty = 1 - sum(answers.values())
|
||||
answers[''] = empty
|
||||
answers_str =''
|
||||
for k,v in answers.items():
|
||||
answers_str += k.strip()+':'+str(v) + ' '
|
||||
return answers_str.rstrip(' ').lstrip(' ')
|
||||
|
||||
def write(f_path_in, f_path_out):
|
||||
with open(f_path_in) as f_in, open(f_path_out,'w') as f_out:
|
||||
i = 0
|
||||
for line in tqdm(f_in,total=10_600):
|
||||
char_context = 400
|
||||
i+=1
|
||||
#print(i)
|
||||
is_ok = False
|
||||
while not is_ok:
|
||||
try:
|
||||
left_text = line.rstrip().split('\t')[-2]
|
||||
right_text = line.rstrip().split('\t')[-1]
|
||||
l_in = left_text[-char_context:] + ' <mask> ' + right_text[:char_context]
|
||||
a = get_formatted(l_in)
|
||||
is_ok = True
|
||||
except:
|
||||
print('lowering context')
|
||||
char_context -= 50
|
||||
if char_context < 60:
|
||||
a = ':1'
|
||||
print('lower threshold context exceeded')
|
||||
is_ok = True
|
||||
|
||||
f_out.write(a + '\n')
|
||||
#left_text = line.rstrip().split('\t')[-2]
|
||||
#right_text = line.rstrip().split('\t')[-1]
|
||||
#l_in = left_text[-char_context:] + ' <mask> ' + right_text[:char_context]
|
||||
#a = get_formatted(l_in)
|
||||
|
||||
#f_out.write(a + '\n')
|
||||
|
||||
model = 'roberta-large'
|
||||
unmasker = pipeline('fill-mask', model=model, device=0)
|
||||
write('../dev-0/in.tsv', '../dev-0/out.tsv')
|
||||
write('../test-A/in.tsv', '../test-A/out.tsv')
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user