Merge remote-tracking branch 'origin/master'

This commit is contained in:
s473603 2023-03-28 20:05:41 +02:00
commit 9acdea6f1d
2 changed files with 84 additions and 0 deletions

51
recognition_v1/main.py Normal file
View File

@ -0,0 +1,51 @@
import numpy as np
import tensorflow as tf
from tensorflow import keras
def normalize(image, label):
return image / 255, label
directoryTRAIN = "C:/Users/KimD/PycharmProjects/Traktor_V1/Vegetable Images/train"
directoryVALIDATION = "C:/Users/KimD/PycharmProjects/Traktor_V1/Vegetable Images/validation"
train_ds = tf.keras.utils.image_dataset_from_directory(directoryTRAIN,
seed=123, batch_size=32,
image_size=(224, 224), color_mode='rgb')
val_ds = tf.keras.utils.image_dataset_from_directory(directoryVALIDATION,
seed=123, batch_size=32,
image_size=(224, 224), color_mode='rgb')
train_ds = train_ds.map(normalize)
val_ds = val_ds.map(normalize)
model = keras.Sequential([
keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Conv2D(128, (3, 3), activation='relu'),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Conv2D(256, (3, 3), activation='relu'),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(1024, activation='relu'),
keras.layers.Dense(9, activation='softmax')
])
print(model.summary())
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
trainHistory = model.fit(train_ds, epochs=4, validation_data=val_ds)
model = keras.models.load_model("C:/Users/KimD/PycharmProjects/Traktor_V1/mode2.h5")
(loss, accuracy) = model.evaluate(val_ds)
print(loss)
print(accuracy)
# model.save("mode2.h5")

View File

@ -0,0 +1,33 @@
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import cv2
directory = "C:/Users/KimD/PycharmProjects/Traktor_V1/Vegetable Images/test"
class VegebatlesRecognizer:
def recognize(self, image_path) -> str:
model = keras.models.load_model("C:/Users/KimD/PycharmProjects/Traktor_V1/mode2.h5")
class_names = ['Bean', 'Broccoli', 'Cabbage', 'Capsicum', 'Carrot', 'Cucumber', 'Potato', 'Pumpkin', 'Tomato']
img = cv2.imread(image_path)
# cv2.imshow("lala", img)
# cv2.waitKey(0)
img = (np.expand_dims(img, 0))
predictions = model.predict(img)[0].tolist()
print(class_names)
print(predictions)
print(max(predictions))
print(predictions.index(max(predictions)))
return class_names[predictions.index(max(predictions))]
# image_path = 'C:/Users/KimD/PycharmProjects/Traktor_V1/Vegetable Images/test/Carrot/1001.jpg'
# uio = VegebatlesRecognizer()
# print(uio.recognize(image_path))