updated agent to collect garbage
This commit is contained in:
parent
63596f20f2
commit
55baf24513
3
.idea/.gitignore
vendored
3
.idea/.gitignore
vendored
@ -1,3 +0,0 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
23
astar.py
23
astar.py
@ -36,26 +36,3 @@ def astar(istate, goalx, goaly, passedFields):
|
||||
element.priority = value.priority
|
||||
return False
|
||||
|
||||
|
||||
# def bfs(istate, goalx, goaly, passedFields):
|
||||
# fringe = [istate]
|
||||
# explored = []
|
||||
# steps = []
|
||||
# while fringe:
|
||||
# state = fringe.pop(0)
|
||||
# if state.xpos == goalx and state.ypos == goaly:
|
||||
# steps.insert(0, state)
|
||||
# while (state.parent != None):
|
||||
# state = state.parent
|
||||
# steps.insert(0, state)
|
||||
# return steps
|
||||
|
||||
# element = successors(state, passedFields)
|
||||
# explored.append((state.xpos, state.ypos, state.orientation))
|
||||
# for value in element:
|
||||
# val = (value.xpos, value.ypos, value.orientation)
|
||||
# if val not in explored and value not in fringe:
|
||||
# fringe.append(value)
|
||||
# return False
|
||||
|
||||
|
||||
|
22
collect
22
collect
@ -24,7 +24,7 @@ edge [fontname="helvetica"] ;
|
||||
6 -> 10 ;
|
||||
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||
10 -> 11 ;
|
||||
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
12 [label="distance <= 10.5\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
11 -> 12 ;
|
||||
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
12 -> 13 ;
|
||||
@ -36,7 +36,7 @@ edge [fontname="helvetica"] ;
|
||||
15 -> 16 ;
|
||||
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||
15 -> 17 ;
|
||||
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||
18 [label="odour_intensity <= 5.724\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||
17 -> 18 ;
|
||||
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
18 -> 19 ;
|
||||
@ -54,11 +54,11 @@ edge [fontname="helvetica"] ;
|
||||
23 -> 25 ;
|
||||
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||
25 -> 26 ;
|
||||
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
27 [label="space_occupied <= 0.936\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
25 -> 27 ;
|
||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
27 -> 28 ;
|
||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
27 -> 29 ;
|
||||
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||
@ -88,14 +88,18 @@ edge [fontname="helvetica"] ;
|
||||
40 -> 42 ;
|
||||
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||
42 -> 43 ;
|
||||
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||
44 [label="days_since_last_collection <= 20.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||
42 -> 44 ;
|
||||
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
44 -> 45 ;
|
||||
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||
46 [label="paid_on_time <= 0.5\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||
44 -> 46 ;
|
||||
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
47 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
46 -> 47 ;
|
||||
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||
48 [label="space_occupied <= 0.243\ngini = 0.245\nsamples = 7\nvalue = [1, 6]\nclass = no-collect"] ;
|
||||
46 -> 48 ;
|
||||
49 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
48 -> 49 ;
|
||||
50 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||
48 -> 50 ;
|
||||
}
|
||||
|
BIN
collect.pdf
BIN
collect.pdf
Binary file not shown.
@ -1,31 +1,38 @@
|
||||
from heuristicfn import heuristicfn
|
||||
|
||||
FIELDWIDTH = 50
|
||||
TURN_FUEL_COST = 10
|
||||
MOVE_FUEL_COST = 200
|
||||
MAX_FUEL = 20000
|
||||
MAX_SPACE = 5
|
||||
MAX_WEIGHT = 200
|
||||
|
||||
class GarbageTank:
|
||||
def __init__(self, volume_capacity, mass_capacity):
|
||||
self.vcapacity = volume_capacity #m^3
|
||||
self.mcapacity = mass_capacity #kg
|
||||
|
||||
class Engine:
|
||||
def __init__(self, power):
|
||||
self.power = power #HP
|
||||
|
||||
class GarbageTruck:
|
||||
def __init__(self, dump_location, fuel_capacity, rect, orientation):
|
||||
self.dump_location = dump_location
|
||||
self.tank = GarbageTank(15, 18000)
|
||||
self.engine = Engine(400)
|
||||
self.fuel = fuel_capacity
|
||||
|
||||
garbage_types = {'bio': 0, 'electronics': 1, 'mixed': 2, 'recyclable': 3}
|
||||
|
||||
def __init__(self, dump_x, dump_y, rect, orientation, request_list: list, clf):
|
||||
self.dump_x = dump_x
|
||||
self.dump_y = dump_y
|
||||
self.fuel = MAX_FUEL
|
||||
self.free_space = MAX_SPACE
|
||||
self.weight_capacity = MAX_WEIGHT
|
||||
self.rect = rect
|
||||
self.orientation = orientation
|
||||
self.houses = [] #lista domów do odwiedzenia
|
||||
self.request_list = request_list #lista domów do odwiedzenia
|
||||
self.clf = clf
|
||||
|
||||
def turn_left(self):
|
||||
self.orientation = (self.orientation - 1) % 4
|
||||
self.fuel -= TURN_FUEL_COST
|
||||
|
||||
def turn_right(self):
|
||||
self.orientation = (self.orientation + 1) % 4
|
||||
self.fuel -= TURN_FUEL_COST
|
||||
|
||||
def forward(self):
|
||||
self.fuel -= MOVE_FUEL_COST
|
||||
if self.orientation == 0:
|
||||
self.rect.x += FIELDWIDTH
|
||||
elif self.orientation == 1:
|
||||
@ -34,3 +41,49 @@ class GarbageTruck:
|
||||
self.rect.x -= FIELDWIDTH
|
||||
else:
|
||||
self.rect.y -= FIELDWIDTH
|
||||
|
||||
def next_destination(self):
|
||||
if self.fuel <= 0 or not self.request_list:
|
||||
return self.dump_x, self.dump_y
|
||||
|
||||
for i in range(len(self.request_list)):
|
||||
request = self.request_list[i]
|
||||
|
||||
#nie ma miejsca w zbiorniku lub za ciężkie śmieci
|
||||
if request.volume > self.free_space or request.weight > self.weight_capacity:
|
||||
continue
|
||||
|
||||
#nie straczy paliwa na dojechanie i powrót na wysypisko
|
||||
if heuristicfn(request.x_pos, request.y_pos, self.dump_x, self.dump_y) / 50 * 200 > self.fuel:
|
||||
continue
|
||||
|
||||
|
||||
|
||||
distance = heuristicfn(self.rect.x, self.rect.y, request.x_pos, request.y_pos) / 50
|
||||
|
||||
r = [
|
||||
self.fuel,
|
||||
distance,
|
||||
request.volume,
|
||||
request.last_collection,
|
||||
request.is_paid,
|
||||
request.odour_intensity,
|
||||
request.weight,
|
||||
request.type
|
||||
]
|
||||
if self.clf.predict([r]) == True:
|
||||
self.request_list.pop(i)
|
||||
self.free_space -= request.volume
|
||||
self.weight_capacity -= request.weight
|
||||
return request.x_pos, request.y_pos
|
||||
return self.dump_x, self.dump_y
|
||||
|
||||
|
||||
|
||||
def collect(self):
|
||||
if self.rect.x == self.dump_x and self.rect.y == self.dump_y:
|
||||
self.fuel = MAX_WEIGHT
|
||||
self.free_space = MAX_SPACE
|
||||
self.weight_capacity = MAX_WEIGHT
|
||||
print(f'agent at ({self.rect.x}, {self.rect.y}); fuel: {self.fuel}; free space: {self.free_space}; weight capacity: {self.weight_capacity}')
|
||||
pass
|
4
home.py
4
home.py
@ -1,4 +0,0 @@
|
||||
class Home:
|
||||
def __init__(self, coord):
|
||||
self.coord = coord
|
||||
self.collect_request = False
|
@ -1,8 +0,0 @@
|
||||
class Litter:
|
||||
|
||||
types = ['PAPER', 'GLASS', 'PLASTIC', 'METAL', 'BIO', 'MUNICIPAL', 'ELECTRONICS']
|
||||
|
||||
def __init__(self, type, volume, mass):
|
||||
self.type = type
|
||||
self.volume = volume
|
||||
self.mass = mass
|
87
main.py
87
main.py
@ -1,9 +1,5 @@
|
||||
import pygame
|
||||
import random
|
||||
import pandas as pd
|
||||
from sklearn import tree
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import graphviz
|
||||
from treelearn import treelearn
|
||||
|
||||
|
||||
from astar import astar
|
||||
@ -11,6 +7,7 @@ from state import State
|
||||
import time
|
||||
from garbage_truck import GarbageTruck
|
||||
from heuristicfn import heuristicfn
|
||||
from map import randomize_map
|
||||
|
||||
pygame.init()
|
||||
WIDTH, HEIGHT = 800, 800
|
||||
@ -18,52 +15,12 @@ window = pygame.display.set_mode((WIDTH, HEIGHT))
|
||||
pygame.display.set_caption("Intelligent Garbage Collector")
|
||||
AGENT_IMG = pygame.image.load("garbage-truck-nbg.png")
|
||||
AGENT = pygame.transform.scale(AGENT_IMG, (50, 50))
|
||||
DIRT_IMG = pygame.image.load("dirt.jpg")
|
||||
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
|
||||
GRASS_IMG = pygame.image.load("grass.png")
|
||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||
SAND_IMG = pygame.image.load("sand.jpeg")
|
||||
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
|
||||
COBBLE_IMG = pygame.image.load("cobble.jpeg")
|
||||
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
|
||||
FPS = 10
|
||||
FIELDCOUNT = 16
|
||||
FIELDWIDTH = 50
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, rect, direction):
|
||||
self.rect = rect
|
||||
self.direction = direction
|
||||
|
||||
|
||||
def randomize_map(): # tworzenie mapy z losowymi polami
|
||||
field_array_1 = []
|
||||
field_array_2 = []
|
||||
field_priority = []
|
||||
for i in range(16):
|
||||
temp_priority = []
|
||||
for j in range(16):
|
||||
if i in (0, 1) and j in (0, 1):
|
||||
field_array_2.append(GRASS)
|
||||
temp_priority.append(1)
|
||||
else:
|
||||
prob = random.uniform(0, 100)
|
||||
if 0 <= prob <= 12:
|
||||
field_array_2.append(COBBLE)
|
||||
temp_priority.append(3)
|
||||
elif 12 < prob <= 24:
|
||||
field_array_2.append(SAND)
|
||||
temp_priority.append(2)
|
||||
else:
|
||||
field_array_2.append(GRASS)
|
||||
temp_priority.append(1)
|
||||
field_array_1.append(field_array_2)
|
||||
field_array_2 = []
|
||||
field_priority.append(temp_priority)
|
||||
return field_array_1, field_priority
|
||||
|
||||
|
||||
GRASS_IMG = pygame.image.load("grass.png")
|
||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||
def draw_window(agent, fields, flip):
|
||||
if flip:
|
||||
direction = pygame.transform.flip(AGENT, True, False)
|
||||
@ -77,35 +34,22 @@ def draw_window(agent, fields, flip):
|
||||
|
||||
|
||||
def main():
|
||||
train_data = pd.read_csv('./data_set.csv')
|
||||
attributes = train_data.drop('collect', axis='columns')
|
||||
e_type = LabelEncoder()
|
||||
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
|
||||
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
|
||||
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
|
||||
label_names = ['collect', 'no-collect']
|
||||
label = train_data['collect']
|
||||
print(attr_encoded)
|
||||
print(label)
|
||||
classifier = tree.DecisionTreeClassifier()
|
||||
classifier.fit(attr_encoded, label)
|
||||
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
|
||||
graph = graphviz.Source(dot_data)
|
||||
graph.render('collect')
|
||||
clf = treelearn()
|
||||
clock = pygame.time.Clock()
|
||||
run = True
|
||||
x, y = [0, 0]
|
||||
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
|
||||
fields, priority_array = randomize_map()
|
||||
final_x, final_y = [100, 300]
|
||||
fields, priority_array, request_list = randomize_map()
|
||||
agent = GarbageTruck(0, 0, pygame.Rect(0, 0, 50, 50), 0, request_list, clf) # tworzenie pola dla agenta
|
||||
while run:
|
||||
clock.tick(FPS)
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
run = False
|
||||
# keys_pressed = pygame.key.get_pressed()
|
||||
draw_window(agent, fields, False) # false = kierunek east (domyslny), true = west
|
||||
steps = astar(State(None, None, x, y, 'E', priority_array[0][0], heuristicfn(x, y, final_x, final_y)), final_x, final_y, priority_array)
|
||||
x, y = agent.next_destination()
|
||||
if x == agent.rect.x and y == agent.rect.y:
|
||||
print('out of jobs')
|
||||
break
|
||||
steps = astar(State(None, None, agent.rect.x, agent.rect.y, agent.orientation, priority_array[0][0], heuristicfn(agent.rect.x, agent.rect.y, x, y)), x, y, priority_array)
|
||||
for interm in steps:
|
||||
if interm.action == 'LEFT':
|
||||
agent.turn_left()
|
||||
@ -121,10 +65,11 @@ def main():
|
||||
draw_window(agent, fields, True)
|
||||
else:
|
||||
draw_window(agent, fields, False)
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.3)
|
||||
agent.collect()
|
||||
fields[agent.rect.x//50][agent.rect.y//50] = GRASS
|
||||
time.sleep(0.5)
|
||||
|
||||
while True:
|
||||
pass
|
||||
|
||||
pygame.quit()
|
||||
|
||||
|
44
map.py
Normal file
44
map.py
Normal file
@ -0,0 +1,44 @@
|
||||
import pygame, random
|
||||
from request import Request
|
||||
|
||||
DIRT_IMG = pygame.image.load("dirt.jpg")
|
||||
DIRT = pygame.transform.scale(DIRT_IMG, (50, 50))
|
||||
GRASS_IMG = pygame.image.load("grass.png")
|
||||
GRASS = pygame.transform.scale(GRASS_IMG, (50, 50))
|
||||
SAND_IMG = pygame.image.load("sand.jpeg")
|
||||
SAND = pygame.transform.scale(SAND_IMG, (50, 50))
|
||||
COBBLE_IMG = pygame.image.load("cobble.jpeg")
|
||||
COBBLE = pygame.transform.scale(COBBLE_IMG, (50, 50))
|
||||
|
||||
def randomize_map(): # tworzenie mapy z losowymi polami
|
||||
request_list = []
|
||||
field_array_1 = []
|
||||
field_array_2 = []
|
||||
field_priority = []
|
||||
for i in range(16):
|
||||
temp_priority = []
|
||||
for j in range(16):
|
||||
if i in (0, 1) and j in (0, 1):
|
||||
field_array_2.append(GRASS)
|
||||
temp_priority.append(1)
|
||||
else:
|
||||
prob = random.uniform(0, 100)
|
||||
if 0 <= prob <= 12:
|
||||
field_array_2.append(COBBLE)
|
||||
temp_priority.append(100)
|
||||
request_list.append(Request(
|
||||
i*50,j*50, #lokacja
|
||||
random.randint(0,3), #typ śmieci
|
||||
random.random(), #objętość śmieci
|
||||
random.randint(0,30), #ostatni odbiór
|
||||
random.randint(0,1), #czy opłacone w terminie
|
||||
random.random() * 10, #intensywność odoru
|
||||
random.random() * 50 #waga śmieci
|
||||
))
|
||||
else:
|
||||
field_array_2.append(GRASS)
|
||||
temp_priority.append(1)
|
||||
field_array_1.append(field_array_2)
|
||||
field_array_2 = []
|
||||
field_priority.append(temp_priority)
|
||||
return field_array_1, field_priority, request_list
|
13
request.py
Normal file
13
request.py
Normal file
@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
def __init__(self, x_pos, y_pos, type, volume, last_collection, is_paid, odour_intensity, weight):
|
||||
self.x_pos = x_pos
|
||||
self.y_pos = y_pos
|
||||
self.type = type
|
||||
self.volume = volume
|
||||
self.last_collection = last_collection
|
||||
self.is_paid = is_paid
|
||||
self.odour_intensity = odour_intensity
|
||||
self.weight = weight
|
32
succ.py
32
succ.py
@ -5,27 +5,27 @@ FIELDWIDTH, FIELDCOUNT = 50, 16
|
||||
def succ(st: State, passedPriorities, goalx, goaly):
|
||||
successors = []
|
||||
|
||||
if st.orientation == 'N':
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.orientation == 3:
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.ypos > 0:
|
||||
successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'FORWARD', st.xpos, st.ypos - FIELDWIDTH , 3, passedPriorities[st.xpos//50][st.ypos//50 - 1], heuristicfn(st.xpos, st.ypos - 50, goalx, goaly)))
|
||||
|
||||
if st.orientation == 'S':
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.orientation == 1:
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 0, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 2, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.ypos < FIELDWIDTH * (FIELDCOUNT - 1):
|
||||
successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'FORWARD', st.xpos, st.ypos + FIELDWIDTH , 1, passedPriorities[st.xpos//50][st.ypos//50 + 1], heuristicfn(st.xpos, st.ypos + 50, goalx, goaly)))
|
||||
|
||||
if st.orientation == 'W':
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.orientation == 2:
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st,'RIGHT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.xpos > 0:
|
||||
successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 'W', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'FORWARD', st.xpos - FIELDWIDTH , st.ypos, 2, passedPriorities[st.xpos//50 - 1][st.ypos//50], heuristicfn(st.xpos - 50, st.ypos, goalx, goaly)))
|
||||
|
||||
if st.orientation == 'E':
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 'N', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 'S', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.orientation == 0:
|
||||
successors.append(State(st, 'LEFT', st.xpos, st.ypos, 3, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'RIGHT', st.xpos, st.ypos, 1, passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
if st.xpos < FIELDWIDTH * (FIELDCOUNT - 1):
|
||||
successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 'E', passedPriorities[st.xpos//50][st.ypos//50], heuristicfn(st.xpos, st.ypos, goalx, goaly)))
|
||||
successors.append(State(st, 'FORWARD', st.xpos + FIELDWIDTH , st.ypos, 0, passedPriorities[st.xpos//50 + 1][st.ypos//50], heuristicfn(st.xpos + 50, st.ypos, goalx, goaly)))
|
||||
return successors
|
||||
|
20
treelearn.py
Normal file
20
treelearn.py
Normal file
@ -0,0 +1,20 @@
|
||||
import pandas as pd
|
||||
from sklearn import tree
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import graphviz
|
||||
|
||||
def treelearn():
|
||||
train_data = pd.read_csv('./data_set.csv')
|
||||
attributes = train_data.drop('collect', axis='columns')
|
||||
e_type = LabelEncoder()
|
||||
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
|
||||
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
|
||||
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
|
||||
label_names = ['collect', 'no-collect']
|
||||
label = train_data['collect']
|
||||
classifier = tree.DecisionTreeClassifier()
|
||||
classifier.fit(attr_encoded.values, label)
|
||||
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
|
||||
graph = graphviz.Source(dot_data)
|
||||
graph.render('collect')
|
||||
return classifier
|
Loading…
Reference in New Issue
Block a user