Add non-complete n-grams

This commit is contained in:
Jan Nowak 2022-04-03 22:59:35 +02:00
parent e806e44383
commit 81f09b68d1

23
run.py
View File

@ -46,6 +46,8 @@ def load_train():
def predict(search_for_words): def predict(search_for_words):
trigrams = {} trigrams = {}
bigrams = {} bigrams = {}
trigrams_nc = {}
bigrams_nc = {}
index = 0 index = 0
expected = open('train/expected.tsv', 'r') expected = open('train/expected.tsv', 'r')
with lzma.open('train/in.tsv.xz', mode='rt') as file: with lzma.open('train/in.tsv.xz', mode='rt') as file:
@ -58,6 +60,9 @@ def predict(search_for_words):
if search_for_word[0] == words[0+mv] and search_for_word[1] == words[1+mv]: if search_for_word[0] == words[0+mv] and search_for_word[1] == words[1+mv]:
set_bigram_count(words[0+mv], words[1+mv], bigrams) set_bigram_count(words[0+mv], words[1+mv], bigrams)
set_trigram_count(expected_word, words[0+mv], words[1+mv], trigrams) set_trigram_count(expected_word, words[0+mv], words[1+mv], trigrams)
elif search_for_word[0] == words[0+mv]:
set_bigram_count(words[0+mv], words[1+mv], bigrams_nc)
set_trigram_count(expected_word, words[0+mv], words[1+mv], trigrams_nc)
if index == 100000: if index == 100000:
break break
@ -66,6 +71,8 @@ def predict(search_for_words):
print(len(search_for_words)) print(len(search_for_words))
print(len(bigrams)) print(len(bigrams))
print(len(trigrams)) print(len(trigrams))
print(len(bigrams_nc))
print(len(trigrams_nc))
left_context_search_for_word = {} left_context_search_for_word = {}
for bigram in bigrams: for bigram in bigrams:
@ -75,6 +82,15 @@ def predict(search_for_words):
max_count = trigrams[trigram] max_count = trigrams[trigram]
left_context = trigram.split("_")[0] left_context = trigram.split("_")[0]
left_context_search_for_word[bigram] = left_context left_context_search_for_word[bigram] = left_context
left_context_search_for_word_nc = {}
for bigram in bigrams_nc:
max_count = 0
for trigram in trigrams_nc:
if bigram == '_'.join(trigram.split("_")[1:3]) and trigrams_nc[trigram] > max_count:
max_count = trigrams_nc[trigram]
left_context = trigram.split("_")[0]
left_context_search_for_word_nc[bigram] = left_context
for index, search_for_word in enumerate(search_for_words): for index, search_for_word in enumerate(search_for_words):
hash_search_for_word = '_'.join(search_for_word) hash_search_for_word = '_'.join(search_for_word)
@ -82,7 +98,12 @@ def predict(search_for_words):
left_context = left_context_search_for_word[hash_search_for_word] left_context = left_context_search_for_word[hash_search_for_word]
print(f"{index+1}: {left_context} {' '.join(search_for_word)} {trigrams['_'.join([left_context]+search_for_word)]/bigrams[hash_search_for_word]}") print(f"{index+1}: {left_context} {' '.join(search_for_word)} {trigrams['_'.join([left_context]+search_for_word)]/bigrams[hash_search_for_word]}")
else: else:
print(f"{index+1}: ??? {' '.join(search_for_word)}") for lfc in left_context_search_for_word_nc:
if search_for_word[0] == lfc.split("_")[0]:
left_context = left_context_search_for_word[lfc]
print(f"{index+1}: {left_context} {' '.join(search_for_word)} {trigrams_nc['_'.join([left_context]+lfc)]/bigrams_nc[lfc]}")
else:
print(f"{index+1}: ??? {' '.join(search_for_word)}")
def load_dev(): def load_dev():
search_for_words = [] search_for_words = []