DecisionTree update

This commit is contained in:
MonoYuku 2021-06-23 11:09:17 +02:00
parent 80e6c700cb
commit 89d1aa7802
16 changed files with 1837 additions and 5 deletions

View File

@ -1 +1 @@
main.py astar.py

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

3
bfs.py
View File

@ -47,7 +47,7 @@ def bfs(pos, direction, end_pos, houses):
if not is_house(curr_node.pos, houses) and curr_node.pos == end_pos: if not is_house(curr_node.pos, houses) and curr_node.pos == end_pos:
while curr_node.parent: while curr_node.parent:
#print(curr_node.pos, end_pos) # print(curr_node.pos, end_pos)
actions.append(curr_node.action) actions.append(curr_node.action)
curr_node = curr_node.parent curr_node = curr_node.parent
return actions return actions
@ -64,3 +64,4 @@ def distance(pos, endpos):
houses = create_houses(40) houses = create_houses(40)
actions = bfs(pos, 0, endpos, houses) actions = bfs(pos, 0, endpos, houses)
return len(actions) return len(actions)

51
csv_gen.py Normal file
View File

@ -0,0 +1,51 @@
import csv
decision = [0, 1] # 0 - go to bin, 1 - pick up
levels = [1, 2, 3, 4, 5]
# 1 - 0;20 2 - 20;40 3 - 40;60 4 - 60;80 5 - 80;100
# 1 - 0;40 2 - 40;80 3 - 80;120 4 - 120;160 5 - 160+
def enough_free_space(available_space, trash_size, available_mass, mass_trash):
if available_space + trash_size <= 5 and available_mass + mass_trash <= 5:
return True
return False
def where_is_closer(bin_distance, trash_distance):
if bin_distance <= trash_distance:
return 0
return 1
with open('tree_dataset.csv', 'w', newline='') as csv_file:
file_writer = csv.writer(csv_file)
file_writer.writerow(["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space", "decision"])
counter = 0
for dis_dump in levels:
for dis_trash in levels:
for mass in levels:
for space in levels:
for trash_mass in levels:
for trash_space in levels:
if counter % 10 == 0:
if dis_dump == 1 and space >= 1 and mass >= 1:
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
elif dis_trash == 1 and enough_free_space(space, trash_space, mass, trash_mass):
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
elif mass == 4 or space == 4 and not enough_free_space(space, trash_space, mass, trash_mass):
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
elif mass == 5 or space == 5:
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
elif mass <= 3 and space <= 3 and enough_free_space(space, trash_space, mass, trash_mass):
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
elif mass == 4 or space == 4 and enough_free_space(space, trash_space, mass, trash_mass):
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, where_is_closer(dis_dump, dis_trash)])
elif not enough_free_space(space, trash_space, mass, trash_mass):
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
else:
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, None])
counter += 1

BIN
decision_tree.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 319 KiB

BIN
img/wet.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

49
main.py
View File

@ -8,16 +8,16 @@ from random import shuffle, choice
import numpy as np import numpy as np
import os import os
import tree
import pygame import pygame
from time import sleep from time import sleep
from os import path
from colors import gray from colors import gray
from house import create_houses from house import create_houses
from truck import Truck from truck import Truck
from trash import Trash from trash import Trash
from TSP import tsp, tspmove from TSP import tsp, tspmove
from bfs import bfs from bfs import bfs, distance
model = load_model("model.h5") model = load_model("model.h5")
@ -38,6 +38,9 @@ def game_keys(truck, multi_trash, houses, auto=False):
print('') print('')
for tindex, trash in enumerate(multi_trash): for tindex, trash in enumerate(multi_trash):
if truck.pos == trash.pos: if truck.pos == trash.pos:
truck.mass += trash.mass
truck.space += trash.space
print(truck.mass, truck.space)
prediction = model.predict(multi_trash[tindex].content) prediction = model.predict(multi_trash[tindex].content)
for i in range (3): for i in range (3):
if multi_trash[tindex].names[i][:3] == 'cat': if multi_trash[tindex].names[i][:3] == 'cat':
@ -106,6 +109,46 @@ def game_loop():
if event.key == pygame.K_ESCAPE: if event.key == pygame.K_ESCAPE:
pygame.quit() pygame.quit()
quit() quit()
if (event.key == pygame.K_l):
if path.isfile('./tree_model') and not os.stat(
'./tree_model').st_size == 0:
decision_tree = tree.load_tree_from_structure('./tree_model')
print("Tree model loaded!")
if (event.key == pygame.K_k):
print(":>")
trash = multi_trash[0]
dis_dump = distance(truck.pos,[80,80])
dis_trash = distance(truck.pos, trash.pos)
print(dis_dump, dis_trash, truck.mass, truck.space, trash.mass, trash.space)
decision = tree.making_decision(decision_tree,
dis_dump // 12 + 1,
dis_trash // 12 + 1,
truck.mass // 20 + 1, truck.space // 20 + 1,
trash.mass // 20 + 1,
trash.space // 20 + 1)
print(decision)
if(decision[0]==0):
actions = bfs(truck.pos, truck.dir_control,
trash.pos, houses)
print(actions)
if not actions:
print('Path couldn\'t be found')
break
print('##################################################')
while actions:
action = actions.pop()
pygame.event.post(pygame.event.Event(
pygame.KEYDOWN, {'key': action}))
game_keys(truck, multi_trash, houses, True)
update_images(gameDisplay, truck, multi_trash, houses)
else:
truck.space=0
truck.mass=0
if (event.key == pygame.K_b): if (event.key == pygame.K_b):
trash = multi_trash[0] trash = multi_trash[0]
actions = bfs(truck.pos, truck.dir_control, actions = bfs(truck.pos, truck.dir_control,

View File

@ -37,6 +37,8 @@ class Trash:
self.size = grid_size self.size = grid_size
self.content = draw_trash(filenames)[0] self.content = draw_trash(filenames)[0]
self.names = draw_trash(filenames)[1] self.names = draw_trash(filenames)[1]
self.mass = random.randint(0, 25)
self.space = random.randint(0, 25)
def new_pos(self, truck_pos, houses, multi): def new_pos(self, truck_pos, houses, multi):
self.trash_content, self.trash_names = draw_trash(filenames) self.trash_content, self.trash_names = draw_trash(filenames)

57
tree.py Normal file
View File

@ -0,0 +1,57 @@
import joblib
import matplotlib.pyplot as plt
import pandas
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
attributes = ["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space"]
decisions = ["decision"]
def learning_tree():
dataset = pandas.read_csv('./tree_dataset.csv')
x = dataset[attributes]
y = dataset[decisions]
decision_tree = DecisionTreeClassifier()
decision_tree = decision_tree.fit(x, y)
return decision_tree
def making_decision(decision_tree, distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass,
trash_space):
decision = decision_tree.predict(
[[distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass, trash_space]])
return decision
def save_all(decision_tree):
save_tree_to_png(decision_tree)
save_tree_to_txt(decision_tree)
save_tree_to_structure(decision_tree)
def save_tree_to_txt(decision_tree):
with open('./tree_in_txt.txt', "w") as file:
file.write(tree.export_text(decision_tree))
def save_tree_to_png(decision_tree):
fig = plt.figure(figsize=(25, 20))
_ = tree.plot_tree(decision_tree, feature_names=attributes, filled=True)
fig.savefig('./decision_tree.png')
def save_tree_to_structure(decision_tree):
joblib.dump(decision_tree, './tree_model')
def load_tree_from_structure(file):
return joblib.load(file)
if __name__ == '__main__':
tre = learning_tree()
save_all(tre)

1564
tree_dataset.csv Normal file

File diff suppressed because it is too large Load Diff

112
tree_in_txt.txt Normal file
View File

@ -0,0 +1,112 @@
|--- feature_2 <= 3.50
| |--- feature_4 <= 3.50
| | |--- feature_3 <= 3.50
| | | |--- feature_0 <= 1.50
| | | | |--- class: 0
| | | |--- feature_0 > 1.50
| | | | |--- feature_4 <= 2.50
| | | | | |--- class: 1
| | | | |--- feature_4 > 2.50
| | | | | |--- feature_2 <= 2.50
| | | | | | |--- class: 1
| | | | | |--- feature_2 > 2.50
| | | | | | |--- class: 0
| | |--- feature_3 > 3.50
| | | |--- feature_3 <= 4.50
| | | | |--- feature_0 <= 2.50
| | | | | |--- feature_1 <= 1.50
| | | | | | |--- feature_0 <= 1.50
| | | | | | | |--- class: 0
| | | | | | |--- feature_0 > 1.50
| | | | | | | |--- feature_2 <= 2.50
| | | | | | | | |--- class: 1
| | | | | | | |--- feature_2 > 2.50
| | | | | | | | |--- feature_4 <= 2.00
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.00
| | | | | | | | | |--- class: 0
| | | | | |--- feature_1 > 1.50
| | | | | | |--- class: 0
| | | | |--- feature_0 > 2.50
| | | | | |--- feature_1 <= 3.50
| | | | | | |--- feature_1 <= 2.50
| | | | | | | |--- feature_4 <= 2.50
| | | | | | | | |--- class: 1
| | | | | | | |--- feature_4 > 2.50
| | | | | | | | |--- feature_2 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_2 > 2.50
| | | | | | | | | |--- class: 0
| | | | | | |--- feature_1 > 2.50
| | | | | | | |--- feature_0 <= 3.50
| | | | | | | | |--- class: 0
| | | | | | | |--- feature_0 > 3.50
| | | | | | | | |--- feature_4 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.50
| | | | | | | | | |--- feature_2 <= 2.50
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- feature_2 > 2.50
| | | | | | | | | | |--- class: 0
| | | | | |--- feature_1 > 3.50
| | | | | | |--- feature_0 <= 4.50
| | | | | | | |--- class: 0
| | | | | | |--- feature_0 > 4.50
| | | | | | | |--- feature_1 <= 4.50
| | | | | | | | |--- feature_4 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_4 > 2.50
| | | | | | | | | |--- feature_2 <= 2.00
| | | | | | | | | | |--- class: 1
| | | | | | | | | |--- feature_2 > 2.00
| | | | | | | | | | |--- class: 0
| | | | | | | |--- feature_1 > 4.50
| | | | | | | | |--- class: 0
| | | |--- feature_3 > 4.50
| | | | |--- class: 0
| |--- feature_4 > 3.50
| | |--- feature_2 <= 1.50
| | | |--- feature_4 <= 4.50
| | | | |--- feature_3 <= 3.50
| | | | | |--- feature_0 <= 1.50
| | | | | | |--- class: 0
| | | | | |--- feature_0 > 1.50
| | | | | | |--- class: 1
| | | | |--- feature_3 > 3.50
| | | | | |--- feature_3 <= 4.50
| | | | | | |--- feature_1 <= 3.50
| | | | | | | |--- feature_0 <= 2.50
| | | | | | | | |--- class: 0
| | | | | | | |--- feature_0 > 2.50
| | | | | | | | |--- feature_1 <= 2.50
| | | | | | | | | |--- class: 1
| | | | | | | | |--- feature_1 > 2.50
| | | | | | | | | |--- feature_0 <= 4.00
| | | | | | | | | | |--- class: 0
| | | | | | | | | |--- feature_0 > 4.00
| | | | | | | | | | |--- class: 1
| | | | | | |--- feature_1 > 3.50
| | | | | | | |--- class: 0
| | | | | |--- feature_3 > 4.50
| | | | | | |--- class: 0
| | | |--- feature_4 > 4.50
| | | | |--- class: 0
| | |--- feature_2 > 1.50
| | | |--- class: 0
|--- feature_2 > 3.50
| |--- feature_1 <= 1.50
| | |--- feature_4 <= 1.50
| | | |--- feature_2 <= 4.50
| | | | |--- feature_3 <= 4.50
| | | | | |--- feature_0 <= 1.50
| | | | | | |--- class: 0
| | | | | |--- feature_0 > 1.50
| | | | | | |--- class: 1
| | | | |--- feature_3 > 4.50
| | | | | |--- class: 0
| | | |--- feature_2 > 4.50
| | | | |--- class: 0
| | |--- feature_4 > 1.50
| | | |--- class: 0
| |--- feature_1 > 1.50
| | |--- class: 0

BIN
tree_model Normal file

Binary file not shown.

View File

@ -16,6 +16,8 @@ class Truck:
self.allCats = 0 self.allCats = 0
self.allTrash = 0 self.allTrash = 0
self.trash = 0 self.trash = 0
self.mass=0
self.space=0
def move(self): def move(self):
self.pos[0] += self.direction[0] * self.size self.pos[0] += self.direction[0] * self.size