Sztuczna_Inteligencja-projekt/id3.py

122 lines
3.2 KiB
Python

from cases import *
from collections import Counter
import operator
from types import prepare_class
import numpy as np
import copy
class Node:
    """A single node of the decision tree.

    Attributes:
        Class: Payload of the node. ``treelearn`` stores the predicted class
            label here for leaves and the splitting attribute for inner nodes.
        tag: Optional caller-supplied label. Nothing in this module reads it;
            it is kept so the constructor argument is no longer discarded.
        childs: List of ``[child_node, attribute_value]`` pairs; empty for a
            leaf.
    """

    def __init__(self, Class, tag=None):
        self.Class = Class
        self.tag = tag
        # A fresh list per instance so nodes never share children.
        self.childs = []
def classes_of_cases(cases):
    """Return the distinct ``Class`` values found in *cases*.

    The result is a list ordered by first appearance. Membership is tested
    with ``in`` on a list, so labels only need to support equality.
    """
    distinct = []
    for label in (case.Class for case in cases):
        if label in distinct:
            continue
        distinct.append(label)
    return distinct
def count_classes(cases):
    """Return the most frequent ``Class`` value among *cases*.

    Despite the name, this returns the majority label, not a count. When
    several labels share the highest frequency, the one encountered first
    wins. An empty *cases* raises ``ValueError`` (from ``max``).
    """
    tally = Counter(case.Class for case in cases)
    # Iterating the Counter walks labels in first-seen order, so ``max``
    # keeps the earliest label on ties.
    return max(tally, key=tally.get)
def chose_attribute(cases, attributes):
    """Pick the attribute with the highest information gain.

    The gain of an attribute is ``I(cases) - E(cases, attribute)``.

    Args:
        cases: Cases to evaluate; passed unchanged to ``I`` and ``E``.
        attributes: Candidate attributes to choose from.

    Returns:
        The attribute with the largest gain. On ties the attribute that
        appears last in *attributes* wins (comparison is ``>=``). Returns
        ``""`` when *attributes* is empty.
    """
    if not attributes:
        return ""

    # The entropy of the whole set does not depend on the attribute, so it is
    # computed once instead of on every iteration.
    base_entropy = I(cases)
    best_attribute = ""
    best_gain = float("-inf")
    for attribute in attributes:
        gain = base_entropy - E(cases, attribute)
        if gain >= best_gain:
            best_gain = gain
            best_attribute = attribute
    return best_attribute
def I(cases):
    """Return the Shannon entropy (base 2) of the ``Class`` labels in *cases*.

    Args:
        cases: Sized collection of objects exposing a ``Class`` attribute.
            Labels must be hashable (``count_classes`` already requires this).

    Returns:
        ``0`` for an empty collection, otherwise the entropy in bits as
        computed with ``numpy.log2``.
    """
    total = len(cases)
    entropy = 0
    # One counting pass; Counter preserves first-seen label order, so the
    # terms are accumulated in the same order as a per-label scan would.
    for count in Counter(case.Class for case in cases).values():
        probability = count / total
        entropy -= probability * np.log2(probability)
    return entropy
def E(cases, attribute):
    """Return the weighted entropy of *cases* after splitting on *attribute*.

    The column to inspect is looked up in the ``attributes`` list of the
    first case, so *cases* must be non-empty (an empty list raises
    ``IndexError``). Each distinct value of that column defines a subset;
    the result is the sum of ``I(subset)`` weighted by the subset's share
    of all cases.
    """
    column = cases[0].attributes.index(attribute)

    # Distinct values in first-seen order; equality-based so unhashable
    # values keep working.
    distinct = []
    for case in cases:
        observed = case.values[column]
        if observed not in distinct:
            distinct.append(observed)

    weighted = 0
    for observed in distinct:
        subset = [case for case in cases if case.values[column] == observed]
        weighted += (len(subset) / len(cases)) * I(subset)
    return weighted
def treelearn(cases, attributes, default_class):
    """Recursively build a decision tree (ID3-style) from *cases*.

    Args:
        cases: Sequence of case objects exposing ``Class``, ``values`` and
            ``attributes``. ``values`` is indexed by the position of the
            attribute in *attributes*.
        attributes: Sequence of attributes still available for splitting.
        default_class: Label used for the leaf when *cases* is empty.

    Returns:
        The root ``Node`` of the learned (sub)tree. Leaves carry a class
        label in ``Class``; inner nodes carry the chosen attribute and a
        ``childs`` list of ``[subtree, attribute_value]`` pairs.
    """
    # Truthiness checks (rather than ``== []``) so empty tuples and other
    # empty sequences hit the base cases instead of crashing further down.
    if not cases:
        return Node(default_class)
    if len(classes_of_cases(cases)) == 1:
        return Node(cases[0].Class)
    if not attributes:
        return Node(count_classes(cases))

    best = chose_attribute(cases, attributes)
    node = Node(best)
    majority = count_classes(cases)
    index = attributes.index(best)
    # Identical for every branch, so computed once outside the loop.
    remaining_attributes = attributes[:index] + attributes[index + 1:]

    # Distinct values of the chosen attribute, in first-seen order.
    values = []
    for case in cases:
        if case.values[index] not in values:
            values.append(case.values[index])

    for value in values:
        subset = []
        for case in cases:
            if case.values[index] == value:
                # Work on a copy so the caller's cases are left untouched,
                # then drop the column that was just split on.
                reduced = copy.deepcopy(case)
                reduced.values = case.values[:index] + case.values[index + 1:]
                reduced.attributes = (
                    case.attributes[:index] + case.attributes[index + 1:]
                )
                subset.append(reduced)
        child = treelearn(subset, remaining_attributes, majority)
        node.childs.append([child, value])
    return node
def pretty_print(root, n):
    """Print the tree rooted at *root* as nested ``if``/``elif`` statements.

    Args:
        root: Node exposing ``Class`` and ``childs`` (a list whose entries
            hold the subtree at index 0 and the attribute value at index 1).
        n: Current nesting depth; each level is indented by one space.

    A leaf prints ``return <Class>``. An inner node prints one
    ``if self.<Class> == <value>:`` line per child (``elif`` for every child
    after the first) followed by the child's subtree. Values whose string
    form is longer than one character are wrapped in double quotes.
    """
    indent = " " * n
    if len(root.childs) == 0:
        print(indent + "return " + str(root.Class))

    # The first branch is identified by position, not by comparing entries,
    # so a later entry that happens to equal the first still gets ``elif``.
    for position, entry in enumerate(root.childs):
        subtree, value = entry[0], entry[1]
        keyword = "if" if position == 0 else "elif"
        text = str(value)
        if len(text) > 1:
            text = "\"" + text + "\""
        print(indent + keyword + " self." + str(root.Class) + " == " + text + ":")
        pretty_print(subtree, n + 1)
# Runs at import time: learn a tree with 0 as the default class for empty
# case sets. NOTE(review): `cases` and `attributes` are not defined in this
# file; presumably they come from `from cases import *` -- confirm.
tree = treelearn(cases, attributes, 0)
# Print the learned tree to stdout as nested if/elif statements, starting at
# indentation depth 0.
pretty_print(tree, 0)