decision tree update
Co-authored-by: Sebastian Piotrowski <sebpio@st.amu.edu.pl> Co-authored-by: Marcin Matoga <marmat35@st.amu.edu.pl> Co-authored-by: Ladislaus3III <Ladislaus3III@users.noreply.github.com>
This commit is contained in:
parent
74788bbc94
commit
f0e0d4f4c1
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,2 +1,6 @@
|
|||||||
# ignore pycache
|
# ignore pycache
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|
||||||
|
# ignore pdf files
|
||||||
|
*.pdf
|
||||||
|
data
|
BIN
images/robot3.bmp
Normal file
BIN
images/robot3.bmp
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 MiB |
46
learning.py
Normal file
46
learning.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from sklearn import tree
|
||||||
|
import graphviz
|
||||||
|
|
||||||
|
class Learning():
|
||||||
|
|
||||||
|
X = []
|
||||||
|
Y = []
|
||||||
|
clf = tree.DecisionTreeClassifier()
|
||||||
|
def load_data(self):
|
||||||
|
file = open('dane.txt', "r")
|
||||||
|
data_str = []
|
||||||
|
data_str_separated = []
|
||||||
|
data = []
|
||||||
|
for lines in file:
|
||||||
|
s = file.readline()
|
||||||
|
data_str.append(s.replace(' \n', ''))
|
||||||
|
for lines in data_str:
|
||||||
|
s = lines.split(' ')
|
||||||
|
data_str_separated.append(s)
|
||||||
|
for lines in data_str_separated:
|
||||||
|
line = list(map(int, lines))
|
||||||
|
data.append(line)
|
||||||
|
for lines in data:
|
||||||
|
self.Y.append(lines[8])
|
||||||
|
lines.pop()
|
||||||
|
self.X.append(lines)
|
||||||
|
def learn(self):
|
||||||
|
#clf = tree.DecisionTreeClassifier()
|
||||||
|
self.clf = self.clf.fit(self.X, self.Y)
|
||||||
|
def draw_tree(self):
|
||||||
|
dot_data = tree.export_graphviz(self.clf, out_file=None, filled=True, class_names= ['1', '2', '3', '4', '5'], rounded=True, special_characters=True)
|
||||||
|
graph = graphviz.Source(dot_data)
|
||||||
|
graph.render("data")
|
||||||
|
|
||||||
|
def predict(self, param_array):
|
||||||
|
print(self.clf.predict([param_array]))
|
||||||
|
if self.clf.predict([param_array]) == 1:
|
||||||
|
print("oflagowac")
|
||||||
|
if self.clf.predict([param_array]) == 2:
|
||||||
|
print("zdetonowac")
|
||||||
|
if self.clf.predict([param_array]) == 3:
|
||||||
|
print("sprzedac na allegro")
|
||||||
|
if self.clf.predict([param_array]) == 4:
|
||||||
|
print("sprzedac na czarnym rynku")
|
||||||
|
if self.clf.predict([param_array]) == 5:
|
||||||
|
print("obejrzec")
|
33
main.py
33
main.py
@ -14,7 +14,8 @@ from settings import *
|
|||||||
from sprites import *
|
from sprites import *
|
||||||
from graphsearch import *
|
from graphsearch import *
|
||||||
from astar2 import *
|
from astar2 import *
|
||||||
from generate_examples import *
|
from pizza import *
|
||||||
|
from learning import *
|
||||||
|
|
||||||
|
|
||||||
class Game:
|
class Game:
|
||||||
@ -26,16 +27,18 @@ class Game:
|
|||||||
pg.key.set_repeat(500, 100)
|
pg.key.set_repeat(500, 100)
|
||||||
self.load_data()
|
self.load_data()
|
||||||
self.wentyl_bezpieczenstwa = 0
|
self.wentyl_bezpieczenstwa = 0
|
||||||
|
#ExampleGenerator.generate(ExampleGenerator)
|
||||||
|
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
game_folder = path.dirname(__file__)
|
game_folder = path.dirname(__file__)
|
||||||
self.map_data = []
|
self.map_data = []
|
||||||
with open(path.join(game_folder, 'map.txt'), 'rt') as f:
|
with open(path.join(game_folder, 'main_map.txt'), 'rt') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
self.map_data.append(line)
|
self.map_data.append(line)
|
||||||
|
|
||||||
def new(self):
|
def new(self):
|
||||||
# initialize all variables and do all the setup for a new game
|
# initialize all variables and do all the setup for a new game
|
||||||
|
#glebokosc rozmiar masa moc szkodliwosc zabawka stan teledysk
|
||||||
self.all_sprites = pg.sprite.Group()
|
self.all_sprites = pg.sprite.Group()
|
||||||
self.walls = pg.sprite.Group()
|
self.walls = pg.sprite.Group()
|
||||||
self.mines = pg.sprite.Group()
|
self.mines = pg.sprite.Group()
|
||||||
@ -44,6 +47,7 @@ class Game:
|
|||||||
for col, tile in enumerate(tiles):
|
for col, tile in enumerate(tiles):
|
||||||
if tile == '2':
|
if tile == '2':
|
||||||
Mine(self, col, row)
|
Mine(self, col, row)
|
||||||
|
Mine.set_parameters(Mine,-5,32,6,7,1,0,0,0)
|
||||||
if tile == '3':
|
if tile == '3':
|
||||||
Bomb(self, col, row)
|
Bomb(self, col, row)
|
||||||
if tile == '4':
|
if tile == '4':
|
||||||
@ -111,21 +115,27 @@ class Game:
|
|||||||
self.player.parse_maze_moves()
|
self.player.parse_maze_moves()
|
||||||
self.i_like_to_move_it()
|
self.i_like_to_move_it()
|
||||||
self.wentyl_bezpieczenstwa = 1
|
self.wentyl_bezpieczenstwa = 1
|
||||||
if event.key == pg.K_F3 and self.wentyl_bezpieczenstwa == 0:
|
if event.key == pg.K_F3:
|
||||||
|
pg.event.clear()
|
||||||
player_moves = BFS.run()
|
player_moves = BFS.run()
|
||||||
self.graph_move(player_moves)
|
self.graph_move(player_moves)
|
||||||
self.wentyl_bezpieczenstwa = 1
|
#self.wentyl_bezpieczenstwa = 1
|
||||||
if event.key == pg.K_F4 and self.wentyl_bezpieczenstwa == 0:
|
if event.key == pg.K_F4 and self.wentyl_bezpieczenstwa == 0:
|
||||||
|
pg.event.clear()
|
||||||
#print("test1")
|
#print("test1")
|
||||||
agent = SweeperAgent()
|
agent = SweeperAgent()
|
||||||
player_moves = SweeperAgent.run(agent)
|
player_moves = SweeperAgent.run(agent)
|
||||||
self.graph_move(player_moves)
|
self.graph_move(player_moves)
|
||||||
self.wentyl_bezpieczenstwa = 1
|
self.wentyl_bezpieczenstwa = 1
|
||||||
|
|
||||||
# Test.run()
|
# Test.run()
|
||||||
if event.key == pg.K_F5:
|
if event.key == pg.K_F5:
|
||||||
ExampleGenerator.generate()
|
print("lol xD")
|
||||||
|
pg.event.clear()
|
||||||
if event.key == pg.K_F6:
|
if event.key == pg.K_F6:
|
||||||
X = [[6,15,31,5,0,0,1,0],[6,36,5,8,1,1,1,0],
|
pg.event.clear()
|
||||||
|
#self.wentyl_bezpieczenstwa = 1
|
||||||
|
"""X = [[6,15,31,5,0,0,1,0],[6,36,5,8,1,1,1,0],
|
||||||
[13,40,29,3,0,0,0,0], [15,25,23,7,0,0,1,1],
|
[13,40,29,3,0,0,0,0], [15,25,23,7,0,0,1,1],
|
||||||
[17,34,5,5,1,1,0,1], [3,9,28,8,0,0,1,0],
|
[17,34,5,5,1,1,0,1], [3,9,28,8,0,0,1,0],
|
||||||
[18,38,43,6,1,1,1,1], [3,8,53,4,1,1,1,1],
|
[18,38,43,6,1,1,1,1], [3,8,53,4,1,1,1,1],
|
||||||
@ -133,7 +143,16 @@ class Game:
|
|||||||
Y = [1,5,1,1,4,1,2,2,1,1]
|
Y = [1,5,1,1,4,1,2,2,1,1]
|
||||||
clf = tree.DecisionTreeClassifier()
|
clf = tree.DecisionTreeClassifier()
|
||||||
clf = clf.fit(X,Y)
|
clf = clf.fit(X,Y)
|
||||||
tree.plot_tree(clf)
|
tree.plot_tree(clf)"""
|
||||||
|
|
||||||
|
"""my_learning = Learning()
|
||||||
|
my_learning.load_data()
|
||||||
|
my_learning.learn()
|
||||||
|
my_learning.draw_tree()
|
||||||
|
my_learning.predict()"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
7
main_map.txt
Normal file
7
main_map.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
.>.p...
|
||||||
|
###.ppp
|
||||||
|
....p..
|
||||||
|
...pp.p
|
||||||
|
..p###.
|
||||||
|
..p#2#.
|
||||||
|
..ppp..
|
@ -3,12 +3,13 @@
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
class ExampleGenerator():
|
class ExampleGenerator():
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate():
|
def generate():
|
||||||
list = [[]]
|
list = [[]]
|
||||||
|
|
||||||
for i in range(1000):
|
for i in range(200):
|
||||||
while True:
|
while True:
|
||||||
c = [random.randrange(0, 21), random.randrange(5, 41), random.randrange(1, 61), random.randrange(0, 11), random.randrange(0, 2), random.randrange(0, 2), random.randrange(0, 2), random.randrange(0, 2)]
|
c = [random.randrange(0, 21), random.randrange(5, 41), random.randrange(1, 61), random.randrange(0, 11), random.randrange(0, 2), random.randrange(0, 2), random.randrange(0, 2), random.randrange(0, 2)]
|
||||||
if c in list:
|
if c in list:
|
||||||
@ -18,42 +19,43 @@ class ExampleGenerator():
|
|||||||
list.append(c)
|
list.append(c)
|
||||||
|
|
||||||
f = open("dane.txt", "w")
|
f = open("dane.txt", "w")
|
||||||
|
#f1 = open("dane_slownie.txt", "w")
|
||||||
|
#glebokosc rozmiar masa moc szkodliwosc zabawka stan teledysk
|
||||||
for i in range(len(list)):
|
for i in range(len(list)):
|
||||||
if i != 0:
|
if i != 0:
|
||||||
#glebokosc
|
#zabawka
|
||||||
if list[i][0] <= 19:
|
if list[i][5] == 1:
|
||||||
#rozmiar
|
#teledysk
|
||||||
if list[i][1] <= 39:
|
if list[i][7] == 1:
|
||||||
|
list[i].append(5)
|
||||||
|
else:
|
||||||
|
list[i].append(3)
|
||||||
|
else:
|
||||||
|
#glebokosc
|
||||||
|
if list[i][0] >= 19:
|
||||||
|
list[i].append(1)
|
||||||
|
else:
|
||||||
#masa
|
#masa
|
||||||
if list[i][2] <= 58:
|
if list[i][2] >= 58:
|
||||||
#moc
|
list[i].append(1)
|
||||||
if list[i][3] <= 9:
|
else:
|
||||||
#szkodliwosc
|
#rozmiar
|
||||||
if list[i][4] == 1:
|
if list[i][1] >= 39:
|
||||||
#zabawka
|
list[i].append(1)
|
||||||
if list[i][5] == 1:
|
else:
|
||||||
|
#moc
|
||||||
|
if list[i][3] >= 9:
|
||||||
|
list[i].append(1)
|
||||||
|
else:
|
||||||
|
#szkodliwosc
|
||||||
|
if list[i][4] == 1:
|
||||||
|
list[i].append(1)
|
||||||
|
else:
|
||||||
#stan
|
#stan
|
||||||
if list[i][6] == 1:
|
if list[i][6] == 1:
|
||||||
#teledysk
|
|
||||||
if list[i][7] == 1:
|
|
||||||
list[i].append(2)
|
|
||||||
else:
|
|
||||||
list[i].append(5)
|
|
||||||
else:
|
|
||||||
list[i].append(4)
|
list[i].append(4)
|
||||||
else:
|
else:
|
||||||
list[i].append(3)
|
list[i].append(2)
|
||||||
else:
|
|
||||||
list[i].append(1)
|
|
||||||
else:
|
|
||||||
list[i].append(1)
|
|
||||||
else:
|
|
||||||
list[i].append(1)
|
|
||||||
else:
|
|
||||||
list[i].append(1)
|
|
||||||
else:
|
|
||||||
list[i].append(1)
|
|
||||||
|
|
||||||
for i in range(len(list)):
|
for i in range(len(list)):
|
||||||
if i == 0:
|
if i == 0:
|
51
sprites.py
51
sprites.py
@ -4,6 +4,7 @@ import ctypes
|
|||||||
import ast
|
import ast
|
||||||
from settings import *
|
from settings import *
|
||||||
from maze import *
|
from maze import *
|
||||||
|
from learning import *
|
||||||
|
|
||||||
class Player(pg.sprite.Sprite):
|
class Player(pg.sprite.Sprite):
|
||||||
def __init__(self, game, x, y, direction = 'Right'):
|
def __init__(self, game, x, y, direction = 'Right'):
|
||||||
@ -11,8 +12,8 @@ class Player(pg.sprite.Sprite):
|
|||||||
pg.sprite.Sprite.__init__(self, self.groups)
|
pg.sprite.Sprite.__init__(self, self.groups)
|
||||||
self.game = game
|
self.game = game
|
||||||
#self.image = pg.Surface((TILESIZE, TILESIZE))
|
#self.image = pg.Surface((TILESIZE, TILESIZE))
|
||||||
self.image = pg.image.load('images/robot2.bmp')
|
self.image = pg.image.load('images/robot3.bmp')
|
||||||
self.baseImage = pg.image.load('images/robot2.bmp')
|
self.baseImage = pg.image.load('images/robot3.bmp')
|
||||||
#self.image.fill(YELLOW)
|
#self.image.fill(YELLOW)
|
||||||
self.image = pg.transform.scale(self.image, (TILESIZE, TILESIZE))
|
self.image = pg.transform.scale(self.image, (TILESIZE, TILESIZE))
|
||||||
self.baseImage = pg.transform.scale(self.image, (TILESIZE, TILESIZE))
|
self.baseImage = pg.transform.scale(self.image, (TILESIZE, TILESIZE))
|
||||||
@ -22,6 +23,8 @@ class Player(pg.sprite.Sprite):
|
|||||||
self.direction = direction
|
self.direction = direction
|
||||||
self.maze = Maze()
|
self.maze = Maze()
|
||||||
self.moves = ''
|
self.moves = ''
|
||||||
|
self.my_learning = Learning()
|
||||||
|
self.decision_tree_learning()
|
||||||
|
|
||||||
def set_direction(self, direction):
|
def set_direction(self, direction):
|
||||||
self.direction = direction
|
self.direction = direction
|
||||||
@ -84,10 +87,11 @@ class Player(pg.sprite.Sprite):
|
|||||||
print("Mine Ahead!")
|
print("Mine Ahead!")
|
||||||
self.x += -1
|
self.x += -1
|
||||||
else:
|
else:
|
||||||
self.x += -1
|
self.x += -1
|
||||||
print("I move: " + str(self.direction))
|
print("I move: " + str(self.direction))
|
||||||
|
|
||||||
print("My direction is: " + str(self.direction))
|
print("My direction is: " + str(self.direction))
|
||||||
|
self.check_bomb()
|
||||||
|
|
||||||
|
|
||||||
def check_border(self, dx=0, dy=0):
|
def check_border(self, dx=0, dy=0):
|
||||||
@ -180,6 +184,32 @@ class Player(pg.sprite.Sprite):
|
|||||||
|
|
||||||
print(self.moves)
|
print(self.moves)
|
||||||
|
|
||||||
|
def decision_tree_learning(self):
|
||||||
|
self.my_learning.load_data()
|
||||||
|
self.my_learning.learn()
|
||||||
|
self.my_learning.draw_tree()
|
||||||
|
#my_learning.predict()
|
||||||
|
|
||||||
|
""" sprawdzenie danych miny """
|
||||||
|
def check_bomb(self):
|
||||||
|
if self.check_if_on_mine():
|
||||||
|
current_mine = self.get_my_mine_object()
|
||||||
|
mine_params = current_mine.get_parameters()
|
||||||
|
self.my_learning.predict(mine_params)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def check_if_on_mine(self):
|
||||||
|
for mine in self.game.mines:
|
||||||
|
if mine.x == self.x and mine.y == self.y:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_my_mine_object(self):
|
||||||
|
if self.check_if_on_mine():
|
||||||
|
for mine in self.game.mines:
|
||||||
|
if mine.x == self.x and mine.y == self.y:
|
||||||
|
return mine
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -215,6 +245,21 @@ class Mine(pg.sprite.Sprite):
|
|||||||
self.x = x
|
self.x = x
|
||||||
self.y = y
|
self.y = y
|
||||||
|
|
||||||
|
|
||||||
|
def set_parameters(self, glebokosc, rozmiar, masa, moc, szkodliwosc, zabawka, teledysk, stan):
|
||||||
|
self.glebokosc = glebokosc
|
||||||
|
self.rozmiar = rozmiar
|
||||||
|
self.masa = masa
|
||||||
|
self.moc = moc
|
||||||
|
self.szkodliwosc = szkodliwosc
|
||||||
|
self.zabawka = zabawka
|
||||||
|
self.teledysk = teledysk
|
||||||
|
self.stan = stan
|
||||||
|
|
||||||
|
def get_parameters(self):
|
||||||
|
param_array = [self.glebokosc, self.rozmiar, self.masa, self.moc, self.szkodliwosc, self.zabawka, self.teledysk, self.stan]
|
||||||
|
return param_array
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
self.rect.x = self.x * TILESIZE
|
self.rect.x = self.x * TILESIZE
|
||||||
self.rect.y = self.y * TILESIZE
|
self.rect.y = self.y * TILESIZE
|
||||||
|
Loading…
Reference in New Issue
Block a user