trained decision tree

This commit is contained in:
Mateusz 2023-05-26 10:52:02 +02:00
parent 0738e210b1
commit 63596f20f2
4 changed files with 123 additions and 1 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/__pycache__

101
collect Normal file
View File

@ -0,0 +1,101 @@
digraph Tree {
node [shape=box, fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ;
1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
1 -> 2 ;
3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ;
1 -> 3 ;
4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ;
3 -> 4 ;
5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ;
3 -> 5 ;
6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ;
5 -> 6 ;
7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ;
6 -> 7 ;
8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ;
7 -> 8 ;
9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
7 -> 9 ;
10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ;
6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ;
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ;
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
12 -> 14 ;
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
11 -> 15 ;
16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ;
15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ;
20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
18 -> 20 ;
21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ;
17 -> 21 ;
22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
10 -> 22 ;
23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ;
5 -> 23 ;
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
23 -> 24 ;
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ;
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ;
30 -> 31 ;
32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ;
31 -> 32 ;
33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ;
32 -> 33 ;
34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ;
32 -> 34 ;
35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ;
34 -> 35 ;
36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
35 -> 36 ;
37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ;
35 -> 37 ;
38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
34 -> 38 ;
39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
31 -> 39 ;
40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ;
30 -> 40 ;
41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ;
40 -> 41 ;
42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ;
40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
46 -> 47 ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
46 -> 48 ;
}

BIN
collect.pdf Normal file

Binary file not shown.

22
main.py
View File

@ -1,5 +1,11 @@
import pygame import pygame
import random import random
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
from astar import astar from astar import astar
from state import State from state import State
import time import time
@ -71,12 +77,26 @@ def draw_window(agent, fields, flip):
def main(): def main():
train_data = pd.read_csv('./data_set.csv')
attributes = train_data.drop('collect', axis='columns')
e_type = LabelEncoder()
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
label_names = ['collect', 'no-collect']
label = train_data['collect']
print(attr_encoded)
print(label)
classifier = tree.DecisionTreeClassifier()
classifier.fit(attr_encoded, label)
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
graph = graphviz.Source(dot_data)
graph.render('collect')
clock = pygame.time.Clock() clock = pygame.time.Clock()
run = True run = True
x, y = [0, 0] x, y = [0, 0]
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
fields, priority_array = randomize_map() fields, priority_array = randomize_map()
print(priority_array)
final_x, final_y = [100, 300] final_x, final_y = [100, 300]
while run: while run:
clock.tick(FPS) clock.tick(FPS)