paranormal-or-skeptic/predict.py
2020-04-06 10:41:14 +02:00

56 lines
1.6 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python3
import functools
import pickle
import re
import sys

from nltk.corpus import stopwords
@functools.lru_cache(maxsize=None)
def _english_stop_words():
    """Return NLTK's English stop-word list as a frozenset, loaded only once."""
    return frozenset(stopwords.words('english'))


def clear_post(post):
    """Normalize a raw post and split it into tokens with stop words removed.

    The cleaning steps (and their order) must match whatever was used when
    the model was trained, so the regexes below are intentionally left as-is.

    Args:
        post: raw post text.

    Returns:
        List of tokens (split on single spaces) that are not in NLTK's English
        stop-word list. Matching is case-sensitive (no lowercasing is done).
        Only trailing spaces are stripped, so a post that starts with a space,
        or an empty post, yields an empty-string token.
    """
    # Turn literal backslash-n sequences into spaces; presumably the input
    # stores newlines escaped this way — TODO confirm against the data.
    post = post.replace('\\n', ' ')
    # Replace URLs (optionally wrapped in parentheses) with a placeholder token.
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    # Runs of dots, commas, slashes and tildes act as word separators.
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    # Drop "&lt" / "&gt" entity fragments and @mentions.
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    # Delete quotes, brackets, digits and assorted punctuation outright.
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    # A spaced hyphen or a run of two or more hyphens becomes a separator.
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    # Collapse runs of spaces, then drop trailing spaces.
    post = re.sub(r' +', ' ', post)
    post = post.rstrip(' ')
    # The stop-word set is built once and reused across calls (it used to be
    # rebuilt for every post).
    stop_words = _english_stop_words()
    return [word for word in post.split(' ') if word not in stop_words]
def calc_prob(posts, weights, word_to_index_mapping, threshold=0.5):
    """Score each cleaned post with the linear model and print a 0/1 label.

    Args:
        posts: iterable of cleaned posts, each a single space-joined string.
        weights: indexable weight vector; ``weights[0]`` is added to every
            score as the bias term and ``weights[i]`` is the weight of the
            token mapped to index ``i``.
        word_to_index_mapping: dict from token to its index in ``weights``.
            Tokens missing from it contribute nothing to the score.
        threshold: a post is labelled ``1`` when its score is strictly
            greater than this value (default ``0.5``, the original cutoff).

    Returns:
        List of labels (``1`` or ``0``), one per post, in input order. Each
        label is also printed to stdout on its own line.
    """
    labels = []
    for post in posts:
        y_hat = weights[0]
        for token in post.split(' '):
            index = word_to_index_mapping.get(token)
            if index is None:
                # Token unseen during training: no contribution.
                continue
            # NOTE(review): post.count(token) counts *substring* occurrences
            # in the whole post (e.g. token "art" also matches inside
            # "start"), and this term is added once per occurrence of the
            # token, so a token appearing k times adds at least k*k*weight.
            # Kept unchanged because the weights are only meaningful with the
            # feature extraction used at training time — confirm against the
            # training script before changing it.
            y_hat += weights[index] * post.count(token)
        label = 1 if y_hat > threshold else 0
        print(label)
        labels.append(label)
    return labels
def main():
    """Classify posts read from stdin and print one 0/1 label per line.

    Usage: ``predict.py MODEL < input.tsv``. MODEL is a pickle file holding
    a sequence whose element 0 is the weight vector and element 2 the
    token-to-index mapping (element 1 is not used here). Each stdin line is
    expected to be ``text<TAB>timestamp``; only the text is used.
    """
    if len(sys.argv) != 2:
        # Report usage on stderr with a non-zero exit status so the message
        # cannot be mistaken for a prediction on stdout.
        sys.exit("Expected model")
    model = sys.argv[1]

    # SECURITY NOTE(review): pickle.load can execute arbitrary code; only
    # load model files that come from a trusted source.
    # The model is loaded before reading stdin so a bad path fails fast.
    with open(model, 'rb') as f:
        pickle_list = pickle.load(f)
    weights = pickle_list[0]
    # NOTE(review): pickle_list[1] (apparently the "lowest loss" weights,
    # judging by the original variable name) is ignored; confirm the final
    # weights are the intended ones for prediction.
    word_to_index_mapping = pickle_list[2]

    posts = []
    for line in sys.stdin:
        # Split off only the last tab-separated field (the timestamp), so a
        # line whose text contains a tab, or that lacks a timestamp column,
        # no longer crashes the whole run on tuple unpacking.
        text = line.rstrip('\n').rsplit('\t', 1)[0]
        posts.append(" ".join(clear_post(text)))
    calc_prob(posts, weights, word_to_index_mapping)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()