From 5560d6687dc6ebe5ec564d2288c53a6ebf4bdb11 Mon Sep 17 00:00:00 2001 From: matixezor Date: Sun, 16 May 2021 19:41:45 +0200 Subject: [PATCH] implement id3 algorithm --- src/machine_learning/decision_tree.py | 84 +++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 src/machine_learning/decision_tree.py diff --git a/src/machine_learning/decision_tree.py b/src/machine_learning/decision_tree.py new file mode 100644 index 0000000..b8ba973 --- /dev/null +++ b/src/machine_learning/decision_tree.py @@ -0,0 +1,84 @@ +from math import log +from typing import List, Dict +from collections import Counter + +from anytree import Node, RenderTree + +from machine_learning.data_set import training_set, test_set, attributes as attribs + + +def calculate_entropy(p: int, n: int) -> float: + entropy = 0 + if p / (p + n) != 0: + entropy += - p / (p + n) * log(p / (p + n), 2) + if n / (p + n) != 0: + entropy += - n / (p + n) * log(n / (p + n), 2) + return entropy + + +def calculate_information_gain(examples: List[Dict], attribute: str) -> float: + counter = Counter([entry['action'] for entry in examples]) + p, n = counter['defuse'], counter['detonation'] + + entropy = calculate_entropy(p, n) + + values = list(Counter([entry[attribute] for entry in examples]).keys()) + + info_needed = 0 + for value in values: + counter = Counter([entry['action'] for entry in examples if entry[attribute] == value]) + p_i, n_i = counter['defuse'], counter['detonation'] + + info_needed += ((p_i + n_i) / (p + n)) * calculate_entropy(p_i, n_i) + + return entropy - info_needed + + +def tree_learn(examples: List[Dict], attributes: List[str], default_class: str) -> Node: + if not examples: + return Node(default_class) + + most_common, occurrences = Counter([entry['action'] for entry in examples]).most_common(1)[0] + if occurrences == len(examples) or not attributes: + return Node(most_common) + + attributes_dict = {} + for attribute in attributes: + attributes_dict[attribute] = calculate_information_gain(examples, attribute) + + root = Node(max(attributes_dict, key=attributes_dict.get)) + values = list(Counter([entry[root.name] for entry in examples]).keys()) + new_attributes = attributes.copy() + new_attributes.remove(root.name) + + for value in values: + new_examples = [entry for entry in examples if entry[root.name] == value] + new_node = tree_learn(new_examples, new_attributes, most_common) + new_node.parent = root + new_node.edge = value + + return root + + +def get_decision(data, root): + while root.children: + for children in root.children: + if data[root.name] == children.edge: + root = children + break + return root.name + + +tree_root = tree_learn(training_set, attribs, 'detonation') +print(RenderTree(tree_root)) +print('-' * 150) + +score = 0 +for test in test_set: + print(f'Test data: {test}') + decision = get_decision(test, tree_root) + print(f'Decision: {decision}') + if decision == test['action']: + score += 1 + +print(f'Accuracy: {score/len(test_set)}')