DecisionTree update
This commit is contained in:
parent
80e6c700cb
commit
89d1aa7802
@ -1 +1 @@
|
||||
main.py
|
||||
astar.py
|
Binary file not shown.
Binary file not shown.
BIN
__pycache__/tree.cpython-39.pyc
Normal file
BIN
__pycache__/tree.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
3
bfs.py
3
bfs.py
@ -47,7 +47,7 @@ def bfs(pos, direction, end_pos, houses):
|
||||
if not is_house(curr_node.pos, houses) and curr_node.pos == end_pos:
|
||||
while curr_node.parent:
|
||||
|
||||
#print(curr_node.pos, end_pos)
|
||||
# print(curr_node.pos, end_pos)
|
||||
actions.append(curr_node.action)
|
||||
curr_node = curr_node.parent
|
||||
return actions
|
||||
@ -64,3 +64,4 @@ def distance(pos, endpos):
|
||||
houses = create_houses(40)
|
||||
actions = bfs(pos, 0, endpos, houses)
|
||||
return len(actions)
|
||||
|
||||
|
51
csv_gen.py
Normal file
51
csv_gen.py
Normal file
@ -0,0 +1,51 @@
|
||||
import csv
|
||||
|
||||
decision = [0, 1] # 0 - go to bin, 1 - pick up
|
||||
levels = [1, 2, 3, 4, 5]
|
||||
# 1 - 0;20 2 - 20;40 3 - 40;60 4 - 60;80 5 - 80;100
|
||||
# 1 - 0;40 2 - 40;80 3 - 80;120 4 - 120;160 5 - 160+
|
||||
|
||||
|
||||
def enough_free_space(available_space, trash_size, available_mass, mass_trash):
|
||||
if available_space + trash_size <= 5 and available_mass + mass_trash <= 5:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def where_is_closer(bin_distance, trash_distance):
|
||||
if bin_distance <= trash_distance:
|
||||
return 0
|
||||
return 1
|
||||
|
||||
|
||||
with open('tree_dataset.csv', 'w', newline='') as csv_file:
|
||||
file_writer = csv.writer(csv_file)
|
||||
file_writer.writerow(["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space", "decision"])
|
||||
|
||||
counter = 0
|
||||
|
||||
for dis_dump in levels:
|
||||
for dis_trash in levels:
|
||||
for mass in levels:
|
||||
for space in levels:
|
||||
for trash_mass in levels:
|
||||
for trash_space in levels:
|
||||
if counter % 10 == 0:
|
||||
if dis_dump == 1 and space >= 1 and mass >= 1:
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||
elif dis_trash == 1 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
|
||||
elif mass == 4 or space == 4 and not enough_free_space(space, trash_space, mass, trash_mass):
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||
elif mass == 5 or space == 5:
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||
elif mass <= 3 and space <= 3 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 1])
|
||||
elif mass == 4 or space == 4 and enough_free_space(space, trash_space, mass, trash_mass):
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, where_is_closer(dis_dump, dis_trash)])
|
||||
elif not enough_free_space(space, trash_space, mass, trash_mass):
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, 0])
|
||||
else:
|
||||
file_writer.writerow([dis_dump, dis_trash, mass, space, trash_mass, trash_space, None])
|
||||
|
||||
counter += 1
|
BIN
decision_tree.png
Normal file
BIN
decision_tree.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 319 KiB |
BIN
img/wet.png
Normal file
BIN
img/wet.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
49
main.py
49
main.py
@ -8,16 +8,16 @@ from random import shuffle, choice
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
import tree
|
||||
import pygame
|
||||
from time import sleep
|
||||
|
||||
from os import path
|
||||
from colors import gray
|
||||
from house import create_houses
|
||||
from truck import Truck
|
||||
from trash import Trash
|
||||
from TSP import tsp, tspmove
|
||||
from bfs import bfs
|
||||
from bfs import bfs, distance
|
||||
|
||||
|
||||
model = load_model("model.h5")
|
||||
@ -38,6 +38,9 @@ def game_keys(truck, multi_trash, houses, auto=False):
|
||||
print('↑')
|
||||
for tindex, trash in enumerate(multi_trash):
|
||||
if truck.pos == trash.pos:
|
||||
truck.mass += trash.mass
|
||||
truck.space += trash.space
|
||||
print(truck.mass, truck.space)
|
||||
prediction = model.predict(multi_trash[tindex].content)
|
||||
for i in range (3):
|
||||
if multi_trash[tindex].names[i][:3] == 'cat':
|
||||
@ -106,6 +109,46 @@ def game_loop():
|
||||
if event.key == pygame.K_ESCAPE:
|
||||
pygame.quit()
|
||||
quit()
|
||||
if (event.key == pygame.K_l):
|
||||
if path.isfile('./tree_model') and not os.stat(
|
||||
'./tree_model').st_size == 0:
|
||||
decision_tree = tree.load_tree_from_structure('./tree_model')
|
||||
print("Tree model loaded!")
|
||||
|
||||
if (event.key == pygame.K_k):
|
||||
print(":>")
|
||||
trash = multi_trash[0]
|
||||
dis_dump = distance(truck.pos,[80,80])
|
||||
dis_trash = distance(truck.pos, trash.pos)
|
||||
print(dis_dump, dis_trash, truck.mass, truck.space, trash.mass, trash.space)
|
||||
decision = tree.making_decision(decision_tree,
|
||||
dis_dump // 12 + 1,
|
||||
dis_trash // 12 + 1,
|
||||
truck.mass // 20 + 1, truck.space // 20 + 1,
|
||||
trash.mass // 20 + 1,
|
||||
trash.space // 20 + 1)
|
||||
|
||||
print(decision)
|
||||
if(decision[0]==0):
|
||||
actions = bfs(truck.pos, truck.dir_control,
|
||||
trash.pos, houses)
|
||||
print(actions)
|
||||
if not actions:
|
||||
print('Path couldn\'t be found')
|
||||
break
|
||||
|
||||
print('##################################################')
|
||||
while actions:
|
||||
action = actions.pop()
|
||||
pygame.event.post(pygame.event.Event(
|
||||
pygame.KEYDOWN, {'key': action}))
|
||||
|
||||
game_keys(truck, multi_trash, houses, True)
|
||||
update_images(gameDisplay, truck, multi_trash, houses)
|
||||
else:
|
||||
truck.space=0
|
||||
truck.mass=0
|
||||
|
||||
if (event.key == pygame.K_b):
|
||||
trash = multi_trash[0]
|
||||
actions = bfs(truck.pos, truck.dir_control,
|
||||
|
2
trash.py
2
trash.py
@ -37,6 +37,8 @@ class Trash:
|
||||
self.size = grid_size
|
||||
self.content = draw_trash(filenames)[0]
|
||||
self.names = draw_trash(filenames)[1]
|
||||
self.mass = random.randint(0, 25)
|
||||
self.space = random.randint(0, 25)
|
||||
|
||||
def new_pos(self, truck_pos, houses, multi):
|
||||
self.trash_content, self.trash_names = draw_trash(filenames)
|
||||
|
57
tree.py
Normal file
57
tree.py
Normal file
@ -0,0 +1,57 @@
|
||||
import joblib
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas
|
||||
from sklearn import tree
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
attributes = ["dis_dump", "dis_trash", "mass", "space", "trash_mass", "trash_space"]
|
||||
decisions = ["decision"]
|
||||
|
||||
|
||||
def learning_tree():
|
||||
dataset = pandas.read_csv('./tree_dataset.csv')
|
||||
|
||||
x = dataset[attributes]
|
||||
y = dataset[decisions]
|
||||
|
||||
decision_tree = DecisionTreeClassifier()
|
||||
decision_tree = decision_tree.fit(x, y)
|
||||
|
||||
return decision_tree
|
||||
|
||||
|
||||
def making_decision(decision_tree, distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass,
|
||||
trash_space):
|
||||
decision = decision_tree.predict(
|
||||
[[distance_to_bin, distance_to_trash, filling_mass, filling_space, trash_mass, trash_space]])
|
||||
|
||||
return decision
|
||||
|
||||
|
||||
def save_all(decision_tree):
|
||||
save_tree_to_png(decision_tree)
|
||||
save_tree_to_txt(decision_tree)
|
||||
save_tree_to_structure(decision_tree)
|
||||
|
||||
|
||||
def save_tree_to_txt(decision_tree):
|
||||
with open('./tree_in_txt.txt', "w") as file:
|
||||
file.write(tree.export_text(decision_tree))
|
||||
|
||||
|
||||
def save_tree_to_png(decision_tree):
|
||||
fig = plt.figure(figsize=(25, 20))
|
||||
_ = tree.plot_tree(decision_tree, feature_names=attributes, filled=True)
|
||||
fig.savefig('./decision_tree.png')
|
||||
|
||||
|
||||
def save_tree_to_structure(decision_tree):
|
||||
joblib.dump(decision_tree, './tree_model')
|
||||
|
||||
|
||||
def load_tree_from_structure(file):
|
||||
return joblib.load(file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tre = learning_tree()
|
||||
save_all(tre)
|
1564
tree_dataset.csv
Normal file
1564
tree_dataset.csv
Normal file
File diff suppressed because it is too large
Load Diff
112
tree_in_txt.txt
Normal file
112
tree_in_txt.txt
Normal file
@ -0,0 +1,112 @@
|
||||
|--- feature_2 <= 3.50
|
||||
| |--- feature_4 <= 3.50
|
||||
| | |--- feature_3 <= 3.50
|
||||
| | | |--- feature_0 <= 1.50
|
||||
| | | | |--- class: 0
|
||||
| | | |--- feature_0 > 1.50
|
||||
| | | | |--- feature_4 <= 2.50
|
||||
| | | | | |--- class: 1
|
||||
| | | | |--- feature_4 > 2.50
|
||||
| | | | | |--- feature_2 <= 2.50
|
||||
| | | | | | |--- class: 1
|
||||
| | | | | |--- feature_2 > 2.50
|
||||
| | | | | | |--- class: 0
|
||||
| | |--- feature_3 > 3.50
|
||||
| | | |--- feature_3 <= 4.50
|
||||
| | | | |--- feature_0 <= 2.50
|
||||
| | | | | |--- feature_1 <= 1.50
|
||||
| | | | | | |--- feature_0 <= 1.50
|
||||
| | | | | | | |--- class: 0
|
||||
| | | | | | |--- feature_0 > 1.50
|
||||
| | | | | | | |--- feature_2 <= 2.50
|
||||
| | | | | | | | |--- class: 1
|
||||
| | | | | | | |--- feature_2 > 2.50
|
||||
| | | | | | | | |--- feature_4 <= 2.00
|
||||
| | | | | | | | | |--- class: 1
|
||||
| | | | | | | | |--- feature_4 > 2.00
|
||||
| | | | | | | | | |--- class: 0
|
||||
| | | | | |--- feature_1 > 1.50
|
||||
| | | | | | |--- class: 0
|
||||
| | | | |--- feature_0 > 2.50
|
||||
| | | | | |--- feature_1 <= 3.50
|
||||
| | | | | | |--- feature_1 <= 2.50
|
||||
| | | | | | | |--- feature_4 <= 2.50
|
||||
| | | | | | | | |--- class: 1
|
||||
| | | | | | | |--- feature_4 > 2.50
|
||||
| | | | | | | | |--- feature_2 <= 2.50
|
||||
| | | | | | | | | |--- class: 1
|
||||
| | | | | | | | |--- feature_2 > 2.50
|
||||
| | | | | | | | | |--- class: 0
|
||||
| | | | | | |--- feature_1 > 2.50
|
||||
| | | | | | | |--- feature_0 <= 3.50
|
||||
| | | | | | | | |--- class: 0
|
||||
| | | | | | | |--- feature_0 > 3.50
|
||||
| | | | | | | | |--- feature_4 <= 2.50
|
||||
| | | | | | | | | |--- class: 1
|
||||
| | | | | | | | |--- feature_4 > 2.50
|
||||
| | | | | | | | | |--- feature_2 <= 2.50
|
||||
| | | | | | | | | | |--- class: 1
|
||||
| | | | | | | | | |--- feature_2 > 2.50
|
||||
| | | | | | | | | | |--- class: 0
|
||||
| | | | | |--- feature_1 > 3.50
|
||||
| | | | | | |--- feature_0 <= 4.50
|
||||
| | | | | | | |--- class: 0
|
||||
| | | | | | |--- feature_0 > 4.50
|
||||
| | | | | | | |--- feature_1 <= 4.50
|
||||
| | | | | | | | |--- feature_4 <= 2.50
|
||||
| | | | | | | | | |--- class: 1
|
||||
| | | | | | | | |--- feature_4 > 2.50
|
||||
| | | | | | | | | |--- feature_2 <= 2.00
|
||||
| | | | | | | | | | |--- class: 1
|
||||
| | | | | | | | | |--- feature_2 > 2.00
|
||||
| | | | | | | | | | |--- class: 0
|
||||
| | | | | | | |--- feature_1 > 4.50
|
||||
| | | | | | | | |--- class: 0
|
||||
| | | |--- feature_3 > 4.50
|
||||
| | | | |--- class: 0
|
||||
| |--- feature_4 > 3.50
|
||||
| | |--- feature_2 <= 1.50
|
||||
| | | |--- feature_4 <= 4.50
|
||||
| | | | |--- feature_3 <= 3.50
|
||||
| | | | | |--- feature_0 <= 1.50
|
||||
| | | | | | |--- class: 0
|
||||
| | | | | |--- feature_0 > 1.50
|
||||
| | | | | | |--- class: 1
|
||||
| | | | |--- feature_3 > 3.50
|
||||
| | | | | |--- feature_3 <= 4.50
|
||||
| | | | | | |--- feature_1 <= 3.50
|
||||
| | | | | | | |--- feature_0 <= 2.50
|
||||
| | | | | | | | |--- class: 0
|
||||
| | | | | | | |--- feature_0 > 2.50
|
||||
| | | | | | | | |--- feature_1 <= 2.50
|
||||
| | | | | | | | | |--- class: 1
|
||||
| | | | | | | | |--- feature_1 > 2.50
|
||||
| | | | | | | | | |--- feature_0 <= 4.00
|
||||
| | | | | | | | | | |--- class: 0
|
||||
| | | | | | | | | |--- feature_0 > 4.00
|
||||
| | | | | | | | | | |--- class: 1
|
||||
| | | | | | |--- feature_1 > 3.50
|
||||
| | | | | | | |--- class: 0
|
||||
| | | | | |--- feature_3 > 4.50
|
||||
| | | | | | |--- class: 0
|
||||
| | | |--- feature_4 > 4.50
|
||||
| | | | |--- class: 0
|
||||
| | |--- feature_2 > 1.50
|
||||
| | | |--- class: 0
|
||||
|--- feature_2 > 3.50
|
||||
| |--- feature_1 <= 1.50
|
||||
| | |--- feature_4 <= 1.50
|
||||
| | | |--- feature_2 <= 4.50
|
||||
| | | | |--- feature_3 <= 4.50
|
||||
| | | | | |--- feature_0 <= 1.50
|
||||
| | | | | | |--- class: 0
|
||||
| | | | | |--- feature_0 > 1.50
|
||||
| | | | | | |--- class: 1
|
||||
| | | | |--- feature_3 > 4.50
|
||||
| | | | | |--- class: 0
|
||||
| | | |--- feature_2 > 4.50
|
||||
| | | | |--- class: 0
|
||||
| | |--- feature_4 > 1.50
|
||||
| | | |--- class: 0
|
||||
| |--- feature_1 > 1.50
|
||||
| | |--- class: 0
|
BIN
tree_model
Normal file
BIN
tree_model
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user