This commit is contained in:
Bartosz Karwacki 2022-04-25 15:44:44 +02:00
parent a7cd8979d3
commit dc7c0f2010
3 changed files with 17943 additions and 17945 deletions

File diff suppressed because it is too large Load Diff

22
run2.py
View File

@ -4,6 +4,7 @@ import regex as re
import kenlm
from english_words import english_words_alpha_set
from nltk import word_tokenize
from math import log10
from pathlib import Path
import os
import numpy as np
@ -63,31 +64,28 @@ def softmax(x):
return e_x / e_x.sum(axis=0)
def predict(model, before, after, vocabulary=None, top_k=12):
    """Rank candidate gap words by model score and format the ranking.

    Each candidate word is placed between ``before`` and ``after`` (joined
    with single spaces), the resulting text is scored with
    ``model.score(text, bos=False, eos=False)``, and the ``top_k``
    highest-scoring candidates are reported.

    Args:
        model: Object exposing ``score(text, bos=..., eos=...)`` that returns
            a comparable number. Presumably a ``kenlm`` model, since the
            file imports kenlm -- confirm against the caller.
        before: Text placed to the left of the candidate word.
        after: Text placed to the right of the candidate word.
        vocabulary: Iterable of candidate words. When ``None`` (the
            default) the module-level ``english_words_alpha_set`` is used.
        top_k: Maximum number of candidates to report. Defaults to 12.

    Returns:
        A string of ``word:score`` entries, each followed by one space and
        ordered from highest to lowest score, terminated by a final entry
        with an empty word whose value is ``log10(0.99)``. Candidates with
        exactly equal scores keep their iteration order from ``vocabulary``.
    """
    # Function-scope import so the block does not depend on the file's
    # import section; the module is cached after the first call.
    import heapq

    if vocabulary is None:
        vocabulary = english_words_alpha_set

    # Lazily score every candidate; nothing but the current top_k is retained.
    scored = (
        (
            word,
            model.score(' '.join([before, word, after]), bos=False, eos=False),
        )
        for word in vocabulary
    )

    # nlargest returns its result already sorted from best to worst, which
    # replaces both the manual "evict the worst of 12" scan and the final sort.
    top = heapq.nlargest(top_k, scored, key=lambda pair: pair[1])

    entries = ''.join(f'{word}:{score} ' for word, score in top)

    # NOTE(review): the trailing empty-word entry is a fixed log10(0.99);
    # presumably the value assigned to all unlisted words -- confirm that
    # this constant is what the consumer of the output expects.
    return f'{entries}:{log10(0.99)}'
def make_prediction(model, path, result_path):

File diff suppressed because it is too large Load Diff