Add missing file with DT testing function
This commit is contained in:
parent
b17221c7c2
commit
f57bc6463b
39
src/AI/DecisionTrees/TestDecisionTree.py
Normal file
39
src/AI/DecisionTrees/TestDecisionTree.py
Normal file
@ -0,0 +1,39 @@
|
||||
import random
|
||||
|
||||
from typing import List
|
||||
|
||||
from src.AI.DecisionTrees.DecisionTreeExample import DecisionTreeExample
|
||||
from src.AI.DecisionTrees.InductiveDecisionTreeLearning import inductiveDecisionTreeLearning
|
||||
from src.AI.DecisionTrees.projectSpecificClasses.SurvivalAttributesDefinitions import SurvivalAttributesDefinitions
|
||||
from src.AI.DecisionTrees.projectSpecificClasses.SurvivalClassification import SurvivalClassification
|
||||
|
||||
|
||||
def testDecisionTree(examples: List[DecisionTreeExample], iterations=10, partOfExamplesAsTrainingSet=0.9):
|
||||
examplesNum = len(examples)
|
||||
trainingSetSize = int(examplesNum * partOfExamplesAsTrainingSet)
|
||||
testSetSize = examplesNum - trainingSetSize
|
||||
|
||||
treeScores = []
|
||||
|
||||
for i in range(iterations):
|
||||
# Shuffling examples
|
||||
random.shuffle(examples)
|
||||
# Test and training set
|
||||
trainingSet = examples[:trainingSetSize]
|
||||
testSet = examples[trainingSetSize:]
|
||||
|
||||
# Create decision tree out of training set
|
||||
dt = inductiveDecisionTreeLearning(trainingSet, SurvivalAttributesDefinitions.allAttributesDefinitions,
|
||||
SurvivalClassification.FOOD, SurvivalClassification)
|
||||
|
||||
# Check how many answers will be correct for test set
|
||||
correctAnswers = 0
|
||||
for testExample in testSet:
|
||||
dtAnswer = dt.giveAnswer(testExample)
|
||||
if dtAnswer == testExample.classification:
|
||||
correctAnswers += 1
|
||||
|
||||
treeScores.append(correctAnswers / testSetSize)
|
||||
|
||||
return treeScores
|
||||
|
Loading…
Reference in New Issue
Block a user