DecisionTree update
This commit is contained in:
parent
80e6c700cb
commit
89d1aa7802
@ -1 +1 @@
|
|||||||
main.py
|
astar.py
|
Binary file not shown.
Binary file not shown.
BIN
__pycache__/tree.cpython-39.pyc
Normal file
BIN
__pycache__/tree.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
3
bfs.py
3
bfs.py
@ -47,7 +47,7 @@ def bfs(pos, direction, end_pos, houses):
|
|||||||
if not is_house(curr_node.pos, houses) and curr_node.pos == end_pos:
|
if not is_house(curr_node.pos, houses) and curr_node.pos == end_pos:
|
||||||
while curr_node.parent:
|
while curr_node.parent:
|
||||||
|
|
||||||
#print(curr_node.pos, end_pos)
|
# print(curr_node.pos, end_pos)
|
||||||
actions.append(curr_node.action)
|
actions.append(curr_node.action)
|
||||||
curr_node = curr_node.parent
|
curr_node = curr_node.parent
|
||||||
return actions
|
return actions
|
||||||
@ -64,3 +64,4 @@ def distance(pos, endpos):
|
|||||||
houses = create_houses(40)
|
houses = create_houses(40)
|
||||||
actions = bfs(pos, 0, endpos, houses)
|
actions = bfs(pos, 0, endpos, houses)
|
||||||
return len(actions)
|
return len(actions)
|
||||||
|
|
||||||
|
51
csv_gen.py
Normal file
51
csv_gen.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import csv
|
||||||
|
|
||||||
|
decision = [0, 1] # 0 - go to bin, 1 - pick up
|
||||||
|
levels = [1, 2, 3, 4, 5]
|
||||||
|
# 1 - 0;20 2 - 20;40 3 - 40;60 4 - 60;80 5 - 80;100
|
||||||
|
# 1 - 0;40 2 - 40;80 3 - 80;120 4 - 120;160 5 - 160+
|
||||||
|
|
||||||
|
|
||||||
|
def enough_free_space(available_space, trash_size, available_mass, mass_trash):
|
||||||
|
if available_space + trash_size <= 5 and available_mass + mass_trash <= 5:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def where_is_closer(bin_distance, trash_distance):
|
||||||
|
if bin_distance <= trash_distance:
|
||||||
|
return 0
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
with open('tree_dataset.csv', 'w', newline='') as csv_file:
|
||||||
|
file_writer = csv.writer(csv_file)
|
||||||
|
file_writer.writerow(["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space", "decision"])
|
||||||
|
|
||||||
|
counter = 0
|
||||||
|
|
||||||
|
for dis_dump in levels:
|
||||||
|
for dis_trash in levels:
|
||||||
|
for mass in levels:
|
||||||
|
for space in levels:
|
||||||
|
for trash_mass in levels:
|
||||||
|
for trash_space in levels:
|
||||||
|
if counter % 10 == 0:
|
||||||
|
if dis_dump == 1 and space >= 1 and mass >= 1:
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||||
|
elif dis_trash == 1 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
|
||||||
|
elif mass == 4 or space == 4 and not enough_free_space(space, trash_space, mass, trash_mass):
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||||
|
elif mass == 5 or space == 5:
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||||
|
elif mass <= 3 and space <= 3 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
|
||||||
|
elif mass == 4 or space == 4 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, where_is_closer(dis_dump, dis_trash)])
|
||||||
|
elif not enough_free_space(space, trash_space, mass, trash_mass):
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||||
|
else:
|
||||||
|
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, None])
|
||||||
|
|
||||||
|
counter += 1
|
BIN
decision_tree.png
Normal file
BIN
decision_tree.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 319 KiB |
BIN
img/wet.png
Normal file
BIN
img/wet.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
49
main.py
49
main.py
@ -8,16 +8,16 @@ from random import shuffle, choice
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import tree
|
||||||
import pygame
|
import pygame
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
from os import path
|
||||||
from colors import gray
|
from colors import gray
|
||||||
from house import create_houses
|
from house import create_houses
|
||||||
from truck import Truck
|
from truck import Truck
|
||||||
from trash import Trash
|
from trash import Trash
|
||||||
from TSP import tsp, tspmove
|
from TSP import tsp, tspmove
|
||||||
from bfs import bfs
|
from bfs import bfs, distance
|
||||||
|
|
||||||
|
|
||||||
model = load_model("model.h5")
|
model = load_model("model.h5")
|
||||||
@ -38,6 +38,9 @@ def game_keys(truck, multi_trash, houses, auto=False):
|
|||||||
print('↑')
|
print('↑')
|
||||||
for tindex, trash in enumerate(multi_trash):
|
for tindex, trash in enumerate(multi_trash):
|
||||||
if truck.pos == trash.pos:
|
if truck.pos == trash.pos:
|
||||||
|
truck.mass += trash.mass
|
||||||
|
truck.space += trash.space
|
||||||
|
print(truck.mass, truck.space)
|
||||||
prediction = model.predict(multi_trash[tindex].content)
|
prediction = model.predict(multi_trash[tindex].content)
|
||||||
for i in range (3):
|
for i in range (3):
|
||||||
if multi_trash[tindex].names[i][:3] == 'cat':
|
if multi_trash[tindex].names[i][:3] == 'cat':
|
||||||
@ -106,6 +109,46 @@ def game_loop():
|
|||||||
if event.key == pygame.K_ESCAPE:
|
if event.key == pygame.K_ESCAPE:
|
||||||
pygame.quit()
|
pygame.quit()
|
||||||
quit()
|
quit()
|
||||||
|
if (event.key == pygame.K_l):
|
||||||
|
if path.isfile('./tree_model') and not os.stat(
|
||||||
|
'./tree_model').st_size == 0:
|
||||||
|
decision_tree = tree.load_tree_from_structure('./tree_model')
|
||||||
|
print("Tree model loaded!")
|
||||||
|
|
||||||
|
if (event.key == pygame.K_k):
|
||||||
|
print(":>")
|
||||||
|
trash = multi_trash[0]
|
||||||
|
dis_dump = distance(truck.pos,[80,80])
|
||||||
|
dis_trash = distance(truck.pos, trash.pos)
|
||||||
|
print(dis_dump, dis_trash, truck.mass, truck.space, trash.mass, trash.space)
|
||||||
|
decision = tree.making_decision(decision_tree,
|
||||||
|
dis_dump // 12 + 1,
|
||||||
|
dis_trash // 12 + 1,
|
||||||
|
truck.mass // 20 + 1, truck.space // 20 + 1,
|
||||||
|
trash.mass // 20 + 1,
|
||||||
|
trash.space // 20 + 1)
|
||||||
|
|
||||||
|
print(decision)
|
||||||
|
if(decision[0]==0):
|
||||||
|
actions = bfs(truck.pos, truck.dir_control,
|
||||||
|
trash.pos, houses)
|
||||||
|
print(actions)
|
||||||
|
if not actions:
|
||||||
|
print('Path couldn\'t be found')
|
||||||
|
break
|
||||||
|
|
||||||
|
print('##################################################')
|
||||||
|
while actions:
|
||||||
|
action = actions.pop()
|
||||||
|
pygame.event.post(pygame.event.Event(
|
||||||
|
pygame.KEYDOWN, {'key': action}))
|
||||||
|
|
||||||
|
game_keys(truck, multi_trash, houses, True)
|
||||||
|
update_images(gameDisplay, truck, multi_trash, houses)
|
||||||
|
else:
|
||||||
|
truck.space=0
|
||||||
|
truck.mass=0
|
||||||
|
|
||||||
if (event.key == pygame.K_b):
|
if (event.key == pygame.K_b):
|
||||||
trash = multi_trash[0]
|
trash = multi_trash[0]
|
||||||
actions = bfs(truck.pos, truck.dir_control,
|
actions = bfs(truck.pos, truck.dir_control,
|
||||||
|
2
trash.py
2
trash.py
@ -37,6 +37,8 @@ class Trash:
|
|||||||
self.size = grid_size
|
self.size = grid_size
|
||||||
self.content = draw_trash(filenames)[0]
|
self.content = draw_trash(filenames)[0]
|
||||||
self.names = draw_trash(filenames)[1]
|
self.names = draw_trash(filenames)[1]
|
||||||
|
self.mass = random.randint(0, 25)
|
||||||
|
self.space = random.randint(0, 25)
|
||||||
|
|
||||||
def new_pos(self, truck_pos, houses, multi):
|
def new_pos(self, truck_pos, houses, multi):
|
||||||
self.trash_content, self.trash_names = draw_trash(filenames)
|
self.trash_content, self.trash_names = draw_trash(filenames)
|
||||||
|
57
tree.py
Normal file
57
tree.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import joblib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas
|
||||||
|
from sklearn import tree
|
||||||
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
|
||||||
|
attributes = ["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space"]
|
||||||
|
decisions = ["decision"]
|
||||||
|
|
||||||
|
|
||||||
|
def learning_tree():
|
||||||
|
dataset = pandas.read_csv('./tree_dataset.csv')
|
||||||
|
|
||||||
|
x = dataset[attributes]
|
||||||
|
y = dataset[decisions]
|
||||||
|
|
||||||
|
decision_tree = DecisionTreeClassifier()
|
||||||
|
decision_tree = decision_tree.fit(x, y)
|
||||||
|
|
||||||
|
return decision_tree
|
||||||
|
|
||||||
|
|
||||||
|
def making_decision(decision_tree, distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass,
|
||||||
|
trash_space):
|
||||||
|
decision = decision_tree.predict(
|
||||||
|
[[distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass, trash_space]])
|
||||||
|
|
||||||
|
return decision
|
||||||
|
|
||||||
|
|
||||||
|
def save_all(decision_tree):
|
||||||
|
save_tree_to_png(decision_tree)
|
||||||
|
save_tree_to_txt(decision_tree)
|
||||||
|
save_tree_to_structure(decision_tree)
|
||||||
|
|
||||||
|
|
||||||
|
def save_tree_to_txt(decision_tree):
|
||||||
|
with open('./tree_in_txt.txt', "w") as file:
|
||||||
|
file.write(tree.export_text(decision_tree))
|
||||||
|
|
||||||
|
|
||||||
|
def save_tree_to_png(decision_tree):
|
||||||
|
fig = plt.figure(figsize=(25, 20))
|
||||||
|
_ = tree.plot_tree(decision_tree, feature_names=attributes, filled=True)
|
||||||
|
fig.savefig('./decision_tree.png')
|
||||||
|
|
||||||
|
|
||||||
|
def save_tree_to_structure(decision_tree):
|
||||||
|
joblib.dump(decision_tree, './tree_model')
|
||||||
|
|
||||||
|
|
||||||
|
def load_tree_from_structure(file):
|
||||||
|
return joblib.load(file)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tre = learning_tree()
|
||||||
|
save_all(tre)
|
1564
tree_dataset.csv
Normal file
1564
tree_dataset.csv
Normal file
File diff suppressed because it is too large
Load Diff
112
tree_in_txt.txt
Normal file
112
tree_in_txt.txt
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
|--- feature_2 <= 3.50
|
||||||
|
| |--- feature_4 <= 3.50
|
||||||
|
| | |--- feature_3 <= 3.50
|
||||||
|
| | | |--- feature_0 <= 1.50
|
||||||
|
| | | | |--- class: 0
|
||||||
|
| | | |--- feature_0 > 1.50
|
||||||
|
| | | | |--- feature_4 <= 2.50
|
||||||
|
| | | | | |--- class: 1
|
||||||
|
| | | | |--- feature_4 > 2.50
|
||||||
|
| | | | | |--- feature_2 <= 2.50
|
||||||
|
| | | | | | |--- class: 1
|
||||||
|
| | | | | |--- feature_2 > 2.50
|
||||||
|
| | | | | | |--- class: 0
|
||||||
|
| | |--- feature_3 > 3.50
|
||||||
|
| | | |--- feature_3 <= 4.50
|
||||||
|
| | | | |--- feature_0 <= 2.50
|
||||||
|
| | | | | |--- feature_1 <= 1.50
|
||||||
|
| | | | | | |--- feature_0 <= 1.50
|
||||||
|
| | | | | | | |--- class: 0
|
||||||
|
| | | | | | |--- feature_0 > 1.50
|
||||||
|
| | | | | | | |--- feature_2 <= 2.50
|
||||||
|
| | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | |--- feature_2 > 2.50
|
||||||
|
| | | | | | | | |--- feature_4 <= 2.00
|
||||||
|
| | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | |--- feature_4 > 2.00
|
||||||
|
| | | | | | | | | |--- class: 0
|
||||||
|
| | | | | |--- feature_1 > 1.50
|
||||||
|
| | | | | | |--- class: 0
|
||||||
|
| | | | |--- feature_0 > 2.50
|
||||||
|
| | | | | |--- feature_1 <= 3.50
|
||||||
|
| | | | | | |--- feature_1 <= 2.50
|
||||||
|
| | | | | | | |--- feature_4 <= 2.50
|
||||||
|
| | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | |--- feature_4 > 2.50
|
||||||
|
| | | | | | | | |--- feature_2 <= 2.50
|
||||||
|
| | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | |--- feature_2 > 2.50
|
||||||
|
| | | | | | | | | |--- class: 0
|
||||||
|
| | | | | | |--- feature_1 > 2.50
|
||||||
|
| | | | | | | |--- feature_0 <= 3.50
|
||||||
|
| | | | | | | | |--- class: 0
|
||||||
|
| | | | | | | |--- feature_0 > 3.50
|
||||||
|
| | | | | | | | |--- feature_4 <= 2.50
|
||||||
|
| | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | |--- feature_4 > 2.50
|
||||||
|
| | | | | | | | | |--- feature_2 <= 2.50
|
||||||
|
| | | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | | |--- feature_2 > 2.50
|
||||||
|
| | | | | | | | | | |--- class: 0
|
||||||
|
| | | | | |--- feature_1 > 3.50
|
||||||
|
| | | | | | |--- feature_0 <= 4.50
|
||||||
|
| | | | | | | |--- class: 0
|
||||||
|
| | | | | | |--- feature_0 > 4.50
|
||||||
|
| | | | | | | |--- feature_1 <= 4.50
|
||||||
|
| | | | | | | | |--- feature_4 <= 2.50
|
||||||
|
| | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | |--- feature_4 > 2.50
|
||||||
|
| | | | | | | | | |--- feature_2 <= 2.00
|
||||||
|
| | | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | | |--- feature_2 > 2.00
|
||||||
|
| | | | | | | | | | |--- class: 0
|
||||||
|
| | | | | | | |--- feature_1 > 4.50
|
||||||
|
| | | | | | | | |--- class: 0
|
||||||
|
| | | |--- feature_3 > 4.50
|
||||||
|
| | | | |--- class: 0
|
||||||
|
| |--- feature_4 > 3.50
|
||||||
|
| | |--- feature_2 <= 1.50
|
||||||
|
| | | |--- feature_4 <= 4.50
|
||||||
|
| | | | |--- feature_3 <= 3.50
|
||||||
|
| | | | | |--- feature_0 <= 1.50
|
||||||
|
| | | | | | |--- class: 0
|
||||||
|
| | | | | |--- feature_0 > 1.50
|
||||||
|
| | | | | | |--- class: 1
|
||||||
|
| | | | |--- feature_3 > 3.50
|
||||||
|
| | | | | |--- feature_3 <= 4.50
|
||||||
|
| | | | | | |--- feature_1 <= 3.50
|
||||||
|
| | | | | | | |--- feature_0 <= 2.50
|
||||||
|
| | | | | | | | |--- class: 0
|
||||||
|
| | | | | | | |--- feature_0 > 2.50
|
||||||
|
| | | | | | | | |--- feature_1 <= 2.50
|
||||||
|
| | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | | | |--- feature_1 > 2.50
|
||||||
|
| | | | | | | | | |--- feature_0 <= 4.00
|
||||||
|
| | | | | | | | | | |--- class: 0
|
||||||
|
| | | | | | | | | |--- feature_0 > 4.00
|
||||||
|
| | | | | | | | | | |--- class: 1
|
||||||
|
| | | | | | |--- feature_1 > 3.50
|
||||||
|
| | | | | | | |--- class: 0
|
||||||
|
| | | | | |--- feature_3 > 4.50
|
||||||
|
| | | | | | |--- class: 0
|
||||||
|
| | | |--- feature_4 > 4.50
|
||||||
|
| | | | |--- class: 0
|
||||||
|
| | |--- feature_2 > 1.50
|
||||||
|
| | | |--- class: 0
|
||||||
|
|--- feature_2 > 3.50
|
||||||
|
| |--- feature_1 <= 1.50
|
||||||
|
| | |--- feature_4 <= 1.50
|
||||||
|
| | | |--- feature_2 <= 4.50
|
||||||
|
| | | | |--- feature_3 <= 4.50
|
||||||
|
| | | | | |--- feature_0 <= 1.50
|
||||||
|
| | | | | | |--- class: 0
|
||||||
|
| | | | | |--- feature_0 > 1.50
|
||||||
|
| | | | | | |--- class: 1
|
||||||
|
| | | | |--- feature_3 > 4.50
|
||||||
|
| | | | | |--- class: 0
|
||||||
|
| | | |--- feature_2 > 4.50
|
||||||
|
| | | | |--- class: 0
|
||||||
|
| | |--- feature_4 > 1.50
|
||||||
|
| | | |--- class: 0
|
||||||
|
| |--- feature_1 > 1.50
|
||||||
|
| | |--- class: 0
|
BIN
tree_model
Normal file
BIN
tree_model
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user