kenlm
This commit is contained in:
parent
a7cd8979d3
commit
dc7c0f2010
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
22
run2.py
22
run2.py
@ -4,6 +4,7 @@ import regex as re
|
||||
import kenlm
|
||||
from english_words import english_words_alpha_set
|
||||
from nltk import word_tokenize
|
||||
from math import log10
|
||||
from pathlib import Path
|
||||
import os
|
||||
import numpy as np
|
||||
@ -63,31 +64,28 @@ def softmax(x):
|
||||
return e_x / e_x.sum(axis=0)
|
||||
|
||||
def predict(model, before, after):
|
||||
prob = 0.0
|
||||
best = []
|
||||
best_scores = []
|
||||
for word in english_words_alpha_set:
|
||||
text = ' '.join([before, word, after])
|
||||
text_score = model.score(text, bos=False, eos=False)
|
||||
if len(best) < 12:
|
||||
best.append((word, text_score))
|
||||
if len(best_scores) < 12:
|
||||
best_scores.append((word, text_score))
|
||||
else:
|
||||
worst_score = None
|
||||
for score in best:
|
||||
for score in best_scores:
|
||||
if not worst_score:
|
||||
worst_score = score
|
||||
else:
|
||||
if worst_score[1] > score[1]:
|
||||
worst_score = score
|
||||
if worst_score[1] < text_score:
|
||||
best.remove(worst_score)
|
||||
best.append((word, text_score))
|
||||
words = [word[0] for word in best]
|
||||
probs = [prob[1] for prob in best]
|
||||
probs = softmax(probs)
|
||||
bests = sorted(zip(words, probs), key=lambda x:x[1], reverse=True)
|
||||
best_scores.remove(worst_score)
|
||||
best_scores.append((word, text_score))
|
||||
probs = sorted(best_scores, key=lambda tup: tup[1], reverse=True)
|
||||
pred_str = ''
|
||||
for word, prob in bests:
|
||||
for word, prob in probs:
|
||||
pred_str += f'{word}:{prob} '
|
||||
pred_str += f':{log10(0.99)}'
|
||||
return pred_str
|
||||
|
||||
def make_prediction(model, path, result_path):
|
||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user