#!/usr/bin/python3
from collections import defaultdict
import functools
import math
import pickle
import re
import sys

import nltk
from nltk.corpus import stopwords

def calc_class_logprob(expected_path):
    """Return the log prior probabilities of the paranormal and sceptic classes.

    Args:
        expected_path: path to the labels file, one label per line. A line
            containing 'P' counts as paranormal; otherwise a line containing
            'S' counts as sceptic; any other line (e.g. blank) is ignored.

    Returns:
        Tuple ``(log P(paranormal), log P(sceptic))`` estimated from the
        relative label frequencies.

    Raises:
        ValueError: if either class has no examples. Its log prior would be
            log(0), and with no labels at all the ratio is undefined.
    """
    paranormal_count = 0
    sceptic_count = 0

    with open(expected_path, encoding='utf-8') as f:
        for line in f:
            # Substring test, so surrounding spaces/newlines are irrelevant.
            if 'P' in line:
                paranormal_count += 1
            elif 'S' in line:
                sceptic_count += 1

    # Fail with a clear message instead of ZeroDivisionError / "math domain error".
    if paranormal_count == 0 or sceptic_count == 0:
        raise ValueError(
            f"{expected_path}: need at least one 'P' and one 'S' label to "
            f"estimate class priors (found P={paranormal_count}, S={sceptic_count})"
        )

    total = paranormal_count + sceptic_count
    return math.log(paranormal_count / total), math.log(sceptic_count / total)

@functools.lru_cache(maxsize=None)
def _english_stopwords():
    """Load NLTK's English stop-word list once and return it as a frozenset."""
    return frozenset(stopwords.words('english'))


def clear_post(post):
    """Normalize a raw post and split it into tokens, dropping English stop words.

    Args:
        post: raw post text (the text field of one TSV line).

    Returns:
        List of lower-cased tokens with NLTK English stop words removed.
        Never contains empty strings; a post that normalizes to nothing
        yields ``[]``.
    """
    # The data holds the two characters '\' 'n' in place of line breaks
    # (presumably escaped newlines in the TSV -- confirm against the corpus).
    post = post.replace('\\n', ' ')
    post = post.lower()
    # Replace URLs (optionally wrapped in parentheses) with a placeholder token.
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    # Periods, commas, slashes and tildes act as word separators.
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    # Drop HTML-escaped angle brackets (&lt / &gt) and @mentions.
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    # Delete quotes, brackets, digits and miscellaneous punctuation outright.
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\–\”\!\=\^]+', '', post)
    # A space-delimited hyphen, or any run of 2+ hyphens, becomes a separator.
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    post = re.sub(r' +', ' ', post)
    stop_words = _english_stopwords()
    # Strip both ends (not just the right). A leading space, e.g. from a post
    # that starts with a link, would otherwise create an empty first token and
    # bogus bigrams such as ' internetlink'. `if w` drops the lone '' that an
    # empty post splits into.
    return [w for w in post.strip(' ').split(' ') if w and w not in stop_words]

def calc_bigram_count(in_path, expected_path, tokenizer=None):
    """Count bigram occurrences per class.

    Args:
        in_path: TSV file with one post per line, ``text<TAB>timestamp``.
            The last tab separates the timestamp, so the text may contain tabs.
        expected_path: labels file aligned line-by-line with ``in_path``.
            After removing all whitespace, 'P' means paranormal and 'S'
            means sceptic; lines with any other label are skipped.
            Pairing stops at the end of the shorter file.
        tokenizer: callable mapping post text to a list of tokens. Defaults
            to ``clear_post``.

    Returns:
        ``{'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}``
        mapping ``"tok1 tok2"`` bigram strings to their counts.

    Raises:
        ValueError: if a line of ``in_path`` has no tab separator.
    """
    if tokenizer is None:
        tokenizer = clear_post
    label_to_class = {'P': 'paranormal', 'S': 'sceptic'}
    bigram_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}

    with open(in_path, encoding='utf-8') as infile, \
            open(expected_path, encoding='utf-8') as expected_file:
        for line_no, (line, label_line) in enumerate(zip(infile, expected_file), start=1):
            fields = line.rstrip('\n').rsplit('\t', 1)
            if len(fields) != 2:
                raise ValueError(
                    f"{in_path}:{line_no}: expected 'text<TAB>timestamp', "
                    f"got {len(fields)} field(s)"
                )
            # Drop every kind of whitespace (spaces, tabs, newline) from the label.
            class_name = label_to_class.get(''.join(label_line.split()))
            if class_name is None:
                continue
            tokens = tokenizer(fields[0])
            counts = bigram_counts[class_name]
            # Adjacent token pairs form the bigrams of this post.
            for first, second in zip(tokens, tokens[1:]):
                counts[first + " " + second] += 1
    return bigram_counts

def calc_bigram_logprobs(bigram_counts, alpha=1):
    """Turn per-class bigram counts into add-alpha smoothed log-probabilities.

    For every class c and every bigram b in the shared vocabulary V (all
    bigrams seen in any class)::

        P(b | c) = (count(b, c) + alpha) / (N_c + alpha * |V|)

    where N_c is the total bigram count of class c. The shared |V| keeps the
    per-class distributions on the same support, as Laplace smoothing for
    Naive Bayes requires. Every bigram of V gets an entry in every class,
    including classes where it was never observed.

    Args:
        bigram_counts: mapping class name -> mapping bigram -> count.
        alpha: smoothing pseudo-count; must be > 0 (default 1 = add-one).

    Returns:
        Dict class name -> dict bigram -> natural-log probability, with
        bigram keys in sorted order so pickled models are reproducible.

    Raises:
        ValueError: if ``alpha`` is not positive.
    """
    if alpha <= 0:
        raise ValueError(f"alpha must be positive, got {alpha!r}")

    vocabulary = set()
    for counts in bigram_counts.values():
        vocabulary.update(counts)
    ordered_vocabulary = sorted(vocabulary)

    bigram_logprobs = {}
    for class_, counts in bigram_counts.items():
        denominator = sum(counts.values()) + alpha * len(vocabulary)
        # .get() rather than [] so defaultdict inputs are not mutated.
        bigram_logprobs[class_] = {
            bigram: math.log((counts.get(bigram, 0) + alpha) / denominator)
            for bigram in ordered_vocabulary
        }
    return bigram_logprobs

def main(argv=None):
    """Train the bigram Naive Bayes model and pickle it.

    Command line: ``./train.py expected.tsv in.tsv model.pkl``. Reads class
    labels from ``expected.tsv`` and posts from ``in.tsv``. Writes the list
    ``[paranormal_log_prior, sceptic_log_prior, bigram_logprobs]`` to
    ``model.pkl`` with pickle.

    Args:
        argv: argument vector including the program name; defaults to
            ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 4:
        # Usage errors go to stderr with a non-zero status so callers/scripts notice.
        print("syntax is ./train.py expected.tsv in.tsv model.pkl", file=sys.stderr)
        sys.exit(2)
    expected_file, in_file, model = argv[1:4]
    paranormal_class_logprob, sceptic_class_logprob = calc_class_logprob(expected_file)
    bigrams_count = calc_bigram_count(in_file, expected_file)
    bigram_logprobs = calc_bigram_logprobs(bigrams_count)
    with open(model, 'wb') as f:
        pickle.dump([paranormal_class_logprob, sceptic_class_logprob, bigram_logprobs], f)


# Only train when run as a script, not when imported.
if __name__ == "__main__":
    main()