diff --git a/src/AI/DecisionTrees/DecisionTree.py b/src/AI/DecisionTrees/DecisionTree.py index f224bda..46a8fb9 100644 --- a/src/AI/DecisionTrees/DecisionTree.py +++ b/src/AI/DecisionTrees/DecisionTree.py @@ -1,7 +1,8 @@ -from typing import List, Any +from typing import List from src.AI.DecisionTrees.AttributeDefinition import AttributeDefinition from src.AI.DecisionTrees.DecisionTreeBranch import DecisionTreeBranch +from src.AI.DecisionTrees.DecisionTreeExample import DecisionTreeExample class DecisionTree(object): @@ -16,3 +17,13 @@ class DecisionTree(object): def addBranch(self, newBranch): self.branches.append(newBranch) self.branchesNum += 1 + + def giveAnswer(self, example: DecisionTreeExample): + if self.branchesNum == 0: + return self.root + + for attr in example.attributes: + if attr.attributeDefinition.id == self.root.id: + for branch in self.branches: + if branch.label == attr.value: + return branch.subtree.giveAnswer(example)