challenging-america-word-ga.../run_nc.py
2022-04-04 17:54:10 +02:00

130 lines
4.9 KiB
Python

from encodings import search_function
import lzma
from re import L
import regex as re
import string
import queue
# text = lzma.open('train/in.tsv.xz').read()
def read_file(file):
    """Yield one cleaned list of words per line of a tab-separated file.

    For every line, columns 6 and 7 (0-based) are joined with a single
    space.  Literal backslash-n sequences become spaces, real newlines are
    dropped, the text is lower-cased, runs of spaces are collapsed, and
    finally every character that is not a word character, digit,
    apostrophe or whitespace is removed.  The result is split on single
    spaces, so empty strings can appear in the list (for example when
    column 6 is empty, or where a stand-alone punctuation mark was removed).

    Args:
        file: iterable of text lines with at least eight tab-separated
            columns each.

    Yields:
        list[str]: the cleaned words of one line.
    """
    unwanted_chars = re.compile(r"[^\w\d'\s]+")
    space_runs = re.compile(' +')
    for row in file:
        columns = row.split("\t")
        joined = ' '.join([columns[6], columns[7]])
        flattened = joined.replace("\\n", " ").replace("\n", "").lower()
        # Spaces are collapsed *before* punctuation is stripped, exactly as
        # the callers expect; stripping may therefore re-introduce gaps.
        collapsed = space_runs.sub(' ', flattened)
        yield unwanted_chars.sub('', collapsed).split(" ")
def get_words(file):
    """Yield the words of *file* one at a time.

    Flattens the per-line word lists produced by ``read_file`` into a
    single stream, preserving order.
    """
    for line_words in read_file(file):
        for word in line_words:
            yield word
def set_bigram_count(first_word, second_word, bigrams):
    """Record one more occurrence of a word pair in *bigrams* (in place).

    The pair is stored under the key ``"<first_word>_<second_word>"``; a
    key seen for the first time starts at 1.
    """
    key = f"{first_word}_{second_word}"
    bigrams[key] = bigrams.get(key, 0) + 1
def set_trigram_count(first_word, second_word, third_word, trigrams):
    """Record one more occurrence of a word triple in *trigrams* (in place).

    The triple is stored under the key
    ``"<first_word>_<second_word>_<third_word>"``; a key seen for the
    first time starts at 1.
    """
    key = f"{first_word}_{second_word}_{third_word}"
    trigrams[key] = trigrams.get(key, 0) + 1
def load_train():
    """Count bigrams and trigrams over the training set and print both tables.

    Reads ``train/in.tsv.xz`` and ``train/expected.tsv`` in lock-step, one
    line each per training example.  For every example the first two
    non-empty-leading words of the cleaned input line form a bigram, and
    the cleaned expected word followed by those two words forms a trigram.

    The count dictionaries and the expected-words file are created and
    opened here; previously the function referenced ``expected``,
    ``bigrams`` and ``trigrams`` without defining them and failed with
    ``NameError`` on the first example.

    Returns:
        None.  The two count dictionaries are printed to stdout.
    """
    bigrams = {}
    trigrams = {}
    with lzma.open('train/in.tsv.xz', mode='rt') as file, \
            open('train/expected.tsv', 'r') as expected:
        for words in read_file(file):
            # Always consume the matching expected line first so the two
            # files stay aligned even when this example is skipped below.
            expected_word = re.sub(r"[^\w\d'\s]+", '', expected.readline().replace("\n", "").lower())
            # A leading empty string means the cleaned text started with a
            # space; shift by one to reach the first real word.
            mv = 0 if words[0] else 1
            if len(words) < mv + 2:
                # Not enough words to form a bigram; skip instead of
                # raising IndexError.
                continue
            set_bigram_count(words[0+mv], words[1+mv], bigrams)
            set_trigram_count(expected_word, words[0+mv], words[1+mv], trigrams)
    print(bigrams)
    print(trigrams)
def _best_left_contexts(bigrams, trigrams):
    """Map each bigram key to the left word of its most frequent trigram."""
    best_word = {}
    best_count = {}
    for trigram, count in trigrams.items():
        # Trigram keys look like "left_first_second": the text before the
        # first "_" is the left word, the remainder is the bigram key.
        # NOTE: a word that itself contains "_" cannot be recovered from
        # such a key (the cleaning regex keeps "_" because \w matches it).
        left, _, bigram = trigram.partition("_")
        # Strict ">" keeps the earliest-inserted trigram on ties.
        if bigram in bigrams and count > best_count.get(bigram, 0):
            best_count[bigram] = count
            best_word[bigram] = left
    return best_word


def predict(search_for_words, max_train_lines=100001):
    """Print the most likely left-context word for each queried word pair.

    Scans the training data (``train/in.tsv.xz`` paired line-by-line with
    ``train/expected.tsv``) and, per training example, compares its first
    two words with every query:

    * both words equal  -> counted in the exact tables
      (``bigrams`` / ``trigrams``);
    * only the first word equal -> counted in the back-off tables
      (``bigrams_nc`` / ``trigrams_nc``).

    For each query one line is printed: ``"<n>: <left> <w1> <w2> <p>"``
    where ``p`` is trigram count / bigram count, taken from the exact
    tables when the pair was seen, otherwise from the first back-off
    bigram sharing the query's first word; ``"<n>: ??? <w1> <w2>"`` is
    printed when neither exists.

    Fixes relative to the previous version: the expected-words file is now
    closed; the back-off branch reads the back-off table (it used the
    exact table and raised ``KeyError``) and builds its trigram key as a
    string (it concatenated a list with a str and raised ``TypeError``);
    exactly one line is printed per query; training rows too short to
    contain two words are skipped instead of raising ``IndexError``; and
    the best-left-context search is a single pass over the trigrams.

    Args:
        search_for_words: sequence of two-item sequences of words, as
            returned by ``load_dev``.
        max_train_lines: number of training lines to scan; ``None`` scans
            the whole file.  The default keeps the previous fixed limit.

    Returns:
        None.  Diagnostics and predictions are written to stdout.
    """
    trigrams = {}
    bigrams = {}
    trigrams_nc = {}
    bigrams_nc = {}
    with lzma.open('train/in.tsv.xz', mode='rt') as file, \
            open('train/expected.tsv', 'r') as expected:
        for index, words in enumerate(read_file(file)):
            if max_train_lines is not None and index >= max_train_lines:
                break
            # Consume the paired expected line before any skip so both
            # files stay aligned.
            expected_word = re.sub(r"[^\w\d'\s]+", '', expected.readline().replace("\n", "").lower())
            # A leading empty string means the cleaned text began with a
            # space; shift by one to reach the first real word.
            mv = 0 if words[0] else 1
            if len(words) < mv + 2:
                continue
            first, second = words[mv], words[mv + 1]
            for search_for_word in search_for_words:
                if search_for_word[0] != first:
                    continue
                if search_for_word[1] == second:
                    set_bigram_count(first, second, bigrams)
                    set_trigram_count(expected_word, first, second, trigrams)
                else:
                    set_bigram_count(first, second, bigrams_nc)
                    set_trigram_count(expected_word, first, second, trigrams_nc)
    print(len(search_for_words))
    print(len(bigrams))
    print(len(trigrams))
    print(len(bigrams_nc))
    print(len(trigrams_nc))
    left_context_search_for_word = _best_left_contexts(bigrams, trigrams)
    left_context_search_for_word_nc = _best_left_contexts(bigrams_nc, trigrams_nc)
    for index, search_for_word in enumerate(search_for_words):
        query_text = ' '.join(search_for_word)
        hash_search_for_word = '_'.join(search_for_word)
        if hash_search_for_word in left_context_search_for_word:
            left_context = left_context_search_for_word[hash_search_for_word]
            probability = trigrams[f"{left_context}_{hash_search_for_word}"] / bigrams[hash_search_for_word]
            print(f"{index+1}: {left_context} {query_text} {probability}")
            continue
        # Back off: use the first recorded bigram that shares the query's
        # first word.
        for lfc in left_context_search_for_word_nc:
            if search_for_word[0] == lfc.partition("_")[0]:
                left_context = left_context_search_for_word_nc[lfc]
                probability = trigrams_nc[f"{left_context}_{lfc}"] / bigrams_nc[lfc]
                print(f"{index+1}: {left_context} {query_text} {probability}")
                break
        else:
            print(f"{index+1}: ??? {query_text}")
def load_dev():
    """Collect the leading word pair of the first 101 dev-set lines.

    Reads ``dev-0/in.tsv.xz`` through ``read_file``.  For each line the
    first two words are taken, skipping one position when the first entry
    is an empty string.  Reading stops after the line with index 100.

    Returns:
        list[list[str]]: one ``[first, second]`` pair per line read.  The
        list is also printed to stdout.
    """
    search_for_words = []
    with lzma.open('dev-0/in.tsv.xz', mode='rt') as file:
        for position, words in enumerate(read_file(file)):
            # An empty first entry means the cleaned text began with a space.
            start = 0 if words[0] else 1
            search_for_words.append([words[start], words[start + 1]])
            if position == 100:
                break
    print(search_for_words)
    return search_for_words
# Script entry point: load the dev-set word pairs and print a prediction for
# each.  load_train() or load_dev() can be called on their own instead when
# only the raw n-gram tables or the query list are wanted.
# For reference, an earlier pass with get_words() over train/in.tsv.xz
# counted 141820215 words.
if __name__ == "__main__":
    dev_queries = load_dev()
    predict(dev_queries)