From 26a2824818834305b1d2923d35ec1a6f361f368f Mon Sep 17 00:00:00 2001
From: Serhii Hromov
Date: Mon, 18 May 2020 12:37:08 +0000
Subject: [PATCH] added some comments

---
 main.py | 86 ++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 49 insertions(+), 37 deletions(-)

diff --git a/main.py b/main.py
index 9a8e129..c0dcef1 100644
--- a/main.py
+++ b/main.py
@@ -56,45 +56,45 @@ tree_format = ["dish", "temperature", "label"]
 
 #print(func_output)
 
 
-def uniq_val_from_data(rows, col):
-    return set([row[col] for row in rows])
-
-
-def class_counts(rows):
-    counts = {}
+def uniq_count(rows):
+    #count how many times each unique label (name) appears
+    count = {}
     for row in rows:
-        label = row[-1]
-        if label not in counts:
-            counts[label] = 0
-        counts[label] += 1
-    return counts
-
-
-def isnumer(value):
-    return isinstance(value, int) or isinstance(value, float)
+        lbl = row[-1]
+        if lbl not in count:
+            count[lbl] = 0
+        count[lbl] += 1
+    return count
+
+#helper: checks whether a value is numeric
+def isnumer(val):
+    return isinstance(val, int) or isinstance(val, float)
 
 
 class Question():
     def __init__(self, col, value):
-        self.col = col
-        self.value = value
+        self.col = col #column index
+        self.value = value #value to compare against
 
     def compare(self, example):
+        #compare the value in the example with the value in the question
         val = example[self.col]
-        if isnumer(val):
+        if isnumer(val): #numeric columns, e.g. prices on the menu
             return val >= self.value
         else:
             return val == self.value
 
     def __repr__(self):
+        #human-readable form of the question
        condition = "=="
         if isnumer(self.value):
             condition = ">="
         return "Is %s %s %s?" % (tree_format[self.col], condition, str(self.value))
 
 
-def partition(rows, quest):
+def split(rows, quest):
+    #split the data into True rows and False rows
     t_rows, f_rows = [], []
     for row in rows:
         if quest.compare(row):
@@ -105,7 +105,7 @@ def partition(rows, quest):
 
 
 def gini(rows):
-    counts = class_counts(rows)
+    counts = uniq_count(rows)
     impurity = 1
     for lbl in counts:
         prob_of_lbl = counts[lbl] / float(len(rows))
@@ -114,11 +114,12 @@ def gini(rows):
 
 
 def info_gain(l, r, current_uncertainty):
-    p = float(len(l)) / (len(l) + len(r))
+    p = float(len(l)) / (len(l) + len(r)) #fraction of rows that went to the left branch
     return current_uncertainty - p*gini(l) - (1-p)*gini(r)
 
 
 def find_best_q(rows):
+    #find the best question to split the data on
     best_gain = 0
     best_quest = None
     current_uncertainty = gini(rows)
@@ -130,7 +131,7 @@
 
     for val in vals:
         quest = Question(col, val)
-        t_rows, f_rows = partition(rows, quest)
+        t_rows, f_rows = split(rows, quest)
 
         if len(t_rows) == 0 or len(f_rows) == 0:
             continue
@@ -144,11 +145,13 @@
 
 
 class Leaf:
+    #stores how many times each label appears in the rows that reach this leaf
     def __init__(self, rows):
-        self.predicts = class_counts(rows)
+        self.predicts = uniq_count(rows)
 
 
 class Decision_Node():
+    #stores the question and the two child nodes
     def __init__(self, quest, t_branch, f_branch):
         self.quest = quest
         self.t_branch = t_branch
@@ -156,37 +159,45 @@ class Decision_Node():
 
 
 def build_tree(rows):
+    #find the question with the highest information gain
     gain, quest = find_best_q(rows)
-    
+
+    #no gain means no useful question is left, so return a Leaf
     if gain == 0:
         return Leaf(rows)
-    
-    t_rows, f_rows = partition(rows, quest)
-    
+
+    #split the rows into the true and false branches
+    t_rows, f_rows = split(rows, quest)
+
+    #recursively build each branch
     t_branch = build_tree(t_rows)
     f_branch = build_tree(f_rows)
-    
+
+    #return a node that holds the question and both branches
     return Decision_Node(quest, t_branch, f_branch)
 
 
 def print_tree(node, spc=""):
-    
+
+    #a leaf just prints its prediction counts
     if isinstance(node, Leaf):
         print("    " + "Predict", node.predicts)
-        return
-    
+        return #nothing below a leaf
+
+    #otherwise print the question at this node
     print("" + str(node.quest))
-    
+    #then the true branch
     print("" + '--> True:')
     print_tree(node.t_branch, spc + "    ")
-    
+    #and the false branch
     print("" + '--> False:')
     print_tree(node.f_branch, spc + "    ")
-    
+
 def classify(row, node):
+    #return the prediction counts if the node is a leaf
     if isinstance(node, Leaf):
-        return node.predictions
-    
+        return node.predicts
+    #otherwise follow the branch that matches the row
     if node.quest.compare(row):
         return classify(row, node.t_branch)
     else:
@@ -194,8 +205,9 @@
 
 
 def print_leaf(counts):
+    #convert the leaf counts into percentages
     total = sum(counts.values())*1.0
-    probs = {}
+    probs = {} #label -> probability string
     for lbl in counts.keys():
         probs[lbl] = str(int(counts[lbl] / total*100)) + "%"
     return probs
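
A quick hand check of the impurity math these comments describe. This note and the snippet below are not part of the commit; they assume gini() implements the usual 1 - sum(p**2) formula hinted at by the visible loop, and that main.py can be imported without side effects. The rows are hypothetical and only follow the tree_format = ["dish", "temperature", "label"] shape.

from main import gini, info_gain   # assumption: main.py imports cleanly

# hypothetical rows, label in the last column
left  = [["soup", 70, "hot"], ["steak", 60, "hot"]]
right = [["salad", 8, "cold"], ["ice cream", -5, "cold"]]

# mixed 2/2 labels: gini = 1 - (0.5**2 + 0.5**2) = 0.5
assert gini(left + right) == 0.5
# pure children: gini(left) = gini(right) = 0 and p = 2 / (2 + 2) = 0.5,
# so info_gain = 0.5 - 0.5*0 - (1 - 0.5)*0 = 0.5
assert info_gain(left, right, gini(left + right)) == 0.5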
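
And a minimal end-to-end sketch of how the renamed pieces fit together, under the same import assumption and additionally assuming the column/value loop of find_best_q outside these hunks still runs unchanged. The training rows are again hypothetical; only the names build_tree, print_tree, classify and print_leaf come from the diff itself.

from main import build_tree, print_tree, classify, print_leaf

training_rows = [              # hypothetical data, label in the last column
    ["soup", 70, "hot"],
    ["steak", 60, "hot"],
    ["salad", 8, "cold"],
    ["ice cream", -5, "cold"],
]

tree = build_tree(training_rows)   # recursively picks the best question at each node
print_tree(tree)                   # prints the questions and the leaf prediction counts

for row in training_rows:
    # classify() walks the tree down to a Leaf; print_leaf() turns its counts into percentages
    print("Actual: %s  Predicted: %s" % (row[-1], print_leaf(classify(row, tree))))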