diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1d7901e --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/__pycache__ \ No newline at end of file diff --git a/collect b/collect new file mode 100644 index 0000000..668d904 --- /dev/null +++ b/collect @@ -0,0 +1,101 @@ +digraph Tree { +node [shape=box, fontname="helvetica"] ; +edge [fontname="helvetica"] ; +0 [label="distance <= 14.5\ngini = 0.484\nsamples = 200\nvalue = [118, 82]\nclass = collect"] ; +1 [label="paid_on_time <= 0.5\ngini = 0.437\nsamples = 93\nvalue = [30, 63]\nclass = no-collect"] ; +0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ; +2 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ; +1 -> 2 ; +3 [label="fuel <= 1508.0\ngini = 0.375\nsamples = 84\nvalue = [21, 63]\nclass = no-collect"] ; +1 -> 3 ; +4 [label="gini = 0.0\nsamples = 7\nvalue = [7, 0]\nclass = collect"] ; +3 -> 4 ; +5 [label="space_occupied <= 0.856\ngini = 0.298\nsamples = 77\nvalue = [14, 63]\nclass = no-collect"] ; +3 -> 5 ; +6 [label="days_since_last_collection <= 4.5\ngini = 0.187\nsamples = 67\nvalue = [7, 60]\nclass = no-collect"] ; +5 -> 6 ; +7 [label="fuel <= 11519.5\ngini = 0.48\nsamples = 5\nvalue = [3, 2]\nclass = collect"] ; +6 -> 7 ; +8 [label="gini = 0.0\nsamples = 3\nvalue = [3, 0]\nclass = collect"] ; +7 -> 8 ; +9 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ; +7 -> 9 ; +10 [label="fuel <= 16955.0\ngini = 0.121\nsamples = 62\nvalue = [4, 58]\nclass = no-collect"] ; +6 -> 10 ; +11 [label="garbage_weight <= 0.612\ngini = 0.094\nsamples = 61\nvalue = [3, 58]\nclass = no-collect"] ; +10 -> 11 ; +12 [label="odour_intensity <= 5.682\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; +11 -> 12 ; +13 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; +12 -> 13 ; +14 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; +12 -> 14 ; +15 [label="garbage_type <= 2.5\ngini = 0.065\nsamples = 59\nvalue = [2, 57]\nclass = no-collect"] ; +11 -> 15 ; +16 [label="gini = 0.0\nsamples = 46\nvalue = [0, 46]\nclass = no-collect"] ; +15 -> 16 ; +17 [label="garbage_weight <= 15.925\ngini = 0.26\nsamples = 13\nvalue = [2, 11]\nclass = no-collect"] ; +15 -> 17 ; +18 [label="fuel <= 13561.0\ngini = 0.444\nsamples = 3\nvalue = [2, 1]\nclass = collect"] ; +17 -> 18 ; +19 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; +18 -> 19 ; +20 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; +18 -> 20 ; +21 [label="gini = 0.0\nsamples = 10\nvalue = [0, 10]\nclass = no-collect"] ; +17 -> 21 ; +22 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; +10 -> 22 ; +23 [label="garbage_type <= 1.0\ngini = 0.42\nsamples = 10\nvalue = [7, 3]\nclass = collect"] ; +5 -> 23 ; +24 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = no-collect"] ; +23 -> 24 ; +25 [label="odour_intensity <= 8.841\ngini = 0.219\nsamples = 8\nvalue = [7, 1]\nclass = collect"] ; +23 -> 25 ; +26 [label="gini = 0.0\nsamples = 6\nvalue = [6, 0]\nclass = collect"] ; +25 -> 26 ; +27 [label="distance <= 7.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]\nclass = collect"] ; +25 -> 27 ; +28 [label="gini = 0.0\nsamples = 1\nvalue = [1, 0]\nclass = collect"] ; +27 -> 28 ; +29 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; +27 -> 29 ; +30 [label="odour_intensity <= 7.156\ngini = 0.292\nsamples = 107\nvalue = [88, 19]\nclass = collect"] ; +0 -> 30 [labeldistance=2.5, labelangle=-45, headlabel="False"] ; +31 [label="garbage_weight <= 49.659\ngini = 0.116\nsamples = 81\nvalue = [76, 5]\nclass = collect"] ; +30 -> 31 ; +32 [label="days_since_last_collection <= 24.5\ngini = 0.095\nsamples = 80\nvalue = [76, 4]\nclass = collect"] ; +31 -> 32 ; +33 [label="gini = 0.0\nsamples = 65\nvalue = [65, 0]\nclass = collect"] ; +32 -> 33 ; +34 [label="distance <= 19.0\ngini = 0.391\nsamples = 15\nvalue = [11, 4]\nclass = collect"] ; +32 -> 34 ; +35 [label="fuel <= 6122.0\ngini = 0.444\nsamples = 6\nvalue = [2, 4]\nclass = no-collect"] ; +34 -> 35 ; +36 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; +35 -> 36 ; +37 [label="gini = 0.0\nsamples = 4\nvalue = [0, 4]\nclass = no-collect"] ; +35 -> 37 ; +38 [label="gini = 0.0\nsamples = 9\nvalue = [9, 0]\nclass = collect"] ; +34 -> 38 ; +39 [label="gini = 0.0\nsamples = 1\nvalue = [0, 1]\nclass = no-collect"] ; +31 -> 39 ; +40 [label="days_since_last_collection <= 13.5\ngini = 0.497\nsamples = 26\nvalue = [12, 14]\nclass = no-collect"] ; +30 -> 40 ; +41 [label="gini = 0.0\nsamples = 8\nvalue = [8, 0]\nclass = collect"] ; +40 -> 41 ; +42 [label="distance <= 20.0\ngini = 0.346\nsamples = 18\nvalue = [4, 14]\nclass = no-collect"] ; +40 -> 42 ; +43 [label="gini = 0.0\nsamples = 8\nvalue = [0, 8]\nclass = no-collect"] ; +42 -> 43 ; +44 [label="distance <= 24.0\ngini = 0.48\nsamples = 10\nvalue = [4, 6]\nclass = no-collect"] ; +42 -> 44 ; +45 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; +44 -> 45 ; +46 [label="space_occupied <= 0.243\ngini = 0.375\nsamples = 8\nvalue = [2, 6]\nclass = no-collect"] ; +44 -> 46 ; +47 [label="gini = 0.0\nsamples = 2\nvalue = [2, 0]\nclass = collect"] ; +46 -> 47 ; +48 [label="gini = 0.0\nsamples = 6\nvalue = [0, 6]\nclass = no-collect"] ; +46 -> 48 ; +} diff --git a/collect.pdf b/collect.pdf new file mode 100644 index 0000000..0bc3799 Binary files /dev/null and b/collect.pdf differ diff --git a/main.py b/main.py index 2c9e7a5..3123f69 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,11 @@ import pygame import random +import pandas as pd +from sklearn import tree +from sklearn.preprocessing import LabelEncoder +import graphviz + + from astar import astar from state import State import time @@ -71,12 +77,26 @@ def draw_window(agent, fields, flip): def main(): + train_data = pd.read_csv('./data_set.csv') + attributes = train_data.drop('collect', axis='columns') + e_type = LabelEncoder() + attributes['type_num'] = e_type.fit_transform(attributes['garbage_type']) + attr_encoded = attributes.drop(['garbage_type'], axis='columns') + attr_names = ['fuel','distance','space_occupied','days_since_last_collection','paid_on_time','odour_intensity','garbage_weight', 'garbage_type'] + label_names = ['collect', 'no-collect'] + label = train_data['collect'] + print(attr_encoded) + print(label) + classifier = tree.DecisionTreeClassifier() + classifier.fit(attr_encoded, label) + dot_data = tree.export_graphviz(classifier, out_file=None, feature_names=attr_names, class_names=label_names) + graph = graphviz.Source(dot_data) + graph.render('collect') clock = pygame.time.Clock() run = True x, y = [0, 0] agent = GarbageTruck(0, 0, pygame.Rect(x, y, 50, 50), 0) # tworzenie pola dla agenta fields, priority_array = randomize_map() - print(priority_array) final_x, final_y = [100, 300] while run: clock.tick(FPS)