paranormal-or-skeptic/predict_bigram.py
2020-04-04 22:07:48 +02:00

93 lines
3.5 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python3
import functools
import math
import pickle
import re
import sys

import nltk
from nltk.corpus import stopwords
def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob,
                    bigrams_logprobs, words_logprobs,
                    total_sceptic_bigram, total_paranormal_bigram,
                    total_sceptic_word, total_paranormal_word,
                    *, bigram_weight=4, word_divisor=7):
    """Return the predicted class key for one line of the input TSV.

    The line is expected to look like ``text<TAB>timestamp``; only the text
    part is tokenised (via ``clear_post``) and scored.  For every class key
    in ``bigrams_logprobs`` a score is accumulated from:

    * every adjacent token pair ``"tok1 tok2"`` found in
      ``bigrams_logprobs[class_]``, multiplied by ``bigram_weight``;
    * every token found in ``words_logprobs[class_]``, divided by
      ``word_divisor``;
    * the class prior, for the class names ``'sceptic'`` and
      ``'paranormal'``.

    Bigrams and words absent from a class's table contribute nothing.

    Args:
        post: one raw input line; a trailing newline is ignored.
        paranormal_class_logprob: prior added to the ``'paranormal'`` score.
        sceptic_class_logprob: prior added to the ``'sceptic'`` score.
        bigrams_logprobs: mapping ``class -> {"tok1 tok2": score}``.
        words_logprobs: mapping ``class -> {token: score}``; a class missing
            here simply gets no word contribution.
        total_sceptic_bigram, total_paranormal_bigram, total_sceptic_word,
        total_paranormal_word: not used by the scoring; accepted so existing
            callers keep working.
        bigram_weight: multiplier for each matched bigram score (default 4,
            the previously hard-coded value).
        word_divisor: divisor for each matched word score (default 7, the
            previously hard-coded value).

    Returns:
        The class key whose accumulated score has the largest absolute
        value; on an exact tie the class iterated last wins.

    Raises:
        ValueError: if ``bigrams_logprobs`` contains no classes.
    """
    # Everything before the last tab is the text.  Unlike a strict two-field
    # unpack, a stray tab inside the post (or a missing timestamp column) no
    # longer aborts the whole prediction run.
    text = post.rstrip('\n').rsplit('\t', 1)[0]
    tokens = clear_post(text)
    adjacent_pairs = list(zip(tokens, tokens[1:]))

    best_class = None
    best_magnitude = None
    for class_ in bigrams_logprobs:
        class_bigrams = bigrams_logprobs[class_]
        class_words = words_logprobs.get(class_, {})
        score = 0
        for left, right in adjacent_pairs:
            score += class_bigrams.get(left + " " + right, 0) * bigram_weight
        for token in tokens:
            score += class_words.get(token, 0) / word_divisor
        if class_ == 'sceptic':
            score += sceptic_class_logprob
        elif class_ == 'paranormal':
            score += paranormal_class_logprob
        # NOTE(review): if the stored scores are log-probabilities (<= 0),
        # taking the largest |score| selects the *lowest*-scoring class, which
        # rewards the class with more (weighted) matched n-grams because
        # unseen ones add 0.  Non-standard for Naive Bayes but kept as is;
        # confirm against the training script before changing it.
        magnitude = abs(score)
        # ">=" so ties go to the class iterated last, as before.
        if best_magnitude is None or magnitude >= best_magnitude:
            best_class, best_magnitude = class_, magnitude
    if best_magnitude is None:
        raise ValueError("model contains no classes")
    return best_class
@functools.lru_cache(maxsize=None)
def _english_stop_words():
    """Return NLTK's English stop words as a frozenset, loaded only once."""
    return frozenset(stopwords.words('english'))


def clear_post(post):
    """Normalise a raw post and return its tokens minus English stop words.

    Steps, in order:

    1. replace the two-character sequence backslash + ``n`` with a space
       (presumably how line breaks are escaped in the TSV - confirm);
    2. lowercase;
    3. replace links starting with ``http``/``https``/``www`` (optionally
       wrapped in parentheses) with the placeholder token ``internetlink``;
    4. turn runs of ``. , / ~`` into a single space;
    5. delete ``&lt``, ``&gt`` and ``@name`` mentions;
    6. delete quotes, brackets, digits and assorted punctuation;
    7. turn `` - `` and runs of two or more hyphens into a space;
    8. collapse repeated spaces, strip trailing spaces, split on ``' '``.

    NOTE(review): only *trailing* spaces are stripped, so normalised text
    that starts with a space yields an empty first token ``''`` (and empty
    input yields ``['']``).  Left untouched on purpose: prediction has to
    tokenise exactly like training did - confirm before changing.

    Args:
        post: post text (without the timestamp column).

    Returns:
        List of tokens, in original order, excluding stop words.
    """
    post = post.replace('\\n', ' ')
    post = post.lower()
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\-\?\+\%]+(\)|)', ' internetlink ', post)
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    tokens = post.rstrip(' ').split(' ')
    # Cached: this function runs once per input line, so re-reading the NLTK
    # corpus on every call is wasted work.
    stop_words = _english_stop_words()
    return [w for w in tokens if w not in stop_words]
def main():
    """Command-line entry point: classify every line of in.tsv into out.tsv.

    Usage: <script> in.tsv out.tsv model.pkl

    The model pickle is read as a sequence whose first eight items are, in
    order: paranormal prior, sceptic prior, per-class bigram scores,
    per-class word scores, and the sceptic/paranormal bigram and word
    totals.  Exactly one output line (" S" or " P") is written per input
    line, so out.tsv stays aligned with in.tsv.

    Raises:
        SystemExit: with status 1 on a wrong number of arguments.
        ValueError: if the model yields a class other than 'sceptic' or
            'paranormal' (writing nothing would misalign the output).
    """
    if len(sys.argv) != 4:
        print("syntax is {} in.tsv out.tsv model.pkl".format(sys.argv[0]),
              file=sys.stderr)
        sys.exit(1)
    in_file, out_file, model = sys.argv[1:4]

    # SECURITY: pickle.load can execute arbitrary code; only load model
    # files you produced yourself.
    with open(model, 'rb') as f:
        pickle_list = pickle.load(f)
    (paranormal_class_logprob, sceptic_class_logprob,
     bigrams_logprobs, words_logprobs,
     total_sceptic_bigram, total_paranormal_bigram,
     total_sceptic_word, total_paranormal_word) = pickle_list[:8]

    # Diagnostics: log(1/total) for each table size.  Sent to stderr so they
    # do not mix with normal program output.
    for total in (total_sceptic_bigram, total_paranormal_bigram,
                  total_sceptic_word, total_paranormal_word):
        print(math.log(1/total), file=sys.stderr)

    labels = {'sceptic': ' S\n', 'paranormal': ' P\n'}
    with open(in_file, encoding='utf-8') as in_f, \
            open(out_file, 'w', encoding='utf-8') as out_f:
        for line_no, line in enumerate(in_f, 1):
            hyp = calc_post_class(line, paranormal_class_logprob,
                                  sceptic_class_logprob,
                                  bigrams_logprobs, words_logprobs,
                                  total_sceptic_bigram,
                                  total_paranormal_bigram,
                                  total_sceptic_word, total_paranormal_word)
            label = labels.get(hyp)
            if label is None:
                # Skipping the line would silently shift every later
                # prediction against its input line.
                raise ValueError(
                    'line {}: unexpected class {!r}'.format(line_no, hyp))
            out_f.write(label)
# Run only when executed as a script, so the helpers can be imported
# (e.g. by a training or evaluation script) without side effects.
if __name__ == "__main__":
    main()