From 7e92796a191b4a4681f6be69090627a76e14643e Mon Sep 17 00:00:00 2001 From: Serhii Hromov Date: Mon, 18 May 2020 10:20:10 +0000 Subject: [PATCH] Fixed tree --- main.py | 70 +++++++++++++++++++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/main.py b/main.py index 6cfa9fa..9a8e129 100644 --- a/main.py +++ b/main.py @@ -30,13 +30,23 @@ EAT_TIME = 15 #### Menu menu = Context.fromstring(''' |meat|salad|meal|drink|cold|hot | - Pork | X | | X | | | X | + Pork | X | | | | | X | Espresso | | | | X | | X | Green Tea | | | | X | X | | - Greek Salad| | X | X | | X | | + Greek Salad| | X | | | X | | Pizza | | | X | | | X |''') +training_data = [ +['meat','hot','Pork'], +['salad','cold','Greek Salad'], +['drink','hot','Espresso'], +['drink','cold','Green Tea'], +['meal','hot','Pizza'], +] + +tree_format = ["dish", "temperature", "label"] + #menu.lattice.graphviz() #Digraph.render('Lattice.gv', view=True) @@ -46,7 +56,6 @@ menu = Context.fromstring(''' |meat|salad|meal|drink|cold|hot | #print(func_output) -''' def uniq_val_from_data(rows, col): return set([row[col] for row in rows]) @@ -65,16 +74,14 @@ def isnumer(value): return isinstance(value, int) or isinstance(value, float) -header = ... - class Question(): - def __init__(self, column, value): - self.column = column + def __init__(self, col, value): + self.col = col self.value = value def compare(self, example): - val = example[self.column] + val = example[self.col] if isnumer(val): return val >= self.value else: @@ -83,14 +90,14 @@ class Question(): def __repr__(self): condition = "==" if isnumer(self.value): - condition = ">=" - return "Is %s %s %s?" % (header[self.column], condition, str(self.value)) + condition = ">=" + return "Is %s %s %s?" % (tree_format[self.col], condition, str(self.value)) def partition(rows, quest): t_rows, f_rows = [], [] - for rows in rows: - if quest.compare(row) + for row in rows: + if quest.compare(row): t_rows.append(row) else: f_rows.append(row) @@ -101,12 +108,12 @@ def gini(rows): counts = class_counts(rows) impurity = 1 for lbl in counts: - prob_of_lbl = counts[lbl] / float(lem(rows)) + prob_of_lbl = counts[lbl] / float(len(rows)) impurity -= prob_of_lbl**2 return impurity -def info_gain(l,r, current_uncertainty): +def info_gain(l, r, current_uncertainty): p = float(len(l)) / (len(l) + len(r)) return current_uncertainty - p*gini(l) - (1-p)*gini(r) @@ -115,29 +122,29 @@ def find_best_q(rows): best_gain = 0 best_quest = None current_uncertainty = gini(rows) - n_features = len(rows[0]) - 1 + n_feat = len(rows[0]) - 1 for col in range(n_feat): - values = set([row[col] for row in rows]) + vals = set([row[col] for row in rows]) - for cal in values: + for val in vals: quest = Question(col, val) t_rows, f_rows = partition(rows, quest) - if len(t_rows) == 0 or len(f_rows) == 0Ж + if len(t_rows) == 0 or len(f_rows) == 0: continue - fain = info_gain(t_rows, f_rows, current_uncertainty) + gain = info_gain(t_rows, f_rows, current_uncertainty) - if gain >= best gain: + if gain >= best_gain: best_gain, best_quest = gain, quest return best_gain, best_quest class Leaf: - def __init__(self,rows): + def __init__(self, rows): self.predicts = class_counts(rows) @@ -148,7 +155,7 @@ class Decision_Node(): self.f_branch = f_branch -def build_tree(): +def build_tree(rows): gain, quest = find_best_q(rows) if gain == 0: @@ -162,22 +169,22 @@ def build_tree(): return Decision_Node(quest, t_branch, f_branch) -def print_tree(node): +def print_tree(node, spc=""): - if isinstance(node, leaf): - print("" + "Predict", node.predictions) + if isinstance(node, Leaf): + print(" " + "Predict", node.predicts) return print("" + str(node.quest)) print("" + '--> True:') - print_tree(node.t_branch, ""+ " ") + print_tree(node.t_branch, spc + " ") print("" + '--> False:') - print_tree(node.f_branch,"" + " ") + print_tree(node.f_branch, spc + " ") def classify(row, node): - if isinstance(node, leaf): + if isinstance(node, Leaf): return node.predictions if node.quest.compare(row): @@ -194,7 +201,12 @@ def print_leaf(counts): return probs -''' + +#print(menu.extension(['meal',])) + +tree = build_tree(training_data) +print_tree(tree) + ### class Node: