paranormal-or-skeptic/predict.py

70 lines
2.4 KiB
Python
Raw Normal View History

2020-03-22 10:15:36 +01:00
#!/usr/bin/python3
import pickle
import math
2020-03-22 11:59:07 +01:00
import re
2020-03-22 10:15:36 +01:00
2020-03-22 13:32:09 +01:00
def clear_tokens(tokens, is_text=True):
    """Normalise raw post text (or a single token) before model lookup.

    The string goes through a fixed pipeline of substitutions:
    escaped newlines, parenthesised http(s) links ending in
    .com/.net/.jpg/.html, punctuation runs, runs of dots/dashes and digits
    all become spaces, and the characters 'œ' and '·' are deleted.

    NOTE(review): the link pattern uses a greedy ``.*``, so with two links
    on one line everything between the first "(http" and the last
    ".com)"-style ending is removed, including ordinary words in between.
    It is left unchanged because training presumably used the same cleaning.

    Args:
        tokens: the string to clean.
        is_text: when true, runs of spaces collapse to a single space
            (keeps words separable); when false, all spaces are removed
            (used for individual tokens).

    Returns:
        The cleaned string.
    """
    cleaned = tokens.replace('\\n', ' ')

    # (pattern, replacement) pairs, applied in order.
    pipeline = (
        (r'\(((http)|(https)).*((\.com)|(\.net)|(\.jpg)|(\.html))\)', " "),
        (r'[\n\&\"\?\\\'\*\[\]\,\;\.\=\+\(\)\!\/\:\`\~\%\^\$\#\@\\\\±]+', ' '),
        (r'[\.\-][\.\-]+', ' '),
        (r'[0-9]+', ' '),
        (r'œ|·', ''),
    )
    for pattern, replacement in pipeline:
        cleaned = re.sub(pattern, replacement, cleaned)

    space_fill = ' ' if is_text else ''
    return re.sub(r' +', space_fill, cleaned)
def calc_post_prob(post, paranormal_class_logprob, sceptic_class_logprob, word_logprobs):
    """Classify one input line as 'paranormal' or 'sceptic' (naive Bayes).

    Args:
        post: one raw line of the input TSV. Everything before the last tab
            is the post text; the trailing column is ignored.
        paranormal_class_logprob: log prior of the 'paranormal' class.
        sceptic_class_logprob: log prior of the 'sceptic' class.
        word_logprobs: mapping class name -> mapping token -> log score,
            presumably log P(token | class) as produced by training.
            Tokens missing from a class mapping contribute nothing.

    Returns:
        'paranormal' if the post contains a keyword recognised by
        search_for_keywords. Otherwise, the class in word_logprobs with the
        highest (least negative) total log score.

    Raises:
        ValueError: if word_logprobs contains no classes.
    """
    # Only the text column matters; the trailing column (timestamp) is unused.
    text = post.rstrip('\n').rsplit('\t', 1)[0]
    # Lowercase once so both the keyword check and the model lookup see the
    # same case-normalised text (keywords and tokens are lowercase).
    text = clear_tokens(text, True).lower()

    # Keyword override: these words are treated as decisive for 'paranormal'.
    if search_for_keywords(text):
        return 'paranormal'

    # Clean each token once, not once per class.
    tokens = [clear_tokens(token, False) for token in text.split(' ')]

    class_priors = {
        'sceptic': sceptic_class_logprob,
        'paranormal': paranormal_class_logprob,
    }
    scores = {}
    for class_, token_logprobs in word_logprobs.items():
        # A sum of log-probabilities is the log of their product, so the
        # accumulator starts from the prior (additive identity 0), not 1.
        score = class_priors.get(class_, 0.0)
        for token in tokens:
            # Index (rather than .get) so a defaultdict-based model keeps
            # supplying its own default for unseen tokens.
            try:
                score += token_logprobs[token]
            except KeyError:
                pass  # token unseen for this class: contributes nothing
        scores[class_] = score

    if not scores:
        raise ValueError("word_logprobs contains no classes")
    # Naive Bayes decision: the most probable class is the one with the
    # highest log score. Taking the largest absolute value, as before,
    # selected the least probable class.
    return max(scores, key=scores.get)
2020-03-22 13:32:09 +01:00
def search_for_keywords(text, keywords=('paranormal', 'ufo', 'aliens', 'conspiracy')):
    """Return True if *text* contains any of *keywords*, ignoring case.

    Matching is by plain substring, so e.g. 'ufo' also matches 'ufology'.

    Args:
        text: the post text to scan.
        keywords: substrings that mark a post as paranormal. Compared
            case-insensitively; the default keeps the original keyword list
            without its duplicate 'aliens' entry.

    Returns:
        True if at least one keyword occurs in the text.
    """
    # Case-fold both sides: "UFO" or "Aliens" in a post must still match.
    lowered = text.lower()
    return any(keyword.lower() in lowered for keyword in keywords)
2020-03-22 10:15:36 +01:00
def main(in_file="test-A/in.tsv", out_file="test-A/out.tsv",
         model_file='naive_base_model.pkl'):
    """Classify every post in *in_file* and write one label per line to *out_file*.

    The model pickle is read as a sequence whose items 0, 1 and 2 are the
    paranormal class log prior, the sceptic class log prior and the per-class
    token log-score mapping passed to calc_post_prob.

    Args:
        in_file: input TSV, one post per line. Pass "dev-0/in.tsv" to score
            the dev split instead of test-A.
        out_file: output path; receives " S" or " P" for each input line.
        model_file: path to the pickled model.

    Raises:
        ValueError: if the classifier returns a label other than 'sceptic' or
            'paranormal'. Skipping the line would misalign the output with
            the input.
    """
    # NOTE: pickle can execute arbitrary code while loading; only load model
    # files produced by this project's own training step.
    with open(model_file, 'rb') as f:
        pickle_list = pickle.load(f)
    paranormal_class_logprob = pickle_list[0]
    sceptic_class_logprob = pickle_list[1]
    word_logprobs = pickle_list[2]

    print(f"in {in_file}")
    print(f"out {out_file}")

    # Per-line output codes; the leading space is part of the existing format.
    labels = {'sceptic': " S\n", 'paranormal': ' P\n'}
    # Explicit UTF-8: the cleaning regexes match non-ASCII characters, so the
    # platform's default encoding must not decide how the posts are decoded.
    with open(in_file, encoding='utf-8') as in_f, \
            open(out_file, 'w', encoding='utf-8') as out_f:
        for line_no, line in enumerate(in_f, start=1):
            hyp = calc_post_prob(line, paranormal_class_logprob,
                                 sceptic_class_logprob, word_logprobs)
            try:
                out_f.write(labels[hyp])
            except KeyError:
                raise ValueError(
                    f"unexpected label {hyp!r} for line {line_no} of {in_file}"
                ) from None
# Run predictions only when executed as a script, not when imported.
if __name__ == "__main__":
    main()