paranormal-or-skeptic/train.py

84 lines
3.3 KiB
Python
Raw Normal View History

2020-03-22 10:15:36 +01:00
#!/usr/bin/python3
from collections import defaultdict
import math
import pickle
2020-03-22 11:59:07 +01:00
import re
2020-03-29 13:39:47 +02:00
import sys
2020-03-22 10:15:36 +01:00
def calc_class_logprob(expected_path):
    """Compute the log prior probability of each class from the label file.

    Every line of ``expected_path`` holds one label. Spaces are removed from
    the line; a line containing ``P`` is counted as paranormal, otherwise a
    line containing ``S`` is counted as sceptic, and any other line is
    ignored.

    Args:
        expected_path: Path to the file with one class label per line.

    Returns:
        A ``(paranormal_logprob, sceptic_logprob)`` tuple holding the natural
        logarithm of each class's share of the counted labels. A class that
        never occurs gets ``-inf`` (the logarithm of a zero probability)
        instead of triggering a math domain error.

    Raises:
        ValueError: If the file contains no ``P`` or ``S`` label at all.
    """
    paranormal_classcount = 0
    sceptic_classcount = 0
    with open(expected_path) as f:
        for line in f:
            label = line.rstrip('\n').replace(' ', '')
            if 'P' in label:
                paranormal_classcount += 1
            elif 'S' in label:
                sceptic_classcount += 1

    total = paranormal_classcount + sceptic_classcount
    if total == 0:
        # Without a single label there is no distribution to estimate.
        raise ValueError(
            "no 'P' or 'S' labels found in {!r}".format(expected_path)
        )

    def _log_share(count):
        # math.log(0) raises ValueError, so map an absent class to -inf.
        return math.log(count / total) if count else float('-inf')

    return _log_share(paranormal_classcount), _log_share(sceptic_classcount)
2020-03-29 13:39:47 +02:00
# Ordered (pattern, replacement) rewrite rules used by clear_post; later rules
# operate on the output of earlier ones, so the order matters.
_POST_CLEANUP_STEPS = (
    # delete links, including an optional opening and/or closing parenthesis
    (re.compile(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+]+(\)|)'), ''),
    # runs of dots, commas and slashes become a single space
    (re.compile(r'[\.\,\/]+'), ' '),
    # leftovers of HTML-escaped angle brackets
    (re.compile(r'(&lt|&gt)'), ''),
    # quotes, brackets, digits and assorted punctuation are removed outright
    (re.compile(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%]+'), ''),
    # a dash standing alone between spaces
    (re.compile(r' \- '), ' '),
    # collapse runs of spaces
    (re.compile(r' +'), ' '),
)


def clear_post(post):
    """Normalize a raw post before tokenization.

    Literal ``\\n`` sequences (a backslash followed by ``n``) are turned into
    spaces, then links, punctuation and digits are stripped by the ordered
    rules in ``_POST_CLEANUP_STEPS``, and finally trailing spaces are
    removed. A leading space, if any, is kept.

    Args:
        post: The raw post text.

    Returns:
        The cleaned text.
    """
    cleaned = post.replace('\\n', ' ')
    for pattern, replacement in _POST_CLEANUP_STEPS:
        cleaned = pattern.sub(replacement, cleaned)
    return cleaned.rstrip(' ')
2020-03-22 10:15:36 +01:00
2020-03-29 13:39:47 +02:00
def calc_bigram_count(in_path, expected_path):
    """Count word bigrams per class over the training posts.

    The two files are read in lockstep: line *i* of ``in_path`` is a post and
    line *i* of ``expected_path`` is its label. Reading stops at the end of
    the shorter file. Rows whose label (after removing spaces) is neither
    ``P`` nor ``S`` are skipped.

    Args:
        in_path: Path to the tab-separated input file. The post text is the
            first column; any further columns (such as a timestamp) are
            ignored, and a row without a tab is treated as text only.
        expected_path: Path to the file with one class label per line.

    Returns:
        ``{'paranormal': counts, 'sceptic': counts}`` where each ``counts``
        is a ``defaultdict(int)`` mapping ``"token1 token2"`` to the number
        of times that bigram occurred in posts of the class.
    """
    bigram_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    label_to_class = {'P': 'paranormal', 'S': 'sceptic'}
    with open(in_path) as infile, open(expected_path) as expected_file:
        for line, exp in zip(infile, expected_file):
            class_name = label_to_class.get(exp.rstrip('\n').replace(' ', ''))
            if class_name is None:
                # Unknown label: the row contributes to neither class.
                continue
            # Take only the first column, so a row with a missing or extra
            # tab no longer aborts the run with an unpacking error.
            text = line.rstrip('\n').partition('\t')[0]
            # Splitting on a single space is kept as-is so the tokens match
            # what this script has always produced.
            # NOTE(review): presumably the prediction side tokenizes the same
            # way - confirm before changing this.
            tokens = clear_post(text).lower().split(' ')
            class_counts = bigram_counts[class_name]
            # Pair every token with its successor.
            for first, second in zip(tokens, tokens[1:]):
                class_counts[first + " " + second] += 1
    return bigram_counts
2020-03-22 10:15:36 +01:00
2020-03-29 13:39:47 +02:00
def calc_bigram_logprobs(bigram_counts):
    """Turn per-class bigram counts into add-one smoothed log probabilities.

    For every class the probability of a bigram is
    ``(count + 1) / (total_count + distinct_bigrams)``, where both the total
    and the number of distinct bigrams are taken from that class alone.

    Args:
        bigram_counts: Mapping of class name to a mapping of bigram to its
            count. Any set of class names is accepted.

    Returns:
        A dict with the same class names, each mapping bigram to the natural
        logarithm of its smoothed probability. A class without bigrams maps
        to an empty dict.
    """
    bigram_logprobs = {}
    for class_, counts in bigram_counts.items():
        # Add-one smoothing: one pseudo-count per distinct bigram of the
        # class is added to the denominator.
        denominator = sum(counts.values()) + len(counts)
        bigram_logprobs[class_] = {
            bigram: math.log((count + 1) / denominator)
            for bigram, count in counts.items()
        }
    return bigram_logprobs
2020-03-22 10:15:36 +01:00
def main():
    """Train the bigram model from command-line arguments and pickle it.

    Expects exactly three arguments after the script name: the label file,
    the input file and the output model path. With any other number of
    arguments the usage line is printed and nothing else happens.

    The pickled object is the list
    ``[paranormal_class_logprob, sceptic_class_logprob, bigram_logprobs]``.
    """
    arguments = sys.argv[1:]
    if len(arguments) != 3:
        print("syntax is ./train.py expected.tsv in.tsv model.pkl")
        return

    expected_file, in_file, model = arguments

    paranormal_prior, sceptic_prior = calc_class_logprob(expected_file)
    counts = calc_bigram_count(in_file, expected_file)
    logprobs = calc_bigram_logprobs(counts)

    # The output file is opened only after training has succeeded.
    with open(model, 'wb') as out:
        pickle.dump([paranormal_prior, sceptic_prior, logprobs], out)
2020-03-22 10:15:36 +01:00
# Run the trainer only when this file is executed as a script, so that
# importing the module does not start a training run.
if __name__ == "__main__":
    main()