49 lines
1.5 KiB
Python
Executable File
49 lines
1.5 KiB
Python
Executable File
#!/usr/bin/python3
|
|
|
|
import pickle
|
|
import math
|
|
|
|
def clear_tokens(tokens):
    """Return *tokens* with every newline character turned into a single space.

    Args:
        tokens: Raw post text.

    Returns:
        The same text, with each '\\n' replaced by ' ' (all other characters,
        including '\\r', are left unchanged).
    """
    pieces = tokens.split('\n')
    return ' '.join(pieces)
|
|
|
|
def calc_post_prob(post, paranormal_class_logprob, sceptic_class_logprob, word_logprobs):
    """Return the most likely class label for one TSV line using naive Bayes.

    The score for each class is computed in log space:

        score(c) = log P(c) + sum(log P(token | c) for token in post)

    and the class with the highest score wins.

    NOTE(review): this assumes the model stores natural-log probabilities, as
    the ``*_logprob`` names suggest — confirm against the training script.

    Args:
        post: One input line, ``"<text>\\t<timestamp>..."``. Only the first
            tab-separated field (the text) is used. A line without a tab is
            treated as text only.
        paranormal_class_logprob: log prior of the ``'paranormal'`` class.
        sceptic_class_logprob: log prior of the ``'sceptic'`` class.
        word_logprobs: mapping ``class label -> {token: log P(token | class)}``.

    Returns:
        The key of ``word_logprobs`` with the highest score. Ties go to the
        class that appears first in ``word_logprobs``.

    Raises:
        ValueError: if ``word_logprobs`` is empty.
    """
    if not word_logprobs:
        raise ValueError("word_logprobs is empty; no classes to score")

    # Keep only the post text; the trailing column(s) (e.g. timestamp) are unused.
    text = post.rstrip('\n').split('\t', 1)[0]
    # str.split() with no argument splits on any whitespace run (including
    # newlines), so no empty tokens are produced for repeated spaces.
    tokens = text.lower().split()

    class_priors = {
        'sceptic': sceptic_class_logprob,
        'paranormal': paranormal_class_logprob,
    }

    scores = {}
    for class_, token_logprobs in word_logprobs.items():
        # Start from the class log prior; classes without a known prior get
        # log(1) = 0, matching the original (which applied no prior to them).
        score = class_priors.get(class_, 0.0)
        for token in tokens:
            # Tokens never seen for this class during training are skipped.
            if token in token_logprobs:
                # Log-probabilities are added, not multiplied.
                score += token_logprobs[token]
        scores[class_] = score

    # Keyed by class (not by score), so equal scores cannot overwrite each other.
    return max(scores, key=scores.get)
|
|
|
|
|
|
def main(model_path='naive_base_model.pkl',
         in_path='test-A/in.tsv',
         out_path='test-A/out.tsv'):
    """Classify every line of *in_path* and write one label per line to *out_path*.

    The model file is a pickled sequence whose first three items are the
    paranormal log prior, the sceptic log prior and the per-class token
    log-probability mapping, in that order.

    Args:
        model_path: path to the pickled model.
        in_path: input TSV, one post per line.
        out_path: output file. Line i holds the label for input line i.

    Raises:
        ValueError: if a post is classified into a class with no output
            label. Silently skipping it would misalign out.tsv with in.tsv.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load a model
    # file produced by a trusted training run.
    with open(model_path, 'rb') as f:
        pickle_list = pickle.load(f)

    paranormal_class_logprob = pickle_list[0]
    sceptic_class_logprob = pickle_list[1]
    word_logprobs = pickle_list[2]

    # Output strings are kept exactly as in the original script (including the
    # leading space); confirm this matches the expected out.tsv format.
    labels = {'sceptic': " S\n", 'paranormal': ' P\n'}

    with open(in_path, encoding='utf-8') as in_f, open(out_path, 'w', encoding='utf-8') as out_f:
        for line_no, line in enumerate(in_f, start=1):
            hyp = calc_post_prob(line, paranormal_class_logprob,
                                 sceptic_class_logprob, word_logprobs)
            if hyp not in labels:
                raise ValueError(
                    f"line {line_no}: unexpected class {hyp!r}; "
                    "cannot write an aligned output line"
                )
            out_f.write(labels[hyp])
|
|
# Run the classifier only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|