2021-06-14 23:54:30 +02:00
|
|
|
import copy
|
2021-05-23 23:15:50 +02:00
|
|
|
import operator
|
2021-06-14 23:54:30 +02:00
|
|
|
from collections import Counter
|
|
|
|
|
2021-05-23 23:15:50 +02:00
|
|
|
import numpy as np
|
2021-06-14 23:54:30 +02:00
|
|
|
|
|
|
|
from src.cases import *
|
|
|
|
|
2021-05-23 23:15:50 +02:00
|
|
|
|
|
|
|
class Node:
    """A decision-tree node.

    For an internal node, ``Class`` holds the name of the attribute the node
    splits on. For a leaf, it holds the predicted class label (this is how
    ``treelearn`` builds nodes). ``childs`` is a list of
    ``[child_node, attribute_value]`` pairs, one per branch.
    """

    def __init__(self, Class, tag=None):
        self.Class = Class
        # Optional caller-supplied label; it used to be accepted and dropped.
        self.tag = tag
        # A fresh list per instance, so nodes never share children.
        self.childs = []

    def __repr__(self):
        return f"{type(self).__name__}({self.Class!r}, childs={len(self.childs)})"
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
|
|
|
def classes_of_cases(cases):
    """Return the distinct ``Class`` labels found in *cases*.

    Labels appear in the order in which they are first seen. Duplicates are
    detected with ``==`` against a list, so labels need not be hashable.
    """
    distinct = []
    for label in (case.Class for case in cases):
        if label in distinct:
            continue
        distinct.append(label)
    return distinct
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
|
|
|
def count_classes(cases):
    """Return the most frequent ``Class`` label among *cases*.

    On a tie, the label seen first in *cases* wins. Raises ``ValueError``
    when *cases* is empty.
    """
    tally = Counter(case.Class for case in cases)
    # Counter iterates keys in first-seen order and max() keeps the first
    # maximal element, which gives the tie-break described above.
    return max(tally, key=tally.__getitem__)
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
|
|
|
def chose_attribute(cases, attributes):
    """Return the attribute in *attributes* with the highest information gain.

    Gain is ``I(cases) - E(cases, attribute)``. Because the comparison is
    ``>=``, the attribute listed last wins a tie. Returns ``""`` when
    *attributes* is empty.
    """
    if not attributes:
        return ""
    # The entropy of the whole set does not depend on the attribute, so it
    # is computed once rather than twice per candidate.
    base_entropy = I(cases)
    best_attribute = ""
    best_gain = float("-inf")
    for attribute in attributes:
        gain = base_entropy - E(cases, attribute)
        if gain >= best_gain:
            best_gain = gain
            best_attribute = attribute
    return best_attribute
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
|
|
|
def I(cases):
    """Return the Shannon entropy (in bits) of the class labels of *cases*.

    Computes ``-sum(p * log2(p))`` over each distinct label, where ``p`` is
    that label's share of *cases*. Returns ``0`` for an empty list.
    """
    total = len(cases)
    entropy = 0
    # Counter keeps labels in first-seen order, so the terms are summed in the
    # same order as a per-label scan would give (identical float result).
    # It counts everything in one pass instead of one pass per label.
    for count in Counter(case.Class for case in cases).values():
        share = count / total
        entropy -= share * np.log2(share)
    return entropy
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
2021-05-23 23:15:50 +02:00
|
|
|
def E(cases, attribute):
    """Return the expected entropy left after splitting *cases* on *attribute*.

    Computes ``sum(|S_v| / |S| * I(S_v))`` over each distinct value ``v`` of the
    attribute, where ``S_v`` is the subset of cases with that value. The
    attribute's column is found through ``cases[0].attributes``. Returns ``0``
    for an empty list instead of raising ``IndexError`` on ``cases[0]``.
    """
    if not cases:
        return 0
    index = cases[0].attributes.index(attribute)
    # Group in one pass. Dicts keep insertion order, so subsets are summed in
    # first-seen value order, as before. Attribute values are assumed to be
    # hashable scalars.
    subsets = {}
    for case in cases:
        subsets.setdefault(case.values[index], []).append(case)
    total = len(cases)
    return sum((len(subset) / total) * I(subset) for subset in subsets.values())
|
2021-06-14 23:54:30 +02:00
|
|
|
|
2021-05-23 23:15:50 +02:00
|
|
|
|
|
|
|
def treelearn(cases, attributes, default_class):
    """Learn a decision tree by recursive information-gain splits; return its root.

    Args:
        cases: Training examples. Each exposes ``Class`` (its label),
            ``attributes`` (attribute names) and ``values`` (values aligned
            with ``attributes``).
        attributes: Names of the attributes still available for splitting.
        default_class: Label for the leaf returned when *cases* is empty.

    Returns:
        One of:

        - A leaf ``Node`` holding a class label.
        - An internal ``Node`` whose ``Class`` is the chosen attribute name
          and whose ``childs`` holds one ``[subtree, value]`` pair for each
          value of that attribute seen in *cases*.
    """
    if not cases:
        return Node(default_class)
    if len(classes_of_cases(cases)) == 1:
        return Node(cases[0].Class)
    if not attributes:
        # No attribute left to split on, so predict the majority label.
        return Node(count_classes(cases))

    best = chose_attribute(cases, attributes)
    tree = Node(best)
    majority = count_classes(cases)

    # The position of `best` in the candidate list and its position in each
    # case's `values` are separate indices. They coincide when the caller
    # passes the cases' own attribute list, but can differ if *attributes* is
    # reordered or partial. E() looks the column up through
    # cases[0].attributes, so do the same here.
    attr_index = attributes.index(best)
    remaining = attributes[:attr_index] + attributes[attr_index + 1:]
    value_index = cases[0].attributes.index(best)

    for value, subset in _split_on_column(cases, value_index):
        # Every subset is non-empty, so `majority` is only a safety default.
        child = treelearn(subset, remaining, majority)
        tree.childs.append([child, value])
    return tree


def _split_on_column(cases, index):
    """Partition *cases* by ``values[index]`` and drop that column from each copy.

    Returns a list of ``(value, subset)`` pairs in first-seen value order. Each
    case in a subset is a deep copy of the original with position *index*
    removed from both ``values`` and ``attributes``. The inputs are not
    modified.
    """
    distinct = []  # a list, so values need not be hashable
    for case in cases:
        if case.values[index] not in distinct:
            distinct.append(case.values[index])

    partitions = []
    for value in distinct:
        subset = []
        for case in cases:
            if case.values[index] == value:
                reduced = copy.deepcopy(case)
                reduced.values = case.values[:index] + case.values[index + 1:]
                reduced.attributes = case.attributes[:index] + case.attributes[index + 1:]
                subset.append(reduced)
        partitions.append((value, subset))
    return partitions
|
|
|
|
|
2021-06-14 23:54:30 +02:00
|
|
|
|
2021-05-23 23:15:50 +02:00
|
|
|
def _format_literal(value):
    """Render *value* as Python source: strings in double quotes, others via ``str``."""
    if isinstance(value, str):
        return '"' + value + '"'
    return str(value)


def pretty_print(root, n):
    """Print the tree under *root* as nested Python ``if``/``elif`` source.

    Each nesting level is indented by one ``" "`` unit, starting at *n* units.

    - A leaf prints ``return <label>``.
    - An internal node prints one ``if``/``elif self.<attribute> == <value>:``
      line per child, followed by that child's subtree.

    Whether a branch value is quoted depends on its type, not on its length.
    Previously a single-character string was emitted unquoted and a
    multi-digit number was emitted quoted.
    """
    indent = " " * n
    if not root.childs:
        # NOTE(review): leaf labels are written with str(), unquoted. That is
        # fine for numeric labels; string labels would need quoting too.
        print(indent + "return " + str(root.Class))
    for position, (child, value) in enumerate(root.childs):
        keyword = "if" if position == 0 else "elif"
        print(indent + keyword + " self." + str(root.Class) + " == " + _format_literal(value) + ":")
        pretty_print(child, n + 1)
|
2021-05-23 23:15:50 +02:00
|
|
|
|
2021-06-03 16:07:54 +02:00
|
|
|
|
2021-06-15 01:36:18 +02:00
|
|
|
# Print the learned tree as nested if/elif source (the generated view of
# decision_tree.py).
if __name__ == "__main__":
    # `cases` and `attributes` are presumably provided by the
    # `from src.cases import *` at the top of the file. 0 is the fallback
    # label for empty branches.
    pretty_print(treelearn(cases, attributes, 0), 0)
|