diff --git a/src/machine_learning/decision_tree.py b/src/machine_learning/decision_tree.py index f50dd4c..5d93991 100644 --- a/src/machine_learning/decision_tree.py +++ b/src/machine_learning/decision_tree.py @@ -69,16 +69,23 @@ def get_decision(data: dict, root: Node) -> str: return root.name +def main(): + print(RenderTree(tree_root)) + print('-' * 150) + + score = 0 + for test in test_set: + print(f'Test data: {test}') + decision = get_decision(test, tree_root) + print(f'Decision: {decision}') + if decision == test['action']: + score += 1 + + print(f'Accuracy: {score/len(test_set)}') + + tree_root = tree_learn(training_set, attribs, 'detonation') -print(RenderTree(tree_root)) -print('-' * 150) -score = 0 -for test in test_set: - print(f'Test data: {test}') - decision = get_decision(test, tree_root) - print(f'Decision: {decision}') - if decision == test['action']: - score += 1 -print(f'Accuracy: {score/len(test_set)}') +if __name__ == "__main__": + main()