added some comments

This commit is contained in:
Serhii Hromov 2020-05-18 12:37:08 +00:00
parent 7e92796a19
commit 26a2824818

70
main.py
View File

@ -56,45 +56,45 @@ tree_format = ["dish", "temperature", "label"]
#print(func_output) #print(func_output)
def uniq_val_from_data(rows, col): def uniq_count(rows):
return set([row[col] for row in rows]) #count uniq labels(names)
count = {}
def class_counts(rows):
counts = {}
for row in rows: for row in rows:
label = row[-1] lbl = row[-1]
if label not in counts: if lbl not in count:
counts[label] = 0 count[lbl] = 0
counts[label] += 1 count[lbl] += 1
return counts return count
#not used
def isnumer(value): def isnumer(val):
return isinstance(value, int) or isinstance(value, float) return isinstance(val, int) or isinstance(val, float)
class Question(): class Question():
def __init__(self, col, value): def __init__(self, col, value):
self.col = col self.col = col #column
self.value = value self.value = value #value of column
def compare(self, example): def compare(self, example):
#compare val in example with val in the question
val = example[self.col] val = example[self.col]
if isnumer(val): if isnumer(val): #in case the menu has prices (numeric values)
return val >= self.value return val >= self.value
else: else:
return val == self.value return val == self.value
def __repr__(self): def __repr__(self):
#just to print
condition = "==" condition = "=="
if isnumer(self.value): if isnumer(self.value):
condition = ">=" condition = ">="
return "Is %s %s %s?" % (tree_format[self.col], condition, str(self.value)) return "Is %s %s %s?" % (tree_format[self.col], condition, str(self.value))
def partition(rows, quest): def split(rows, quest):
#split data into True and False
t_rows, f_rows = [], [] t_rows, f_rows = [], []
for row in rows: for row in rows:
if quest.compare(row): if quest.compare(row):
@ -105,7 +105,7 @@ def partition(rows, quest):
def gini(rows): def gini(rows):
counts = class_counts(rows) counts = uniq_count(rows)
impurity = 1 impurity = 1
for lbl in counts: for lbl in counts:
prob_of_lbl = counts[lbl] / float(len(rows)) prob_of_lbl = counts[lbl] / float(len(rows))
@ -114,11 +114,12 @@ def gini(rows):
def info_gain(l, r, current_uncertainty): def info_gain(l, r, current_uncertainty):
p = float(len(l)) / (len(l) + len(r)) p = float(len(l)) / (len(l) + len(r)) #weight of the left branch: its share of all rows
return current_uncertainty - p*gini(l) - (1-p)*gini(r) return current_uncertainty - p*gini(l) - (1-p)*gini(r)
def find_best_q(rows): def find_best_q(rows):
#best question to split the data
best_gain = 0 best_gain = 0
best_quest = None best_quest = None
current_uncertainty = gini(rows) current_uncertainty = gini(rows)
@ -130,7 +131,7 @@ def find_best_q(rows):
for val in vals: for val in vals:
quest = Question(col, val) quest = Question(col, val)
t_rows, f_rows = partition(rows, quest) t_rows, f_rows = split(rows, quest)
if len(t_rows) == 0 or len(f_rows) == 0: if len(t_rows) == 0 or len(f_rows) == 0:
continue continue
@ -144,11 +145,13 @@ def find_best_q(rows):
class Leaf: class Leaf:
#holds a count of how many times each label appears in the rows that reach this leaf
def __init__(self, rows): def __init__(self, rows):
self.predicts = class_counts(rows) self.predicts = uniq_count(rows)
class Decision_Node(): class Decision_Node():
#holds the question and the two child branches (true and false)
def __init__(self, quest, t_branch, f_branch): def __init__(self, quest, t_branch, f_branch):
self.quest = quest self.quest = quest
self.t_branch = t_branch self.t_branch = t_branch
@ -156,37 +159,45 @@ class Decision_Node():
def build_tree(rows): def build_tree(rows):
#use info gain and question
gain, quest = find_best_q(rows) gain, quest = find_best_q(rows)
#no gain = no more question, so return a Leaf
if gain == 0: if gain == 0:
return Leaf(rows) return Leaf(rows)
t_rows, f_rows = partition(rows, quest) #split into true and false branch
t_rows, f_rows = split(rows, quest)
#recursively build the true and false branches
t_branch = build_tree(t_rows) t_branch = build_tree(t_rows)
f_branch = build_tree(f_rows) f_branch = build_tree(f_rows)
#return a decision node holding the question and both branches
return Decision_Node(quest, t_branch, f_branch) return Decision_Node(quest, t_branch, f_branch)
def print_tree(node, spc=""): def print_tree(node, spc=""):
#if node is a leaf
if isinstance(node, Leaf): if isinstance(node, Leaf):
print(" " + "Predict", node.predicts) print(" " + "Predict", node.predicts)
return return #end of function
#Or question
print("" + str(node.quest)) print("" + str(node.quest))
#True branch
print("" + '--> True:') print("" + '--> True:')
print_tree(node.t_branch, spc + " ") print_tree(node.t_branch, spc + " ")
#False branch
print("" + '--> False:') print("" + '--> False:')
print_tree(node.f_branch, spc + " ") print_tree(node.f_branch, spc + " ")
def classify(row, node): def classify(row, node):
#return our prediction in case the node is a leaf
if isinstance(node, Leaf): if isinstance(node, Leaf):
return node.predictions return node.predicts
#otherwise go to the child
if node.quest.compare(row): if node.quest.compare(row):
return classify(row, node.t_branch) return classify(row, node.t_branch)
else: else:
@ -194,8 +205,9 @@ def classify(row, node):
def print_leaf(counts): def print_leaf(counts):
#convert the label counts into percentage strings
total = sum(counts.values())*1.0 total = sum(counts.values())*1.0
probs = {} probs = {} #probability
for lbl in counts.keys(): for lbl in counts.keys():
probs[lbl] = str(int(counts[lbl] / total*100)) + "%" probs[lbl] = str(int(counts[lbl] / total*100)) + "%"
return probs return probs