# retro-gap/predict.py
#
# Repository-viewer metadata (not part of the program): 67 lines, 2.1 KiB,
# Python. First revision timestamp shown by the viewer: 2020-12-08 12:01:14 +01:00.
import pickle
import sys
from math import log
import regex as re
2020-12-08 14:14:52 +01:00
def count_prob(bigrams, unigrams):
    """Return the add-one smoothed ratio of two counts, capped at 1.0.

    The value is ``(bigrams + 1.0) / (unigrams + 1)``; anything above 1.0 is
    clamped so the result can be used as a probability.
    """
    smoothed = (bigrams + 1.0) / (unigrams + 1)
    return min(smoothed, 1.0)
def _format_predictions(ranked):
    """Render ranked ``(word, probability)`` pairs as one output line.

    Each word is written as ``word:log(probability)``, separated by spaces,
    followed by a final ``:log(p)`` entry with an empty word. Writing stops
    early once a candidate's probability would use up all of the remaining
    probability mass (which starts at 1.0).
    """
    entries = []
    remaining = 1.0
    for word, prob in ranked:
        if remaining - prob <= 0.0:
            if not entries:
                # The very first candidate already exhausts the mass: emit a
                # fixed 0.95 / 0.05 split between it and the empty-word entry.
                return word + ':' + str(log(0.95)) + ' :' + str(log(0.05))
            # Give whatever mass is left to the empty-word entry.
            entries.append(':' + str(log(remaining)))
            return ' '.join(entries)
        entries.append(word + ':' + str(log(prob)))
        remaining -= prob
    # All candidates fitted: close the line with a fixed small probability.
    entries.append(':' + str(log(0.0001)))
    return ' '.join(entries)


def main():
    """Print a ranked word prediction line for every line on standard input.

    Loads n-gram counts from ``ngrams_2.pkl``. As used here, ``ngrams[1]`` is
    looked up with 1-tuples of words and ``ngrams[2]`` with 2-tuples of words,
    each yielding a count. For each lower-cased input line the word right
    before a tab (left context) and the word right after it (right context)
    are extracted. Every vocabulary word is scored by the product of the
    smoothed (left, word) and (word, right) bigram probabilities, and the 50
    best candidates are printed in ``word:logprob`` form.

    Raises:
        IndexError: if an input line does not match the expected
            tab-separated layout.
    """
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only use a model file that comes from a trusted source.
    with open('ngrams_2.pkl', 'rb') as model_file:
        ngrams = pickle.load(model_file)

    unigrams = ngrams[1]
    bigrams = ngrams[2]
    vocabulary_size = len(unigrams)

    # The smoothing denominator depends only on the candidate word, so compute
    # it once instead of once per input line. The unigram count is looked up
    # by the whole word: indexing the word string again would select only its
    # first character and fetch the count of an unrelated one-letter token.
    candidates = []
    for key in unigrams:
        word = str(key[0])
        candidates.append((word, unigrams.get((word,), 0) + vocabulary_size))

    context_pattern = re.compile(r'.*\t.*\t.* (.*?)\t(.*?) ')

    for line in sys.stdin:
        left_word, right_word = context_pattern.findall(line.lower())[0]

        scored = [
            (
                word,
                count_prob(bigrams.get((left_word, word), 0), denominator)
                * count_prob(bigrams.get((word, right_word), 0), denominator),
            )
            for word, denominator in candidates
        ]
        best = sorted(scored, key=lambda t: t[1], reverse=True)[:50]
        print(_format_predictions(best))
# Run the predictor only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()