import os
import json
from joblib import dump, load
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer

from disarming.parameters.mine_parameters import MineParameters
from disarming.parameters.parameter_json import generate_data


class DecisionTree:
    """Decision-tree classifier predicting the ``'wire'`` field of mine parameters.

    Each sample is a dict of mine parameters without its ``'wire'`` key. A
    ``DictVectorizer`` turns those dicts into numeric feature arrays (string
    values become one-hot columns) for a scikit-learn ``DecisionTreeClassifier``.
    """

    def __init__(self, load_from_file: bool = False):
        """Create an untrained tree, or load a previously saved one.

        :param load_from_file: when True, load the classifier and vectorizer from
            ``algorithms/learn/decision_tree/*.joblib`` (paths relative to the
            current working directory); otherwise both start as ``None``.
        """
        if load_from_file:
            clf_source = r"algorithms/learn/decision_tree/decision_tree.joblib"
            vec_source = r"algorithms/learn/decision_tree/dict_vectorizer.joblib"
            self.load(clf_source, vec_source)
        else:
            self.clf = None
            self.vec = None

    def build(self, training_file: str, depth: int):
        """Fit the vectorizer and the classifier on a JSON-lines training file.

        :param training_file: file name inside ``../../../resources/data/decision_tree``
            (resolved against the current working directory). Each non-blank line
            is a JSON object that contains a ``'wire'`` key (the label).
        :param depth: maximum depth of the decision tree.
        """
        path = os.path.join("../..", "..", "resources", "data", "decision_tree", training_file)
        samples, results = self._read_samples(path)

        # DictVectorizer maps the list of feature dicts to a numeric matrix the
        # tree can work on; densify it because the tree is fitted on arrays here.
        self.vec = DictVectorizer()
        features = self.vec.fit_transform(samples).toarray()

        # fit() returns the estimator itself.
        self.clf = tree.DecisionTreeClassifier(max_depth=depth).fit(features, results)

        # Print a text rendering of the learned tree (informational only).
        print(tree.export_text(self.clf, feature_names=self._feature_names()))

    def save(self, clf_file: str = 'decision_tree.joblib', vec_file: str = 'dict_vectorizer.joblib'):
        """Persist the classifier and the vectorizer with joblib.

        :param clf_file: destination of the classifier (default: cwd-relative name).
        :param vec_file: destination of the vectorizer (default: cwd-relative name).
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load(self, clf_file, vec_file):
        """Load a classifier and vectorizer previously written by :meth:`save`.

        NOTE: ``joblib.load`` unpickles its input and can execute arbitrary code;
        only load files from a trusted source.
        """
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def get_answer(self, mine_params):
        """Predict the wire for a parameter dict (or an iterable of such dicts).

        :param mine_params: feature mapping without the ``'wire'`` key, as built
            in :meth:`test`.
        :return: the classifier's prediction array, one label per input sample.
        """
        return self.clf.predict(self.vec.transform(mine_params).toarray())

    def test(self, n_tests: int = 1000):
        """Evaluate the tree on freshly generated random mine parameters.

        Prints every misprediction and finally the accuracy in percent.

        :param n_tests: number of random samples to evaluate (default 1000).
        :return: the accuracy in percent.
        :raises ValueError: if ``n_tests`` is not positive.
        """
        if n_tests <= 0:
            raise ValueError("n_tests must be a positive integer")

        mistakes = 0
        for _ in range(n_tests):
            mine_params = MineParameters().jsonifyable_dict()
            correct = mine_params.pop('wire')

            # predict() returns one label per sample; a single dict gives one label.
            answer = self.get_answer(mine_params)[0]

            if correct != answer:
                print(f"Answer: {answer}\nCorrect: {correct}")
                mistakes += 1

        accuracy = self._accuracy(mistakes, n_tests)
        print(f"Accuracy: {accuracy}")
        return accuracy

    def _feature_names(self):
        """Return the vectorizer's feature names as a list, across sklearn versions."""
        getter = getattr(self.vec, "get_feature_names_out", None)
        if getter is None:
            # scikit-learn < 1.0 only provides the (since removed) old accessor.
            return list(self.vec.get_feature_names())
        return list(getter())

    @staticmethod
    def _accuracy(mistakes: int, total: int) -> float:
        """Return the percentage of ``total`` predictions that were correct.

        Written as ``100 - 100 * mistakes / total`` so that for ``total == 1000``
        it yields exactly the same float as the former ``100 - mistakes / 10``.
        """
        return 100 - 100 * mistakes / total

    @classmethod
    def _read_samples(cls, path):
        """Parse a JSON-lines file into parallel lists of feature dicts and labels.

        Blank lines are skipped rather than passed to ``json.loads``, which
        rejects empty input.
        """
        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as handle:
            for line in handle:
                if not line.strip():
                    continue
                sample, result = cls._process_input_line(line)
                samples.append(sample)
                results.append(result)
        return samples, results

    @staticmethod
    def _process_input_line(line):
        """Split one JSON line into ``(features, wire)``.

        :return: the decoded object without its ``'wire'`` key, and that key's value.
        :raises KeyError: if the object has no ``'wire'`` key.
        """
        sample = json.loads(line.strip())
        result = sample.pop('wire')
        return sample, result


if __name__ == "__main__":
    # Regenerate a training set, fit a depth-15 tree on it, report its
    # accuracy on random samples and persist the fitted model.
    dataset_name = "training_set.txt"
    generate_data(dataset_name, 4500)

    model = DecisionTree()
    model.build(dataset_name, 15)
    model.test()
    model.save()