diff --git a/bigrams_predict.py b/bigrams_predict.py new file mode 100644 index 0000000..d675b9c --- /dev/null +++ b/bigrams_predict.py @@ -0,0 +1,122 @@ +import pickle +import sys +from collections import Counter +from tqdm import tqdm +from math import log + +from itertools import dropwhile + + +def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats): + if prefix: + bigram_count = bigram_stats.get((context, word)) + else: + bigram_count = bigram_stats.get((word, context)) + + context_count = word_stats.get(context) + + if not context_count or not bigram_count: + return 0 + + prob = log(bigram_count / context_count) + return prob + + +with open('word_stats.pickle', 'rb') as file: + word_stats = pickle.load(file) +with open('bigram_stats.pickle', 'rb') as file: + bigram_stats = pickle.load(file) + +# print("Unpickled") + +for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, word_stats.most_common()): + del word_stats[key] + +for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, bigram_stats.most_common()): + del bigram_stats[key] + +# print(word_stats.most_common(10)) +# print(bigram_stats.most_common(10)) + +line_num = 1 + +for line in tqdm(sys.stdin): + # print(f"Line {line_num}") + line_num += 1 + _, _, _, _, _, _, l_context, r_context = line.split("\t") + l_context = l_context.replace(r"\n", " ") + r_context = r_context.replace(r"\n", " ") + prev_word = l_context.split()[-1] + next_word = r_context.split()[0] + # print(f"Context: {prev_word} {next_word}") + # print(f"{prev_word in word_stats=}") + # print(f"{next_word in word_stats=}") + + l_probs = dict() + r_probs = dict() + + for key in bigram_stats.keys(): + if key[0] == prev_word: + l_probs[key[1]] = get_bigram_prob(prev_word, key[1], True, word_stats, bigram_stats) + if key[1] == next_word: + r_probs[key[0]] = get_bigram_prob(key[0], next_word, False, word_stats, bigram_stats) + + mult_probs = dict() + for key in l_probs.keys(): + prob = 
def get_words(text):
    """Yield word tokens: maximal runs of Unicode letters and apostrophes.

    The class [^\W\d_] (word char minus digits/underscore) is the stdlib-`re`
    approximation of \p{L}, so this pattern works under both the stdlib `re`
    module and the third-party `regex` module -- the original required `regex`.
    """
    for match in re.finditer(r"(?:[^\W\d_]|')+", text):
        yield match.group(0)


def get_ngrams(iterable, n):
    """Lazily yield every run of n consecutive items as a tuple.

    Yields nothing when the input has fewer than n items.
    """
    window = []
    for item in iterable:
        window.append(item)
        if len(window) == n:
            yield tuple(window)
            window = window[1:]  # slide the window by one
def get_stats():
    """Count unigram and bigram frequencies over the training corpus and
    pickle the two Counters to word_stats.pickle / bigram_stats.pickle.

    Each TSV row carries a left and a right context around a masked word.
    """
    word_stats = Counter()
    bigram_stats = Counter()

    with lzma.open("train/in.tsv.xz", mode="rt", encoding="utf-8") as file:
        for line in tqdm(file):
            _, _, _, _, _, _, l_context, r_context = line.split("\t")
            # Count each context separately: the gap between them is the masked
            # word, so a bigram spanning the join would be spurious (the
            # original concatenated the contexts and counted one such bogus
            # bigram per line). Word counts are unaffected.
            for context in (l_context, r_context):
                # Tokenize once and reuse; the original ran the regex twice.
                words = list(get_words(context.strip().replace("\n", " ")))
                word_stats.update(words)
                bigram_stats.update(get_ngrams(words, 2))

    with open("word_stats.pickle", "wb") as out:
        pickle.dump(word_stats, out, protocol=pickle.HIGHEST_PROTOCOL)
    with open("bigram_stats.pickle", "wb") as out:
        pickle.dump(bigram_stats, out, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    get_stats()