diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 26d3352..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml diff --git a/__pycache__/astar.cpython-310.pyc b/__pycache__/astar.cpython-310.pyc deleted file mode 100644 index 42aa76a..0000000 Binary files a/__pycache__/astar.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/bfs.cpython-310.pyc b/__pycache__/bfs.cpython-310.pyc deleted file mode 100644 index 6cd41fc..0000000 Binary files a/__pycache__/bfs.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/garbage_truck.cpython-310.pyc b/__pycache__/garbage_truck.cpython-310.pyc deleted file mode 100644 index d4dcb18..0000000 Binary files a/__pycache__/garbage_truck.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/heuristicfn.cpython-310.pyc b/__pycache__/heuristicfn.cpython-310.pyc deleted file mode 100644 index a8d8775..0000000 Binary files a/__pycache__/heuristicfn.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/main.cpython-310.pyc b/__pycache__/main.cpython-310.pyc deleted file mode 100644 index 6af3024..0000000 Binary files a/__pycache__/main.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/state.cpython-310.pyc b/__pycache__/state.cpython-310.pyc deleted file mode 100644 index 927e172..0000000 Binary files a/__pycache__/state.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/succ.cpython-310.pyc b/__pycache__/succ.cpython-310.pyc deleted file mode 100644 index d3465fc..0000000 Binary files a/__pycache__/succ.cpython-310.pyc and /dev/null differ diff --git a/astar.py b/astar.py index 2b30eaf..8621c06 100644 --- a/astar.py +++ b/astar.py @@ -36,26 +36,3 @@ def astar(istate, goalx, goaly, passedFields): element.priority = value.priority return False - -# def bfs(istate, goalx, goaly, passedFields): -# fringe = [istate] -# explored = [] -# steps = [] -# while fringe: -# state = fringe.pop(0) -# if state.xpos == goalx and state.ypos == goaly: -# steps.insert(0, state) -# while (state.parent != None): -# state = state.parent -# steps.insert(0, state) -# return steps - -# element = successors(state, passedFields) -# explored.append((state.xpos, state.ypos, state.orientation)) -# for value in element: -# val = (value.xpos, value.ypos, value.orientation) -# if val not in explored and value not in fringe: -# fringe.append(value) -# return False - - diff --git a/collect b/collect index 668d904..bcbe918 100644 --- a/collect +++ b/collect @@ -24,7 +24,7 @@ edge [fontname="helvetica"] ; 6 -> 10 ; 11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ; 10 -> 11 ; -12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; +12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; 11 -> 12 ; 13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; 12 -> 13 ; @@ -36,7 +36,7 @@ edge [fontname="helvetica"] ; 15 -> 16 ; 17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ; 15 -> 17 ; -18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ; +18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ; 17 -> 18 ; 19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; 18 -> 19 ; @@ -54,11 +54,11 @@ edge [fontname="helvetica"] ; 23 -> 25 ; 26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ; 25 -> 26 ; -27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; +27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; 25 -> 27 ; -28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; +28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; 27 -> 28 ; -29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; +29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; 27 -> 29 ; 30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ; 0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ; @@ -88,14 +88,18 @@ edge [fontname="helvetica"] ; 40 -> 42 ; 43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ; 42 -> 43 ; -44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ; +44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ; 42 -> 44 ; 45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; 44 -> 45 ; -46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ; +46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ; 44 -> 46 ; -47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; +47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; 46 -> 47 ; -48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ; +48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ; 46 -> 48 ; +49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; +48 -> 49 ; +50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ; +48 -> 50 ; } diff --git a/collect.pdf b/collect.pdf index 0bc3799..2f2f23c 100644 Binary files a/collect.pdf and b/collect.pdf differ diff --git a/garbage_truck.py b/garbage_truck.py index 9998400..af0e228 100644 --- a/garbage_truck.py +++ b/garbage_truck.py @@ -1,31 +1,38 @@ +from heuristicfn import heuristicfn + FIELDWIDTH = 50 +TURN_FUEL_COST = 10 +MOVE_FUEL_COST = 200 +MAX_FUEL = 20000 +MAX_SPACE = 5 +MAX_WEIGHT = 200 -class GarbageTank: - def __init__(self, volume_capacity, mass_capacity): - self.vcapacity = volume_capacity #m^3 - self.mcapacity = mass_capacity #kg - -class Engine: - def __init__(self, power): - self.power = power #HP class GarbageTruck: - def __init__(self, dump_location, fuel_capacity, rect, orientation): - self.dump_location = dump_location - self.tank = GarbageTank(15, 18000) - self.engine = Engine(400) - self.fuel = fuel_capacity + + garbage_types = {'bio': 0, 'electronics': 1, 'mixed': 2, 'recyclable': 3} + + def __init__(self, dump_x, dump_y, rect, orientation, request_list: list, clf): + self.dump_x = dump_x + self.dump_y = dump_y + self.fuel = MAX_FUEL + self.free_space = MAX_SPACE + self.weight_capacity = MAX_WEIGHT self.rect = rect self.orientation = orientation - self.houses = [] #lista domów do odwiedzenia + self.request_list = request_list #lista domów do odwiedzenia + self.clf = clf def turn_left(self): self.orientation = (self.orientation - 1) % 4 + self.fuel -= TURN_FUEL_COST def turn_right(self): self.orientation = (self.orientation + 1) % 4 + self.fuel -= TURN_FUEL_COST def forward(self): + self.fuel -= MOVE_FUEL_COST if self.orientation == 0: self.rect.x += FIELDWIDTH elif self.orientation == 1: @@ -33,4 +40,50 @@ class GarbageTruck: elif self.orientation == 2: self.rect.x -= FIELDWIDTH else: - self.rect.y -= FIELDWIDTH \ No newline at end of file + self.rect.y -= FIELDWIDTH + + def next_destination(self): + if self.fuel <= 0 or not self.request_list: + return self.dump_x, self.dump_y + + for i in range(len(self.request_list)): + request = self.request_list[i] + + #nie ma miejsca w zbiorniku lub za ciężkie śmieci + if request.volume > self.free_space or request.weight > self.weight_capacity: + continue + + #nie straczy paliwa na dojechanie i powrót na wysypisko + if heuristicfn(request.x_pos, request.y_pos, self.dump_x, self.dump_y) / 50 * 200 > self.fuel: + continue + + + + distance = heuristicfn(self.rect.x, self.rect.y, request.x_pos, request.y_pos) / 50 + + r = [ + self.fuel, + distance, + request.volume, + request.last_collection, + request.is_paid, + request.odour_intensity, + request.weight, + request.type + ] + if self.clf.predict([r]) == True: + self.request_list.pop(i) + self.free_space -= request.volume + self.weight_capacity -= request.weight + return request.x_pos, request.y_pos + return self.dump_x, self.dump_y + + + + def collect(self): + if self.rect.x == self.dump_x and self.rect.y == self.dump_y: + self.fuel = MAX_WEIGHT + self.free_space = MAX_SPACE + self.weight_capacity = MAX_WEIGHT + print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}') + pass \ No newline at end of file diff --git a/home.py b/home.py deleted file mode 100644 index e795dfc..0000000 --- a/home.py +++ /dev/null @@ -1,4 +0,0 @@ -class Home: - def __init__(self, coord): - self.coord = coord - self.collect_request = False \ No newline at end of file diff --git a/litter.py b/litter.py deleted file mode 100644 index 6dc9dd9..0000000 --- a/litter.py +++ /dev/null @@ -1,8 +0,0 @@ -class Litter: - - types = ['PAPER', 'GLASS', 'PLASTIC', 'METAL', 'BIO', 'MUNICIPAL', 'ELECTRONICS'] - - def __init__(self, type, volume, mass): - self.type = type - self.volume = volume - self.mass = mass diff --git a/main.py b/main.py index 3123f69..e37f883 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,5 @@ import pygame -import random -import pandas as pd -from sklearn import tree -from sklearn.preprocessing import LabelEncoder -import graphviz +from treelearn import treelearn from astar import astar @@ -11,6 +7,7 @@ from state import State import time from garbage_truck import GarbageTruck from heuristicfn import heuristicfn +from map import randomize_map pygame.init() WIDTH, HEIGHT = 800, 800 @@ -18,52 +15,12 @@ window = pygame.display.set_mode((WIDTH, HEIGHT)) pygame.display.set_caption("Intelligent Garbage Collector") AGENT_IMG = pygame.image.load("garbage-truck-nbg.png") AGENT = pygame.transform.scale(AGENT_IMG, (50, 50)) -DIRT_IMG = pygame.image.load("dirt.jpg") -DIRT = pygame.transform.scale(DIRT_IMG, (50, 50)) -GRASS_IMG = pygame.image.load("grass.png") -GRASS = pygame.transform.scale(GRASS_IMG, (50, 50)) -SAND_IMG = pygame.image.load("sand.jpeg") -SAND = pygame.transform.scale(SAND_IMG, (50, 50)) -COBBLE_IMG = pygame.image.load("cobble.jpeg") -COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50)) FPS = 10 FIELDCOUNT = 16 FIELDWIDTH = 50 - -class Agent: - def __init__(self, rect, direction): - self.rect = rect - self.direction = direction - - -def randomize_map(): # tworzenie mapy z losowymi polami - field_array_1 = [] - field_array_2 = [] - field_priority = [] - for i in range(16): - temp_priority = [] - for j in range(16): - if i in (0, 1) and j in (0, 1): - field_array_2.append(GRASS) - temp_priority.append(1) - else: - prob = random.uniform(0, 100) - if 0 <= prob <= 12: - field_array_2.append(COBBLE) - temp_priority.append(3) - elif 12 < prob <= 24: - field_array_2.append(SAND) - temp_priority.append(2) - else: - field_array_2.append(GRASS) - temp_priority.append(1) - field_array_1.append(field_array_2) - field_array_2 = [] - field_priority.append(temp_priority) - return field_array_1, field_priority - - +GRASS_IMG = pygame.image.load("grass.png") +GRASS = pygame.transform.scale(GRASS_IMG, (50, 50)) def draw_window(agent, fields, flip): if flip: direction = pygame.transform.flip(AGENT, True, False) @@ -77,35 +34,22 @@ def draw_window(agent, fields, flip): def main(): - train_data = pd.read_csv('./data_set.csv') - attributes = train_data.drop('collect', axis='columns') - e_type = LabelEncoder() - attributes['type_num'] = e_type.fit_transform(attributes['garbage_type']) - attr_encoded = attributes.drop(['garbage_type'], axis='columns') - attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type'] - label_names = ['collect', 'no-collect'] - label = train_data['collect'] - print(attr_encoded) - print(label) - classifier = tree.DecisionTreeClassifier() - classifier.fit(attr_encoded, label) - dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names) - graph = graphviz.Source(dot_data) - graph.render('collect') + clf = treelearn() clock = pygame.time.Clock() run = True - x, y = [0, 0] - agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta - fields, priority_array = randomize_map() - final_x, final_y = [100, 300] + fields, priority_array, request_list = randomize_map() + agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta while run: clock.tick(FPS) for event in pygame.event.get(): if event.type == pygame.QUIT: run = False - # keys_pressed = pygame.key.get_pressed() draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west - steps = astar(State(None, None, x, y, 'E', priority_array[0][0], heuristicfn(x, y, final_x, final_y)), final_x, final_y, priority_array) + x, y = agent.next_destination() + if x == agent.rect.x and y == agent.rect.y: + print('out of jobs') + break + steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[0][0], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array) for interm in steps: if interm.action == 'LEFT': agent.turn_left() @@ -121,10 +65,11 @@ def main(): draw_window(agent, fields, True) else: draw_window(agent, fields, False) - time.sleep(0.5) + time.sleep(0.3) + agent.collect() + fields[agent.rect.x//50][agent.rect.y//50] = GRASS + time.sleep(0.5) - while True: - pass pygame.quit() diff --git a/map.py b/map.py new file mode 100644 index 0000000..f155c2b --- /dev/null +++ b/map.py @@ -0,0 +1,44 @@ +import pygame, random +from request import Request + +DIRT_IMG = pygame.image.load("dirt.jpg") +DIRT = pygame.transform.scale(DIRT_IMG, (50, 50)) +GRASS_IMG = pygame.image.load("grass.png") +GRASS = pygame.transform.scale(GRASS_IMG, (50, 50)) +SAND_IMG = pygame.image.load("sand.jpeg") +SAND = pygame.transform.scale(SAND_IMG, (50, 50)) +COBBLE_IMG = pygame.image.load("cobble.jpeg") +COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50)) + +def randomize_map(): # tworzenie mapy z losowymi polami + request_list = [] + field_array_1 = [] + field_array_2 = [] + field_priority = [] + for i in range(16): + temp_priority = [] + for j in range(16): + if i in (0, 1) and j in (0, 1): + field_array_2.append(GRASS) + temp_priority.append(1) + else: + prob = random.uniform(0, 100) + if 0 <= prob <= 12: + field_array_2.append(COBBLE) + temp_priority.append(100) + request_list.append(Request( + i*50,j*50, #lokacja + random.randint(0,3), #typ śmieci + random.random(), #objętość śmieci + random.randint(0,30), #ostatni odbiór + random.randint(0,1), #czy opłacone w terminie + random.random() * 10, #intensywność odoru + random.random() * 50 #waga śmieci + )) + else: + field_array_2.append(GRASS) + temp_priority.append(1) + field_array_1.append(field_array_2) + field_array_2 = [] + field_priority.append(temp_priority) + return field_array_1, field_priority, request_list \ No newline at end of file diff --git a/request.py b/request.py new file mode 100644 index 0000000..9064eb8 --- /dev/null +++ b/request.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +@dataclass +class Request: + def __init__(self, x_pos, y_pos, type, volume, last_collection, is_paid, odour_intensity, weight): + self.x_pos = x_pos + self.y_pos = y_pos + self.type = type + self.volume = volume + self.last_collection = last_collection + self.is_paid = is_paid + self.odour_intensity = odour_intensity + self.weight = weight \ No newline at end of file diff --git a/succ.py b/succ.py index 8773dcd..07f73fa 100644 --- a/succ.py +++ b/succ.py @@ -5,27 +5,27 @@ FIELDWIDTH, FIELDCOUNT = 50, 16 def succ(st: State, passedPriorities, goalx, goaly): successors = [] - if st.orientation == 'N': - successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) - successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + if st.orientation == 3: + successors.append(State(st, 'LEFT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) if st.ypos > 0: - successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 3, passedPriorities[st.xpos//50][st.ypos//50 - 1], heuristicfn(st.xpos, st.ypos - 50, goalx, goaly))) - if st.orientation == 'S': - successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) - successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + if st.orientation == 1: + successors.append(State(st, 'LEFT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st,'RIGHT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) if st.ypos < FIELDWIDTH * (FIELDCOUNT - 1): - successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 1, passedPriorities[st.xpos//50][st.ypos//50 + 1], heuristicfn(st.xpos, st.ypos + 50, goalx, goaly))) - if st.orientation == 'W': - successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) - successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + if st.orientation == 2: + successors.append(State(st, 'LEFT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st,'RIGHT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) if st.xpos > 0: - successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 2, passedPriorities[st.xpos//50 - 1][st.ypos//50], heuristicfn(st.xpos - 50, st.ypos, goalx, goaly))) - if st.orientation == 'E': - successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) - successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + if st.orientation == 0: + successors.append(State(st, 'LEFT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) if st.xpos < FIELDWIDTH * (FIELDCOUNT - 1): - successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) + successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 0, passedPriorities[st.xpos//50 + 1][st.ypos//50], heuristicfn(st.xpos + 50, st.ypos, goalx, goaly))) return successors diff --git a/treelearn.py b/treelearn.py new file mode 100644 index 0000000..7d20976 --- /dev/null +++ b/treelearn.py @@ -0,0 +1,20 @@ +import pandas as pd +from sklearn import tree +from sklearn.preprocessing import LabelEncoder +import graphviz + +def treelearn(): + train_data = pd.read_csv('./data_set.csv') + attributes = train_data.drop('collect', axis='columns') + e_type = LabelEncoder() + attributes['type_num'] = e_type.fit_transform(attributes['garbage_type']) + attr_encoded = attributes.drop(['garbage_type'], axis='columns') + attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type'] + label_names = ['collect', 'no-collect'] + label = train_data['collect'] + classifier = tree.DecisionTreeClassifier() + classifier.fit(attr_encoded.values, label) + dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names) + graph = graphviz.Source(dot_data) + graph.render('collect') + return classifier \ No newline at end of file