From bb1de89d71792330e29c424102076a70e3262299 Mon Sep 17 00:00:00 2001 From: Serhii Hromov Date: Mon, 25 May 2020 13:55:43 +0000 Subject: [PATCH] moved tree --- main.py | 203 +++++--------------------------------------------------- 1 file changed, 15 insertions(+), 188 deletions(-) diff --git a/main.py b/main.py index c0dcef1..4662d4c 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,8 @@ import pygad from concepts import * from graphviz import * import numpy as np +from data import * +from choice_tree import * pygame.init() @@ -28,198 +30,23 @@ display = pygame.display.set_mode((WIDTH * 32 + 200, HEIGHT * 32)) # eating time EAT_TIME = 15 -#### Menu -menu = Context.fromstring(''' |meat|salad|meal|drink|cold|hot | - Pork | X | | | | | X | - Espresso | | | | X | | X | - Green Tea | | | | X | X | | - Greek Salad| | X | | | X | | - Pizza | | | X | | | X |''') - - -training_data = [ -['meat','hot','Pork'], -['salad','cold','Greek Salad'], -['drink','hot','Espresso'], -['drink','cold','Green Tea'], -['meal','hot','Pizza'], -] - -tree_format = ["dish", "temperature", "label"] - -#menu.lattice.graphviz() -#Digraph.render('Lattice.gv', view=True) - - -#print(menu.extension(['meal',])) - - -#print(func_output) - -def uniq_count(rows): - #count uniq labels(names) - count = {} - for row in rows: - lbl = row[-1] - if lbl not in count: - count[lbl] = 0 - count[lbl] += 1 - return count - -#didn't used -def isnumer(val): - return isinstance(val, int) or isinstance(val, float) - - -class Question(): - - def __init__(self, col, value): - self.col = col #column - self.value = value #value of column - - def compare(self, example): - #compare val in example with val in the question - val = example[self.col] - if isnumer(val): #in case menu have prices - return val >= self.value - else: - return val == self.value - - def __repr__(self): - #just to print - condition = "==" - if isnumer(self.value): - condition = ">=" - return "Is %s %s %s?" % (tree_format[self.col], condition, str(self.value)) - - -def split(rows, quest): - #split data into True and False - t_rows, f_rows = [], [] - for row in rows: - if quest.compare(row): - t_rows.append(row) - else: - f_rows.append(row) - return t_rows, f_rows - - -def gini(rows): - counts = uniq_count(rows) - impurity = 1 - for lbl in counts: - prob_of_lbl = counts[lbl] / float(len(rows)) - impurity -= prob_of_lbl**2 - return impurity - - -def info_gain(l, r, current_uncertainty): - p = float(len(l)) / (len(l) + len(r)) #something like an enthropy? - return current_uncertainty - p*gini(l) - (1-p)*gini(r) - - -def find_best_q(rows): - #best question to split the data - best_gain = 0 - best_quest = None - current_uncertainty = gini(rows) - n_feat = len(rows[0]) - 1 - - for col in range(n_feat): - vals = set([row[col] for row in rows]) - - for val in vals: - quest = Question(col, val) - - t_rows, f_rows = split(rows, quest) - - if len(t_rows) == 0 or len(f_rows) == 0: - continue - - gain = info_gain(t_rows, f_rows, current_uncertainty) - - if gain >= best_gain: - best_gain, best_quest = gain, quest - - return best_gain, best_quest - - -class Leaf: - #contain a number of how many times the label has appeared in dataset - def __init__(self, rows): - self.predicts = uniq_count(rows) - - -class Decision_Node(): - #contain the question and child nodes - def __init__(self, quest, t_branch, f_branch): - self.quest = quest - self.t_branch = t_branch - self.f_branch = f_branch - - -def build_tree(rows): - #use info gain and question - gain, quest = find_best_q(rows) - - #no gain = no more question, so return a Leaf - if gain == 0: - return Leaf(rows) - - #split into true and false branch - t_rows, f_rows = split(rows, quest) - - #print out branches - t_branch = build_tree(t_rows) - f_branch = build_tree(f_rows) - - #return the child/leaf - return Decision_Node(quest, t_branch, f_branch) - - -def print_tree(node, spc=""): - - #if node is a leaf - if isinstance(node, Leaf): - print(" " + "Predict", node.predicts) - return #end of function - - #Or question - print("" + str(node.quest)) - #True branch - print("" + '--> True:') - print_tree(node.t_branch, spc + " ") - #False branch - print("" + '--> False:') - print_tree(node.f_branch, spc + " ") - -def classify(row, node): - #return our prediction in case the node is a leaf - if isinstance(node, Leaf): - return node.predicts - #otherwise go to the child - if node.quest.compare(row): - return classify(row, node.t_branch) - else: - return classify(row, node.f_branch) - - -def print_leaf(counts): - #count prediction - total = sum(counts.values())*1.0 - probs = {} #probability - for lbl in counts.keys(): - probs[lbl] = str(int(counts[lbl] / total*100)) + "%" - return probs - - - -#print(menu.extension(['meal',])) - tree = build_tree(training_data) +#order_len = len(tree_format) print_tree(tree) +def client_ordering(): + order = [] + + for i in range(0, len(tree_format)-1): + tmpr = random.sample(rand_data[i], 1) + order.append(tmpr[0]) + + order.append('order') + return order +### + + ### class Node: def __init__(self, state, parent, action):