trained decision tree
This commit is contained in:
parent
0738e210b1
commit
63596f20f2
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/__pycache__
|
101
collect
Normal file
101
collect
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
digraph Tree {
|
||||||
|
node [shape=box, fontname="helvetica"] ;
|
||||||
|
edge [fontname="helvetica"] ;
|
||||||
|
0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ;
|
||||||
|
1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ;
|
||||||
|
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
|
||||||
|
2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
|
||||||
|
1 -> 2 ;
|
||||||
|
3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ;
|
||||||
|
1 -> 3 ;
|
||||||
|
4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ;
|
||||||
|
3 -> 4 ;
|
||||||
|
5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ;
|
||||||
|
3 -> 5 ;
|
||||||
|
6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ;
|
||||||
|
5 -> 6 ;
|
||||||
|
7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ;
|
||||||
|
6 -> 7 ;
|
||||||
|
8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ;
|
||||||
|
7 -> 8 ;
|
||||||
|
9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
||||||
|
7 -> 9 ;
|
||||||
|
10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ;
|
||||||
|
6 -> 10 ;
|
||||||
|
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||||
|
10 -> 11 ;
|
||||||
|
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
|
11 -> 12 ;
|
||||||
|
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
|
12 -> 13 ;
|
||||||
|
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
|
12 -> 14 ;
|
||||||
|
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
|
||||||
|
11 -> 15 ;
|
||||||
|
16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ;
|
||||||
|
15 -> 16 ;
|
||||||
|
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||||
|
15 -> 17 ;
|
||||||
|
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||||
|
17 -> 18 ;
|
||||||
|
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
|
18 -> 19 ;
|
||||||
|
20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
|
18 -> 20 ;
|
||||||
|
21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ;
|
||||||
|
17 -> 21 ;
|
||||||
|
22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
|
10 -> 22 ;
|
||||||
|
23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ;
|
||||||
|
5 -> 23 ;
|
||||||
|
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
||||||
|
23 -> 24 ;
|
||||||
|
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
|
||||||
|
23 -> 25 ;
|
||||||
|
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||||
|
25 -> 26 ;
|
||||||
|
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||||
|
25 -> 27 ;
|
||||||
|
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||||
|
27 -> 28 ;
|
||||||
|
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
|
27 -> 29 ;
|
||||||
|
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||||
|
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||||
|
31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ;
|
||||||
|
30 -> 31 ;
|
||||||
|
32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ;
|
||||||
|
31 -> 32 ;
|
||||||
|
33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ;
|
||||||
|
32 -> 33 ;
|
||||||
|
34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ;
|
||||||
|
32 -> 34 ;
|
||||||
|
35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ;
|
||||||
|
34 -> 35 ;
|
||||||
|
36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
|
35 -> 36 ;
|
||||||
|
37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ;
|
||||||
|
35 -> 37 ;
|
||||||
|
38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
|
||||||
|
34 -> 38 ;
|
||||||
|
39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||||
|
31 -> 39 ;
|
||||||
|
40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ;
|
||||||
|
30 -> 40 ;
|
||||||
|
41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ;
|
||||||
|
40 -> 41 ;
|
||||||
|
42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ;
|
||||||
|
40 -> 42 ;
|
||||||
|
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||||
|
42 -> 43 ;
|
||||||
|
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||||
|
42 -> 44 ;
|
||||||
|
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
|
44 -> 45 ;
|
||||||
|
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||||
|
44 -> 46 ;
|
||||||
|
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||||
|
46 -> 47 ;
|
||||||
|
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||||
|
46 -> 48 ;
|
||||||
|
}
|
BIN
collect.pdf
Normal file
BIN
collect.pdf
Normal file
Binary file not shown.
22
main.py
22
main.py
@ -1,5 +1,11 @@
|
|||||||
import pygame
|
import pygame
|
||||||
import random
|
import random
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn import tree
|
||||||
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
import graphviz
|
||||||
|
|
||||||
|
|
||||||
from astar import astar
|
from astar import astar
|
||||||
from state import State
|
from state import State
|
||||||
import time
|
import time
|
||||||
@ -71,12 +77,26 @@ def draw_window(agent, fields, flip):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
train_data = pd.read_csv('./data_set.csv')
|
||||||
|
attributes = train_data.drop('collect', axis='columns')
|
||||||
|
e_type = LabelEncoder()
|
||||||
|
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
|
||||||
|
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
|
||||||
|
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
|
||||||
|
label_names = ['collect', 'no-collect']
|
||||||
|
label = train_data['collect']
|
||||||
|
print(attr_encoded)
|
||||||
|
print(label)
|
||||||
|
classifier = tree.DecisionTreeClassifier()
|
||||||
|
classifier.fit(attr_encoded, label)
|
||||||
|
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
|
||||||
|
graph = graphviz.Source(dot_data)
|
||||||
|
graph.render('collect')
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
run = True
|
run = True
|
||||||
x, y = [0, 0]
|
x, y = [0, 0]
|
||||||
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
|
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
|
||||||
fields, priority_array = randomize_map()
|
fields, priority_array = randomize_map()
|
||||||
print(priority_array)
|
|
||||||
final_x, final_y = [100, 300]
|
final_x, final_y = [100, 300]
|
||||||
while run:
|
while run:
|
||||||
clock.tick(FPS)
|
clock.tick(FPS)
|
||||||
|
Loading…
Reference in New Issue
Block a user