Implement ID3 algorithm
This commit is contained in:
parent
c6e2f5b796
commit
5560d6687d
84
src/machine_learning/decision_tree.py
Normal file
84
src/machine_learning/decision_tree.py
Normal file
@ -0,0 +1,84 @@
|
||||
from math import log
|
||||
from typing import List, Dict
|
||||
from collections import Counter
|
||||
|
||||
from anytree import Node, RenderTree
|
||||
|
||||
from machine_learning.data_set import training_set, test_set, attributes as attribs
|
||||
|
||||
|
||||
def calculate_entropy(p: int, n: int) -> float:
    """Return the binary entropy, in bits, of a split with p and n examples.

    Args:
        p: Number of examples belonging to the first class.
        n: Number of examples belonging to the second class.

    Returns:
        ``-sum(q * log2(q))`` over the two class proportions. A class with a
        zero count contributes nothing, and an empty split (``p + n == 0``)
        yields ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total = p + n
    if total == 0:
        # No examples means no uncertainty; computing proportions would
        # otherwise divide by zero.
        return 0.0

    entropy = 0.0
    for count in (p, n):
        if count:
            # log(0) is undefined, so only non-empty classes are summed.
            proportion = count / total
            entropy -= proportion * log(proportion, 2)
    return entropy
|
||||
|
||||
|
||||
def calculate_information_gain(examples: List[Dict], attribute: str) -> float:
    """Return the information gain obtained by splitting examples on attribute.

    Each example is a dict whose ``'action'`` entry is the class label; only
    the labels ``'defuse'`` and ``'detonation'`` are counted.

    Args:
        examples: Training examples to evaluate.
        attribute: Key of the attribute whose split is being scored.

    Returns:
        The entropy of the whole set minus the weighted entropy of the
        subsets produced by each distinct value of ``attribute``. Returns
        ``0.0`` when no example carries one of the two counted labels
        (including an empty ``examples``), since there is nothing to gain.
    """
    label_counts = Counter(entry['action'] for entry in examples)
    p, n = label_counts['defuse'], label_counts['detonation']
    total = p + n
    if total == 0:
        # Avoids dividing by zero below when there is nothing to count.
        return 0.0

    entropy = calculate_entropy(p, n)

    # Group label counts by attribute value in a single pass; the dict keeps
    # first-seen order, so the summation order is stable.
    counts_by_value: Dict = {}
    for entry in examples:
        counts_by_value.setdefault(entry[attribute], Counter())[entry['action']] += 1

    info_needed = 0.0
    for counts in counts_by_value.values():
        p_i, n_i = counts['defuse'], counts['detonation']
        subset_size = p_i + n_i
        if subset_size == 0:
            # A subset holding only uncounted labels has zero weight.
            continue
        info_needed += (subset_size / total) * calculate_entropy(p_i, n_i)

    return entropy - info_needed
|
||||
|
||||
|
||||
def tree_learn(examples: List[Dict], attributes: List[str], default_class: str) -> Node:
    """Recursively build a decision tree from examples.

    Inner nodes are named after the attribute they split on; leaves are named
    after a class label. Every non-root node gets an ``edge`` attribute
    holding the attribute value that leads to it from its parent.

    Args:
        examples: Training examples; each is a dict with an ``'action'`` label.
        attributes: Attribute keys still available for splitting.
        default_class: Label used for the leaf when ``examples`` is empty.

    Returns:
        The root node of the learned (sub)tree.
    """
    if not examples:
        return Node(default_class)

    label_counts = Counter(entry['action'] for entry in examples)
    majority_label, majority_count = label_counts.most_common(1)[0]

    # Stop when the examples are pure or nothing is left to split on.
    if not attributes or majority_count == len(examples):
        return Node(majority_label)

    # Score every candidate and split on the one with the highest gain.
    gain_by_attribute = {
        name: calculate_information_gain(examples, name) for name in attributes
    }
    root = Node(max(gain_by_attribute, key=gain_by_attribute.get))

    remaining_attributes = list(attributes)
    remaining_attributes.remove(root.name)

    # Distinct values of the chosen attribute, in first-seen order.
    branch_values = dict.fromkeys(entry[root.name] for entry in examples)
    for branch_value in branch_values:
        subset = [entry for entry in examples if entry[root.name] == branch_value]
        child = tree_learn(subset, remaining_attributes, majority_label)
        child.parent = root
        child.edge = branch_value

    return root
|
||||
|
||||
|
||||
def get_decision(data, root):
    """Walk the decision tree from root and return the label chosen for data.

    At each inner node the value ``data[node.name]`` is compared against the
    ``edge`` of every child, and the walk follows the first child that matches.

    Args:
        data: Mapping from attribute name to the sample's value.
        root: Root node of the tree; nodes expose ``name``, ``children`` and,
            for non-root nodes, ``edge``.

    Returns:
        The ``name`` of the leaf that is reached, or ``None`` when an inner
        node has no child whose edge matches the sample's value (previously
        this case looped forever).
    """
    node = root
    while node.children:
        for child in node.children:
            if data[node.name] == child.edge:
                node = child
                break
        else:
            # No branch covers this value; without this exit the loop would
            # keep re-examining the same node indefinitely.
            return None
    return node.name
|
||||
|
||||
|
||||
# Learn the tree from the training set and show its structure.
# 'detonation' is the label used if the training set is empty.
tree_root = tree_learn(training_set, attribs, 'detonation')
print(RenderTree(tree_root))
print('-' * 150)

# Classify every test sample and count how many decisions match the
# sample's own 'action' label.
score = 0
for test in test_set:
    print(f'Test data: {test}')
    decision = get_decision(test, tree_root)
    print(f'Decision: {decision}')
    score += 1 if decision == test['action'] else 0

# Fraction of test samples classified correctly.
print(f'Accuracy: {score/len(test_set)}')
|
Loading…
Reference in New Issue
Block a user