paranormal-or-skeptic/code_prediction.py

39 lines
1.5 KiB
Python
Raw Normal View History

2020-03-28 20:40:28 +01:00
from collections import defaultdict
import math
import pickle
# Load the trained model from its pickle. It is a 3-tuple:
#   - the paranormal class score (presumably a log-prior; confirm in training),
#   - the skeptic class score (same caveat),
#   - word_logprobs, indexed as word_logprobs['paranormal'|'skeptic'][token].
# NOTE: pickle.load can execute arbitrary code; only load a trusted model file.
with open('naive_base_model.pkl', 'rb') as open_file:
    pickle_loaded = pickle.load(open_file)
paranomal_class_logprob, skeptic_class_logprob, word_logprobs = pickle_loaded
# Some words in the dev-0 and test-A inputs never occur in the training set;
# such tokens contribute log-probability 0 to a class score (see below).
def prediction(input, output, model=None):
    """Classify each line of *input* and write one label per line to *output*.

    Each input line is ``text<TAB>...``; only the part before the first tab
    is used. The text is lower-cased and split on single spaces. Each class
    score is the class log-prior (added once) plus the sum of that class's
    per-token log-probabilities. ``'0'`` is written when the paranormal
    score is strictly greater than the skeptic score, otherwise ``'1'``.

    Args:
        input: path of the UTF-8 TSV file to classify.
        output: path of the label file to write (overwritten).
        model: optional ``(paranormal_logprob, skeptic_logprob,
            word_logprobs)`` tuple in the same layout as the pickled model;
            defaults to the model loaded from ``naive_base_model.pkl`` at
            import time.
    """
    if model is None:
        model = (paranomal_class_logprob, skeptic_class_logprob, word_logprobs)
    paranormal_prior, skeptic_prior, token_logprobs = model
    paranormal_words = token_logprobs['paranormal']
    skeptic_words = token_logprobs['skeptic']

    with open(input, encoding='utf-8') as in_file:
        with open(output, 'w', encoding='utf-8') as output_file:
            for line in in_file:
                text = line.rstrip('\n').split('\t', 1)[0]
                # The prior is counted once per document, not once per token.
                paranormal_score = paranormal_prior
                skeptic_score = skeptic_prior
                for token in text.lower().split(' '):
                    # Unseen tokens count as 0; .get() avoids writing them
                    # into the shared model dictionaries.
                    paranormal_score += paranormal_words.get(token, 0)
                    skeptic_score += skeptic_words.get(token, 0)
                output_file.write(
                    '0\n' if paranormal_score > skeptic_score else '1\n'
                )
def main():
    """Write predictions for the dev-0 and test-A splits."""
    import os

    prediction('dev-0/in.tsv', 'dev-0/out.tsv')
    # The test-A input may be a plain file (test-A/in.tsv) or, when the
    # archive was unpacked into a directory, test-A/in.tsv/in.tsv.
    test_input = 'test-A/in.tsv'
    if os.path.isdir(test_input):
        test_input = os.path.join(test_input, 'in.tsv')
    prediction(test_input, 'test-A/out.tsv')


if __name__ == '__main__':
    main()