paranormal-or-skeptic/predict.py
2020-03-29 19:48:30 +02:00

64 lines
2.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python3
import pickle
import math
import re
import sys
def calc_post_class(post, paranormal_class_logprob, sceptic_class_logprob, bigrams_logprobs):
    """Pick a class label for one tab-separated input line.

    The line is split into exactly two tab-separated fields (anything else
    raises ValueError from the unpacking); only the first field, the post
    text, is used. The text is normalised with clear_post, lower-cased and
    split on single spaces, and every pair of adjacent tokens forms a bigram.

    For each class in ``bigrams_logprobs`` the scores of the known bigrams
    are summed (unknown bigrams contribute nothing), and the matching class
    score is added for the 'sceptic' and 'paranormal' classes.

    Returns the class whose total has the largest absolute value; when two
    classes reach the same absolute total, the one iterated last wins.

    NOTE(review): if the scores are log-probabilities (non-positive), taking
    the largest absolute value selects the most negative total -- confirm
    that this is the intended decision rule before changing it.
    """
    # The second field is required to be present but is never used.
    raw_text, _ = post.rstrip('\n').split('\t')
    words = clear_post(raw_text).lower().split(' ')

    class_scores = {
        'sceptic': sceptic_class_logprob,
        'paranormal': paranormal_class_logprob,
    }

    label_by_score = {}
    for label, table in bigrams_logprobs.items():
        total = 0
        for left, right in zip(words, words[1:]):
            try:
                total += table[left + " " + right]
            except KeyError:
                # Bigrams missing from the model are treated as neutral.
                continue
        if label in class_scores:
            total += class_scores[label]
        # Keyed by absolute score: an equal score from a later class
        # overwrites the earlier entry.
        label_by_score[abs(total)] = label

    return label_by_score[max(label_by_score)]
# Ordered (pattern, replacement) passes applied by clear_post; order matters.
_CLEAR_POST_PASSES = (
    # URL-like runs starting with http/https/www, optionally parenthesised.
    (r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\-\?\+\%]+(\)|)', ' internetlink '),
    # Runs of '.', ',', '/' or '~' become a single space.
    (r'[\.\,\/\~]+', ' '),
    # '&lt', '&gt' and '@' followed by alphanumerics are dropped.
    (r'(&lt|&gt|\@[a-zA-Z0-9]+)', ''),
    # Quotes, brackets, digits and assorted punctuation are dropped.
    (r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', ''),
    # A spaced hyphen or a run of two or more hyphens becomes a space.
    (r'( \- |\-\-+)', ' '),
    # Collapse runs of spaces left behind by the passes above.
    (r' +', ' '),
)


def clear_post(post):
    """Normalise raw post text before it is split into tokens.

    First every literal backslash followed by 'n' is replaced by a space,
    then the substitutions in _CLEAR_POST_PASSES are applied in order, and
    finally trailing spaces are stripped. Leading spaces are not stripped.
    Returns the cleaned string.
    """
    cleaned = post.replace('\\n', ' ')
    for pattern, replacement in _CLEAR_POST_PASSES:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.rstrip(' ')
def main():
    """Command-line entry point: label every line of a TSV file.

    Expects exactly three arguments: input TSV path, output path and pickled
    model path. With any other argument count a usage line is printed and
    the function returns without doing anything else.

    The pickled object is indexed as [0] paranormal class score,
    [1] sceptic class score, [2] per-class bigram score tables. For each
    input line, ' S' is written for 'sceptic' and ' P' for 'paranormal';
    any other predicted label writes nothing for that line.
    """
    if len(sys.argv) != 4:
        print("syntax is ./predict.py in.tsv out.tsv model.pkl")
        return
    _, in_file, out_file, model = sys.argv

    # NOTE: unpickling can execute arbitrary code; only load trusted models.
    with open(model, 'rb') as model_f:
        pickle_list = pickle.load(model_f)
    paranormal_class_logprob = pickle_list[0]
    sceptic_class_logprob = pickle_list[1]
    bigrams_logprobs = pickle_list[2]

    output_by_class = {'sceptic': ' S\n', 'paranormal': ' P\n'}
    with open(in_file) as in_f, open(out_file, 'w') as out_f:
        for line in in_f:
            hyp = calc_post_class(
                line,
                paranormal_class_logprob,
                sceptic_class_logprob,
                bigrams_logprobs,
            )
            if hyp in output_by_class:
                out_f.write(output_by_class[hyp])
# Run the command-line entry point only when this file is executed as a
# script, so that importing the module does not trigger argument parsing
# or file I/O.
if __name__ == "__main__":
    main()