Implement ID3 algorithm
parent c6e2f5b796
commit 5560d6687d
84  src/machine_learning/decision_tree.py  Normal file

@@ -0,0 +1,84 @@
from math import log
from typing import List, Dict
from collections import Counter

from anytree import Node, RenderTree

from machine_learning.data_set import training_set, test_set, attributes as attribs


def calculate_entropy(p: int, n: int) -> float:
    # Binary entropy of a node containing p positive and n negative examples.
    entropy = 0.0
    if p / (p + n) != 0:
        entropy += - p / (p + n) * log(p / (p + n), 2)
    if n / (p + n) != 0:
        entropy += - n / (p + n) * log(n / (p + n), 2)
    return entropy


def calculate_information_gain(examples: List[Dict], attribute: str) -> float:
    # Gain(S, A) = Entropy(S) - sum over values v of A of (|S_v| / |S|) * Entropy(S_v)
    counter = Counter([entry['action'] for entry in examples])
    p, n = counter['defuse'], counter['detonation']

    entropy = calculate_entropy(p, n)

    values = list(Counter([entry[attribute] for entry in examples]).keys())

    info_needed = 0
    for value in values:
        counter = Counter([entry['action'] for entry in examples if entry[attribute] == value])
        p_i, n_i = counter['defuse'], counter['detonation']

        info_needed += ((p_i + n_i) / (p + n)) * calculate_entropy(p_i, n_i)

    return entropy - info_needed


def tree_learn(examples: List[Dict], attributes: List[str], default_class: str) -> Node:
    # No examples left: fall back to the parent's majority class.
    if not examples:
        return Node(default_class)

    most_common, occurrences = Counter([entry['action'] for entry in examples]).most_common(1)[0]
    # All examples share one class, or no attributes remain: return a leaf.
    if occurrences == len(examples) or not attributes:
        return Node(most_common)

    # Split on the attribute with the highest information gain.
    attributes_dict = {}
    for attribute in attributes:
        attributes_dict[attribute] = calculate_information_gain(examples, attribute)

    root = Node(max(attributes_dict, key=attributes_dict.get))
    values = list(Counter([entry[root.name] for entry in examples]).keys())
    new_attributes = attributes.copy()
    new_attributes.remove(root.name)

    # Recurse on each observed value of the chosen attribute.
    for value in values:
        new_examples = [entry for entry in examples if entry[root.name] == value]
        new_node = tree_learn(new_examples, new_attributes, most_common)
        new_node.parent = root
        new_node.edge = value

    return root


def get_decision(data, root):
    # Walk down the tree, following the edge whose value matches the data's
    # value for the attribute tested at the current node.
    while root.children:
        for child in root.children:
            if data[root.name] == child.edge:
                root = child
                break
        else:
            # No branch matches this value; stop descending instead of looping forever.
            break
    return root.name


tree_root = tree_learn(training_set, attribs, 'detonation')
print(RenderTree(tree_root))
print('-' * 150)

score = 0
for test in test_set:
    print(f'Test data: {test}')
    decision = get_decision(test, tree_root)
    print(f'Decision: {decision}')
    if decision == test['action']:
        score += 1

print(f'Accuracy: {score / len(test_set)}')
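
Note: the commit does not include the machine_learning/data_set.py module that the script imports, so the following is only a minimal sketch of the shape that module would need. The 'action' key and the 'defuse'/'detonation' labels are taken from how decision_tree.py uses the data; the attribute names ('colour', 'size') and the example rows are hypothetical placeholders.

    # machine_learning/data_set.py -- hypothetical sketch, not part of this commit.
    # Each example is a dict mapping attribute names to values, plus an
    # 'action' key holding the class label ('defuse' or 'detonation').
    attributes = ['colour', 'size']

    training_set = [
        {'colour': 'red', 'size': 'big', 'action': 'detonation'},
        {'colour': 'red', 'size': 'small', 'action': 'defuse'},
        {'colour': 'blue', 'size': 'big', 'action': 'defuse'},
        {'colour': 'blue', 'size': 'small', 'action': 'defuse'},
    ]

    test_set = [
        {'colour': 'red', 'size': 'big', 'action': 'detonation'},
        {'colour': 'blue', 'size': 'small', 'action': 'defuse'},
    ]

With a module shaped like this on the import path (for example, run from src/ as python -m machine_learning.decision_tree), the script learns the tree from training_set, renders it with anytree, and reports classification accuracy on test_set.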