Add missing code for bigram model
This commit is contained in:
parent
b64c1769e0
commit
436e58c8a3
122
bigrams_predict.py
Normal file
122
bigrams_predict.py
Normal file
@ -0,0 +1,122 @@
|
||||
import pickle
|
||||
import sys
|
||||
from collections import Counter
|
||||
from tqdm import tqdm
|
||||
from math import log
|
||||
|
||||
from itertools import dropwhile
|
||||
|
||||
|
||||
def get_bigram_prob(context, word, prefix: bool, word_stats, bigram_stats):
|
||||
if prefix:
|
||||
bigram_count = bigram_stats.get((context, word))
|
||||
else:
|
||||
bigram_count = bigram_stats.get((word, context))
|
||||
|
||||
context_count = word_stats.get(context)
|
||||
|
||||
if not context_count or not bigram_count:
|
||||
return 0
|
||||
|
||||
prob = log(bigram_count / context_count)
|
||||
return prob
|
||||
|
||||
|
||||
with open('word_stats.pickle', 'rb') as file:
|
||||
word_stats = pickle.load(file)
|
||||
with open('bigram_stats.pickle', 'rb') as file:
|
||||
bigram_stats = pickle.load(file)
|
||||
|
||||
# print("Unpickled")
|
||||
|
||||
for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, word_stats.most_common()):
|
||||
del word_stats[key]
|
||||
|
||||
for key, count in dropwhile(lambda key_count: key_count[1] >= 1000, bigram_stats.most_common()):
|
||||
del bigram_stats[key]
|
||||
|
||||
# print(word_stats.most_common(10))
|
||||
# print(bigram_stats.most_common(10))
|
||||
|
||||
line_num = 1
|
||||
|
||||
for line in tqdm(sys.stdin):
|
||||
# print(f"Line {line_num}")
|
||||
line_num += 1
|
||||
_, _, _, _, _, _, l_context, r_context = line.split("\t")
|
||||
l_context = l_context.replace(r"\n", " ")
|
||||
r_context = r_context.replace(r"\n", " ")
|
||||
prev_word = l_context.split()[-1]
|
||||
next_word = r_context.split()[0]
|
||||
# print(f"Context: {prev_word} <MASK> {next_word}")
|
||||
# print(f"{prev_word in word_stats=}")
|
||||
# print(f"{next_word in word_stats=}")
|
||||
|
||||
l_probs = dict()
|
||||
r_probs = dict()
|
||||
|
||||
for key in bigram_stats.keys():
|
||||
if key[0] == prev_word:
|
||||
l_probs[key[1]] = get_bigram_prob(prev_word, key[1], True, word_stats, bigram_stats)
|
||||
if key[1] == next_word:
|
||||
r_probs[key[0]] = get_bigram_prob(key[0], next_word, False, word_stats, bigram_stats)
|
||||
|
||||
mult_probs = dict()
|
||||
for key in l_probs.keys():
|
||||
prob = float(l_probs.get(key, 0.0)) + float(r_probs.get(key, 0.0))
|
||||
mult_probs[key] = prob
|
||||
# if prob > 0:
|
||||
# print(key)
|
||||
|
||||
sorted_probs = sorted(mult_probs.items(), key=lambda item: item[1], reverse=True)
|
||||
# print(r_probs)
|
||||
#print(mult_probs)
|
||||
# print(len(sorted_probs))
|
||||
# print(sorted_probs[:5])
|
||||
|
||||
k = 10
|
||||
|
||||
top_5 = sorted_probs[:k]
|
||||
|
||||
# sum = 0
|
||||
# for word, prob in top_5:
|
||||
# sum += prob
|
||||
|
||||
result = []
|
||||
for word, prob in top_5:
|
||||
# if sum != 0:
|
||||
result.append(f"{word}:{prob}")
|
||||
# else:
|
||||
# result.append(f"{word}:{1/k}")
|
||||
if not result:
|
||||
top_5 = sorted(l_probs.items(), key=lambda item: item[1], reverse=True)
|
||||
#print(len(top_5))
|
||||
top_5 = top_5[:k]
|
||||
# sum = 0
|
||||
# for word, prob in top_5:
|
||||
# sum += prob
|
||||
|
||||
result = []
|
||||
for word, prob in top_5:
|
||||
# if sum != 0:
|
||||
result.append(f"{word}:{prob}")
|
||||
# else:
|
||||
# result.append(f"{word}:{1/k}")
|
||||
if not result:
|
||||
top_5 = sorted(r_probs.items(), key=lambda item: item[1], reverse=True)
|
||||
#print(len(top_5))
|
||||
top_5 = top_5[:k]
|
||||
# sum = 0
|
||||
# for word, prob in top_5:
|
||||
# sum += prob
|
||||
|
||||
result = []
|
||||
for word, prob in top_5:
|
||||
# if sum != 0:
|
||||
result.append(f"{word}:{prob}")
|
||||
# else:
|
||||
# result.append(f"{word}:{1/k}")
|
||||
if not result:
|
||||
result.append("the:-10.0")
|
||||
sum = 0.01
|
||||
print(" ".join(result) + f" :{-0.01}")
|
40
bigrams_train.py
Normal file
40
bigrams_train.py
Normal file
@ -0,0 +1,40 @@
|
||||
import sys
|
||||
import lzma
|
||||
import regex as re
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
from collections import Counter
|
||||
|
||||
def get_words(text):
|
||||
for m in re.finditer(r'[\p{L}\']+', text):
|
||||
yield m.group(0)
|
||||
|
||||
def get_ngrams(iterable, n):
|
||||
ngram = []
|
||||
for item in iterable:
|
||||
ngram.append(item)
|
||||
if len(ngram) == n:
|
||||
yield tuple(ngram)
|
||||
ngram = ngram[1:]
|
||||
|
||||
|
||||
def get_stats():
|
||||
word_stats = Counter()
|
||||
bigram_stats = Counter()
|
||||
|
||||
with lzma.open("train/in.tsv.xz", mode="rt", encoding="utf-8") as file:
|
||||
for line in tqdm(file):
|
||||
_, _, _, _, _, _, l_context, r_context = line.split("\t")
|
||||
text = f"{l_context.strip()} {r_context.strip()}".replace("\n", " ")
|
||||
word_stats.update(get_words(text))
|
||||
bigram_stats.update(get_ngrams(get_words(text), 2))
|
||||
|
||||
with open("word_stats.pickle", "wb") as file:
|
||||
pickle.dump(word_stats, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
with open("bigram_stats.pickle", "wb") as file:
|
||||
pickle.dump(bigram_stats, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
get_stats()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user