"""Classify TSV rows as paranormal ('P') or skeptic ('S') with a Naive Bayes model.

The model is read from a pickle file holding a 3-tuple:
``(paranormal_class_logprob, skeptic_class_logprob, word_logprobs)`` where
``word_logprobs`` is indexed as ``word_logprobs['paranormal'][token]`` and
``word_logprobs['skeptic'][token]``.
"""
from collections import defaultdict
import math
import pickle

MODEL_PATH = 'naive_base_model.pkl'


def load_model(path=MODEL_PATH):
    """Load the pickled model tuple from *path* and return it.

    Returns:
        The unpickled object; the rest of this file expects the 3-tuple
        ``(paranormal_class_logprob, skeptic_class_logprob, word_logprobs)``.

    Raises:
        OSError: if *path* cannot be opened.
    """
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load model files produced by a trusted training run.
    with open(path, 'rb') as model_file:
        return pickle.load(model_file)


def _classify(text, paranormal_prior, skeptic_prior, word_logprobs):
    """Return 'P' if *text* scores strictly higher as paranormal, else 'S'."""
    # Tokenisation is a lowercase split on single spaces.
    # NOTE(review): presumably this mirrors the training script - confirm.
    tokens = text.lower().split(' ')

    paranormal_words = word_logprobs['paranormal']
    skeptic_words = word_logprobs['skeptic']

    # Naive Bayes decision rule: the class log-prior counts exactly once,
    # followed by the sum of the per-token log-likelihoods.
    paranormal_score = paranormal_prior
    skeptic_score = skeptic_prior
    for token in tokens:
        # Some words in the evaluation sets (dev-0, test-A) do not occur in
        # the training data; they contribute 0 for the class that lacks
        # them. The lookup deliberately does not insert into the model.
        # NOTE(review): a token known to only one class therefore favours
        # the class that never saw it - confirm that training smooths both
        # classes over a shared vocabulary.
        paranormal_score += paranormal_words.get(token, 0)
        skeptic_score += skeptic_words.get(token, 0)

    # Ties go to 'S'.
    return 'P' if paranormal_score > skeptic_score else 'S'


def prediction(input, output, *, model=None):
    """Write one 'P' or 'S' line to *output* for every line of *input*.

    Args:
        input: path to a UTF-8 TSV file; the first tab-separated column is
            the text to classify, any further columns are ignored.
            (The name shadows the builtin but is kept for callers.)
        output: path of the predictions file to create or overwrite.
        model: optional ``(paranormal_class_logprob, skeptic_class_logprob,
            word_logprobs)`` tuple. When omitted, the model is loaded from
            ``MODEL_PATH``.

    Raises:
        OSError: if the input, output, or model file cannot be opened.
    """
    if model is None:
        model = load_model()
    paranormal_prior, skeptic_prior, word_logprobs = model

    # Both files are managed by `with` so the predictions are flushed and
    # the handles are closed even if classification fails midway.
    with open(input, encoding='utf-8') as in_file, \
            open(output, 'w') as output_file:
        for line in in_file:
            # Take only the first column so a row with a missing or extra
            # field still yields exactly one prediction line.
            text = line.rstrip('\n').partition('\t')[0]
            label = _classify(
                text, paranormal_prior, skeptic_prior, word_logprobs)
            output_file.write(label + '\n')


def main():
    """Classify the dev-0 and test-A sets with the default model."""
    model = load_model()  # load once and reuse for both data sets
    prediction('dev-0/in.tsv', 'dev-0/out.tsv', model=model)
    # NOTE(review): 'in.tsv/in.tsv' looks like a duplicated path component;
    # left unchanged because the on-disk layout is not visible from here.
    prediction('test-A/in.tsv/in.tsv', 'test-A/out.tsv', model=model)


if __name__ == '__main__':
    main()