"""Predict the missing word between a left and a right context from bigram counts.

Reads tab-separated records from stdin.  Each record must have exactly eight
fields; the last two are the text to the left and to the right of the gap
(literal ``\\n`` sequences inside them are treated as spaces).  For every
record one line is printed: up to ``TOP_K`` ``word:log_prob`` pairs followed
by the trailing ``:-0.01`` field.

Unigram and bigram counts are loaded from ``word_stats.pickle`` and
``bigram_stats.pickle`` in the current directory.
"""
import pickle
import sys
from collections import Counter
from itertools import dropwhile
from math import log

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic; keep the script usable without it.
    def tqdm(iterable):
        return iterable


MIN_COUNT = 1000  # entries seen fewer times than this are discarded
TOP_K = 10  # maximum number of candidates printed per input line
FALLBACK_PREDICTION = "the:-10.0"  # emitted when no candidate is found


def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats):
    """Return the natural-log conditional probability of *word* given *context*.

    Args:
        context: The known neighbouring word; its unigram count is the
            denominator.
        word: The candidate word.
        prefix: If true, *context* precedes *word* and the bigram looked up is
            ``(context, word)``; otherwise *context* follows *word* and the
            bigram looked up is ``(word, context)``.
        word_stats: Mapping from word to its count.
        bigram_stats: Mapping from ``(first, second)`` to its count.

    Returns:
        ``log(bigram_count / context_count)``, or ``0`` when either count is
        missing or zero.  NOTE: ``0`` is a "no estimate" sentinel here, not a
        real log-probability (it would mean certainty), so callers ranking
        candidates should only query bigrams that are present.
    """
    if prefix:
        bigram_count = bigram_stats.get((context, word))
    else:
        bigram_count = bigram_stats.get((word, context))
    context_count = word_stats.get(context)
    if not context_count or not bigram_count:
        return 0
    return log(bigram_count / context_count)


def prune_rare(counts: Counter, min_count: int = MIN_COUNT) -> None:
    """Delete, in place, every entry of *counts* seen fewer than *min_count* times."""
    # most_common() returns a list sorted by descending count, so everything
    # left after the leading run of frequent entries is rare.  Because it is
    # a separate list, deleting from the counter while looping is safe.
    rare = dropwhile(lambda item: item[1] >= min_count, counts.most_common())
    for key, _ in rare:
        del counts[key]


def build_context_index(bigram_stats):
    """Index the bigram keys by their first and by their second word.

    Returns:
        A ``(followers, preceders)`` pair of dicts: ``followers[w]`` lists the
        words seen right after ``w`` and ``preceders[w]`` lists the words seen
        right before ``w``, both in the iteration order of *bigram_stats*.
    """
    followers = {}
    preceders = {}
    for first, second in bigram_stats:
        followers.setdefault(first, []).append(second)
        preceders.setdefault(second, []).append(first)
    return followers, preceders


def context_words(line: str):
    """Extract the words adjacent to the gap from one input record.

    Returns:
        ``(prev_word, next_word)``: the last token of the left context and the
        first token of the right context.  Either is ``None`` when that
        context contains no tokens.

    Raises:
        ValueError: If *line* does not have exactly eight tab-separated fields.
    """
    _, _, _, _, _, _, l_context, r_context = line.split("\t")
    l_tokens = l_context.replace(r"\n", " ").split()
    r_tokens = r_context.replace(r"\n", " ").split()
    prev_word = l_tokens[-1] if l_tokens else None
    next_word = r_tokens[0] if r_tokens else None
    return prev_word, next_word


def predict_gap(prev_word, next_word, followers, preceders,
                word_stats, bigram_stats, k: int = TOP_K):
    """Rank candidate gap words between *prev_word* and *next_word*.

    Candidates seen both after *prev_word* and before *next_word* are scored
    by the sum of the two log-probabilities.  Only when no such candidate
    exists does the ranking fall back to the left-context scores, and then to
    the right-context scores.  A candidate lacking evidence on one side is
    never given a free log-probability of 0 for that side.

    Returns:
        At most *k* ``(word, score)`` pairs, best first; empty when neither
        context yields a candidate.
    """
    l_probs = {
        cand: get_bigram_prob(prev_word, cand, True, word_stats, bigram_stats)
        for cand in followers.get(prev_word, ())
    }
    # The known neighbour is the *context* argument: with prefix=False the
    # lookup is bigram (cand, next_word) normalised by count(next_word).
    r_probs = {
        cand: get_bigram_prob(next_word, cand, False, word_stats, bigram_stats)
        for cand in preceders.get(next_word, ())
    }
    joint = {
        cand: l_prob + r_probs[cand]
        for cand, l_prob in l_probs.items()
        if cand in r_probs
    }
    for scores in (joint, l_probs, r_probs):
        if scores:
            ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
            return ranked[:k]
    return []


def format_predictions(ranked) -> str:
    """Render ranked ``(word, score)`` pairs as one output line."""
    fields = [f"{word}:{prob}" for word, prob in ranked]
    if not fields:
        fields.append(FALLBACK_PREDICTION)
    return " ".join(fields) + " :-0.01"


def main() -> None:
    """Load and prune the count tables, then print one prediction line per stdin line."""
    # NOTE(security): pickle.load can execute arbitrary code; only run this
    # against statistics files you produced yourself.
    with open('word_stats.pickle', 'rb') as file:
        word_stats = pickle.load(file)
    with open('bigram_stats.pickle', 'rb') as file:
        bigram_stats = pickle.load(file)

    prune_rare(word_stats)
    prune_rare(bigram_stats)

    # Built once so each input line costs two dict lookups instead of a scan
    # over every bigram.
    followers, preceders = build_context_index(bigram_stats)

    for line in tqdm(sys.stdin):
        prev_word, next_word = context_words(line)
        ranked = predict_gap(prev_word, next_word, followers, preceders,
                             word_stats, bigram_stats)
        print(format_predictions(ranked))


if __name__ == "__main__":
    main()