trained decision tree

This commit is contained in:
Mateusz 2023-05-26 10:52:02 +02:00
parent 0738e210b1
commit 63596f20f2
4 changed files with 123 additions and 1 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/__pycache__

101
collect Normal file
View File

@ -0,0 +1,101 @@
digraph Tree {
node [shape=box, fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ;
1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
1 -> 2 ;
3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ;
1 -> 3 ;
4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ;
3 -> 4 ;
5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ;
3 -> 5 ;
6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ;
5 -> 6 ;
7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ;
6 -> 7 ;
8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ;
7 -> 8 ;
9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
7 -> 9 ;
10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ;
6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ;
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ;
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
12 -> 14 ;
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
11 -> 15 ;
16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ;
15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ;
20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
18 -> 20 ;
21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ;
17 -> 21 ;
22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
10 -> 22 ;
23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ;
5 -> 23 ;
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
23 -> 24 ;
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ;
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ;
30 -> 31 ;
32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ;
31 -> 32 ;
33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ;
32 -> 33 ;
34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ;
32 -> 34 ;
35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ;
34 -> 35 ;
36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
35 -> 36 ;
37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ;
35 -> 37 ;
38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
34 -> 38 ;
39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
31 -> 39 ;
40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ;
30 -> 40 ;
41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ;
40 -> 41 ;
42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ;
40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
46 -> 47 ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
46 -> 48 ;
}

BIN
collect.pdf Normal file

Binary file not shown.

22
main.py
View File

@ -1,5 +1,11 @@
import pygame
import random
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
from astar import astar
from state import State
import time
@ -71,12 +77,26 @@ def draw_window(agent, fields, flip):
def main():
train_data = pd.read_csv('./data_set.csv')
attributes = train_data.drop('collect', axis='columns')
e_type = LabelEncoder()
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
label_names = ['collect', 'no-collect']
label = train_data['collect']
print(attr_encoded)
print(label)
classifier = tree.DecisionTreeClassifier()
classifier.fit(attr_encoded, label)
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
graph = graphviz.Source(dot_data)
graph.render('collect')
clock = pygame.time.Clock()
run = True
x, y = [0, 0]
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
fields, priority_array = randomize_map()
print(priority_array)
final_x, final_y = [100, 300]
while run:
clock.tick(FPS)