from collections import defaultdict
import math
import pickle
import re

# Log-probabilities assigned to tokens that are absent from a class's trained
# vocabulary (values carried over from the original script).
UNKNOWN_PARANORMAL_LOGPROB = -14.78
UNKNOWN_SKEPTIC_LOGPROB = -15.6

# Text-cleanup substitutions, applied in exactly this order to the lowercased
# text. NOTE(review): these presumably mirror the training-time preprocessing;
# keep them in sync with the training script.
_CLEANUP_PATTERNS = [
    (re.compile(r'http\S+'), " "),                   # drop URLs
    (re.compile(r'\\n+'), " "),                      # literal backslash-n sequences
    (re.compile(r'\/[a-z]\/'), " "),                 # single letter between slashes, e.g. "/r/"
    (re.compile(r'[^a-z]'), " "),                    # keep only ASCII lowercase letters
    (re.compile(r'\s{2,}'), " "),                    # collapse runs of whitespace
    (re.compile(r'(\s+|\\n)'), ' '),
    (re.compile(r'\W\w{1,3}\W|\A\w{1,3}\W'), " "),   # drop (most) 1-3 letter words
    (re.compile(r'^\s'), ""),                        # strip one leading whitespace char
]


def _clean_text(text):
    """Lowercase *text* and apply every cleanup substitution in order."""
    text = text.lower()
    for pattern, replacement in _CLEANUP_PATTERNS:
        text = pattern.sub(replacement, text)
    return text


def _score(tokens, class_logprob, token_logprobs, unknown_logprob):
    """Return the Naive Bayes log score of *tokens* for one class.

    The class prior *class_logprob* is counted exactly once; each token then
    adds its log-likelihood from *token_logprobs*, or *unknown_logprob* when
    the token is not in that mapping. The mapping is not modified.
    """
    # sum(..., start) accumulates left to right starting from the prior.
    return sum(
        (token_logprobs.get(token, unknown_logprob) for token in tokens),
        class_logprob,
    )


def prediction(input, output, model_path='naive_base_model.pkl'):
    """Classify each line of *input* and write one label per line to *output*.

    Args:
        input: Path to a UTF-8 text file. On each line, the text to classify is
            the first tab-separated column; any further columns are ignored.
        output: Path of the file to write. It receives '0\\n' when the
            paranormal score is strictly greater than the skeptic score, and
            '1\\n' otherwise (ties go to '1').
        model_path: Pickle file holding the tuple
            (paranormal_class_logprob, skeptic_class_logprob, word_logprobs),
            where word_logprobs['paranormal'] and word_logprobs['skeptic']
            map tokens to log-probabilities.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load a trusted model.
    with open(model_path, 'rb') as model_file:
        paranormal_class_logprob, skeptic_class_logprob, word_logprobs = pickle.load(model_file)
    paranormal_words = word_logprobs['paranormal']
    skeptic_words = word_logprobs['skeptic']

    with open(input, encoding='utf-8') as in_file, \
            open(output, 'w', encoding='utf-8') as out_file:
        for line in in_file:
            # Only the text column is used; a trailing timestamp (or any other
            # extra column) is ignored instead of breaking the unpacking.
            text = line.rstrip('\n').split('\t', 1)[0]
            # split(' ') (not split()) is kept so tokens match the original
            # behavior, including empty tokens from leftover single spaces.
            tokens = _clean_text(text).split(' ')
            paranormal_score = _score(tokens, paranormal_class_logprob,
                                      paranormal_words, UNKNOWN_PARANORMAL_LOGPROB)
            skeptic_score = _score(tokens, skeptic_class_logprob,
                                   skeptic_words, UNKNOWN_SKEPTIC_LOGPROB)
            out_file.write('0\n' if paranormal_score > skeptic_score else '1\n')


def main():
    """Write predictions for the dev-0 and test-A splits."""
    prediction('dev-0/in.tsv', 'dev-0/out.tsv')
    # NOTE(review): this path differs from the dev-0 layout; confirm the
    # nested 'in.tsv/in.tsv' is intentional and not a typo for 'test-A/in.tsv'.
    prediction('test-A/in.tsv/in.tsv', 'test-A/out.tsv')


if __name__ == '__main__':
    main()