#!/usr/bin/python3
from collections import defaultdict
import math
import pickle
import re

# Class priors, estimated from the labels in expected.tsv.
def calc_class_logprob(expected_path):
    """Return the log prior probabilities of the two classes.

    Args:
        expected_path: path to expected.tsv, one label per line: 'P'
            (paranormal) or 'S' (sceptic), optionally padded with
            whitespace. Lines with any other content are ignored.

    Returns:
        A tuple ``(log P(paranormal), log P(sceptic))``.

    Raises:
        ValueError: if either class never occurs. This includes an empty
            file, where the priors are undefined, and a single-class file,
            where one log prior would be log(0).
    """
    paranormal_count = 0
    sceptic_count = 0
    with open(expected_path, encoding='utf-8') as f:
        for line in f:
            # Exact match after stripping, the same rule calc_word_count
            # uses, so priors and word counts cover the same documents.
            label = line.strip()
            if label == 'P':
                paranormal_count += 1
            elif label == 'S':
                sceptic_count += 1

    if paranormal_count == 0 or sceptic_count == 0:
        raise ValueError(
            f"both classes must occur in {expected_path!r} "
            f"(P={paranormal_count}, S={sceptic_count})"
        )

    total = paranormal_count + sceptic_count
    return math.log(paranormal_count / total), math.log(sceptic_count / total)

def clear_tokens(tokens, is_text=True, *, full_clean=False):
    """Normalize a raw text fragment before tokenization.

    By default, only the literal two-character escape sequence backslash +
    'n' (as it appears in the TSV data) is replaced with a space, and
    *is_text* has no effect. This matches what the function has always
    done.

    Args:
        tokens: text to clean (a whole document or a single token).
        is_text: only used when *full_clean* is true. If true, runs of
            spaces are collapsed to a single space; if false, all spaces
            are removed (meant for cleaning a single token).
        full_clean: opt in to the extra regex cleanup that previously sat
            unreachable after an early ``return``. It removes:
            parenthesized http(s) links ending in .com/.net/.jpg/.html;
            lowercase words joined by '-' or '_'; most punctuation and
            special characters; runs of '.'/'-'; digits; and the
            characters 'œ' and '·'.

    Returns:
        The cleaned string.
    """
    tokens = tokens.replace('\\n', ' ')
    if not full_clean:
        return tokens

    # Delete links, special characters, dot runs and newlines.
    # NOTE(review): the link pattern uses a greedy `.*`, so two links on one
    # line also swallow the text between them.
    tokens = re.sub(r'\(((http)|(https)).*((\.com)|(\.net)|(\.jpg)|(\.html))\)'," ", tokens)
    tokens = re.sub(r'(|\-|\_)([a-z]+(\-|\_))+[a-z]+(|\-|\_)', ' ', tokens)
    tokens = re.sub(r'[\n\&\"\?\\\'\*\[\]\,\;\.\=\+\(\)\!\/\:\`\~\%\^\$\#\@\’\>\″\±]+', ' ', tokens)
    tokens = re.sub(r'[\.\-][\.\-]+', ' ', tokens)
    tokens = re.sub(r'[0-9]+', ' ', tokens)
    tokens = re.sub(r'œ|·', '', tokens)
    if is_text:
        tokens = re.sub(r' +', ' ', tokens)
    else:
        tokens = re.sub(r' +', '', tokens)
    return tokens

# How many times each word occurs in the documents of each class.
def calc_word_count(in_path, expected_path):
    """Count how many times each word occurs in the documents of each class.

    Args:
        in_path: path to in.tsv; each line is ``text<TAB>timestamp``. The
            timestamp is the last tab-separated field and is ignored.
        expected_path: path to expected.tsv; line *i* holds the label of
            line *i* of *in_path*: 'P' (paranormal) or 'S' (sceptic),
            optionally padded with whitespace.

    Returns:
        ``{'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}``,
        mapping each lower-cased token to its number of occurrences in that
        class. Documents with any other label are skipped. Empty tokens,
        produced by consecutive spaces, are not counted.

    Raises:
        ValueError: if a labelled line of *in_path* contains no tab.
    """
    # Outer dict: class name -> {word: number of occurrences}.
    word_counts = {'paranormal': defaultdict(int), 'sceptic': defaultdict(int)}
    label_to_class = {'P': 'paranormal', 'S': 'sceptic'}
    # NOTE: zip() stops at the shorter file; both files are assumed to have
    # the same number of lines.
    with open(in_path, encoding='utf-8') as infile, \
            open(expected_path, encoding='utf-8') as expectedfile:
        for line, exp in zip(infile, expectedfile):
            class_name = label_to_class.get(exp.strip())
            if class_name is None:
                continue
            counts = word_counts[class_name]
            # rsplit keeps any tab inside the text in the text field.
            text, _timestamp = line.rstrip('\r\n').rsplit('\t', 1)
            text = clear_tokens(text, True)
            for token in text.lower().split(' '):
                # clear_tokens returns a new string; it must be reassigned.
                token = clear_tokens(token, False)
                if token:
                    counts[token] += 1

    return word_counts

def calc_word_logprobs(word_counts):
    """Convert per-class word counts into add-one smoothed log P(word | class).

    For class c with N_c total word occurrences and a vocabulary V that is
    the union of the words seen in *all* classes:

        P(w | c) = (count_c(w) + 1) / (N_c + |V|)

    Using the shared |V| (rather than each class's own vocabulary size)
    leaves probability mass for words that occur in the other class but
    not in c. Such a word would get 1 / (N_c + |V|). Only words actually
    seen in a class are stored in the returned dict.

    Args:
        word_counts: mapping class name -> mapping word -> count, as
            produced by calc_word_count.

    Returns:
        Mapping class name -> {word: log probability}, with the same class
        keys as *word_counts*.
    """
    vocabulary = set()
    for counts in word_counts.values():
        vocabulary.update(counts)
    vocab_size = len(vocabulary)

    word_logprobs = {}
    for class_name, counts in word_counts.items():
        denominator = sum(counts.values()) + vocab_size
        word_logprobs[class_name] = {
            token: math.log((count + 1) / denominator)
            for token, count in counts.items()
        }

    return word_logprobs

def main(expected='./dev-0/expected.tsv', in_f='./dev-0/in.tsv',
         model_path='naive_base_model.pkl'):
    """Train the naive Bayes model and pickle it.

    The pickle holds the list
    ``[log P(paranormal), log P(sceptic), word_logprobs]``.

    Args:
        expected: path to the labels file (expected.tsv).
        in_f: path to the documents file (in.tsv).
        model_path: where to write the pickled model.

    NOTE(review): the defaults train on ./dev-0/, while earlier commented-out
    paths pointed at ./train/. If the model is later evaluated on dev-0,
    this trains on the evaluation data. Confirm which split is intended,
    e.g. ``main('./train/expected.tsv', './train/in.tsv')``.
    """
    print(f"expected {expected}")
    print(f"in {in_f}")
    paranormal_class_logprob, sceptic_class_logprob = calc_class_logprob(expected)
    word_counts = calc_word_count(in_f, expected)

    word_logprobs = calc_word_logprobs(word_counts)
    with open(model_path, 'wb') as f:
        pickle.dump([paranormal_class_logprob, sceptic_class_logprob, word_logprobs], f)
    # predict.py is meant to pick argmax_c P(c) * prod_w P(w|c) from these
    # pickled log-probabilities.

# Train only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()