paranormal-or-skeptic/train.py
#!/usr/bin/python3
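"""Train a simple naive-Bayes-style model for the paranormal-or-skeptic task.

The script reads class labels (P for paranormal, S for sceptic) from one TSV
file and the corresponding posts from another, estimates class log priors plus
add-one smoothed word and bigram log probabilities, and pickles everything
into a single model file.

Usage: ./train.py expected.tsv in.tsv model.pkl
"""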
from collections import defaultdict
import math
import pickle
import re
import sys
import nltk
from nltk.corpus import stopwords
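# stopwords.words('english') needs the NLTK 'stopwords' corpus on disk; if it
# is missing, run nltk.download('stopwords') once before training.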

def calc_class_logprob(expected_path):
    # Count how many training posts belong to each class and return the log
    # prior probability of each class.
    paranormal_classcount = 0
    sceptic_classcount = 0
    with open(expected_path) as f:
        for line in f:
            line = line.rstrip('\n').replace(' ', '')
            if 'P' in line:
                paranormal_classcount += 1
            elif 'S' in line:
                sceptic_classcount += 1
    total = paranormal_classcount + sceptic_classcount
    paranormal_prob = paranormal_classcount / total
    sceptic_prob = sceptic_classcount / total
    return math.log(paranormal_prob), math.log(sceptic_prob)
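# Example: with 40 paranormal and 60 sceptic training posts the returned priors
# are math.log(40 / 100) and math.log(60 / 100).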

def clear_post(post):
    # Normalise a raw post: lower-case it, replace links with a placeholder
    # token, strip punctuation and stop words, and return the token list.
    post = post.replace('\\n', ' ')
    post = post.lower()
    # replace links with a placeholder token
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\\\!\=\^]+', '', post)
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    post = post.rstrip(' ')
    post = post.split(' ')
    stop_words = set(stopwords.words('english'))
    post_no_stop = [w for w in post if w not in stop_words]
    return post_no_stop
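# Rough example of the cleaning step (the exact output depends on the regexes
# above and on the NLTK stop-word list):
#   clear_post("Check https://example.com, it's REAL!")
#   -> ['check', 'internetlink', 'real']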

#def calc_bigram_count(in_path, expected_path):
#    bigram_counts = {'paranormal' : defaultdict(int), 'sceptic' : defaultdict(int)}
#    with open(in_path) as infile, open(expected_path) as expected_file:
#        num_of_bigrams = 0
#        for line, exp in zip(infile, expected_file):
#            class_ = exp.rstrip('\n').replace(' ', '')
#            text, timestamp = line.rstrip('\n').split('\t')
#            tokens = clear_post(text)
#            #tokens = text.lower().split(' ')
#            for index in range(len(tokens)-1):
#                # if there is a next token, join the current and next token into a bigram
#                bigram = tokens[index] + " " + tokens[index + 1]
#                #print(bigram)
#                #print(f"bigram constructed from ;;;;{tokens[index]}:{tokens[index+1]};;;;;;;")
#                if class_ == 'P':
#                    bigram_counts['paranormal'][bigram] += 1
#                elif class_ == 'S':
#                    bigram_counts['sceptic'][bigram] += 1
#                num_of_bigrams += 1
#    #print(f"num of all added bigrams with repetitions: {num_of_bigrams}")
#    #print(f"num of bigrams in paranormal {len(bigram_counts['paranormal'])} and sceptic {len(bigram_counts['sceptic'])}")
#    return bigram_counts

def calc_bigram_logprobs(bigram_counts):
    # Turn raw bigram counts into add-one smoothed log probabilities per class.
    total_sceptic = sum(bigram_counts['sceptic'].values()) + len(bigram_counts['sceptic'].keys())
    total_paranormal = sum(bigram_counts['paranormal'].values()) + len(bigram_counts['paranormal'].keys())
    bigram_logprobs = {'paranormal': {}, 'sceptic': {}}
    for class_ in bigram_counts.keys():
        for bigram, value in bigram_counts[class_].items():
            if class_ == "sceptic":
                bigram_prob = (value + 1) / total_sceptic
            elif class_ == "paranormal":
                bigram_prob = (value + 1) / total_paranormal
            bigram_logprobs[class_][bigram] = math.log(bigram_prob)
    return bigram_logprobs
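# The per-class totals above are sum(counts) + number of distinct bigrams, i.e.
# the denominator of add-one (Laplace) smoothing:
#   P(bigram | class) = (count(bigram, class) + 1) / (total_count_class + V_class)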

#def calc_word_count(in_path, expected_path):
#    # per-class dictionary mapping each word to how many times it occurs
#    word_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
#    with open(in_path) as infile, open(expected_path) as expectedfile:
#        for line, exp in zip(infile, expectedfile):
#            class_ = exp.rstrip('\n').replace(' ', '')
#            text, timestamp = line.rstrip('\n').split('\t')
#            #print(f"text {type(text)}")
#            text = clear_tokens(text, True)
#            tokens = text.lower().split(' ')
#            #print(f"tokens {type(tokens)}")
#            for token in tokens:
#                clear_tokens(token, False)
#                if class_ == 'P':
#                    word_counts['paranormal'][token] += 1
#                elif class_ == 'S':
#                    word_counts['sceptic'][token] += 1
#
#    return word_counts

def calc_word_logprobs(word_counts):
    # Turn raw word counts into add-one smoothed log probabilities per class.
    total_sceptic = sum(word_counts['sceptic'].values()) + len(word_counts['sceptic'].keys())
    total_paranormal = sum(word_counts['paranormal'].values()) + len(word_counts['paranormal'].keys())
    word_logprobs = {'paranormal': {}, 'sceptic': {}}
    for class_ in word_counts.keys():  # 'sceptic' and 'paranormal'
        for token, value in word_counts[class_].items():
            if class_ == 'sceptic':
                word_prob = (value + 1) / total_sceptic
            elif class_ == 'paranormal':
                word_prob = (value + 1) / total_paranormal
            word_logprobs[class_][token] = math.log(word_prob)
    return word_logprobs

def launch_bigrams_and_words(in_path, expected_path):
    # Count unigrams and bigrams per class in a single pass over the training data.
    word_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    bigram_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    with open(in_path) as infile, open(expected_path) as expected_file:
        for line, exp in zip(infile, expected_file):
            class_ = exp.rstrip('\n').replace(' ', '')
            text, timestamp = line.rstrip('\n').split('\t')
            tokens = clear_post(text)
            for index in range(len(tokens) - 1):
                # join the current token with the next one into a bigram;
                # note that only tokens[index] is added to word_counts, so the
                # last token of every post is not counted as a unigram
                bigram = tokens[index] + " " + tokens[index + 1]
                if class_ == 'P':
                    bigram_counts['paranormal'][bigram] += 1
                    word_counts['paranormal'][tokens[index]] += 1
                elif class_ == 'S':
                    bigram_counts['sceptic'][bigram] += 1
                    word_counts['sceptic'][tokens[index]] += 1
    return bigram_counts, word_counts

def main():
    if len(sys.argv) != 4:
        print("syntax is ./train.py expected.tsv in.tsv model.pkl")
        return
    expected_file = str(sys.argv[1])
    in_file = str(sys.argv[2])
    model = str(sys.argv[3])
    paranormal_class_logprob, sceptic_class_logprob = calc_class_logprob(expected_file)
    #bigrams_count = calc_bigram_count(in_file, expected_file)
    bigrams_count, words_count = launch_bigrams_and_words(in_file, expected_file)
    bigram_logprobs = calc_bigram_logprobs(bigrams_count)
    word_logprobs = calc_word_logprobs(words_count)
    # smoothing denominators (token count + vocabulary size) saved for the predictor
    total_sceptic_bigram = sum(bigrams_count['sceptic'].values()) + len(bigrams_count['sceptic'].keys())
    total_paranormal_bigram = sum(bigrams_count['paranormal'].values()) + len(bigrams_count['paranormal'].keys())
    total_sceptic_word = sum(words_count['sceptic'].values()) + len(words_count['sceptic'].keys())
    total_paranormal_word = sum(words_count['paranormal'].values()) + len(words_count['paranormal'].keys())
    with open(model, 'wb') as f:
        pickle.dump([paranormal_class_logprob, sceptic_class_logprob, bigram_logprobs, word_logprobs,
                     total_sceptic_bigram, total_paranormal_bigram,
                     total_sceptic_word, total_paranormal_word], f)


if __name__ == '__main__':
    main()
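# Minimal sketch of how a companion prediction script might load the model
# (the tuple order matches the pickle.dump call above; the file name is illustrative):
#
#   with open('model.pkl', 'rb') as f:
#       (paranormal_class_logprob, sceptic_class_logprob,
#        bigram_logprobs, word_logprobs,
#        total_sceptic_bigram, total_paranormal_bigram,
#        total_sceptic_word, total_paranormal_word) = pickle.load(f)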