Dodanie 'id3.py'
This commit is contained in:
parent
b9370cb6f9
commit
4c3da838af
169
id3.py
Normal file
169
id3.py
Normal file
@ -0,0 +1,169 @@
|
||||
from collections import Counter
|
||||
import operator
|
||||
from types import prepare_class
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
class Case:
|
||||
def __init__(self, values, attributes, Class):
|
||||
self.values = values
|
||||
self.attributes = attributes
|
||||
self.Class = Class
|
||||
|
||||
class Node:
|
||||
def __init__(self, Class, tag=None):
|
||||
self.Class = Class
|
||||
self.childs = []
|
||||
|
||||
def classes_of_cases (cases):
|
||||
classes = []
|
||||
for case in cases:
|
||||
if case.Class not in classes:
|
||||
classes.append(case.Class)
|
||||
return classes
|
||||
|
||||
def count_classes (cases):
|
||||
classes = []
|
||||
for case in cases:
|
||||
classes.append(case.Class)
|
||||
c = Counter(classes)
|
||||
return max(c.items(), key=operator.itemgetter(1))[0]
|
||||
|
||||
def chose_attribute (cases, attributes):
|
||||
a = ""
|
||||
max = float("-inf")
|
||||
for attribute in attributes:
|
||||
if I(cases) - E(cases, attribute) >= max:
|
||||
max = I(cases) - E(cases, attribute)
|
||||
a = attribute
|
||||
return a
|
||||
|
||||
def I (cases):
|
||||
i = 0
|
||||
all = len(cases)
|
||||
classes = classes_of_cases(cases)
|
||||
for Class in classes:
|
||||
noc = 0
|
||||
for case in cases:
|
||||
if case.Class == Class:
|
||||
noc += 1
|
||||
i -= (noc/all)*np.log2(noc/all)
|
||||
return i
|
||||
|
||||
def E(cases, attribute):
|
||||
e = 0
|
||||
values = []
|
||||
index = cases[0].attributes.index(attribute)
|
||||
print(attribute)
|
||||
print(cases[0].values)
|
||||
print(index)
|
||||
for case in cases:
|
||||
if case.values[index] not in values:
|
||||
values.append(case.values[index])
|
||||
for value in values:
|
||||
ei = []
|
||||
for case in cases:
|
||||
if case.values[index] == value:
|
||||
ei.append(case)
|
||||
e += (len(ei)/len(cases))*I(ei)
|
||||
return e
|
||||
|
||||
|
||||
def treelearn(cases, attributes, default_class):
|
||||
if cases == []:
|
||||
t = Node(default_class)
|
||||
return t
|
||||
if len(classes_of_cases(cases)) == 1:
|
||||
t = Node(cases[0].Class)
|
||||
return t
|
||||
if attributes == []:
|
||||
t = Node(count_classes(cases))
|
||||
return t
|
||||
A = chose_attribute(cases, attributes)
|
||||
t = Node(A)
|
||||
new_default_class = count_classes(cases)
|
||||
|
||||
print(attributes, end=" ")
|
||||
|
||||
values = []
|
||||
index = attributes.index(A)
|
||||
print(index)
|
||||
for case in cases:
|
||||
if case.values[index] not in values:
|
||||
values.append(case.values[index])
|
||||
|
||||
print(values, end="")
|
||||
print(A)
|
||||
for value in values:
|
||||
new_cases = []
|
||||
for case in cases:
|
||||
if case.values[index] == value:
|
||||
new_case = copy.deepcopy(case)
|
||||
new_case.values = case.values[:index] + case.values[index+1:]
|
||||
new_case.attributes = case.attributes[:index] + case.attributes[index+1:]
|
||||
new_cases.append(new_case)
|
||||
new_attributes = attributes[:index] + attributes[index+1 :]
|
||||
child = treelearn(new_cases, new_attributes, new_default_class)
|
||||
t.childs.append([child, value])
|
||||
|
||||
return t
|
||||
|
||||
def pretty_print(root, n):
|
||||
for _ in range(n):
|
||||
print("\t", end="")
|
||||
print(root.Class)
|
||||
for child in root.childs:
|
||||
for _ in range(n):
|
||||
print("\t", end="")
|
||||
print("== " + str(child[1]))
|
||||
pretty_print(child[0], n+1)
|
||||
|
||||
attr = ["hydration", "fertility", "plant_type", "ticks", "is_healthy", "tractor_there"]
|
||||
ccases = []
|
||||
k = Case([2, 0, "wheat", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([3, 0, "wheat", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([4, 0, "wheat", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([1, 1, "wheat", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([3, 0, "wheat", 20, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([2, 0, "wheat", 20, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 0, "potato", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([3, 0, "potato", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([2, 0, "potato", 31, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([2, 0, "potato", 31, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([2, 1, "potato", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([1, 1, "potato", 31, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "potato", 31, 0, 0], attr, 1)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "potato", 19, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "potato", 31, 1, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "wheat", 19, 0, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "potato", 31, 0, 1], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([4, 1, "wheat", 31, 1, 0], attr, 0)
|
||||
ccases.append(k)
|
||||
k = Case([2, 0, "wheat", 31, 0, 1], attr, 0)
|
||||
ccases.append(k)
|
||||
|
||||
tree = treelearn(ccases, attr, 0)
|
||||
pretty_print(tree, 0)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user