paranormal-or-skeptic/train_baseline.py
2020-03-29 13:39:47 +02:00

92 lines
3.5 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python3
from collections import defaultdict
import math
import pickle
import re
# Class log-priors are estimated from the labels in expected.tsv.
def calc_class_logprob(expected_path):
    """Return the natural-log class frequencies found in the label file.

    Every line of ``expected_path`` has its trailing newline and all spaces
    removed.  A line containing ``'P'`` counts toward the first class;
    otherwise a line containing ``'S'`` counts toward the second.  Lines
    with neither marker are ignored.

    Returns:
        A ``(log_share_of_P, log_share_of_S)`` tuple.

    Raises:
        ZeroDivisionError: If no line carries either marker.
        ValueError: From ``math.log`` if one of the two counts is zero.
    """
    label_totals = {'P': 0, 'S': 0}
    with open(expected_path) as labels_file:
        for raw_label in labels_file:
            label = raw_label.rstrip('\n').replace(' ', '')
            # 'P' is tested first, so a line holding both markers counts as 'P'.
            for marker in ('P', 'S'):
                if marker in label:
                    label_totals[marker] += 1
                    break
    n_labelled = label_totals['P'] + label_totals['S']
    paranormal_share = label_totals['P'] / n_labelled
    sceptic_share = label_totals['S'] / n_labelled
    return math.log(paranormal_share), math.log(sceptic_share)
def clear_tokens(tokens, is_text=True):
    r"""Replace each literal ``\n`` escape (backslash followed by ``n``) with a space.

    Real newline characters are not touched; only the two-character
    sequence backslash + ``n`` is replaced.

    Args:
        tokens: String to normalise.
        is_text: Accepted for backward compatibility; it has no effect on
            the result.

    Returns:
        ``tokens`` with every two-character ``\n`` sequence turned into a
        single space.
    """
    # The regex-based stripping of links, punctuation and digits that used to
    # sit below an unconditional ``return`` never executed; it is left out so
    # the body shows exactly what runs.
    return tokens.replace('\\n', ' ')
# How many times each word occurs in the documents of a given class.
def calc_word_count(in_path, expected_path):
    """Count how often each token occurs in the documents of each class.

    ``in_path`` and ``expected_path`` are read in lockstep, one document per
    line.  Each input line must hold exactly two tab-separated columns; only
    the first (the text) is used.  The matching label line, with spaces
    removed, must equal ``'P'`` or ``'S'`` for the document to be counted.

    Args:
        in_path: Path to the tab-separated input file.
        expected_path: Path to the label file, one label per line.

    Returns:
        ``{'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}``
        mapping each lower-cased token to its number of occurrences.

    Raises:
        ValueError: If an input line does not split into exactly two columns.
    """
    # Outer dict is keyed by class; each inner dict maps a word to its count.
    word_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    class_keys = {'P': 'paranormal', 'S': 'sceptic'}
    # Explicit encoding keeps decoding independent of the platform default.
    with open(in_path, encoding='utf-8') as infile, \
            open(expected_path, encoding='utf-8') as expectedfile:
        # zip() stops at the shorter file, so surplus lines are ignored.
        for line, exp in zip(infile, expectedfile):
            class_ = exp.rstrip('\n').replace(' ', '')
            text, _timestamp = line.rstrip('\n').split('\t')
            class_key = class_keys.get(class_)
            if class_key is None:
                continue  # rows whose label is neither 'P' nor 'S' are not counted
            text = clear_tokens(text, True)
            # split(' ') keeps empty strings for consecutive spaces; they are
            # counted like any other token.
            for token in text.lower().split(' '):
                # Keep the cleaned value; the result was previously discarded.
                token = clear_tokens(token, False)
                word_counts[class_key][token] += 1
    return word_counts
def calc_word_logprobs(word_counts):
    """Convert per-class token counts into add-one smoothed log-probabilities.

    For each class the denominator is the sum of that class's counts plus the
    number of distinct tokens seen in that class.  Each token's probability
    is ``(count + 1) / denominator``, stored as its natural logarithm.

    Args:
        word_counts: Mapping with ``'paranormal'`` and ``'sceptic'`` keys,
            each holding a token -> count mapping.

    Returns:
        ``{'paranormal': {token: logprob}, 'sceptic': {token: logprob}}``.
    """
    denominators = {
        'sceptic': sum(word_counts['sceptic'].values()) + len(word_counts['sceptic'].keys()),
        'paranormal': sum(word_counts['paranormal'].values()) + len(word_counts['paranormal'].keys()),
    }
    word_logprobs = {'paranormal': {}, 'sceptic': {}}
    for class_, counts in word_counts.items():
        for token, occurrences in counts.items():
            smoothed = (occurrences + 1) / denominators[class_]
            word_logprobs[class_][token] = math.log(smoothed)
    return word_logprobs
def main():
    """Build the model from the training split and pickle it.

    Prints the two input paths, computes the class log-priors and the
    per-class word log-probabilities, and writes them as a three-element
    list to ``naive_base_model.pkl`` in the current directory.
    """
    # Point these at './dev-0/expected.tsv' and './dev-0/in.tsv' to build the
    # model from the dev split instead.
    expected = './train/expected.tsv'
    in_f = './train/in.tsv'
    print (f"expected {expected}")
    print (f"in {in_f}")

    class_logprobs = calc_class_logprob(expected)
    word_logprobs = calc_word_logprobs(calc_word_count(in_f, expected))

    # Stored order: paranormal log-prior, sceptic log-prior, word log-probs.
    model = [class_logprobs[0], class_logprobs[1], word_logprobs]
    with open('naive_base_model.pkl', 'wb') as model_file:
        pickle.dump(model, model_file)
# Per the original note, predict.py applies the decision rule
# argmax_c P(c) * prod_w P(w|c) to the values stored by main().
# Guarded so that importing this module does not retrain and overwrite the
# pickled model; running the file as a script behaves as before.
if __name__ == '__main__':
    main()