This commit is contained in:
Bartosz Karwacki 2022-04-25 15:44:44 +02:00
parent a7cd8979d3
commit dc7c0f2010
3 changed files with 17943 additions and 17945 deletions

File diff suppressed because it is too large Load Diff

22
run2.py
View File

@ -4,6 +4,7 @@ import regex as re
import kenlm import kenlm
from english_words import english_words_alpha_set from english_words import english_words_alpha_set
from nltk import word_tokenize from nltk import word_tokenize
from math import log10
from pathlib import Path from pathlib import Path
import os import os
import numpy as np import numpy as np
@ -63,31 +64,28 @@ def softmax(x):
return e_x / e_x.sum(axis=0) return e_x / e_x.sum(axis=0)
def predict(model, before, after): def predict(model, before, after):
prob = 0.0 best_scores = []
best = []
for word in english_words_alpha_set: for word in english_words_alpha_set:
text = ' '.join([before, word, after]) text = ' '.join([before, word, after])
text_score = model.score(text, bos=False, eos=False) text_score = model.score(text, bos=False, eos=False)
if len(best) < 12: if len(best_scores) < 12:
best.append((word, text_score)) best_scores.append((word, text_score))
else: else:
worst_score = None worst_score = None
for score in best: for score in best_scores:
if not worst_score: if not worst_score:
worst_score = score worst_score = score
else: else:
if worst_score[1] > score[1]: if worst_score[1] > score[1]:
worst_score = score worst_score = score
if worst_score[1] < text_score: if worst_score[1] < text_score:
best.remove(worst_score) best_scores.remove(worst_score)
best.append((word, text_score)) best_scores.append((word, text_score))
words = [word[0] for word in best] probs = sorted(best_scores, key=lambda tup: tup[1], reverse=True)
probs = [prob[1] for prob in best]
probs = softmax(probs)
bests = sorted(zip(words, probs), key=lambda x:x[1], reverse=True)
pred_str = '' pred_str = ''
for word, prob in bests: for word, prob in probs:
pred_str += f'{word}:{prob} ' pred_str += f'{word}:{prob} '
pred_str += f':{log10(0.99)}'
return pred_str return pred_str
def make_prediction(model, path, result_path): def make_prediction(model, path, result_path):

File diff suppressed because it is too large Load Diff