.
This commit is contained in:
parent
c5e0b65445
commit
e5a7a975e8
@ -103,7 +103,7 @@ class ForkliftAgent(AgentBase):
|
||||
stations = dict(self.graph.packingStations)
|
||||
if i.real_type == ItemType.SHELF:
|
||||
packing_station = stations[PatchType.packingA]
|
||||
elif i.real_type == ItemType.FRIDGE:
|
||||
elif i.real_type == ItemType.REFRIGERATOR:
|
||||
packing_station = stations[PatchType.packingB]
|
||||
elif i.real_type == ItemType.DOOR:
|
||||
packing_station = stations[PatchType.packingC]
|
||||
|
@ -18,6 +18,7 @@ from data.Order import Order
|
||||
from data.enum.ItemType import ItemType
|
||||
from decision.Action import Action
|
||||
from decision.ActionType import ActionType
|
||||
from imageClasification.Classificator import image_classification
|
||||
from pathfinding.PathfinderOnStates import PathFinderOnStates, PathFinderState
|
||||
from util.PathByEnum import PathByEnum
|
||||
from util.PathDefinitions import GridLocation, GridWithWeights
|
||||
@ -188,13 +189,13 @@ class GameModel(Model):
|
||||
|
||||
def recognise_item(self, item: Item):
|
||||
# TODO IMAGE PROCESSING
|
||||
val = self.classificator.image_clasification(self.picture_visualization.img)
|
||||
val = image_classification(self.picture_visualization.img, self.classificator)
|
||||
print("VAL: {}".format(val))
|
||||
|
||||
if val == ItemType.DOOR:
|
||||
item.guessed_type = ItemType.DOOR
|
||||
elif val == ItemType.FRIDGE:
|
||||
item.guessed_type = ItemType.FRIDGE
|
||||
elif val == ItemType.REFRIGERATOR:
|
||||
item.guessed_type = ItemType.REFRIGERATOR
|
||||
elif val == ItemType.SHELF:
|
||||
item.guessed_type = ItemType.SHELF
|
||||
|
||||
|
@ -4,4 +4,4 @@ from enum import Enum
|
||||
class ItemType(Enum):
|
||||
DOOR = "door"
|
||||
SHELF = "shelf"
|
||||
FRIDGE = "fridge"
|
||||
REFRIGERATOR = "refrigerator"
|
@ -0,0 +1,22 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
|
||||
# loaded_model = keras.models.load_model("my_model")
|
||||
|
||||
def image_classification(path, model):
|
||||
class_names = ['door', 'refrigerator', 'shelf']
|
||||
|
||||
img = tf.keras.utils.load_img(
|
||||
path, target_size=(180, 180)
|
||||
)
|
||||
img_array = tf.keras.utils.img_to_array(img)
|
||||
img_array = tf.expand_dims(img_array, 0) # Create a batch
|
||||
|
||||
predictions = model.predict(img_array)
|
||||
score = tf.nn.softmax(predictions[0])
|
||||
print(class_names[np.argmax(score)])
|
||||
return class_names[np.argmax(score)]
|
||||
|
||||
|
@ -9,7 +9,7 @@ from tensorflow.keras import layers
|
||||
from tensorflow.keras.models import Sequential
|
||||
|
||||
|
||||
class Classificator():
|
||||
class TrainClassificator():
|
||||
|
||||
def __init__(self, data_dir: str) -> None:
|
||||
super().__init__()
|
||||
|
6
main.py
6
main.py
@ -9,7 +9,7 @@ from PatchAgent import PatchAgent
|
||||
from PatchType import PatchType
|
||||
from PictureVisualizationAgent import PictureVisualizationAgent
|
||||
from data.enum.Direction import Direction
|
||||
from imageClasification.Classificator import Classificator
|
||||
from tensorflow import keras
|
||||
from util.PathDefinitions import GridWithWeights
|
||||
from visualization.DisplayAttributeElement import DisplayAttributeElement
|
||||
from visualization.DisplayItemListAttribute import DisplayItemListAttributeElement
|
||||
@ -88,13 +88,13 @@ if __name__ == '__main__':
|
||||
ordersText = DisplayOrderList("orderList")
|
||||
fulfilled_orders = DisplayOrderList("fulfilled_orders")
|
||||
|
||||
classificator = Classificator("imageClasification/images")
|
||||
model = keras.models.load_model("imageClasification/my_model")
|
||||
|
||||
server = ModularServer(GameModel,
|
||||
[grid, readyText, provided_itesm, recognised_items, ordersText,
|
||||
fulfilled_orders],
|
||||
"Automatyczny Wózek Widłowy",
|
||||
dict(width=gridHeight, height=gridWidth, graph=diagram, items=50, orders=3, classificator=classificator))
|
||||
dict(width=gridHeight, height=gridWidth, graph=diagram, items=50, orders=3, classificator=model))
|
||||
|
||||
server.port = 8888
|
||||
server.launch()
|
||||
|
@ -9,7 +9,7 @@ class PathByEnum:
|
||||
if item == ItemType.DOOR:
|
||||
a = str(random.randint(1, 10))
|
||||
return "item_images/door/drzwi" + a + ".jpg"
|
||||
if item == ItemType.FRIDGE:
|
||||
if item == ItemType.REFRIGERATOR:
|
||||
a = str(random.randint(1, 10))
|
||||
return "item_images/refrigerator/lodowka" + a + ".jpg"
|
||||
if item == ItemType.SHELF:
|
||||
|
Loading…
Reference in New Issue
Block a user