from numpy.lib.shape_base import split
import pandas as pd
import numpy as np
import gensim
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter
from torchtext.vocab import vocab
from TorchCRF import CRF
from tqdm import tqdm

# Training / evaluation hyper-parameters.
EPOCHS: int = 1  # number of passes over the training set used by train()
BATCH: int = 5  # NOTE(review): not referenced anywhere in this file
SEQ_LEN: int = 5  # NOTE(review): not referenced anywhere in this file

# Functions ported from the Jupyter notebook


def build_vocab(dataset, unk_token=None):
    """Build a torchtext vocabulary from tokenised documents.

    Tokens are added in the iteration order of the ``Counter`` (first-seen
    order), which torchtext's ``vocab`` factory preserves.

    Args:
        dataset: iterable of token lists.
        unk_token: optional token reserved at index 0 and used as the default
            index for out-of-vocabulary lookups. When ``None`` (the default,
            kept so vocabularies match previously saved checkpoints) no special
            token is added and unknown tokens fall back to index 0, which is
            simply the first token seen in ``dataset``.

    Returns:
        A torchtext ``Vocab`` with its default index set as described above.
    """
    counter = Counter()
    for document in dataset:
        counter.update(document)
    if unk_token is None:
        v = vocab(counter)
        # NOTE: index 0 is a real token here, so unknown tokens alias it.
        v.set_default_index(0)
    else:
        v = vocab(counter, specials=[unk_token])
        v.set_default_index(v[unk_token])
    return v


def data_process(dt, vocab):
    """Turn each tokenised document into a 1-D ``torch.long`` tensor of vocab indices."""
    tensors = []
    for document in dt:
        ids = [vocab[token] for token in document]
        tensors.append(torch.tensor(ids, dtype=torch.long))
    return tensors


def get_scores(y_true, y_pred):
    """Return token-level accuracy of ``y_pred`` against ``y_true``.

    Both arguments are sequences of label sequences and are flattened before
    comparison. Tokens are compared pairwise; if the flattened lengths differ,
    the surplus tokens are not counted as correct, and the denominator is
    always the number of predicted tokens.

    Returns:
        A float in [0, 1]; 0.0 when there are no predicted tokens (instead of
        dividing by zero).
    """
    flat_true = [item for sublist in y_true for item in sublist]
    flat_pred = [item for sublist in y_pred for item in sublist]
    if not flat_pred:
        return 0.0
    correct = sum(1 for p, t in zip(flat_pred, flat_true) if p == t)
    return correct / len(flat_pred)


class GRU(torch.nn.Module):
    """Embedding -> single-layer GRU -> linear layer producing per-token tag scores.

    Args:
        vocab_len: number of rows in the embedding table.
        num_tags: size of the output layer (default 9, as used by this script).
        emb_dim: embedding width (default 100).
        hidden_dim: GRU hidden size (default 256).

    The submodule attribute names (``emb``, ``rec``, ``fc1``) determine the
    ``state_dict`` keys and are kept stable so saved checkpoints still load.
    """

    def __init__(self, vocab_len, num_tags=9, emb_dim=100, hidden_dim=256):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_len, emb_dim)
        # No dropout argument: torch.nn.GRU applies dropout only between
        # stacked layers, so with one layer it is a no-op (and PyTorch warns).
        self.rec = torch.nn.GRU(emb_dim, hidden_dim, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, num_tags)

    def forward(self, x):
        """Return tag scores (batch, seq_len, num_tags) for an index tensor x (batch, seq_len)."""
        emb = torch.relu(self.emb(x))
        gru_output, _ = self.rec(emb)
        return self.fc1(gru_output)

# Helpers


def translate(dt, vocab):
    """Map sequences of indices back to their string tokens.

    Args:
        dt: iterable of index sequences (ints or anything usable as a list index).
        vocab: object whose ``get_itos()`` returns the index-to-string list.

    Returns:
        A list of token lists, parallel to ``dt``.
    """
    # get_itos() returns the whole index-to-string list; fetch it once instead
    # of once per token.
    itos = vocab.get_itos()
    return [[itos[token] for token in d] for d in dt]


def FIX_OUTPUT_FOR_GEVAL(out):
    """Repair BIO label sequences and join each into a space-separated line.

    For every label starting with ``I-``:
      * at the start of a line or right after ``O`` it becomes ``B-<type>``;
      * otherwise it becomes ``I-`` followed by the previous label's type, so
        an ``I-`` tag always continues the entity before it.
    All other labels are left untouched.

    Args:
        out: iterable of label sequences (lists of strings).

    Returns:
        List of strings, one per input sequence.
    """
    result = []
    for line in out:
        last_label = None
        new_line = []
        for label in line:
            if label.startswith("I-"):
                if last_label is None or last_label == "O":
                    # Rewrite only the leading prefix; str.replace would also
                    # alter an "I-" occurring inside the entity type.
                    label = "B-" + label[2:]
                else:
                    label = "I-" + last_label[2:]
            last_label = label
            new_line.append(label)
        result.append(" ".join(new_line))
    return result


def save_to_file(out, out_path):
    """Write predicted label sequences to ``out_path``, one line per sequence.

    Labels are normalised with ``FIX_OUTPUT_FOR_GEVAL`` first. The file is
    overwritten and encoded as UTF-8 explicitly (instead of the platform's
    locale default), matching pandas' default encoding for ``read_csv``.
    """
    lines = FIX_OUTPUT_FOR_GEVAL(out)
    with open(out_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in lines)


# Load data


def _read_tsv(path, names):
    """Read a header-less TSV file, keeping every cell verbatim as a string."""
    import csv

    # QUOTE_NONE: a token such as `"` must not be parsed as a CSV quote.
    # dtype=str / keep_default_na=False: lines such as "1999" or "NA" stay text
    # instead of becoming numbers/NaN, which have no .split().
    return pd.read_csv(path, sep='\t', names=names, quoting=csv.QUOTE_NONE,
                       dtype=str, keep_default_na=False)


def _split_column(frame, column):
    """Split every cell of ``frame[column]`` on single spaces."""
    return [cell.split(sep=" ") for cell in frame[column].values]


def load_data():
    """Load the train, dev and test splits from the current directory.

    Reads ``train/train.tsv`` (labels<TAB>document), ``dev-0/in.tsv``,
    ``dev-0/expected.tsv`` and ``test-A/in.tsv``, and splits each line on
    single spaces.

    Returns:
        (X_train, Y_train, X_dev, Y_dev, X_test): lists of token / label lists.
    """
    train_df = _read_tsv('train/train.tsv', ['labels', 'document'])
    Y_train = _split_column(train_df, 'labels')
    X_train = _split_column(train_df, 'document')

    dev = _read_tsv('dev-0/in.tsv', ['document'])
    exp = _read_tsv('dev-0/expected.tsv', ['labels'])
    X_dev = _split_column(dev, 'document')
    Y_dev = _split_column(exp, 'labels')

    test = _read_tsv('test-A/in.tsv', ['document'])
    X_test = _split_column(test, 'document')

    return X_train, Y_train, X_dev, Y_dev, X_test

# Train and save model


def train(model, crf, train_tokens, labels_tokens, optimizer=None, epochs=None,
          model_path="model.torch", crf_path="crf.torch"):
    """Train ``model`` and ``crf`` jointly, one sentence per step, then save both.

    Args:
        model: emission model, called on a (1, seq_len) index tensor.
        crf: CRF layer; ``-crf(emissions, tags)`` is used as the loss.
        train_tokens: list of 1-D token-index tensors.
        labels_tokens: list of 1-D label-index tensors aligned with ``train_tokens``.
        optimizer: optimizer over the model and CRF parameters. Defaults to a
            fresh ``torch.optim.Adam`` with default settings over both.
        epochs: number of passes over the data; defaults to ``EPOCHS``.
        model_path: file that receives ``model.state_dict()``.
        crf_path: file that receives ``crf.state_dict()`` so the learned CRF
            parameters are kept alongside the model weights.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(list(model.parameters()) + list(crf.parameters()))
    if epochs is None:
        epochs = EPOCHS

    for _epoch in range(epochs):
        crf.train()
        model.train()
        for tokens, labels in tqdm(zip(train_tokens, labels_tokens), total=len(labels_tokens)):
            # (seq_len,) -> (seq_len, 1): sequence-first with a batch of one,
            # matching the emissions layout below.
            tags = labels.unsqueeze(1)
            # (1, seq_len) -> model -> squeeze batch -> (seq_len, 1, tags);
            # assumes the model returns (1, seq_len, num_tags).
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)

            optimizer.zero_grad()
            # Presumably crf(...) is the sequence log-likelihood; negate it to
            # obtain a loss to minimise.
            loss = -crf(emissions, tags)

            loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), model_path)
    torch.save(crf.state_dict(), crf_path)

# Evaluate the dev set and return its predictions


def dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab):
    """Decode the dev set, print token-level accuracy and return the predictions.

    Args:
        model: emission model, called on a (1, seq_len) index tensor.
        crf: CRF layer whose ``decode`` turns emissions into tag-index lists.
        dev_tokens: list of 1-D token-index tensors, one per sentence.
        dev_labels_tokens: list of 1-D gold label-index tensors aligned with
            ``dev_tokens``.
        vocab: label vocabulary used to map indices back to label strings.

    Returns:
        List of predicted label sequences (lists of strings).
    """
    Y_true = []
    Y_pred = []
    model.eval()
    crf.eval()
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for tokens, gold in tqdm(zip(dev_tokens, dev_labels_tokens), total=len(dev_labels_tokens)):
            # Gold labels come from the dev split passed in, as plain ints.
            Y_true.append(gold.tolist())
            # (1, seq_len) -> model -> (seq_len, 1, tags): sequence-first with a
            # batch of one; assumes the model returns (1, seq_len, num_tags).
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            Y_pred.append(crf.decode(emissions)[0])

    Y_pred_translate = translate(Y_pred, vocab)
    Y_true_translate = translate(Y_true, vocab)

    # get_scores computes token accuracy and takes (y_true, y_pred).
    accuracy = get_scores(Y_true_translate, Y_pred_translate)
    print(f'accuracy: {accuracy}')
    return Y_pred_translate


def test_eval(model, crf, test_tokens, vocab):
    """Decode the test set and return predicted label sequences as strings.

    Args:
        model: emission model, called on a (1, seq_len) index tensor.
        crf: CRF layer whose ``decode`` turns emissions into tag-index lists.
        test_tokens: list of 1-D token-index tensors, one per sentence.
        vocab: label vocabulary used to map indices back to label strings.

    Returns:
        List of predicted label sequences (lists of strings).
    """
    Y_pred = []
    model.eval()
    crf.eval()
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for tokens in tqdm(test_tokens):
            # (1, seq_len) -> model -> (seq_len, 1, tags): sequence-first with a
            # batch of one; assumes the model returns (1, seq_len, num_tags).
            emissions = model(tokens.unsqueeze(0)).squeeze(0).unsqueeze(1)
            Y_pred.append(crf.decode(emissions)[0])

    return translate(Y_pred, vocab)


if __name__ == "__main__":
    import os

    X_train, Y_train, X_dev, Y_dev, X_test = load_data()
    vocab_x = build_vocab(X_train)
    vocab_y = build_vocab(Y_train)
    # NOTE: `labels_tokens` and `optimizer` are module-level globals that other
    # functions in this file may read; keep these names.
    train_tokens = data_process(X_train, vocab_x)
    labels_tokens = data_process(Y_train, vocab_y)

    # Build the model and CRF (9 tags, matching GRU's default output size).
    model = GRU(len(vocab_x))
    crf = CRF(9)
    p = list(model.parameters()) + list(crf.parameters())
    optimizer = torch.optim.Adam(p)
    # Uncomment to retrain; this overwrites model.torch.
    # train(model, crf, train_tokens, labels_tokens)

    # Restore trained weights. The CRF state is optional because a checkpoint
    # may contain only the model; without it the CRF keeps its fresh
    # initialisation.
    model.load_state_dict(torch.load("model.torch"))
    if os.path.exists("crf.torch"):
        crf.load_state_dict(torch.load("crf.torch"))

    # Evaluate on dev and write predictions.
    dev_tokens = data_process(X_dev, vocab_x)
    dev_labels_tokens = data_process(Y_dev, vocab_y)
    dev_pred = dev_eval(model, crf, dev_tokens, dev_labels_tokens, vocab_y)
    save_to_file(dev_pred, 'dev-0/out.tsv')

    # Predict on test and write predictions.
    test_tokens = data_process(X_test, vocab_x)
    test_pred = test_eval(model, crf, test_tokens, vocab_y)
    save_to_file(test_pred, 'test-A/out.tsv')