ID3 #3

Merged
s473585 merged 4 commits from ID3 into master 2023-05-30 16:22:25 +02:00
4 changed files with 123 additions and 1 deletions
Showing only changes of commit 63596f20f2 - Show all commits

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/__pycache__

101
collect Normal file
View File

@ -0,0 +1,101 @@
digraph Tree {
node [shape=box, fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ;
1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
1 -> 2 ;
3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ;
1 -> 3 ;
4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ;
3 -> 4 ;
5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ;
3 -> 5 ;
6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ;
5 -> 6 ;
7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ;
6 -> 7 ;
8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ;
7 -> 8 ;
9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
7 -> 9 ;
10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ;
6 -> 10 ;
11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ;
10 -> 11 ;
12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
11 -> 12 ;
13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
12 -> 13 ;
14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
12 -> 14 ;
15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ;
11 -> 15 ;
16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ;
15 -> 16 ;
17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ;
15 -> 17 ;
18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ;
17 -> 18 ;
19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
18 -> 19 ;
20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
18 -> 20 ;
21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ;
17 -> 21 ;
22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
10 -> 22 ;
23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ;
5 -> 23 ;
24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ;
23 -> 24 ;
25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ;
23 -> 25 ;
26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ;
25 -> 26 ;
27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ;
25 -> 27 ;
28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ;
27 -> 28 ;
29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
27 -> 29 ;
30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ;
0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ;
30 -> 31 ;
32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ;
31 -> 32 ;
33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ;
32 -> 33 ;
34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ;
32 -> 34 ;
35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ;
34 -> 35 ;
36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
35 -> 36 ;
37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ;
35 -> 37 ;
38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ;
34 -> 38 ;
39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ;
31 -> 39 ;
40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ;
30 -> 40 ;
41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ;
40 -> 41 ;
42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ;
40 -> 42 ;
43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ;
42 -> 43 ;
44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ;
42 -> 44 ;
45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
44 -> 45 ;
46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ;
44 -> 46 ;
47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ;
46 -> 47 ;
48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ;
46 -> 48 ;
}

BIN
collect.pdf Normal file

Binary file not shown.

22
main.py
View File

@ -1,5 +1,11 @@
import pygame import pygame
import random import random
import pandas as pd
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
import graphviz
from astar import astar from astar import astar
from state import State from state import State
import time import time
@ -71,12 +77,26 @@ def draw_window(agent, fields, flip):
def main(): def main():
train_data = pd.read_csv('./data_set.csv')
attributes = train_data.drop('collect', axis='columns')
e_type = LabelEncoder()
attributes['type_num'] = e_type.fit_transform(attributes['garbage_type'])
attr_encoded = attributes.drop(['garbage_type'], axis='columns')
attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type']
label_names = ['collect', 'no-collect']
label = train_data['collect']
print(attr_encoded)
print(label)
classifier = tree.DecisionTreeClassifier()
classifier.fit(attr_encoded, label)
dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names)
graph = graphviz.Source(dot_data)
graph.render('collect')
clock = pygame.time.Clock() clock = pygame.time.Clock()
run = True run = True
x, y = [0, 0] x, y = [0, 0]
agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta
fields, priority_array = randomize_map() fields, priority_array = randomize_map()
print(priority_array)
final_x, final_y = [100, 300] final_x, final_y = [100, 300]
while run: while run:
clock.tick(FPS) clock.tick(FPS)