import json
import os
from typing import Optional

from joblib import dump, load
from matplotlib import pyplot
from sklearn import tree
from sklearn.feature_extraction import DictVectorizer

from objects.mines.disarming.mine_parameters import MineParameters
from objects.mines.disarming.parameter_json import generate_data


class DecisionTree:
    """Decision-tree classifier that predicts the 'wire' value of a sample.

    Wraps a scikit-learn ``DecisionTreeClassifier`` (``self.clf``) together
    with the ``DictVectorizer`` (``self.vec``) used to turn sample
    dictionaries into numeric feature arrays. Both attributes are ``None``
    until ``build`` or ``load`` has been called.
    """

    def __init__(self, clf_source: Optional[str] = None,
                 vec_source: Optional[str] = None):
        """Create a tree, optionally restoring a previously saved one.

        Args:
            clf_source: Path to a dumped classifier.
            vec_source: Path to a dumped vectorizer.

        The saved model is loaded only when BOTH paths are given; otherwise
        the instance starts empty and must be trained with ``build``.
        """
        if clf_source is not None and vec_source is not None:
            self.load(clf_source, vec_source)
        else:
            self.clf = None
            self.vec = None

    def build(self, training_file: str, depth: int):
        """Train the classifier from a file of JSON lines.

        Args:
            training_file: File name inside ``../../resources/data``
                (resolved relative to the current working directory).
            depth: Maximum depth of the fitted tree.

        Each non-blank line must be a JSON object containing a ``'wire'``
        key (the label); all remaining keys are used as features.
        """
        path = os.path.join("..", "..", "resources", "data", training_file)

        samples = []
        results = []
        with open(path, "r", encoding="utf-8") as lines:
            for line in lines:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing inside json.loads.
                if not line.strip():
                    continue
                sample, result = self._process_input_line(line)
                samples.append(sample)
                results.append(result)

        # vec transforms the samples (a list of dictionaries) into numeric
        # arrays for the tree to work on
        self.vec = DictVectorizer()

        # create and run the tree classifier upon the provided data
        self.clf = tree.DecisionTreeClassifier(max_depth=depth)
        self.clf = self.clf.fit(self.vec.fit_transform(samples).toarray(), results)

        # print a tree (not necessary)
        print(tree.export_text(self.clf, feature_names=self._feature_names()))

        # plot a tree (not necessary)
        # fig = pyplot.figure(figsize=(50, 40))
        # _ = tree.plot_tree(self.clf,
        #                    feature_names=self._feature_names(),
        #                    class_names=self.clf.classes_,
        #                    filled=True)
        # fig.savefig("decision_tree.png")

    def _feature_names(self):
        """Return the vectorizer's feature names as a plain list.

        ``get_feature_names`` was deprecated in scikit-learn 1.0 and removed
        in 1.2 in favour of ``get_feature_names_out``; use whichever the
        installed version provides. A list (not an ndarray) is returned
        because older ``export_text`` versions evaluate the truthiness of
        ``feature_names``.
        """
        getter = getattr(self.vec, "get_feature_names_out", None)
        if getter is None:
            getter = self.vec.get_feature_names
        return list(getter())

    def save(self, clf_file: str = 'decision_tree.joblib',
             vec_file: str = 'dict_vectorizer.joblib'):
        """Dump the classifier and the vectorizer with joblib.

        Args:
            clf_file: Destination for the classifier.
            vec_file: Destination for the vectorizer.
        """
        dump(self.clf, clf_file)
        dump(self.vec, vec_file)

    def load(self, clf_file, vec_file):
        """Restore the classifier and the vectorizer from joblib dumps.

        Args:
            clf_file: Path to the dumped classifier.
            vec_file: Path to the dumped vectorizer.
        """
        # NOTE: joblib.load is pickle-based and can execute arbitrary code;
        # only load files that come from a trusted source.
        self.clf = load(clf_file)
        self.vec = load(vec_file)

    def get_answer(self, mine_params):
        """Predict the label for ``mine_params``.

        Args:
            mine_params: Feature dictionary without the ``'wire'`` key.

        Returns:
            The array produced by the classifier's ``predict``.
        """
        return self.clf.predict(self.vec.transform(mine_params).toarray())

    def test(self, samples: int = 1000):
        """Measure accuracy on freshly generated ``MineParameters``.

        Args:
            samples: Number of generated mines to classify.

        Returns:
            The accuracy in percent (also printed).

        Raises:
            ValueError: If ``samples`` is not positive.
        """
        if samples <= 0:
            raise ValueError("samples must be a positive integer")

        mistakes = 0
        for _ in range(samples):
            mine_params = MineParameters().jsonifyable_dict()
            correct = mine_params['wire']
            del mine_params['wire']

            answer = self.get_answer(mine_params)

            # predict returns an array; compare against its single element
            # rather than relying on the truthiness of an array comparison.
            if correct != answer[0]:
                print(f"Answer: {answer}\nCorrect: {correct}")
                mistakes += 1

        accuracy = 100 - (100 * mistakes / samples)
        print(f"Accuracy: {accuracy}")
        return accuracy

    @staticmethod
    def _process_input_line(line):
        """Split one JSON training line into ``(sample, result)``.

        ``result`` is the value stored under ``'wire'``; ``sample`` is the
        decoded dictionary with that key removed.
        """
        data = json.loads(line.strip())
        result = data.pop('wire')
        return data, result


if __name__ == "__main__":
    # generate_data("training_set.txt", 12000)
    decision_tree = DecisionTree()
    decision_tree.build("training_set.txt", 15)
    decision_tree.test()
    decision_tree.save()