Add specifying training set size when testing DT

This commit is contained in:
Michał Czekański 2020-05-25 02:45:13 +02:00
parent a256fc1cfe
commit b17221c7c2

View File

@ -68,9 +68,9 @@ class Game:
if argv[2] == "-p": if argv[2] == "-p":
print("Running Decision Tree in pause mode.") print("Running Decision Tree in pause mode.")
self.dtRun(filesPath, True) self.dtRun(filesPath, True)
elif argv[2] == "-test": elif argv[2] == "-test" and len(argv) >= 5:
print("Testing Decision Tree.") print("Testing Decision Tree.")
self.dtTestRun(filesPath, int(argv[3])) self.dtTestRun(filesPath, iterations=int(argv[3]), trainingSetSize=float(argv[4]))
else: else:
print("Running Decision Tree.") print("Running Decision Tree.")
self.dtRun(filesPath) self.dtRun(filesPath)
@ -364,13 +364,23 @@ class Game:
geneticAlgorithmWithDecisionTree(self.map, iter, 10, dtExamples, 0.1) geneticAlgorithmWithDecisionTree(self.map, iter, 10, dtExamples, 0.1)
print("Time elapsed: ", self.pgTimer.tick() // 1000) print("Time elapsed: ", self.pgTimer.tick() // 1000)
def dtTestRun(self, filesPath, iterations): @staticmethod
def dtTestRun(filesPath, iterations: int, trainingSetSize: float = 0.9):
"""
:param filesPath:
:param iterations:
:param trainingSetSize: How many % of examples that will be read from file make as trainingSet. Value from 0 to 1.
"""
# Read examples for decision tree testing # Read examples for decision tree testing
examplesFilePath = str( examplesFilePath = str(
filesPath) + os.sep + "data" + os.sep + "AI_data" + os.sep + "dt_exmpls" + os.sep + "dt_examples" filesPath) + os.sep + "data" + os.sep + "AI_data" + os.sep + "dt_exmpls" + os.sep + "dt_examples"
examplesManager = ExamplesManager(examplesFilePath) examplesManager = ExamplesManager(examplesFilePath)
examples = examplesManager.readExamples() examples = examplesManager.readExamples()
scores = testDecisionTree(examples, iterations)
# Testing tree
scores = testDecisionTree(examples, iterations, trainingSetSize)
avg = sum(scores) / iterations avg = sum(scores) / iterations
print("Average: {}".format(str(avg))) print("Average: {}".format(str(avg)))