Kenlm
This commit is contained in:
parent
8359ba19e6
commit
43036240f0
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
50
run.py
50
run.py
@ -6,7 +6,8 @@ from nltk import trigrams
|
||||
import regex as re
|
||||
import lzma
|
||||
import kenlm
|
||||
|
||||
from math import log10
|
||||
from english_words import english_words_set
|
||||
|
||||
class WordPred:
|
||||
|
||||
@ -31,15 +32,17 @@ class WordPred:
|
||||
with open(output_file, 'w') as out:
|
||||
with lzma.open(file_path, mode='rt') as file:
|
||||
for text in self.read_file(file):
|
||||
for word in text.split(" "):
|
||||
if word not in self.words:
|
||||
out.write(word + "\n")
|
||||
self.words.add(word)
|
||||
for mword in text.split(" "):
|
||||
if mword not in self.words:
|
||||
out.write(mword + "\n")
|
||||
self.words.add(mword)
|
||||
|
||||
def read_words(self, file_path):
|
||||
with open(file_path, 'r') as fin:
|
||||
for word in fin.readline():
|
||||
self.words.add(word.replace("\n",""))
|
||||
for word in fin.readlines():
|
||||
word = word.replace("\n", "")
|
||||
if word:
|
||||
self.words.add(word)
|
||||
|
||||
|
||||
def create_train_file(self, file_path, output_path, rows=10000):
|
||||
@ -63,22 +66,26 @@ class WordPred:
|
||||
outputf.write(prediction + '\n')
|
||||
|
||||
def predict_probs(self, word1, word2):
|
||||
preds = []
|
||||
for word in english_words_set:
|
||||
sentence = word1 + ' ' + word + ' ' + word2
|
||||
words_score = self.model.score(sentence, bos=False, eos=False)
|
||||
|
||||
|
||||
total_prob = 0.0
|
||||
if len(preds) < 12:
|
||||
preds.append((word, words_score))
|
||||
else:
|
||||
min_score = preds[0]
|
||||
for score in preds:
|
||||
if min_score[1] > score[1]:
|
||||
min_score = score
|
||||
if min_score[1] < words_score:
|
||||
preds.remove(min_score)
|
||||
preds.append((word, words_score))
|
||||
probs = sorted(preds, key=lambda sc: sc[1], reverse=True)
|
||||
str_prediction = ''
|
||||
|
||||
for word, prob in most_common.items():
|
||||
total_prob += prob
|
||||
for word, prob in probs:
|
||||
str_prediction += f'{word}:{prob} '
|
||||
|
||||
if total_prob == 0.0:
|
||||
return 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
||||
|
||||
if 1 - total_prob >= 0.01:
|
||||
str_prediction += f":{1 - total_prob}"
|
||||
else:
|
||||
str_prediction += f":0.01"
|
||||
str_prediction += f':{log10(0.99)}'
|
||||
|
||||
return str_prediction
|
||||
|
||||
@ -86,3 +93,6 @@ if __name__ == "__main__":
|
||||
wp = WordPred()
|
||||
# wp.create_train_file("train/in.tsv.xz", "train/in.txt")
|
||||
# wp.fill_words("train/in.tsv.xz", "words.txt")
|
||||
# wp.read_words("words.txt")
|
||||
wp.generate_outputs("dev-0/in.tsv.xz", "dev-0/out3.tsv")
|
||||
wp.generate_outputs("test-A/in.tsv.xz", "test-A/out3.tsv")
|
||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user