Implement the ID3 decision-tree learning algorithm

This commit is contained in:
matixezor 2021-05-16 19:41:45 +02:00
parent c6e2f5b796
commit 5560d6687d

View File

@ -0,0 +1,84 @@
from math import log
from typing import List, Dict
from collections import Counter
from anytree import Node, RenderTree
from machine_learning.data_set import training_set, test_set, attributes as attribs
def calculate_entropy(p: int, n: int) -> float:
    """Return the binary (Shannon) entropy of a p/n class split.

    :param p: count of positive examples ('defuse')
    :param n: count of negative examples ('detonation')
    :return: entropy in bits; 0.0 for a pure or empty split

    Bug fix: the original divided by (p + n) unconditionally and raised
    ZeroDivisionError for an empty split; an empty split carries no
    information, so 0.0 is returned instead.
    """
    total = p + n
    if total == 0:
        return 0.0
    entropy = 0.0
    for count in (p, n):
        fraction = count / total
        if fraction != 0:
            # 0 * log(0) is taken as 0, so skip zero fractions.
            entropy -= fraction * log(fraction, 2)
    return entropy
def calculate_information_gain(examples: List[Dict], attribute: str) -> float:
    """Return the information gain of splitting *examples* on *attribute*.

    Gain = entropy(whole set) - expected entropy after the split, where
    classes are read from the 'action' key ('defuse' vs 'detonation').
    """
    actions = Counter(entry['action'] for entry in examples)
    p, n = actions['defuse'], actions['detonation']
    total_entropy = calculate_entropy(p, n)
    # Distinct attribute values, in first-seen order.
    seen_values = list(Counter(entry[attribute] for entry in examples))
    remainder = 0
    for value in seen_values:
        subset_actions = Counter(
            entry['action'] for entry in examples if entry[attribute] == value
        )
        p_v, n_v = subset_actions['defuse'], subset_actions['detonation']
        # Weight each branch's entropy by the fraction of examples it holds.
        remainder += ((p_v + n_v) / (p + n)) * calculate_entropy(p_v, n_v)
    return total_entropy - remainder
def tree_learn(examples: List[Dict], attributes: List[str], default_class: str) -> Node:
    """Recursively grow an ID3 decision tree and return its root Node.

    :param examples: training rows; class label is under the 'action' key
    :param attributes: attribute names still available for splitting
    :param default_class: label to use when no examples remain
    :return: an anytree Node; internal nodes are named after the split
             attribute, leaves after the predicted class, and each child
             carries an ``edge`` attribute with the branch's value
    """
    # No data left: fall back to the caller-supplied majority class.
    if not examples:
        return Node(default_class)
    majority, count = Counter(entry['action'] for entry in examples).most_common(1)[0]
    # Pure node, or nothing left to split on -> leaf with the majority class.
    if count == len(examples) or not attributes:
        return Node(majority)
    # Split on the attribute with the highest information gain.
    gains = {attr: calculate_information_gain(examples, attr) for attr in attributes}
    root = Node(max(gains, key=gains.get))
    remaining = attributes.copy()
    remaining.remove(root.name)
    # One subtree per observed value of the chosen attribute.
    for value in list(Counter(entry[root.name] for entry in examples)):
        subset = [entry for entry in examples if entry[root.name] == value]
        branch = tree_learn(subset, remaining, majority)
        branch.parent = root
        branch.edge = value  # label the branch with the value it follows
    return root
def get_decision(data, root):
    """Classify *data* by walking the tree from *root*.

    At each internal node, follow the child whose ``edge`` equals
    ``data[node.name]``; the final node's name is the decision.

    Bug fix: if no child matches (an attribute value never seen during
    training), the original ``while`` loop never advanced and spun
    forever. Traversal now stops and the current node's name is
    returned in that case.
    """
    while root.children:
        for child in root.children:
            if data[root.name] == child.edge:
                root = child
                break
        else:
            # Unseen attribute value: no edge matches — bail out.
            break
    return root.name
# Train the ID3 tree on the training set ('detonation' is the fallback class),
# render it, then measure accuracy on the held-out test set.
tree_root = tree_learn(training_set, attribs, 'detonation')
print(RenderTree(tree_root))
print('-' * 150)
correct = 0
for sample in test_set:
    print(f'Test data: {sample}')
    verdict = get_decision(sample, tree_root)
    print(f'Decision: {verdict}')
    # bool adds as 0/1, so this counts matching predictions.
    correct += verdict == sample['action']
print(f'Accuracy: {correct/len(test_set)}')