diff --git a/main.py b/main.py index 91a926b..6a8e94a 100644 --- a/main.py +++ b/main.py @@ -2,15 +2,16 @@ import sys import secrets from src.graphics import * +from src.rabbit import Rabbit from src.waiter import * if __name__ == "__main__": # SETUP pygame.init() clock = pygame.time.Clock() - fps = 2 graphics = Graphics() waiter = Waiter(graphics) + rabbit = Rabbit() # init functions graphics.drawBackground(waiter.matrix) @@ -19,7 +20,9 @@ if __name__ == "__main__": goal = None path = '' while True: + for event in pygame.event.get(): + # rabbit.check(waiter.matrix, waiter.X, waiter.Y) if event.type == pygame.QUIT: pygame.quit() sys.exit() @@ -50,4 +53,4 @@ if __name__ == "__main__": waiter.travel(nextStep, graphics) pygame.display.flip() - clock.tick(fps) + clock.tick(graphics.fps) diff --git a/resources/simulations/learning_data.txt b/resources/simulations/learning_data.txt new file mode 100644 index 0000000..2a10388 --- /dev/null +++ b/resources/simulations/learning_data.txt @@ -0,0 +1,6 @@ +0 | price: 0 x: 0 y: 0 +1 | price: 100 x: 0 y: 0 +0 | price: 0 x: 1 y: 1 +1 | price: 100 x: 1 y: 1 +0 | price: 0 x: 2 y: 1 +1 | price: 100 x: 2 y: 1 diff --git a/src/graphics.py b/src/graphics.py index 736bc70..28158a9 100644 --- a/src/graphics.py +++ b/src/graphics.py @@ -5,25 +5,22 @@ class Graphics: def __init__(self): self.image = { 'floor': pygame.image.load('./../resources/images/floor.jpg'), - # 'wall': pygame.image.load('./../resources/images/wall.png'), - # 'bar': pygame.image.load('./../resources/images/table3.png'), 'bar_floor': pygame.image.load('./../resources/images/waiter-up.png'), 'table': pygame.image.load('./../resources/images/table3.png'), + # 'waiter_N': pygame.image.load('./../resources/images/waiter-up.png'), 'waiter_S': pygame.image.load('./../resources/images/waiter-down.png'), 'waiter_E': pygame.image.load('./../resources/images/waiter-right.png'), 'waiter_W': pygame.image.load('./../resources/images/waiter-left.png'), # 'chair_front': pygame.image.load('./../resources/images/chair-front.png'), - # 'chair_back': pygame.image.load('./../resources/images/chair-back.png'), - # 'chair_left': pygame.image.load('./../resources/images/chair-left.png'), - # 'chair_right': pygame.image.load('./../resources/images/chair-right.png') } + self.fps = 2 self.block_size = 50 self.height = 15 self.width: int = 14 diff --git a/src/matrix.py b/src/matrix.py index 0d324a6..ff9bf88 100644 --- a/src/matrix.py +++ b/src/matrix.py @@ -44,3 +44,6 @@ class Matrix: def watch_through(self, x, y): return self.matrix[x][y].watch_through + + def tile_worth(self, x, y): + return self.matrix[x][y].worth \ No newline at end of file diff --git a/src/rabbit.py b/src/rabbit.py new file mode 100644 index 0000000..51f330f --- /dev/null +++ b/src/rabbit.py @@ -0,0 +1,20 @@ +from vowpalwabbit import pyvw + + +class Rabbit: + def __init__(self): + self.data = open('./../resources/simulations/learning_data.txt') + self.model = pyvw.vw(quiet=True) + self.train_set = self.data.read().splitlines() + self.learn() + + def learn(self): + for example in self.train_set: + self.model.learn(example) + + def check(self, matrix, x, y): + set = ["| price:", str(matrix.tile_worth(x, y)), "x:", str(x), "y:", str(y)] + test_sample = ' '.join(set) + prediction = self.model.predict(test_sample) + print(test_sample) + print(prediction) diff --git a/src/tile.py b/src/tile.py index a9a6eae..37ab459 100644 --- a/src/tile.py +++ b/src/tile.py @@ -23,6 +23,9 @@ class Tile: # Dystant do wierzcholka koncowego oszacowany za pomoca funkcji heurystyki H self.heuristic = 0 + #Atrybuty AI do rozpoznawania wartości pól + self.worth = 0 + # Operator porownywania pol def __eq__(self, other): return True if (self.position == other.position) else False diff --git a/src/waiter.py b/src/waiter.py index 5fb43d8..7d00a19 100644 --- a/src/waiter.py +++ b/src/waiter.py @@ -50,7 +50,7 @@ class Waiter(pygame.sprite.Sprite): self.move(1, 0, graphics) if self.direction == 'W': self.move(-1, 0, graphics) - #print(self.X, self.Y) + # print(self.X, self.Y) # AStar def findPath(self, goal): @@ -243,5 +243,6 @@ class Waiter(pygame.sprite.Sprite): self.update('L', graphics) graphics.update(self) + def getTotalCost(tile): return tile.totalCost