challenging-america-word-ga.../bigrams_predict.py

122 lines
3.4 KiB
Python

import pickle
import sys
from collections import Counter
from tqdm import tqdm
from math import log
from itertools import dropwhile
def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats):
    """Estimate log P(word | context) from raw unigram and bigram counts.

    Args:
        context: the known neighbouring word; its count is the denominator.
        word: the candidate word being scored.
        prefix: True when ``context`` sits directly before ``word`` (bigram
            ``(context, word)``), False when it sits directly after it
            (bigram ``(word, context)``).
        word_stats: mapping word -> occurrence count.
        bigram_stats: mapping ``(first, second)`` word pair -> occurrence count.

    Returns:
        The natural log of ``bigram_count / context_count``. Returns ``0`` when
        either count is missing or zero. Note that 0 is also log(1), so this
        sentinel cannot be told apart from a genuinely certain bigram.
    """
    pair = (context, word) if prefix else (word, context)
    pair_total = bigram_stats.get(pair)
    context_total = word_stats.get(context)
    if context_total and pair_total:
        return log(pair_total / context_total)
    return 0
# Unigram and bigram entries seen fewer times than this are pruned before
# prediction to keep the candidate sets small.
MIN_COUNT = 1000
# Maximum number of candidate words written per output line.
TOP_K = 10
# Emitted when no candidate has any bigram evidence on either side.
FALLBACK_RESULT = "the:-10.0"
# Appended verbatim to every output line. Presumably this is the score the
# output format reserves for all unlisted words (TODO confirm).
REST_SUFFIX = " :-0.01"


def _prune_rare(stats, min_count=MIN_COUNT):
    """Delete, in place, every entry of ``stats`` whose count is below ``min_count``.

    ``stats`` must provide ``most_common()``, e.g. a ``collections.Counter``.
    """
    # most_common() returns a new list sorted by descending count, so deleting
    # from ``stats`` while walking that list is safe. dropwhile skips the
    # frequent head and yields only the rare tail.
    for key, _count in dropwhile(lambda item: item[1] >= min_count, stats.most_common()):
        del stats[key]


def _index_bigrams(bigram_stats):
    """Group bigram keys by position so each input line avoids a full table scan.

    Returns ``(followers, predecessors)``. ``followers[w]`` lists the words
    seen directly after ``w``, and ``predecessors[w]`` lists the words seen
    directly before ``w``. Both follow ``bigram_stats`` iteration order.
    """
    followers = {}
    predecessors = {}
    for first, second in bigram_stats:
        followers.setdefault(first, []).append(second)
        predecessors.setdefault(second, []).append(first)
    return followers, predecessors


def _edge_word(text, index):
    """Return the word of ``text`` at ``index`` (0 first, -1 last), or None if it has no words."""
    # Line breaks inside a context are stored as a literal backslash-n sequence.
    words = text.replace(r"\n", " ").split()
    return words[index] if words else None


def _format_top(scores, k=TOP_K):
    """Render the ``k`` highest-scoring ``(word, score)`` pairs as ``word:score`` tokens."""
    # sorted() is stable, so tied scores keep the bigram table's order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [f"{word}:{float(score)}" for word, score in ranked[:k]]


def _predict(prev_word, next_word, word_stats, bigram_stats, followers, predecessors, k=TOP_K):
    """Return up to ``k`` ``word:logprob`` tokens for the gap between two words.

    Candidates seen both after ``prev_word`` and before ``next_word`` are scored
    with log P(c | prev_word) + log P(c | next_word). If no candidate has
    evidence on both sides, the left-only scores are used, then the right-only
    scores, then ``FALLBACK_RESULT``. Either word may be None (empty context).
    """
    left = {
        cand: get_bigram_prob(prev_word, cand, True, word_stats, bigram_stats)
        for cand in followers.get(prev_word, ())
    }
    # next_word is the known context and the candidate is the word. With
    # prefix=False this scores the (candidate, next_word) bigram, normalised
    # by next_word's count, mirroring the left-hand call.
    right = {
        cand: get_bigram_prob(next_word, cand, False, word_stats, bigram_stats)
        for cand in predecessors.get(next_word, ())
    }
    # Combine only candidates with evidence on both sides. Defaulting a missing
    # side to 0 would mean log(1), so unseen pairs would rank as certain.
    both = {cand: left[cand] + right[cand] for cand in left if cand in right}
    for scores in (both, left, right):
        if scores:
            return _format_top(scores, k)
    return [FALLBACK_RESULT]


def main():
    """Read tab-separated gap records from stdin and print one prediction line per record."""
    # NOTE(review): pickle.load can execute arbitrary code; only load count
    # files produced by a trusted training step.
    with open('word_stats.pickle', 'rb') as file:
        word_stats = pickle.load(file)
    with open('bigram_stats.pickle', 'rb') as file:
        bigram_stats = pickle.load(file)
    _prune_rare(word_stats)
    _prune_rare(bigram_stats)
    followers, predecessors = _index_bigrams(bigram_stats)
    for line in tqdm(sys.stdin):
        # Each record has exactly 8 tab-separated fields. Only the last two,
        # the left and right context of the gap, are used.
        _, _, _, _, _, _, l_context, r_context = line.split("\t")
        prev_word = _edge_word(l_context, -1)
        next_word = _edge_word(r_context, 0)
        result = _predict(prev_word, next_word, word_stats, bigram_stats, followers, predecessors)
        print(" ".join(result) + REST_SUFFIX)


if __name__ == "__main__":
    main()