updated agent to collect garbage

This commit is contained in:
Mateusz 2023-05-26 19:08:22 +02:00
parent 63596f20f2
commit 55baf24513
19 changed files with 190 additions and 149 deletions

3
.idea/.gitignore vendored
View File

@ -1,3 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -36,26 +36,3 @@ def astar(istate, goalx, goaly, passedFields):
element.priority = value.priority element.priority = value.priority
return False return False
# def bfs(istate, goalx, goaly, passedFields):
# fringe = [istate]
# explored = []
# steps = []
# while fringe:
# state = fringe.pop(0)
# if state.xpos == goalx and state.ypos == goaly:
# steps.insert(0, state)
# while (state.parent != None):
# state = state.parent
# steps.insert(0, state)
# return steps
# element = successors(state, passedFields)
# explored.append((state.xpos, state.ypos, state.orientation))
# for value in element:
# val = (value.xpos, value.ypos, value.orientation)
# if val not in explored and value not in fringe:
# fringe.append(value)
# return False

22
collect
View File

@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
6 -> 10 ; 6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ; 11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ; 10 -> 11 ;
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; 12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ; 11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; 13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ; 12 -> 13 ;
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
15 -> 16 ; 15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ; 17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ; 15 -> 17 ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ; 18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ; 17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; 19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ; 18 -> 19 ;
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
23 -> 25 ; 23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ; 26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ; 25 -> 26 ;
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; 27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ; 25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; 28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 28 ; 27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; 29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 29 ; 27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ; 30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ; 0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
@ -88,14 +88,18 @@ edge [fontname="helvetica"] ;
40 -> 42 ; 40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ; 43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ; 42 -> 43 ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ; 44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ; 42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; 45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ; 44 -> 45 ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ; 46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ; 44 -> 46 ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; 47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
46 -> 47 ; 46 -> 47 ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ; 48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
46 -> 48 ; 46 -> 48 ;
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
48 -> 49 ;
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
48 -> 50 ;
} }

Binary file not shown.

View File

@ -1,31 +1,38 @@
from heuristicfn import heuristicfn
FIELDWIDTH = 50 FIELDWIDTH = 50
TURN_FUEL_COST = 10
MOVE_FUEL_COST = 200
MAX_FUEL = 20000
MAX_SPACE = 5
MAX_WEIGHT = 200
class GarbageTank:
def __init__(self, volume_capacity, mass_capacity):
self.vcapacity = volume_capacity #m^3
self.mcapacity = mass_capacity #kg
class Engine:
def __init__(self, power):
self.power = power #HP
class GarbageTruck: class GarbageTruck:
def __init__(self, dump_location, fuel_capacity, rect, orientation):
self.dump_location = dump_location garbage_types = {'bio': 0, 'electronics': 1, 'mixed': 2, 'recyclable': 3}
self.tank = GarbageTank(15, 18000)
self.engine = Engine(400) def __init__(self, dump_x, dump_y, rect, orientation, request_list: list, clf):
self.fuel = fuel_capacity self.dump_x = dump_x
self.dump_y = dump_y
self.fuel = MAX_FUEL
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
self.rect = rect self.rect = rect
self.orientation = orientation self.orientation = orientation
self.houses = [] #lista domów do odwiedzenia self.request_list = request_list #lista domów do odwiedzenia
self.clf = clf
def turn_left(self): def turn_left(self):
self.orientation = (self.orientation - 1) % 4 self.orientation = (self.orientation - 1) % 4
self.fuel -= TURN_FUEL_COST
def turn_right(self): def turn_right(self):
self.orientation = (self.orientation + 1) % 4 self.orientation = (self.orientation + 1) % 4
self.fuel -= TURN_FUEL_COST
def forward(self): def forward(self):
self.fuel -= MOVE_FUEL_COST
if self.orientation == 0: if self.orientation == 0:
self.rect.x += FIELDWIDTH self.rect.x += FIELDWIDTH
elif self.orientation == 1: elif self.orientation == 1:
@ -34,3 +41,49 @@ class GarbageTruck:
self.rect.x -= FIELDWIDTH self.rect.x -= FIELDWIDTH
else: else:
self.rect.y -= FIELDWIDTH self.rect.y -= FIELDWIDTH
def next_destination(self):
if self.fuel <= 0 or not self.request_list:
return self.dump_x, self.dump_y
for i in range(len(self.request_list)):
request = self.request_list[i]
#nie ma miejsca w zbiorniku lub za ciężkie śmieci
if request.volume > self.free_space or request.weight > self.weight_capacity:
continue
#nie straczy paliwa na dojechanie i powrót na wysypisko
if heuristicfn(request.x_pos, request.y_pos, self.dump_x, self.dump_y) / 50 * 200 > self.fuel:
continue
distance = heuristicfn(self.rect.x, self.rect.y, request.x_pos, request.y_pos) / 50
r = [
self.fuel,
distance,
request.volume,
request.last_collection,
request.is_paid,
request.odour_intensity,
request.weight,
request.type
]
if self.clf.predict([r]) == True:
self.request_list.pop(i)
self.free_space -= request.volume
self.weight_capacity -= request.weight
return request.x_pos, request.y_pos
return self.dump_x, self.dump_y
def collect(self):
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
self.fuel = MAX_WEIGHT
self.free_space = MAX_SPACE
self.weight_capacity = MAX_WEIGHT
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
pass

View File

@ -1,4 +0,0 @@
class Home:
def __init__(self, coord):
self.coord = coord
self.collect_request = False

View File

@ -1,8 +0,0 @@
class Litter:
types = ['PAPER', 'GLASS', 'PLASTIC', 'METAL', 'BIO', 'MUNICIPAL', 'ELECTRONICS']
def __init__(self, type, volume, mass):
self.type = type
self.volume = volume
self.mass = mass

87
main.py
View File

@ -1,9 +1,5 @@
import pygame import pygame
import random from treelearn import treelearn
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
from astar import astar from astar import astar
@ -11,6 +7,7 @@ from state import State
import time import time
from garbage_truck import GarbageTruck from garbage_truck import GarbageTruck
from heuristicfn import heuristicfn from heuristicfn import heuristicfn
from map import randomize_map
pygame.init() pygame.init()
WIDTH, HEIGHT = 800, 800 WIDTH, HEIGHT = 800, 800
@ -18,52 +15,12 @@ window = pygame.display.set_mode((WIDTH, HEIGHT))
pygame.display.set_caption("Intelligent Garbage Collector") pygame.display.set_caption("Intelligent Garbage Collector")
AGENT_IMG = pygame.image.load("garbage-truck-nbg.png") AGENT_IMG = pygame.image.load("garbage-truck-nbg.png")
AGENT = pygame.transform.scale(AGENT_IMG, (50, 50)) AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
DIRT_IMG = pygame.image.load("dirt.jpg")
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pygame.image.load("sand.jpeg")
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pygame.image.load("cobble.jpeg")
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
FPS = 10 FPS = 10
FIELDCOUNT = 16 FIELDCOUNT = 16
FIELDWIDTH = 50 FIELDWIDTH = 50
GRASS_IMG = pygame.image.load("grass.png")
class Agent: GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
def __init__(self, rect, direction):
self.rect = rect
self.direction = direction
def randomize_map(): # tworzenie mapy z losowymi polami
field_array_1 = []
field_array_2 = []
field_priority = []
for i in range(16):
temp_priority = []
for j in range(16):
if i in (0, 1) and j in (0, 1):
field_array_2.append(GRASS)
temp_priority.append(1)
else:
prob = random.uniform(0, 100)
if 0 <= prob <= 12:
field_array_2.append(COBBLE)
temp_priority.append(3)
elif 12 < prob <= 24:
field_array_2.append(SAND)
temp_priority.append(2)
else:
field_array_2.append(GRASS)
temp_priority.append(1)
field_array_1.append(field_array_2)
field_array_2 = []
field_priority.append(temp_priority)
return field_array_1, field_priority
def draw_window(agent, fields, flip): def draw_window(agent, fields, flip):
if flip: if flip:
direction = pygame.transform.flip(AGENT, True, False) direction = pygame.transform.flip(AGENT, True, False)
@ -77,35 +34,22 @@ def draw_window(agent, fields, flip):
def main(): def main():
train_data = pd.read_csv('./data_set.csv') clf = treelearn()
attributes = train_data.drop('collect', axis='columns')
e_type = LabelEncoder()
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
label_names = ['collect', 'no-collect']
label = train_data['collect']
print(attr_encoded)
print(label)
classifier = tree.DecisionTreeClassifier()
classifier.fit(attr_encoded, label)
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
graph = graphviz.Source(dot_data)
graph.render('collect')
clock = pygame.time.Clock() clock = pygame.time.Clock()
run = True run = True
x, y = [0, 0] fields, priority_array, request_list = randomize_map()
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
fields, priority_array = randomize_map()
final_x, final_y = [100, 300]
while run: while run:
clock.tick(FPS) clock.tick(FPS)
for event in pygame.event.get(): for event in pygame.event.get():
if event.type == pygame.QUIT: if event.type == pygame.QUIT:
run = False run = False
# keys_pressed = pygame.key.get_pressed()
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
steps = astar(State(None, None, x, y, 'E', priority_array[0][0], heuristicfn(x, y, final_x, final_y)), final_x, final_y, priority_array) x, y = agent.next_destination()
if x == agent.rect.x and y == agent.rect.y:
print('out of jobs')
break
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[0][0], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
for interm in steps: for interm in steps:
if interm.action == 'LEFT': if interm.action == 'LEFT':
agent.turn_left() agent.turn_left()
@ -121,10 +65,11 @@ def main():
draw_window(agent, fields, True) draw_window(agent, fields, True)
else: else:
draw_window(agent, fields, False) draw_window(agent, fields, False)
time.sleep(0.5) time.sleep(0.3)
agent.collect()
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
time.sleep(0.5)
while True:
pass
pygame.quit() pygame.quit()

44
map.py Normal file
View File

@ -0,0 +1,44 @@
import pygame, random
from request import Request
DIRT_IMG = pygame.image.load("dirt.jpg")
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
GRASS_IMG = pygame.image.load("grass.png")
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
SAND_IMG = pygame.image.load("sand.jpeg")
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
COBBLE_IMG = pygame.image.load("cobble.jpeg")
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
def randomize_map(): # tworzenie mapy z losowymi polami
request_list = []
field_array_1 = []
field_array_2 = []
field_priority = []
for i in range(16):
temp_priority = []
for j in range(16):
if i in (0, 1) and j in (0, 1):
field_array_2.append(GRASS)
temp_priority.append(1)
else:
prob = random.uniform(0, 100)
if 0 <= prob <= 12:
field_array_2.append(COBBLE)
temp_priority.append(100)
request_list.append(Request(
i*50,j*50, #lokacja
random.randint(0,3), #typ śmieci
random.random(), #objętość śmieci
random.randint(0,30), #ostatni odbiór
random.randint(0,1), #czy opłacone w terminie
random.random() * 10, #intensywność odoru
random.random() * 50 #waga śmieci
))
else:
field_array_2.append(GRASS)
temp_priority.append(1)
field_array_1.append(field_array_2)
field_array_2 = []
field_priority.append(temp_priority)
return field_array_1, field_priority, request_list

13
request.py Normal file
View File

@ -0,0 +1,13 @@
from dataclasses import dataclass
@dataclass
class Request:
def __init__(self, x_pos, y_pos, type, volume, last_collection, is_paid, odour_intensity, weight):
self.x_pos = x_pos
self.y_pos = y_pos
self.type = type
self.volume = volume
self.last_collection = last_collection
self.is_paid = is_paid
self.odour_intensity = odour_intensity
self.weight = weight

32
succ.py
View File

@ -5,27 +5,27 @@ FIELDWIDTH, FIELDCOUNT = 50, 16
def succ(st: State, passedPriorities, goalx, goaly): def succ(st: State, passedPriorities, goalx, goaly):
successors = [] successors = []
if st.orientation == 'N': if st.orientation == 3:
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'LEFT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
if st.ypos > 0: if st.ypos > 0:
successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 3, passedPriorities[st.xpos//50][st.ypos//50 - 1], heuristicfn(st.xpos, st.ypos - 50, goalx, goaly)))
if st.orientation == 'S': if st.orientation == 1:
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'LEFT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st,'RIGHT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
if st.ypos < FIELDWIDTH * (FIELDCOUNT - 1): if st.ypos < FIELDWIDTH * (FIELDCOUNT - 1):
successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 1, passedPriorities[st.xpos//50][st.ypos//50 + 1], heuristicfn(st.xpos, st.ypos + 50, goalx, goaly)))
if st.orientation == 'W': if st.orientation == 2:
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'LEFT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st,'RIGHT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
if st.xpos > 0: if st.xpos > 0:
successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 2, passedPriorities[st.xpos//50 - 1][st.ypos//50], heuristicfn(st.xpos - 50, st.ypos, goalx, goaly)))
if st.orientation == 'E': if st.orientation == 0:
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'LEFT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
if st.xpos < FIELDWIDTH * (FIELDCOUNT - 1): if st.xpos < FIELDWIDTH * (FIELDCOUNT - 1):
successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly))) successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 0, passedPriorities[st.xpos//50 + 1][st.ypos//50], heuristicfn(st.xpos + 50, st.ypos, goalx, goaly)))
return successors return successors

20
treelearn.py Normal file
View File

@ -0,0 +1,20 @@
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
def treelearn():
train_data = pd.read_csv('./data_set.csv')
attributes = train_data.drop('collect', axis='columns')
e_type = LabelEncoder()
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
label_names = ['collect', 'no-collect']
label = train_data['collect']
classifier = tree.DecisionTreeClassifier()
classifier.fit(attr_encoded.values, label)
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
graph = graphviz.Source(dot_data)
graph.render('collect')
return classifier