2023-06-01 16:09:01 +02:00
|
|
|
import os
|
|
|
|
import glob
|
|
|
|
import PIL
|
|
|
|
from PIL import Image
|
|
|
|
import tensorflow as tf
|
|
|
|
import pickle
|
|
|
|
from tensorflow import keras
|
|
|
|
from keras import layers
|
|
|
|
from keras.models import Sequential
|
|
|
|
import pathlib
|
|
|
|
|
|
|
|
|
|
|
|
class NeuralN:
    """Image classifier backed by a small Keras CNN.

    The first time a prediction is needed, the saved model and class names
    are loaded from disk if both files exist. Otherwise a model is trained on
    the class sub-directories of ``DATA_DIR`` (Keras infers one label per
    sub-directory) and saved. The result is cached on the instance, so
    repeated ``predict`` calls do not reload or retrain.

    All paths are relative to the current working directory.
    """

    # Training data: one sub-directory per class (e.g. ``ORK_ARCHER``).
    DATA_DIR = pathlib.Path('zdjecia')
    MODEL_PATH = pathlib.Path('trained_model.h5')
    CLASS_NAMES_PATH = pathlib.Path("class_names.pkl")
    # Keras-style (height, width), used for both training and prediction input.
    IMAGE_SIZE = (180, 180)
    EPOCHS = 1
    # Image classified when ``predict`` gets no image. This is the sample the
    # previous implementation always used, written portably instead of as a
    # Windows-only backslash string.
    DEFAULT_IMAGE_PATH = DATA_DIR / 'ORK_ARCHER' / 'ork_archer (942).jpg'

    # Lazily filled by ``_ensure_model``. Instance assignments shadow these.
    _probability_model = None
    _class_names = None

    def predict(self, image=None):
        """Classify *image* and return its predicted class name.

        Args:
            image: Path (``str`` or ``os.PathLike``) of an image file, or an
                already-opened PIL image. ``None`` falls back to
                ``DEFAULT_IMAGE_PATH``.

        Returns:
            The most probable class name, taken from the class names the
            model was trained with.
        """
        probability_model, class_names = self._ensure_model()

        pixels = tf.keras.preprocessing.image.img_to_array(self._load_image(image))
        # Do not divide by 255 here. The network built in ``_train`` starts
        # with ``Rescaling(1/255)``, so scaling again would feed it
        # near-black images.
        batch = tf.expand_dims(pixels, 0)  # model expects a batch dimension

        predictions = probability_model.predict(batch)
        predicted_label = class_names[int(predictions[0].argmax())]
        print(predicted_label)
        return predicted_label

    def _ensure_model(self):
        """Return ``(probability_model, class_names)``, loading or training once."""
        if self._probability_model is None:
            # Both files are required; a model without its class names
            # cannot map outputs back to labels.
            if self.MODEL_PATH.exists() and self.CLASS_NAMES_PATH.exists():
                model = tf.keras.models.load_model(self.MODEL_PATH)
                print("Saved model loaded")
                # NOTE: pickle.load can execute arbitrary code. This file is
                # expected to be the one written by ``_train``; never point
                # CLASS_NAMES_PATH at untrusted input.
                with open(self.CLASS_NAMES_PATH, 'rb') as f:
                    class_names = pickle.load(f)
                print("Class names loaded.")
            else:
                model, class_names = self._train()
            # The network outputs logits; append Softmax to get probabilities.
            self._probability_model = tf.keras.Sequential(
                [model, tf.keras.layers.Softmax()])
            self._class_names = class_names
        return self._probability_model, self._class_names

    def _train(self):
        """Train a CNN on ``DATA_DIR``, save it and its class names, return both."""
        image_count = sum(len(files) for _, _, files in os.walk(self.DATA_DIR))
        print(image_count)

        # Same split and seed for both subsets so they do not overlap.
        dataset_options = dict(
            validation_split=0.2,
            seed=123,
            image_size=self.IMAGE_SIZE,
            batch_size=32,
        )
        train_ds = tf.keras.utils.image_dataset_from_directory(
            self.DATA_DIR, subset="training", **dataset_options)
        val_ds = tf.keras.utils.image_dataset_from_directory(
            self.DATA_DIR, subset="validation", **dataset_options)

        class_names = train_ds.class_names
        print(class_names)

        model = Sequential([
            # Normalisation lives inside the model; callers feed raw 0-255 pixels.
            layers.Rescaling(1. / 255, input_shape=(*self.IMAGE_SIZE, 3)),
            layers.Conv2D(16, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(32, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Conv2D(64, 3, padding='same', activation='relu'),
            layers.MaxPooling2D(),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(len(class_names)),  # logits, one per class
        ])
        model.compile(
            optimizer='adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])
        model.summary()

        model.fit(train_ds, validation_data=val_ds, epochs=self.EPOCHS)

        model.save(str(self.MODEL_PATH))
        print("Model trained and saved.")
        with open(self.CLASS_NAMES_PATH, 'wb') as f:
            pickle.dump(class_names, f)
        print("Class names saved.")
        return model, class_names

    @staticmethod
    def _resolve_image_path(image):
        """Return the file path named by *image*, or ``None`` if it is an image object."""
        if image is None:
            return NeuralN.DEFAULT_IMAGE_PATH
        if isinstance(image, (str, os.PathLike)):
            return pathlib.Path(image)
        return None

    def _load_image(self, image):
        """Return *image* as an RGB PIL image sized for the model."""
        path = self._resolve_image_path(image)
        if path is None:
            return self._fit_to_model(image)
        # Context manager closes the underlying file once pixels are copied.
        with Image.open(path) as opened:
            return self._fit_to_model(opened)

    def _fit_to_model(self, img):
        """Convert *img* to 3-channel RGB and resize it to ``IMAGE_SIZE``."""
        # Match training: image_dataset_from_directory defaults to
        # color_mode='rgb' and bilinear resizing. PIL sizes are
        # (width, height), whereas IMAGE_SIZE is Keras-style (height, width).
        height, width = self.IMAGE_SIZE
        return img.convert("RGB").resize((width, height), Image.BILINEAR)
|