diff --git a/src/AI/DecisionTrees/DecisionTree.py b/src/AI/DecisionTrees/DecisionTree.py index 46a8fb9..0cf24b0 100644 --- a/src/AI/DecisionTrees/DecisionTree.py +++ b/src/AI/DecisionTrees/DecisionTree.py @@ -27,3 +27,14 @@ class DecisionTree(object): for branch in self.branches: if branch.label == attr.value: return branch.subtree.giveAnswer(example) + + @staticmethod + def printTree(tree, depth: int, indent: int = 20): + if isinstance(tree.root, AttributeDefinition): + print("NODE: {}".format(tree.root.name).rjust(indent * depth)) + else: + print("NODE: {}".format(str(tree.root)).rjust(indent * depth)) + + for branch in tree.branches: + print("| {}".format(str(branch.label)).rjust(indent * depth)) + DecisionTree.printTree(branch.subtree, depth + 1, indent)