# Decision-tree classifier helpers: label counting, Gini impurity,
# information gain, best-split search, tree construction and classification.
# Every row is a sequence whose last element is the class label.


def uniq_val_from_data(rows, col):
    """Return the set of distinct values found in column ``col`` of ``rows``."""
    return {row[col] for row in rows}


def class_counts(rows):
    """Return a dict mapping each label (last element of a row) to its count."""
    counts = {}
    for row in rows:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    return counts


def isnumer(value):
    """Return True when ``value`` is an ``int`` or a ``float``."""
    return isinstance(value, (int, float))


# Column names looked up by Question.__repr__. Still the ``...`` placeholder
# here; assign a sequence of names to get readable questions.
header = ...


class Question:
    """A test on a single column, used to split rows into two groups.

    A numeric example value is compared with ``>=``; any other value is
    compared with ``==``.
    """

    def __init__(self, column, value):
        self.column = column
        self.value = value

    def compare(self, example):
        """Return True when ``example`` satisfies this question."""
        val = example[self.column]
        if isnumer(val):
            return val >= self.value
        return val == self.value

    def __repr__(self):
        condition = ">=" if isnumer(self.value) else "=="
        try:
            name = header[self.column]
        except (TypeError, IndexError, KeyError):
            # header is not (yet) a usable collection of column names;
            # fall back to the column index instead of crashing.
            name = "column %s" % (self.column,)
        return "Is %s %s %s?" % (name, condition, str(self.value))


def partition(rows, quest):
    """Split ``rows`` into ``(t_rows, f_rows)`` by the answer to ``quest``.

    Rows for which ``quest.compare(row)`` is true go to ``t_rows``; all
    others go to ``f_rows``. Input order is preserved in both lists.
    """
    t_rows, f_rows = [], []
    for row in rows:
        if quest.compare(row):
            t_rows.append(row)
        else:
            f_rows.append(row)
    return t_rows, f_rows


def gini(rows):
    """Return the Gini impurity of the labels in ``rows``.

    The result is ``1 - sum(p_label ** 2)``; an empty ``rows`` yields 1
    because there are no labels to subtract.
    """
    counts = class_counts(rows)
    total = float(len(rows))
    impurity = 1
    for lbl in counts:
        prob_of_lbl = counts[lbl] / total
        impurity -= prob_of_lbl ** 2
    return impurity


def info_gain(l, r, current_uncertainty):
    """Return the impurity reduction achieved by splitting into ``l`` and ``r``.

    It is ``current_uncertainty`` minus the size-weighted Gini impurity of
    the two children. At least one of ``l`` / ``r`` must be non-empty,
    otherwise the weight computation divides by zero.
    """
    p = float(len(l)) / (len(l) + len(r))
    return current_uncertainty - p * gini(l) - (1 - p) * gini(r)


def find_best_q(rows):
    """Find the question with the highest information gain for ``rows``.

    Every (column, distinct value) pair over the feature columns (all but
    the last) is tried. Returns ``(best_gain, best_quest)``; the result is
    ``(0, None)`` when ``rows`` is empty or no question splits the rows
    into two non-empty groups.
    """
    best_gain = 0
    best_quest = None
    if not rows:
        # Nothing to inspect; rows[0] below would raise IndexError.
        return best_gain, best_quest

    current_uncertainty = gini(rows)
    n_features = len(rows[0]) - 1  # the last column is the label

    for col in range(n_features):
        for val in uniq_val_from_data(rows, col):
            quest = Question(col, val)
            t_rows, f_rows = partition(rows, quest)

            # A question that does not divide the data is useless.
            if not t_rows or not f_rows:
                continue

            gain = info_gain(t_rows, f_rows, current_uncertainty)
            if gain >= best_gain:
                best_gain, best_quest = gain, quest

    return best_gain, best_quest


class Leaf:
    """Terminal node holding the label counts of the rows that reached it."""

    def __init__(self, rows):
        self.predictions = class_counts(rows)
        # Alias for the attribute name this class originally assigned.
        self.predicts = self.predictions


class Decision_Node:
    """Internal node: a question plus the subtrees for its two answers."""

    def __init__(self, quest, t_branch, f_branch):
        self.quest = quest
        self.t_branch = t_branch
        self.f_branch = f_branch


def build_tree(rows):
    """Recursively build a decision tree for ``rows`` and return its root.

    Returns a ``Leaf`` when no question yields a positive gain, otherwise a
    ``Decision_Node`` whose branches are built from the two partitions.
    """
    gain, quest = find_best_q(rows)

    if gain == 0:
        return Leaf(rows)

    t_rows, f_rows = partition(rows, quest)
    t_branch = build_tree(t_rows)
    f_branch = build_tree(f_rows)
    return Decision_Node(quest, t_branch, f_branch)


def print_tree(node, spacing=""):
    """Print the tree rooted at ``node``, indenting each level by ``spacing``."""
    if isinstance(node, Leaf):
        print(spacing + "Predict", node.predictions)
        return

    print(spacing + str(node.quest))

    print(spacing + '--> True:')
    print_tree(node.t_branch, spacing + " ")

    print(spacing + '--> False:')
    print_tree(node.f_branch, spacing + " ")


def classify(row, node):
    """Walk the tree from ``node`` and return the label counts of the leaf
    that ``row`` ends up in."""
    if isinstance(node, Leaf):
        return node.predictions

    if node.quest.compare(row):
        return classify(row, node.t_branch)
    return classify(row, node.f_branch)


def print_leaf(counts):
    """Convert label counts to whole-number percentage strings.

    Returns a dict mapping each label to e.g. ``"50%"``; percentages are
    truncated, not rounded.
    """
    total = sum(counts.values()) * 1.0
    probs = {}
    for lbl in counts:
        probs[lbl] = str(int(counts[lbl] / total * 100)) + "%"
    return probs