#!/usr/bin/python3

import pickle
import math
import re

def clear_tokens(tokens, is_text=True, *, full_clean=False):
    """Normalise a raw text fragment before tokenisation.

    By default this only turns escaped line breaks -- the two-character
    sequence backslash + ``n`` -- into spaces, and ``is_text`` has no effect.

    The heavier regex pipeline below used to be unreachable (it sat after an
    unconditional ``return``). It is kept unchanged behind the opt-in
    ``full_clean`` flag, so existing callers get identical results.
    NOTE(review): presumably training used the same default cleaning;
    verify before enabling ``full_clean`` for prediction.

    Args:
        tokens: text to clean.
        is_text: only with ``full_clean``. True collapses runs of spaces to
            a single space (whole text); False deletes spaces (single token).
        full_clean: also strip URLs in parentheses, punctuation/symbols,
            runs of dots/dashes, digits and stray ``œ``/``·`` characters.

    Returns:
        The cleaned string.
    """
    tokens = tokens.replace('\\n', ' ')
    if not full_clean:
        return tokens
    # Drop "(http...something.com)"-style links.
    tokens = re.sub(r'\(((http)|(https)).*((\.com)|(\.net)|(\.jpg)|(\.html))\)'," ", tokens)
    # Punctuation and symbols become word separators.
    tokens = re.sub(r'[\n\&\"\?\\\'\*\[\]\,\;\.\=\+\(\)\!\/\:\`\~\%\^\$\#\@\’\>\″\±]+', ' ', tokens)
    # Runs such as "..." or "--".
    tokens = re.sub(r'[\.\-][\.\-]+', ' ', tokens)
    tokens = re.sub(r'[0-9]+', ' ', tokens)
    tokens = re.sub(r'œ|·', '', tokens)
    if is_text:
        tokens = re.sub(r' +', ' ', tokens)
    else:
        tokens = re.sub(r' +', '', tokens)
    return tokens

def calc_post_prob(post, paranormal_class_logprob, sceptic_class_logprob, word_logprobs):
    """Classify one input line as 'paranormal' or 'sceptic'.

    A keyword match forces 'paranormal'. Otherwise the line is scored with
    naive Bayes in log space, and the highest-scoring class is returned.

    Args:
        post: one raw line of the input TSV. Only the first tab-separated
            field (the post text) is used.
        paranormal_class_logprob: log prior of the 'paranormal' class.
        sceptic_class_logprob: log prior of the 'sceptic' class.
        word_logprobs: mapping of class name -> (token -> log P(token | class)).
            NOTE(review): log-space values are assumed from the parameter
            names; confirm against the training script.

    Returns:
        The winning class name (a key of ``word_logprobs``), or
        'paranormal' on a keyword hit.

    Raises:
        ValueError: if ``word_logprobs`` is empty and no keyword matched.
    """
    # Only the text column matters. Splitting at most once means extra tabs
    # cannot break the unpacking.
    text = clear_tokens(post.rstrip('\n').split('\t', 1)[0], True)
    lowered = text.lower()
    # Keyword override: checked first so scoring is skipped on a hit.
    if search_for_keywords(lowered):
        return 'paranormal'
    # Clean each token once, rather than once per class.
    tokens = [clear_tokens(token, False) for token in lowered.split(' ')]
    class_log_priors = {
        'sceptic': sceptic_class_logprob,
        'paranormal': paranormal_class_logprob,
    }
    scores = {}
    for class_, token_logprobs in word_logprobs.items():
        # In log space the naive Bayes joint probability is a SUM:
        #   log P(c) + sum_t log P(t | c)
        # The previous code multiplied log-probabilities and compared their
        # absolute values, which does not rank classes by likelihood.
        score = class_log_priors.get(class_, 0.0)
        for token in tokens:
            try:
                score += token_logprobs[token]
            except KeyError:
                pass  # unseen token: contributes no evidence (log 1 = 0)
        scores[class_] = score
    # Keyed by class rather than by score, so equal scores cannot collide.
    return max(scores, key=scores.get)

def search_for_keywords(text):
    """Return True if ``text`` contains any paranormal-topic keyword.

    Matching is a case-insensitive substring test, so "UFO" and "ufos" both
    match 'ufo'.
    """
    keywords = ('paranormal', 'ufo', 'aliens', 'conspiracy', 'atlantis')
    # The model path lowercases tokens; match keywords the same way.
    lowered = text.lower()
    return any(keyword in lowered for keyword in keywords)

def main(in_file="test-A/in.tsv", out_file="test-A/out.tsv",
         model_file='naive_base_model.pkl'):
    """Classify every line of ``in_file`` and write one label per line to ``out_file``.

    The defaults reproduce the original hard-coded test-A run. Pass
    ``"dev-0/in.tsv"`` / ``"dev-0/out.tsv"`` to score the dev set instead.

    Raises:
        ValueError: if the classifier returns a class with no output label.
            Skipping it silently would misalign output rows with input rows.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load a model
    # file produced by our own training script.
    with open(model_file, 'rb') as f:
        pickle_list = pickle.load(f)
    # Model layout: [paranormal prior, sceptic prior, per-class word table].
    paranormal_class_logprob, sceptic_class_logprob, word_logprobs = pickle_list[:3]
    labels = {'sceptic': " S\n", 'paranormal': ' P\n'}
    print(f"in {in_file}")
    print(f"out {out_file}")
    # Explicit UTF-8: the cleaning regexes expect non-ASCII characters
    # (e.g. ’ and ±), so don't rely on the platform default encoding.
    with open(in_file, encoding='utf-8') as in_f, \
            open(out_file, 'w', encoding='utf-8') as out_f:
        for line in in_f:
            hyp = calc_post_prob(line, paranormal_class_logprob, sceptic_class_logprob, word_logprobs)
            try:
                out_f.write(labels[hyp])
            except KeyError:
                raise ValueError(f"unexpected class {hyp!r}; no output label") from None
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()