#!/usr/bin/env python3
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import re, sys, pickle, random
from nltk.corpus import stopwords
import torch
import ipdb as ip
from string import punctuation
from collections import Counter
import numpy as np

# Set to True to run on CUDA. ClassifyLSTM.init_hidden and main() read this
# flag to move the hidden state, the model and each batch onto the GPU.
train_on_gpu = False

class ClassifyLSTM(nn.Module):
    """Binary text classifier: Embedding -> stacked LSTM -> Dropout -> Linear -> Sigmoid.

    The LSTM is built with ``batch_first=True``, so ``forward`` takes token ids
    laid out as (batch, seq_len). Only the last element of each sequence's
    sigmoid output is returned, so sequences should be padded at the front
    (as ``pad_features`` in this file does) rather than at the back.
    """

    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):
        super().__init__()
        # Sizes reused by forward() and init_hidden().
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # Sub-modules are registered in this exact order; keep it so that
        # parameters()/state_dict ordering and seeded initialisation are stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            n_layers,
            dropout=drop_prob,  # dropout between stacked LSTM layers
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.3)  # applied to LSTM outputs before the linear head
        self.fc = nn.Linear(hidden_dim, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden):
        """Run one batch through the network.

        Args:
            x: integer token ids, (batch, seq_len).
            hidden: ``(h0, c0)`` tuple, each (n_layers, batch, hidden_dim),
                e.g. from ``init_hidden``.

        Returns:
            ``(probs, hidden)``: ``probs`` has one value per sequence (the last
            column of the per-sequence flattened sigmoid output, i.e. the final
            time step when ``output_size == 1``); ``hidden`` is the LSTM's new state.
        """
        n_seq = x.size(0)
        lstm_out, new_hidden = self.lstm(self.embedding(x), hidden)

        # Flatten time steps so the linear head scores every step at once.
        flat = lstm_out.contiguous().view(-1, self.hidden_dim)
        probs = self.sig(self.fc(self.dropout(flat)))

        # Regroup per sequence and keep only the last column.
        last = probs.view(n_seq, -1)[:, -1]
        return last, new_hidden

    def init_hidden(self, batch_size):
        """Return a zero ``(h0, c0)`` pair shaped (n_layers, batch_size, hidden_dim).

        The tensors share dtype/device with the model's first parameter and are
        additionally moved to CUDA when the module-level ``train_on_gpu`` is set.
        """
        ref = next(self.parameters()).data
        shape = (self.n_layers, batch_size, self.hidden_dim)
        h0 = ref.new_zeros(shape)
        c0 = ref.new_zeros(shape)
        if train_on_gpu:
            h0, c0 = h0.cuda(), c0.cuda()
        return h0, c0


# Lazily-built cache of NLTK's English stop words; filled by _english_stop_words().
_ENGLISH_STOP_WORDS = None


def _english_stop_words():
    """Return NLTK's English stop-word list as a frozenset, loading it only once."""
    global _ENGLISH_STOP_WORDS
    if _ENGLISH_STOP_WORDS is None:
        _ENGLISH_STOP_WORDS = frozenset(stopwords.words('english'))
    return _ENGLISH_STOP_WORDS


def clear_post(post, stop_words=None):
    """Normalise a raw post and return its tokens with stop words removed.

    Literal ``\\n`` sequences become spaces, text is lower-cased, URLs are
    replaced by the token ``internetlink``, HTML ``&lt``/``&gt`` entities and
    ``@mentions`` are dropped, and most punctuation and digits are removed
    before the text is split on single spaces.

    Args:
        post: raw post text.
        stop_words: optional collection of words to drop. Defaults to NLTK's
            English stop-word list, loaded once and cached.

    Returns:
        List of tokens in their original order; never contains empty strings.
    """
    if stop_words is None:
        stop_words = _english_stop_words()

    # The input contains the two characters "\" + "n", not real newlines.
    post = post.replace('\\n', ' ')
    post = post.lower()
    post = re.sub(r'(\(|)(http|https|www)[a-zA-Z0-9\.\:\/\_\=\&\;\?\+\-\%]+(\)|)', ' internetlink ', post)
    post = re.sub(r'[\.\,\/\~]+', ' ', post)
    post = re.sub(r'(&lt|&gt|\@[a-zA-Z0-9]+)', '', post)
    post = re.sub(r'[\'\(\)\?\*\"\`\;0-9\[\]\:\%\|\–\”\!\=\^]+', '', post)
    post = re.sub(r'( \- |\-\-+)', ' ', post)
    # Collapse runs of spaces and trim BOTH ends: a leading space (e.g. from a
    # post that starts with a link) used to yield an empty '' token.
    post = re.sub(r' +', ' ', post).strip(' ')
    return [w for w in post.split(' ') if w and w not in stop_words]

def count_all_words(posts):
    """Split posts on whitespace and tally how often each word occurs.

    Args:
        posts: iterable of post strings.

    Returns:
        ``(sorted_words, total_words, count_words)``:
        ``sorted_words`` is a list of ``(word, count)`` pairs, most frequent
        first (equal counts keep first-seen order); ``total_words`` is the
        number of tokens; ``count_words`` is the underlying ``Counter``.
    """
    tokens = ' '.join(posts).split()
    n_tokens = len(tokens)
    tally = Counter(tokens)
    # n_tokens >= number of distinct words, so every word is included.
    ranked = tally.most_common(n_tokens)
    return ranked, n_tokens, tally

def pad_features(posts_int, seq_length):
    """Left-pad with zeros or truncate every encoded post to ``seq_length`` ids.

    Shorter posts get zeros prepended; longer posts keep only their first
    ``seq_length`` ids.

    Args:
        posts_int: sequence of lists of integer token ids.
        seq_length: target length of every row.

    Returns:
        Integer ``np.ndarray`` of shape ``(len(posts_int), seq_length)``.
    """
    padded = np.zeros((len(posts_int), seq_length), dtype=int)
    for row, ids in enumerate(posts_int):
        kept = ids[:seq_length]
        # Right-align the kept ids; the zeros already in ``padded`` to their
        # left act as the padding.
        padded[row, seq_length - len(kept):] = kept
    return padded

def _read_posts(path):
    """Read one post per line: text before the first tab, lower-cased, ASCII punctuation removed."""
    posts = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            post = line.split('\t')[0].rstrip().lower()
            posts.append(''.join(c for c in post if c not in punctuation))
    return posts


def _read_labels(path):
    """Read one integer label per non-blank line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [int(line) for line in f if line.strip()]


def _encode_posts(posts):
    """Map words to ids by descending frequency (id 0 is reserved for padding).

    Returns ``(vocab_to_int, posts_int)``.
    """
    sorted_words, _, _ = count_all_words(posts)
    vocab_to_int = {w: i for i, (w, _) in enumerate(sorted_words, start=1)}
    posts_int = [[vocab_to_int[w] for w in post.split()] for post in posts]
    return vocab_to_int, posts_int


def _train(model, train_loader, batch_size, epochs=4, lr=0.001, clip=5, print_every=100):
    """Train ``model`` in place with BCE loss and Adam, printing the loss every ``print_every`` steps."""
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    counter = 0
    model.train()
    for e in range(epochs):
        # The hidden state is reset once per epoch and carried across batches.
        h = model.init_hidden(batch_size)
        for inputs, targets in train_loader:
            counter += 1
            if train_on_gpu:
                inputs, targets = inputs.cuda(), targets.cuda()
            # Detach so backprop stops at the batch boundary.
            h = tuple(each.detach() for each in h)
            model.zero_grad()
            # .long() keeps the tensor on its current device, unlike
            # .type(torch.LongTensor), which always returns a CPU tensor.
            output, h = model(inputs.long(), h)
            loss = criterion(output.squeeze(), targets.float())
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            if counter % print_every == 0:
                print("Epoch: {}/{}...".format(e + 1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.6f}".format(loss.item()))
    # TODO: there is no validation/test split yet. Once one exists, evaluate
    # under model.eval() + torch.no_grad(), then call model.train() again.


def main():
    """Train ClassifyLSTM from the command line.

    Usage: <script> POSTS_FILE LABELS_FILE

    POSTS_FILE holds one post per line (only the text before the first tab is
    used); LABELS_FILE holds one integer label per line, aligned with it.
    Posts with two or fewer tokens are dropped before training.

    Returns:
        The trained model (handy when calling ``main()`` interactively).
    """
    if len(sys.argv) != 3:
        sys.exit('usage: {} POSTS_FILE LABELS_FILE'.format(sys.argv[0]))

    posts = _read_posts(sys.argv[1])
    labels = _read_labels(sys.argv[2])
    if len(posts) != len(labels):
        sys.exit('{} posts but {} labels; the files must be line-aligned'.format(len(posts), len(labels)))

    vocab_to_int, posts_int = _encode_posts(posts)

    posts_len = [len(p) for p in posts_int]
    print(pd.Series(posts_len).describe())

    # Drop outliers: posts with two or fewer tokens.
    keep = [i for i, n in enumerate(posts_len) if n > 2]
    posts_int = [posts_int[i] for i in keep]
    train_y = np.array(labels)[keep]

    seq_length = 63
    train_x = pad_features(posts_int, seq_length)

    batch_size = 50
    train_data = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))
    train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)
    if len(train_loader) == 0:
        sys.exit('need at least {} usable posts for one batch, got {}'.format(batch_size, len(train_data)))

    vocab_size = len(vocab_to_int) + 1  # +1 because id 0 is the padding id
    model = ClassifyLSTM(vocab_size, output_size=1, embedding_dim=400, hidden_dim=256, n_layers=2)
    if train_on_gpu:
        model.cuda()

    _train(model, train_loader, batch_size)
    return model


if __name__ == "__main__":
    main()