paranormal-or-skeptic/train.py
2020-03-22 10:15:36 +01:00

73 lines
2.7 KiB
Python
Executable File

#!/usr/bin/python3
from collections import defaultdict
import math
import pickle
# Class labels (P / S) are read from expected.tsv, one label per line.
def calc_class_logprob(expected_path):
    """Estimate the log prior probability of each class from a label file.

    Every line of *expected_path* is treated as one label. After dropping the
    trailing newline and any spaces (the same normalisation the word counter
    applies), ``'P'`` counts towards the paranormal class and ``'S'`` towards
    the sceptic class; any other line is ignored.

    Args:
        expected_path: Path to the label file.

    Returns:
        A tuple ``(paranormal_logprob, sceptic_logprob)`` of natural
        logarithms. A class without any example gets ``-inf`` (log of 0).

    Raises:
        ValueError: If the file contains no ``'P'`` or ``'S'`` label at all.
    """
    label_counts = {'P': 0, 'S': 0}
    with open(expected_path, encoding='utf-8') as f:
        for line in f:
            # Exact match on the normalised label, not a substring test, so
            # the priors count exactly the rows the word counter uses.
            label = line.rstrip('\n').replace(' ', '')
            if label in label_counts:
                label_counts[label] += 1

    total = label_counts['P'] + label_counts['S']
    if total == 0:
        raise ValueError(
            f"no 'P' or 'S' labels found in {expected_path!r}"
        )

    def log_share(count):
        # math.log(0) raises ValueError; an absent class has log-prob -inf.
        return math.log(count / total) if count else float('-inf')

    return log_share(label_counts['P']), log_share(label_counts['S'])
def clear_tokens(tokens):
    """Return *tokens* with every newline character replaced by a space.

    NOTE: the original note here also planned to strip links, special
    characters and periods; only the newline replacement is implemented.
    """
    pieces = tokens.split('\n')
    return ' '.join(pieces)
# How many times each word occurs in the documents of a given class.
def calc_word_count(in_path, expected_path):
    """Count how often each token occurs in the documents of each class.

    The two files are read in lockstep: line *i* of *in_path* is a document
    and line *i* of *expected_path* is its label. Reading stops at the end of
    the shorter file. A label of ``'P'`` or ``'S'`` (after removing the
    newline and spaces) selects the ``'paranormal'`` or ``'sceptic'``
    counter; rows with any other label are skipped.

    A document line is expected to look like ``text<TAB>timestamp``. Only the
    part before the last tab is used as text, so a line with no tab or with
    extra tabs inside the text is still counted instead of aborting the run.

    The text is passed through ``clear_tokens``, lower-cased and split on
    single spaces. Consecutive spaces therefore produce empty-string tokens;
    this is kept as is so the resulting vocabulary does not change.

    Args:
        in_path: Path to the tab-separated document file.
        expected_path: Path to the label file.

    Returns:
        ``{'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}``
        mapping each token to its number of occurrences in that class.
    """
    word_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    label_to_class = {'P': 'paranormal', 'S': 'sceptic'}
    with open(in_path, encoding='utf-8') as infile, \
            open(expected_path, encoding='utf-8') as expectedfile:
        for line, exp in zip(infile, expectedfile):
            label = exp.rstrip('\n').replace(' ', '')
            class_name = label_to_class.get(label)
            if class_name is None:
                # Unknown label: nothing would be counted for this row.
                continue
            # Drop the trailing timestamp column; tolerate missing/extra tabs.
            text = line.rstrip('\n').rsplit('\t', 1)[0]
            text = clear_tokens(text)
            class_counts = word_counts[class_name]
            for token in text.lower().split(' '):
                class_counts[token] += 1
    return word_counts
def calc_word_logprobs(word_counts):
    """Turn per-class token counts into add-one smoothed log-probabilities.

    For every class ``c`` and every token ``w`` of the shared vocabulary (the
    union of the tokens seen in any class) this computes::

        log((count(w, c) + 1) / (sum of counts in c + vocabulary size))

    Smoothing over the shared vocabulary, rather than over each class's own
    token set, means a token that occurs only in the other class still gets
    a small non-zero probability, and each class's probabilities sum to 1
    over the same set of tokens.

    Args:
        word_counts: Mapping ``class name -> {token: count}``. It is not
            modified (important when the inner mappings are defaultdicts).

    Returns:
        ``{class name: {token: log-probability}}`` with plain dicts, one
        entry per class in *word_counts* and one token entry per vocabulary
        token.
    """
    # A dict (not a set) keeps first-seen order, so output order is stable.
    vocabulary = {}
    for counts in word_counts.values():
        vocabulary.update(dict.fromkeys(counts))
    vocab_size = len(vocabulary)

    word_logprobs = {}
    for class_, counts in word_counts.items():
        denominator = sum(counts.values()) + vocab_size
        # .get() avoids inserting missing tokens into a defaultdict input.
        word_logprobs[class_] = {
            token: math.log((counts.get(token, 0) + 1) / denominator)
            for token in vocabulary
        }
    return word_logprobs
def main():
    """Train the model on ./train and pickle it to naive_base_model.pkl.

    The pickle holds a three-item list: the paranormal class log-prior, the
    sceptic class log-prior, and the per-class word log-probabilities.
    """
    class_logpriors = calc_class_logprob('./train/expected.tsv')
    counts_by_class = calc_word_count('./train/in.tsv', './train/expected.tsv')
    logprobs_by_class = calc_word_logprobs(counts_by_class)

    model = list(class_logpriors) + [logprobs_by_class]
    with open('naive_base_model.pkl', 'wb') as model_file:
        pickle.dump(model, model_file)
    # Per the original note, predict.py applies argmax over classes of
    # P(c) * product of P(w|c). The note wrote P(w); presumably the class
    # prior P(c) is meant - confirm against predict.py.
# Train only when executed as a script, so that importing this module
# (e.g. to reuse its helpers) does not retrain and overwrite the model.
if __name__ == "__main__":
    main()