# Masked-word prediction from bigram statistics.
# Reads tab-separated lines on stdin (last two fields are the left/right
# context of a masked token) and prints top-k candidates as "word:logprob".
|
import pickle
|
||
|
import sys
|
||
|
from collections import Counter
|
||
|
from tqdm import tqdm
|
||
|
from math import log
|
||
|
|
||
|
from itertools import dropwhile
|
||
|
|
||
|
|
||
|
def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats):
    """Log-probability estimate of a bigram from precomputed counts.

    When ``prefix`` is true the bigram is ``(context, word)`` (context
    precedes the candidate); otherwise it is ``(word, context)``.

    Returns ``log(count(bigram) / count(context))``, or ``0`` when either
    count is missing or zero — ``0`` doubles as the "no evidence" sentinel
    used by the caller.
    """
    pair = (context, word) if prefix else (word, context)
    pair_count = bigram_stats.get(pair)
    ctx_count = word_stats.get(context)
    if not pair_count or not ctx_count:
        return 0
    return log(pair_count / ctx_count)
def _load_counts(path):
    """Unpickle one statistics file (a Counter produced by the trainer).

    SECURITY NOTE: pickle.load executes arbitrary code if the file is
    untrusted; these stat files are assumed to be locally generated.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


word_stats = _load_counts('word_stats.pickle')
bigram_stats = _load_counts('bigram_stats.pickle')
# Trim the long tail of both counters: keep only entries observed at least
# 1000 times.  (most_common() is sorted descending, so the original
# dropwhile(count >= 1000) yielded exactly the count < 1000 entries —
# selecting them directly is equivalent.)
for _stats in (word_stats, bigram_stats):
    _rare = [key for key, count in _stats.items() if count < 1000]
    for key in _rare:
        del _stats[key]
def _top_k_formatted(probs, k):
    """Rank candidate->score descending and render the best k as 'word:score'."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    return [f"{word}:{prob}" for word, prob in ranked[:k]]


line_num = 1

# For each stdin line, score candidate fill-in words by summing the
# log-probs of the (prev_word, cand) and (cand, next_word) bigrams, then
# print the top-k candidates as "word:score" pairs.
for line in tqdm(sys.stdin):
    line_num += 1
    # Only the last two tab-separated fields (left/right context) are used.
    _, _, _, _, _, _, l_context, r_context = line.split("\t")
    # Contexts encode newlines as the literal two-character sequence "\n".
    l_context = l_context.replace(r"\n", " ")
    r_context = r_context.replace(r"\n", " ")
    prev_word = l_context.split()[-1]
    next_word = r_context.split()[0]

    # Candidates seen after prev_word / before next_word, with log-probs.
    l_probs = {}
    r_probs = {}
    for first, second in bigram_stats.keys():
        if first == prev_word:
            l_probs[second] = get_bigram_prob(prev_word, second, True, word_stats, bigram_stats)
        if second == next_word:
            r_probs[first] = get_bigram_prob(first, next_word, False, word_stats, bigram_stats)

    # Combined score: left log-prob plus right log-prob (0.0 when the right
    # bigram is unseen).  Keys are restricted to left candidates, matching
    # the original behaviour.
    mult_probs = {word: prob + r_probs.get(word, 0.0) for word, prob in l_probs.items()}

    k = 10
    # Fall back to one-sided evidence when the combined table is empty, and
    # to a fixed guess when there is no evidence at all.  This replaces
    # three duplicated copies of the sort/format code and drops the unused
    # `sum = 0.01` assignment that shadowed the builtin.  (When mult_probs
    # is empty l_probs is empty too; the l_probs step is kept only for
    # fidelity with the original control flow.)
    result = (
        _top_k_formatted(mult_probs, k)
        or _top_k_formatted(l_probs, k)
        or _top_k_formatted(r_probs, k)
        or ["the:-10.0"]
    )

    print(" ".join(result) + f" :{-0.01}")