# Source listing: retro-gap/tau_bigram.py (84 lines, 2.6 KiB, Python)

import pickle
import sys
from math import log
import regex as re
def smoothing(count, total, num_class):
    """Return the add-one smoothed estimate, capped at 1.0.

    The estimate is ``(count + 1.0) / (total + num_class)``; any value
    above 1.0 is replaced by exactly 1.0.

    Args:
        count: Number of observations of the event.
        total: Number of observations of the conditioning context.
        num_class: Number of possible outcomes (added to the denominator).

    Returns:
        The smoothed estimate as a float no greater than 1.0.
    """
    estimate = (count + 1.0) / (total + num_class)
    return 1.0 if estimate > 1.0 else estimate
def _format_distribution(candidates):
    """Render ranked ``(word, probability)`` pairs as one output line.

    Pairs are emitted as ``word:logprob `` while probability mass remains;
    the line always ends with a ``:logprob`` entry that carries the mass
    assigned to every word not listed. No trailing newline is added.
    """
    remaining = 1.0
    parts = []
    for rank, (word, prob) in enumerate(candidates):
        if remaining - prob <= 0.0:
            if rank == 0:
                # The best candidate alone would use up all the mass: give
                # it a fixed 0.95 and leave 0.05 for everything else.
                return word + ':' + str(log(0.95)) + ' :' + str(log(0.05))
            # This candidate no longer fits; what is left goes to the rest.
            parts.append(':' + str(log(remaining)))
            return ''.join(parts)
        parts.append(word + ':' + str(log(prob)) + ' ')
        remaining -= prob
    # Every candidate fitted: reserve a small fixed mass for unseen words.
    parts.append(':' + str(log(0.0001)))
    return ''.join(parts)


def prediction(file_in, file_out):
    """Write one gap-word distribution to ``file_out`` per line of ``file_in``.

    The n-gram tables are unpickled from ``ngrams.pkl`` in the current
    working directory. For every input line the two words next to the tab
    that separates the left and right context are extracted, each vocabulary
    word is scored as ``smoothing(left, word) * smoothing(word, right)``, and
    the 50 best-scoring words are written as ``word:logprob`` pairs followed
    by a ``:logprob`` entry for the remaining probability mass.

    A line on which the context pattern does not match (for example an empty
    line or a context shorter than two words) gets the remainder-only line
    ``:0.0``, so the output stays aligned line-by-line with the input.

    Args:
        file_in: Path of the UTF-8 input file, one example per line.
        file_out: Path of the UTF-8 output file; it is overwritten.
    """
    # NOTE: unpickling can execute arbitrary code; only use a trusted ngrams.pkl.
    with open('ngrams.pkl', 'rb') as model_file:
        ngrams = pickle.load(model_file)
    unigrams = ngrams[1]
    bigrams = ngrams[2]
    dictionary_size = len(unigrams)
    # Keys of the order-1 table are indexable; their first item is the word.
    vocabulary = [str(key[0]) for key in unigrams]
    context_pattern = re.compile(r'.*\t.*\t.* (.*?) (.*?)\t(.*?) (.*?) ')

    with open(file_in, encoding='utf-8') as in_file, \
            open(file_out, 'w', encoding='utf-8') as out_file:
        for line in in_file:
            matches = context_pattern.findall(line.lower())
            if not matches:
                # No usable context: put all the mass on the remainder entry.
                out_file.write(':' + str(log(1.0)) + '\n')
                continue
            groups = matches[0]
            # The pattern captures two words on each side of the tab. The
            # gap's neighbours are the last word before the tab (group 2)
            # and the first word after it (group 3), i.e. indices 1 and 2.
            left = str(groups[1])
            right = str(groups[2])

            # NOTE(review): both context totals are looked up as 1-tuples in
            # the order-2 table, as the code always did. If that table only
            # holds 2-tuples they are always 0 and the order-1 table is
            # probably what was meant - confirm against how ngrams.pkl is built.
            left_total = bigrams.get((left,), 0)

            scored = []
            for word in vocabulary:
                pre_prob = smoothing(
                    bigrams.get((left, word), 0), left_total, dictionary_size)
                post_prob = smoothing(
                    bigrams.get((word, right), 0),
                    bigrams.get((word,), 0),
                    dictionary_size)
                scored.append((word, pre_prob * post_prob))

            best = sorted(scored, key=lambda pair: pair[1], reverse=True)[:50]
            out_file.write(_format_distribution(best))
            out_file.write('\n')
if __name__ == '__main__':
    # Input/output file pairs for each data split (absolute paths on the
    # original author's machine).
    in_dev_0, out_dev_0 = (
        'C:/Users/eryk6/PycharmProjects/retro-gap/dev-0/in.tsv',
        'C:/Users/eryk6/PycharmProjects/retro-gap/dev-0/out.tsv',
    )
    in_dev_1, out_dev_1 = (
        'C:/Users/eryk6/PycharmProjects/retro-gap/dev-1/in.tsv',
        'C:/Users/eryk6/PycharmProjects/retro-gap/dev-1/out.tsv',
    )
    in_test_a, out_test_a = (
        'C:/Users/eryk6/PycharmProjects/retro-gap/test-A/in.tsv',
        'C:/Users/eryk6/PycharmProjects/retro-gap/test-A/out.tsv',
    )
    # All runs are currently disabled; uncomment a call to generate the
    # predictions for that split.
    # prediction(in_dev_0, out_dev_0)
    # prediction(in_dev_1, out_dev_1)
    # prediction(in_test_a, out_test_a)