Add missing code for bigram model

This commit is contained in:
s444380 2023-04-22 18:58:52 +02:00
parent b64c1769e0
commit 436e58c8a3
2 changed files with 162 additions and 0 deletions

122
bigrams_predict.py Normal file
View File

@ -0,0 +1,122 @@
import pickle
import sys
from collections import Counter
from tqdm import tqdm
from math import log
from itertools import dropwhile
def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats):
if prefix:
bigram_count = bigram_stats.get((context, word))
else:
bigram_count = bigram_stats.get((word, context))
context_count = word_stats.get(context)
if not context_count or not bigram_count:
return 0
prob = log(bigram_count / context_count)
return prob
with open('word_stats.pickle', 'rb') as file:
word_stats = pickle.load(file)
with open('bigram_stats.pickle', 'rb') as file:
bigram_stats = pickle.load(file)
# print("Unpickled")
for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, word_stats.most_common()):
del word_stats[key]
for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, bigram_stats.most_common()):
del bigram_stats[key]
# print(word_stats.most_common(10))
# print(bigram_stats.most_common(10))
line_num = 1
for line in tqdm(sys.stdin):
# print(f"Line {line_num}")
line_num += 1
_, _, _, _, _, _, l_context, r_context = line.split("\t")
l_context = l_context.replace(r"\n", " ")
r_context = r_context.replace(r"\n", " ")
prev_word = l_context.split()[-1]
next_word = r_context.split()[0]
# print(f"Context: {prev_word} <MASK> {next_word}")
# print(f"{prev_word in word_stats=}")
# print(f"{next_word in word_stats=}")
l_probs = dict()
r_probs = dict()
for key in bigram_stats.keys():
if key[0] == prev_word:
l_probs[key[1]] = get_bigram_prob(prev_word, key[1], True, word_stats, bigram_stats)
if key[1] == next_word:
r_probs[key[0]] = get_bigram_prob(key[0], next_word, False, word_stats, bigram_stats)
mult_probs = dict()
for key in l_probs.keys():
prob = float(l_probs.get(key, 0.0)) + float(r_probs.get(key, 0.0))
mult_probs[key] = prob
# if prob > 0:
# print(key)
sorted_probs = sorted(mult_probs.items(), key=lambda item: item[1], reverse=True)
# print(r_probs)
#print(mult_probs)
# print(len(sorted_probs))
# print(sorted_probs[:5])
k = 10
top_5 = sorted_probs[:k]
# sum = 0
# for word, prob in top_5:
# sum += prob
result = []
for word, prob in top_5:
# if sum != 0:
result.append(f"{word}:{prob}")
# else:
# result.append(f"{word}:{1/k}")
if not result:
top_5 = sorted(l_probs.items(), key=lambda item: item[1], reverse=True)
#print(len(top_5))
top_5 = top_5[:k]
# sum = 0
# for word, prob in top_5:
# sum += prob
result = []
for word, prob in top_5:
# if sum != 0:
result.append(f"{word}:{prob}")
# else:
# result.append(f"{word}:{1/k}")
if not result:
top_5 = sorted(r_probs.items(), key=lambda item: item[1], reverse=True)
#print(len(top_5))
top_5 = top_5[:k]
# sum = 0
# for word, prob in top_5:
# sum += prob
result = []
for word, prob in top_5:
# if sum != 0:
result.append(f"{word}:{prob}")
# else:
# result.append(f"{word}:{1/k}")
if not result:
result.append("the:-10.0")
sum = 0.01
print(" ".join(result) + f" :{-0.01}")

40
bigrams_train.py Normal file
View File

@ -0,0 +1,40 @@
import sys
import lzma
import regex as re
import pickle
from tqdm import tqdm
from collections import Counter
def get_words(text):
for m in re.finditer(r'[\p{L}\']+', text):
yield m.group(0)
def get_ngrams(iterable, n):
ngram = []
for item in iterable:
ngram.append(item)
if len(ngram) == n:
yield tuple(ngram)
ngram = ngram[1:]
def get_stats():
word_stats = Counter()
bigram_stats = Counter()
with lzma.open("train/in.tsv.xz", mode="rt", encoding="utf-8") as file:
for line in tqdm(file):
_, _, _, _, _, _, l_context, r_context = line.split("\t")
text = f"{l_context.strip()} {r_context.strip()}".replace("\n", " ")
word_stats.update(get_words(text))
bigram_stats.update(get_ngrams(get_words(text), 2))
with open("word_stats.pickle", "wb") as file:
pickle.dump(word_stats, file, protocol=pickle.HIGHEST_PROTOCOL)
with open("bigram_stats.pickle", "wb") as file:
pickle.dump(bigram_stats, file, protocol=pickle.HIGHEST_PROTOCOL)
get_stats()