diff --git a/ForkliftAgent.py b/ForkliftAgent.py index d40448b..502bdc6 100644 --- a/ForkliftAgent.py +++ b/ForkliftAgent.py @@ -103,7 +103,7 @@ class ForkliftAgent(AgentBase): stations = dict(self.graph.packingStations) if i.real_type == ItemType.SHELF: packing_station = stations[PatchType.packingA] - elif i.real_type == ItemType.FRIDGE: + elif i.real_type == ItemType.REFRIGERATOR: packing_station = stations[PatchType.packingB] elif i.real_type == ItemType.DOOR: packing_station = stations[PatchType.packingC] diff --git a/GameModel.py b/GameModel.py index a373124..b36070f 100644 --- a/GameModel.py +++ b/GameModel.py @@ -18,6 +18,7 @@ from data.Order import Order from data.enum.ItemType import ItemType from decision.Action import Action from decision.ActionType import ActionType +from imageClasification.Classificator import image_classification from pathfinding.PathfinderOnStates import PathFinderOnStates, PathFinderState from util.PathByEnum import PathByEnum from util.PathDefinitions import GridLocation, GridWithWeights @@ -188,13 +189,13 @@ class GameModel(Model): def recognise_item(self, item: Item): # TODO IMAGE PROCESSING - val = self.classificator.image_clasification(self.picture_visualization.img) + val = image_classification(self.picture_visualization.img, self.classificator) print("VAL: {}".format(val)) if val == ItemType.DOOR: item.guessed_type = ItemType.DOOR - elif val == ItemType.FRIDGE: - item.guessed_type = ItemType.FRIDGE + elif val == ItemType.REFRIGERATOR: + item.guessed_type = ItemType.REFRIGERATOR elif val == ItemType.SHELF: item.guessed_type = ItemType.SHELF diff --git a/data/enum/ItemType.py b/data/enum/ItemType.py index 54fc4c1..af038ed 100644 --- a/data/enum/ItemType.py +++ b/data/enum/ItemType.py @@ -4,4 +4,4 @@ from enum import Enum class ItemType(Enum): DOOR = "door" SHELF = "shelf" - FRIDGE = "fridge" \ No newline at end of file + REFRIGERATOR = "refrigerator" \ No newline at end of file diff --git a/imageClasification/Classificator.py b/imageClasification/Classificator.py index e69de29..f50f2fa 100644 --- a/imageClasification/Classificator.py +++ b/imageClasification/Classificator.py @@ -0,0 +1,22 @@ +import numpy as np +import tensorflow as tf +from tensorflow import keras + + +# loaded_model = keras.models.load_model("my_model") + +def image_classification(path, model): + class_names = ['door', 'refrigerator', 'shelf'] + + img = tf.keras.utils.load_img( + path, target_size=(180, 180) + ) + img_array = tf.keras.utils.img_to_array(img) + img_array = tf.expand_dims(img_array, 0) # Create a batch + + predictions = model.predict(img_array) + score = tf.nn.softmax(predictions[0]) + print(class_names[np.argmax(score)]) + return class_names[np.argmax(score)] + + diff --git a/imageClasification/usedAssets/TrainClassificator.py b/imageClasification/usedAssets/TrainClassificator.py index 4f9d2fd..07a22e0 100644 --- a/imageClasification/usedAssets/TrainClassificator.py +++ b/imageClasification/usedAssets/TrainClassificator.py @@ -9,7 +9,7 @@ from tensorflow.keras import layers from tensorflow.keras.models import Sequential -class Classificator(): +class TrainClassificator(): def __init__(self, data_dir: str) -> None: super().__init__() diff --git a/main.py b/main.py index 1991f58..c72ea9e 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ from PatchAgent import PatchAgent from PatchType import PatchType from PictureVisualizationAgent import PictureVisualizationAgent from data.enum.Direction import Direction -from imageClasification.Classificator import Classificator +from tensorflow import keras from util.PathDefinitions import GridWithWeights from visualization.DisplayAttributeElement import DisplayAttributeElement from visualization.DisplayItemListAttribute import DisplayItemListAttributeElement @@ -88,13 +88,13 @@ if __name__ == '__main__': ordersText = DisplayOrderList("orderList") fulfilled_orders = DisplayOrderList("fulfilled_orders") - classificator = Classificator("imageClasification/images") + model = keras.models.load_model("imageClasification/my_model") server = ModularServer(GameModel, [grid, readyText, provided_itesm, recognised_items, ordersText, fulfilled_orders], "Automatyczny Wózek Widłowy", - dict(width=gridHeight, height=gridWidth, graph=diagram, items=50, orders=3, classificator=classificator)) + dict(width=gridHeight, height=gridWidth, graph=diagram, items=50, orders=3, classificator=model)) server.port = 8888 server.launch() diff --git a/util/PathByEnum.py b/util/PathByEnum.py index b517494..63d3b7c 100644 --- a/util/PathByEnum.py +++ b/util/PathByEnum.py @@ -9,7 +9,7 @@ class PathByEnum: if item == ItemType.DOOR: a = str(random.randint(1, 10)) return "item_images/door/drzwi" + a + ".jpg" - if item == ItemType.FRIDGE: + if item == ItemType.REFRIGERATOR: a = str(random.randint(1, 10)) return "item_images/refrigerator/lodowka" + a + ".jpg" if item == ItemType.SHELF: