Added recognition of content in boxes using neural network
This commit is contained in:
parent
63c9ebaf0d
commit
fe3f9bf4fe
BIN
nn_model.h5
Normal file
BIN
nn_model.h5
Normal file
Binary file not shown.
@ -2,7 +2,7 @@ from src.agent.map.gameMap import GameMap
|
|||||||
from mesa.visualization.modules import CanvasGrid
|
from mesa.visualization.modules import CanvasGrid
|
||||||
from mesa.visualization.modules.TextVisualization import TextElement
|
from mesa.visualization.modules.TextVisualization import TextElement
|
||||||
from mesa.visualization.ModularVisualization import ModularServer
|
from mesa.visualization.ModularVisualization import ModularServer
|
||||||
from src.decisiontree import create_model
|
#from src.decisiontree import create_model
|
||||||
from src.direction import Direction
|
from src.direction import Direction
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
@ -1,17 +1,29 @@
|
|||||||
import random
|
import random
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import rpy2.robjects as robjects
|
#import rpy2.robjects as robjects
|
||||||
from rpy2.robjects.packages import importr
|
#from rpy2.robjects.packages import importr
|
||||||
from rpy2.robjects import pandas2ri
|
#from rpy2.robjects import pandas2ri
|
||||||
from rpy2.robjects.conversion import localconverter
|
#from rpy2.robjects.conversion import localconverter
|
||||||
from rpy2.robjects.packages import importr
|
#from rpy2.robjects.packages import importr
|
||||||
from src.agent.model import *
|
from src.agent.model import *
|
||||||
from src.agent.state import AgentState
|
from src.agent.state import AgentState
|
||||||
from src.direction import Direction
|
from src.direction import Direction
|
||||||
|
from src.items.armory import WM9
|
||||||
from src.treesearch.actionsInterpreter import ActionInterpreter
|
from src.treesearch.actionsInterpreter import ActionInterpreter
|
||||||
from src.agent.model.dice.dice import roll_the_dice
|
from src.agent.model.dice.dice import roll_the_dice
|
||||||
from src.treesearch.bfs import BFS
|
from src.treesearch.bfs import BFS
|
||||||
from src.nominalize import nominalize
|
from src.nominalize import nominalize
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import PIL
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from tensorflow import keras
|
||||||
|
from tensorflow.keras import layers
|
||||||
|
from tensorflow.keras.models import Sequential
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
|
||||||
class Player(Creature):
|
class Player(Creature):
|
||||||
def __init__(self, unique_id, model, n, s, a, w, max_hp, hp, weapon, arm, g, w2, w3, list_of_chests):
|
def __init__(self, unique_id, model, n, s, a, w, max_hp, hp, weapon, arm, g, w2, w3, list_of_chests):
|
||||||
@ -40,6 +52,75 @@ class Player(Creature):
|
|||||||
self.opponent=None
|
self.opponent=None
|
||||||
self.strategy="PASS"
|
self.strategy="PASS"
|
||||||
|
|
||||||
|
def identify_content(self, chest):
|
||||||
|
dataset_url = "https://drive.google.com/uc?export=download&id=1b6w1FbupRmgVC-q9Lpdlg5OBK_gRKEUy"
|
||||||
|
data_dir = tf.keras.utils.get_file(fname='loot', origin=dataset_url, untar=True)
|
||||||
|
data_dir = pathlib.Path(data_dir)
|
||||||
|
|
||||||
|
#image_count = len(list(data_dir.glob('*/*.jpg')))
|
||||||
|
#print(image_count)
|
||||||
|
|
||||||
|
batch_size = 32
|
||||||
|
img_height = 180
|
||||||
|
img_width = 180
|
||||||
|
|
||||||
|
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
|
||||||
|
data_dir,
|
||||||
|
validation_split=0.2,
|
||||||
|
subset="training",
|
||||||
|
seed=123,
|
||||||
|
image_size=(img_height, img_width),
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
|
||||||
|
data_dir,
|
||||||
|
validation_split=0.2,
|
||||||
|
subset="validation",
|
||||||
|
seed=123,
|
||||||
|
image_size=(img_height, img_width),
|
||||||
|
batch_size=batch_size)
|
||||||
|
|
||||||
|
class_names = train_ds.class_names
|
||||||
|
#print(class_names)
|
||||||
|
|
||||||
|
normalization_layer = layers.experimental.preprocessing.Rescaling(1. / 255)
|
||||||
|
|
||||||
|
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
|
||||||
|
image_batch, labels_batch = next(iter(normalized_ds))
|
||||||
|
first_image = image_batch[0]
|
||||||
|
# Notice the pixels values are now in `[0,1]`.
|
||||||
|
#print(np.min(first_image), np.max(first_image))
|
||||||
|
|
||||||
|
#num_classes = 3
|
||||||
|
|
||||||
|
# Recreate the exact same model, including its weights and the optimizer
|
||||||
|
new_model = tf.keras.models.load_model('nn_model.h5')
|
||||||
|
|
||||||
|
# Show the model architecture
|
||||||
|
#new_model.summary()
|
||||||
|
|
||||||
|
# loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
|
||||||
|
# print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
|
||||||
|
|
||||||
|
object_url = chest.type
|
||||||
|
object_path = tf.keras.utils.get_file(chest.file_name, origin=object_url)
|
||||||
|
|
||||||
|
img = keras.preprocessing.image.load_img(
|
||||||
|
object_path, target_size=(img_height, img_width)
|
||||||
|
)
|
||||||
|
img_array = keras.preprocessing.image.img_to_array(img)
|
||||||
|
img_array = tf.expand_dims(img_array, 0) # Create a batch
|
||||||
|
|
||||||
|
predictions = new_model.predict(img_array)
|
||||||
|
score = tf.nn.softmax(predictions[0])
|
||||||
|
|
||||||
|
print(
|
||||||
|
"This image most likely belongs to {} with a {:.2f} percent confidence."
|
||||||
|
.format(class_names[np.argmax(score)], 100 * np.max(score))
|
||||||
|
)
|
||||||
|
|
||||||
|
return class_names[np.argmax(score)]
|
||||||
|
|
||||||
def predict_strategy(self, opponent):
|
def predict_strategy(self, opponent):
|
||||||
testcase = pd.DataFrame({'p_strength': nominalize(self.get_strength(),10),
|
testcase = pd.DataFrame({'p_strength': nominalize(self.get_strength(),10),
|
||||||
'p_agility':nominalize(self.get_agility(), 10),
|
'p_agility':nominalize(self.get_agility(), 10),
|
||||||
@ -169,12 +250,14 @@ class Player(Creature):
|
|||||||
print("Player died :/")
|
print("Player died :/")
|
||||||
|
|
||||||
def open_chest(self, chest):
|
def open_chest(self, chest):
|
||||||
self.gold = self.gold + chest.gold
|
if chest.type == 1:
|
||||||
print("------Chest opened. Gold inside:", chest.gold, "-----")
|
ch_gold = 3 * roll_the_dice(6)
|
||||||
chest.gold = 0
|
self.gold = self.gold + ch_gold
|
||||||
self.opened_chests += 1
|
print("------Chest opened. Gold inside:", ch_gold, "-----")
|
||||||
self.has_goal_chest = False
|
elif chest.type == 2:
|
||||||
chest.model.grid.remove_agent(chest)
|
self.weapon1 = WM9
|
||||||
|
else:
|
||||||
|
self.health = 0
|
||||||
# self.direction = 0 # po osiągnięciu jednego celu 'restartuje sie' na szukanie ścieżki do kolejnego -- NIE ZEROWAĆ OBROTU - to psuje goldState w bfs!!!
|
# self.direction = 0 # po osiągnięciu jednego celu 'restartuje sie' na szukanie ścieżki do kolejnego -- NIE ZEROWAĆ OBROTU - to psuje goldState w bfs!!!
|
||||||
# if isinstance(chest.loot,Armor):
|
# if isinstance(chest.loot,Armor):
|
||||||
# buffer = self.armor
|
# buffer = self.armor
|
||||||
@ -247,7 +330,18 @@ class Player(Creature):
|
|||||||
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||||
if len(cellmates) > 1:
|
if len(cellmates) > 1:
|
||||||
if isinstance(cellmates[0], Box):
|
if isinstance(cellmates[0], Box):
|
||||||
|
decision = self.identify_content(cellmates[0])
|
||||||
|
print("Content of chest: "+cellmates[0].address)
|
||||||
|
if (decision == 'coins') or (decision == 'weapons'):
|
||||||
|
print("I will open this chest!")
|
||||||
self.open_chest(cellmates[0])
|
self.open_chest(cellmates[0])
|
||||||
|
print("Type of opened chest: ", cellmates[0].type)
|
||||||
|
else:
|
||||||
|
print("Probably a trap - chest skipped!")
|
||||||
|
print("Type of skipped chest: ", cellmates[0].type)
|
||||||
|
self.opened_chests += 1
|
||||||
|
self.has_goal_chest = False
|
||||||
|
cellmates[0].model.grid.remove_agent(cellmates[0])
|
||||||
else:
|
else:
|
||||||
self.opponent = cellmates[0]
|
self.opponent = cellmates[0]
|
||||||
self.strategy=self.predict_strategy(self.opponent)
|
self.strategy=self.predict_strategy(self.opponent)
|
||||||
|
@ -12,12 +12,14 @@ from src.items.armory import *#WM1, A1, WR1, S1, WM2, A2
|
|||||||
from mesa.time import RandomActivation
|
from mesa.time import RandomActivation
|
||||||
from mesa.space import MultiGrid
|
from mesa.space import MultiGrid
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
# from mesa.datacollection import DataCollector
|
# from mesa.datacollection import DataCollector
|
||||||
|
|
||||||
x = 10
|
x = 10
|
||||||
y = 10
|
y = 10
|
||||||
step_counter = 0
|
step_counter = 0
|
||||||
boxes_number = 5
|
boxes_number = 6
|
||||||
creatures_number = 35
|
creatures_number = 35
|
||||||
|
|
||||||
class GameMap(Model):
|
class GameMap(Model):
|
||||||
@ -34,7 +36,8 @@ class GameMap(Model):
|
|||||||
y = self.random.randrange(self.grid.height)
|
y = self.random.randrange(self.grid.height)
|
||||||
self.grid.place_agent(self.player, (x, y))
|
self.grid.place_agent(self.player, (x, y))
|
||||||
for i in range(self.boxes_number):
|
for i in range(self.boxes_number):
|
||||||
box = Box(i, self)
|
r_type = random.randrange(1, 4)
|
||||||
|
box = Box(i, self, r_type)
|
||||||
x = self.random.randrange(self.grid.width)
|
x = self.random.randrange(self.grid.width)
|
||||||
y = self.random.randrange(self.grid.height)
|
y = self.random.randrange(self.grid.height)
|
||||||
if self.grid.is_cell_empty((x, y)):
|
if self.grid.is_cell_empty((x, y)):
|
||||||
@ -70,7 +73,7 @@ class GameMap(Model):
|
|||||||
self.schedule.add(creature)
|
self.schedule.add(creature)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
self.datacollector=DataCollector(model_reporters={"HP":self.player.health, "Gold":self.player.gold, "Position (x, y)":self.player.pos}) #informacje o stanie planszy, pozycja agenta
|
#self.datacollector=DataCollector(model_reporters={"HP":self.player.health, "Gold":self.player.gold, "Position (x, y)":self.player.pos}) #informacje o stanie planszy, pozycja agenta
|
||||||
#other data: position, strength & other parameters
|
#other data: position, strength & other parameters
|
||||||
def get_list_of_chests(self):
|
def get_list_of_chests(self):
|
||||||
return self.listOfChests
|
return self.listOfChests
|
||||||
@ -84,5 +87,5 @@ class GameMap(Model):
|
|||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.schedule.step()
|
self.schedule.step()
|
||||||
self.datacollector.collect(self)
|
#self.datacollector.collect(self)
|
||||||
#print(str(self.datacollector.model_reporters))
|
#print(str(self.datacollector.model_reporters))
|
||||||
|
@ -1,14 +1,41 @@
|
|||||||
from mesa import Agent
|
from mesa import Agent
|
||||||
from .dice.dice import roll_the_dice
|
import random
|
||||||
|
#from .dice.dice import roll_the_dice
|
||||||
|
|
||||||
|
golds = [("https://drive.google.com/uc?export=download&id=1fWeew0jXZ1lZBmv6CG5viLGloJAex6ao",'moneta'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1UrXbbfJhfCuDZSnxund7sVk40QMS3R2Q", 'moneta2'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1qH0OP4X1NQqpHtUkwD8SryJtKCKcZzVe", 'moneta3'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1b9tZf639mEWgiWq_EYyqjeKYPEMi7dX9", 'moneta4'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1z9jt-j3aS1fRUgVA_t1zR7TS5QzuXW5b", 'moneta5')]
|
||||||
|
weapons = [("https://drive.google.com/uc?export=download&id=1TA-ObC33FaiHmgQ6i71Kcb32VrHsKUd1", 'miecz'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1oyv15FSPJ84xx1tQLJW8Wlbb1JUx2xCy", 'miecz2'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1Ha0eeLRLcidrMAN1P59V9zB8uSQq3GM5", 'miecz3'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1GetUWnglUtqWqcK4sd5HsdVuaUpU5Ec_", 'miecz4'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1RImo84OykYICvwfLycDEb5tr4tPbGVy1", 'miecz5')]
|
||||||
|
traps = [("https://drive.google.com/uc?export=download&id=1G-AxY712V-eT2ylW0VXn4o2V_4pvy7lz", 'kwiat'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1i7MwzJRPBZ-KrCDqhT5RCXLstLlWCsx9", 'kwiat2'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1zF7wQuG1gtQ796m6TM08FOUOVK8s_qhT", 'kwiat3'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1qwrIThsoKg44b57JXvbXe--TIplacd-i", 'kwiat4'),
|
||||||
|
("https://drive.google.com/uc?export=download&id=1YsqdQaLyD8Es4w09p4Zdw2CufNkl_avN", 'kwiat5')]
|
||||||
|
|
||||||
chest_names=['c'+str(i)+'.jpg' for i in range(1,10)]
|
|
||||||
class Box(Agent):
|
class Box(Agent):
|
||||||
def __init__(self, unique_id, model):
|
def __init__(self, unique_id, model, type):
|
||||||
super().__init__(unique_id, model)
|
super().__init__(unique_id, model)
|
||||||
self.gold = 3 * roll_the_dice(6)
|
#self.gold = 3 * roll_the_dice(6)
|
||||||
self.isBox = True
|
self.isBox = True
|
||||||
self.isCreature = False
|
self.isCreature = False
|
||||||
self.isPlayer = False
|
self.isPlayer = False
|
||||||
|
self.type = type
|
||||||
|
r_img = random.randrange(0,4)
|
||||||
|
if type == 1:
|
||||||
|
self.address = golds[r_img][0]
|
||||||
|
self.file_name = golds[r_img][1]
|
||||||
|
elif type == 2:
|
||||||
|
self.address = weapons[r_img][0]
|
||||||
|
self.file_name = weapons[r_img][1]
|
||||||
|
else:
|
||||||
|
self.address = traps[r_img][0]
|
||||||
|
self.file_name = traps[r_img][1]
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user