automatyczny_kelner/Network/Network.py

81 lines
1.9 KiB
Python
Raw Normal View History

2023-06-01 20:15:18 +02:00
import tensorflow as tf
from keras import layers
from keras.models import Sequential
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
import os
import PIL
import PIL.Image
import numpy
2023-06-01 21:50:13 +02:00
def normalize(image, label):
    """Scale pixel values by 1/255 and pass the label through unchanged.

    Intended to be used with ``Dataset.map`` on (image, label) pairs.

    Args:
        image: Numeric value/array/tensor of pixel intensities
            (presumably in the 0-255 range, which maps to 0-1).
        label: Label associated with ``image``; returned as-is.

    Returns:
        A ``(image / 255, label)`` tuple.
    """
    return image / 255, label
2023-06-01 20:15:18 +02:00
# Root folder of the training images (read by image_dataset_from_directory)
train_data_dir = "Training/"

# Mini-batch size and number of output classes of the classifier
batch_size = 32
num_classes = 3

# Image dimensions in pixels; every image is resized to this size on load
img_width = img_height = 100
2023-06-01 21:50:13 +02:00
# Keras image tensors are laid out as (height, width, channels); the trailing
# 1 is the channel count. Height comes first to match `image_size` and the
# first Conv2D layer's `input_shape`.
input_shape = (img_height, img_width, 1)
2023-06-01 20:15:18 +02:00
# Build the training and validation datasets from the same directory.
# Both calls share `validation_split` and `seed` so the two subsets are
# complementary parts of one 80/20 split.
# color_mode="grayscale" loads single-channel images; the default ("rgb")
# would yield 3 channels and not match the model's 1-channel input shape.
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="training",
    shuffle=True,
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="validation",
    shuffle=True,
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# Read the class names now: the attribute lives on the dataset object returned
# by image_dataset_from_directory, not on datasets derived from it via map().
class_names = train_ds.class_names
print(class_names)
2023-06-01 21:50:13 +02:00
# Apply normalize() to every (image, label) element of both splits
train_ds, val_ds = (split.map(normalize) for split in (train_ds, val_ds))
# Define the model architecture: three Conv2D + MaxPooling2D stages with
# 16/32/64 filters, then a dense head that outputs one score per class.
model = tf.keras.Sequential([
    # Input is (height, width, channels) with a single channel.
    layers.Conv2D(16, 3, padding='same', activation='relu',
                  input_shape=(img_height, img_width, 1)),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # Softmax makes the network emit class probabilities, not raw logits.
    layers.Dense(num_classes, activation='softmax'),
])

# Compile the model.
# The output layer already applies softmax, so the loss must treat its input
# as probabilities (from_logits=False). With from_logits=True the softmax
# would effectively be applied twice, distorting the loss and its gradients.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)
model.summary()
# Number of passes over the training dataset
epochs = 10

# Fit the model on the training split, using the validation split for
# validation metrics
model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
)

# Write the trained model to disk
model.save('trained_model.h5')