paranormal-or-skeptic/predict.py

77 lines
2.5 KiB
Python
Raw Normal View History

2020-03-22 10:15:36 +01:00
#!/usr/bin/python3
import pickle
import math
2020-03-22 11:59:07 +01:00
import re
2020-03-29 13:39:47 +02:00
import sys
2020-03-29 23:29:19 +02:00
import nltk
from nltk.corpus import stopwords
2020-03-22 10:15:36 +01:00
2020-03-29 23:29:19 +02:00
def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs):
    """Choose the class label for one line of the input file.

    Each class in ``bigrams_logprobs`` gets a score: the sum of the stored
    values of every adjacent token pair and every single token found in the
    post, plus the class's own prior value. N-grams that are missing from
    the model contribute 0.

    Args:
        post: One raw input line. The text to classify is the first
            tab-separated field; any further fields are ignored.
        paranormal_class_logprob: Value added to the 'paranormal' score.
        sceptic_class_logprob: Value added to the 'sceptic' score.
        bigrams_logprobs: Mapping of class name to a mapping of
            ``"token1 token2"`` strings to values.
        words_logprobs: Mapping of class name to a mapping of single
            tokens to values.

    Returns:
        The key of ``bigrams_logprobs`` whose score has the largest absolute
        value. If several classes tie, the last one in iteration order wins.

    Raises:
        ValueError: If ``bigrams_logprobs`` contains no classes.
    """
    # Only the first column is classified; tolerate lines with a missing or
    # extra tab-separated column instead of failing on tuple unpacking.
    text = post.rstrip('\n').partition('\t')[0]
    tokens = clear_post(text)

    class_priors = {
        'sceptic': sceptic_class_logprob,
        'paranormal': paranormal_class_logprob,
    }

    best_class = None
    best_magnitude = None
    for class_, bigram_table in bigrams_logprobs.items():
        # A class without a unigram table scores 0 for every single token.
        word_table = words_logprobs.get(class_, {})

        # Accumulate in a fixed order (bigrams, unigrams, prior) so the
        # floating-point result does not depend on summation order.
        score = 0
        for first, second in zip(tokens, tokens[1:]):
            # Bigrams that are not in the model are treated as neutral.
            score += bigram_table.get(first + " " + second, 0)
        for token in tokens:
            score += word_table.get(token, 0)
        score += class_priors.get(class_, 0)

        # NOTE(review): the winner is the class with the largest |score|.
        # If the stored values are ordinary log-probabilities (<= 0) this is
        # the most negative sum, i.e. the class that matched the most known
        # n-grams. Kept as is because unseen n-grams score 0 rather than
        # being penalised -- confirm against the training script before
        # changing it.
        magnitude = abs(score)
        if best_magnitude is None or magnitude >= best_magnitude:
            best_class = class_
            best_magnitude = magnitude

    if best_magnitude is None:
        raise ValueError("bigrams_logprobs contains no classes to choose from")
    return best_class
2020-03-22 10:15:36 +01:00
2020-03-29 13:39:47 +02:00
def clear_post(post):
    """Normalise the text of a post and split it into tokens.

    The steps run in a fixed order: escaped newlines become spaces, the text
    is lower-cased, links are replaced by the placeholder ``internetlink``,
    punctuation and digits are removed, runs of spaces are collapsed, and
    English stop words are dropped.

    Args:
        post: Text of the post (without the trailing input-file columns).

    Returns:
        List of tokens in their original order. Only trailing spaces are
        stripped before splitting, so text that starts with a space (or is
        empty) yields an empty-string token.
    """
    # NOTE(review): presumably this must stay in sync with the tokenisation
    # used when the model was built -- confirm before changing any pattern.

    # The input stores line breaks as the two characters backslash + 'n'.
    post = post.replace('\\n', ' ')
    post = post.lower()
    # Replace anything starting with http/https/www (optionally wrapped in
    # parentheses) by a single placeholder token.
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\-\?\+\%]+(\)|)', ' internetlink ', post)
    # Runs of '.', ',', '/' and '~' separate words.
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    # Drop '&lt' / '&gt' fragments and @mentions.
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)','',post)
    # Remove the remaining punctuation and all digits without adding a space.
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    # A spaced hyphen or a run of two or more hyphens acts as a separator.
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    post = post.rstrip(' ')

    stop_words = _english_stop_words()
    return [word for word in post.split(' ') if word not in stop_words]


# Filled on first use by _english_stop_words().
_STOP_WORDS_CACHE = None


def _english_stop_words():
    """Return NLTK's English stop words as a set, loading them only once."""
    global _STOP_WORDS_CACHE
    if _STOP_WORDS_CACHE is None:
        # clear_post runs once per input line; reading the stop-word list
        # every time would repeat the same work for each post.
        _STOP_WORDS_CACHE = frozenset(stopwords.words('english'))
    return _STOP_WORDS_CACHE
2020-03-22 10:15:36 +01:00
def main():
    """Classify every line of an input file and write one label per line.

    Expects exactly three command-line arguments: the input TSV, the output
    file and the pickled model. With any other argument count it prints a
    usage message and returns without doing anything else.

    The model pickle is indexed by position: 0 is the 'paranormal' class
    value, 1 the 'sceptic' class value, 2 the per-class bigram table and
    3 the per-class word table.
    """
    if len(sys.argv) != 4:
        print("syntax is ./predict.py in.tsv out.tsv model.pkl")
        return

    _, in_file, out_file, model = sys.argv

    # NOTE(review): pickle.load can execute arbitrary code from the file it
    # reads; only point this script at model files you created yourself.
    with open(model, 'rb') as model_file:
        pickle_list = pickle.load(model_file)

    paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs, words_logprobs = (
        pickle_list[0],
        pickle_list[1],
        pickle_list[2],
        pickle_list[3],
    )

    # Text written to the output file for each predicted class; any other
    # prediction writes nothing.
    output_for = {'sceptic': ' S\n', 'paranormal': ' P\n'}

    with open(in_file) as in_f, open(out_file, 'w') as out_f:
        for line in in_f:
            hyp = calc_post_class(
                line,
                paranormal_class_logprob,
                sceptic_class_logprob,
                bigrams_logprobs,
                words_logprobs,
            )
            marker = output_for.get(hyp)
            if marker is not None:
                out_f.write(marker)
2020-03-22 10:15:36 +01:00
# Run the predictor only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or file I/O.
if __name__ == "__main__":
    main()