trained decision tree
This commit is contained in:
parent
0738e210b1
commit
63596f20f2
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/__pycache__
|
101
collect
Normal file
101
collect
Normal file
@ -0,0 +1,101 @@
|
||||
digraph Tree {
|
||||
node [shape=box, fontname="helvetica"] ;
|
||||
edge [fontname="helvetica"] ;
|
||||
0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ;
|
||||
1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ;
|
||||
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
|
||||
2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
|
||||
1 -> 2 ;
|
||||
3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ;
|
||||
1 -> 3 ;
|
||||
4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ;
|
||||
3 -> 4 ;
|
||||
5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ;
|
||||
3 -> 5 ;
|
||||
6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ;
|
||||
5 -> 6 ;
|
||||
7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ;
|
||||
6 -> 7 ;
|
||||
8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ;
|
||||
7 -> 8 ;
|
||||
9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
||||
7 -> 9 ;
|
||||
10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ;
|
||||
6 -> 10 ;
|
||||
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
|
||||
10 -> 11 ;
|
||||
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
11 -> 12 ;
|
||||
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
12 -> 13 ;
|
||||
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
12 -> 14 ;
|
||||
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
|
||||
11 -> 15 ;
|
||||
16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ;
|
||||
15 -> 16 ;
|
||||
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
|
||||
15 -> 17 ;
|
||||
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
|
||||
17 -> 18 ;
|
||||
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
18 -> 19 ;
|
||||
20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
18 -> 20 ;
|
||||
21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ;
|
||||
17 -> 21 ;
|
||||
22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
10 -> 22 ;
|
||||
23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ;
|
||||
5 -> 23 ;
|
||||
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
|
||||
23 -> 24 ;
|
||||
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
|
||||
23 -> 25 ;
|
||||
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
|
||||
25 -> 26 ;
|
||||
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
|
||||
25 -> 27 ;
|
||||
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
|
||||
27 -> 28 ;
|
||||
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
27 -> 29 ;
|
||||
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
|
||||
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
|
||||
31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ;
|
||||
30 -> 31 ;
|
||||
32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ;
|
||||
31 -> 32 ;
|
||||
33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ;
|
||||
32 -> 33 ;
|
||||
34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ;
|
||||
32 -> 34 ;
|
||||
35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ;
|
||||
34 -> 35 ;
|
||||
36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
35 -> 36 ;
|
||||
37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ;
|
||||
35 -> 37 ;
|
||||
38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
|
||||
34 -> 38 ;
|
||||
39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
|
||||
31 -> 39 ;
|
||||
40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ;
|
||||
30 -> 40 ;
|
||||
41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ;
|
||||
40 -> 41 ;
|
||||
42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ;
|
||||
40 -> 42 ;
|
||||
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
|
||||
42 -> 43 ;
|
||||
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
|
||||
42 -> 44 ;
|
||||
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
44 -> 45 ;
|
||||
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
|
||||
44 -> 46 ;
|
||||
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
|
||||
46 -> 47 ;
|
||||
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
|
||||
46 -> 48 ;
|
||||
}
|
BIN
collect.pdf
Normal file
BIN
collect.pdf
Normal file
Binary file not shown.
22
main.py
22
main.py
@ -1,5 +1,11 @@
|
||||
import pygame
|
||||
import random
|
||||
import pandas as pd
|
||||
from sklearn import tree
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import graphviz
|
||||
|
||||
|
||||
from astar import astar
|
||||
from state import State
|
||||
import time
|
||||
@ -71,12 +77,26 @@ def draw_window(agent, fields, flip):
|
||||
|
||||
|
||||
def main():
|
||||
train_data = pd.read_csv('./data_set.csv')
|
||||
attributes = train_data.drop('collect', axis='columns')
|
||||
e_type = LabelEncoder()
|
||||
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
|
||||
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
|
||||
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
|
||||
label_names = ['collect', 'no-collect']
|
||||
label = train_data['collect']
|
||||
print(attr_encoded)
|
||||
print(label)
|
||||
classifier = tree.DecisionTreeClassifier()
|
||||
classifier.fit(attr_encoded, label)
|
||||
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
|
||||
graph = graphviz.Source(dot_data)
|
||||
graph.render('collect')
|
||||
clock = pygame.time.Clock()
|
||||
run = True
|
||||
x, y = [0, 0]
|
||||
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
|
||||
fields, priority_array = randomize_map()
|
||||
print(priority_array)
|
||||
final_x, final_y = [100, 300]
|
||||
while run:
|
||||
clock.tick(FPS)
|
||||
|
Loading…
Reference in New Issue
Block a user