automatyczny_kelner/Network/Network.py
2023-06-02 00:14:50 +02:00

68 lines
2.1 KiB
Python

import tensorflow as tf
from keras import layers
from keras.models import Sequential
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
def normalize(image, label):
    """Scale an image by 1/255 and pass its label through unchanged.

    Intended for use with ``Dataset.map``; with pixel values in the
    0-255 range (assumed, not checked here) the result lies in [0, 1].

    Args:
        image: Pixel data supporting true division by a number.
        label: Label paired with the image; returned as-is.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    return image / 255, label
# Directory the training images are read from.
train_data_dir = "Network/Training/"

# Size of the classifier output and of each mini-batch.
num_classes, batch_size = 3, 32

# Images are resized to a square of this side length; input_shape adds a
# trailing single-channel axis.
img_width = img_height = 100
input_shape = (img_width, img_height, 1)
# Load the training and validation data.
# Both calls share the same seed and validation_split so that the two subsets
# are complementary parts of one 80/20 split of the directory.
# color_mode="grayscale" is required: the loader defaults to 3-channel RGB,
# while the network's first layer is declared with a single input channel.
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="training",
    shuffle=True,
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="validation",
    shuffle=True,
    seed=123,
    color_mode="grayscale",
    image_size=(img_height, img_width),
    batch_size=batch_size)
# Report the class names exposed by the training dataset.
class_names = train_ds.class_names
print(class_names)

# Run both splits through normalize (training split first, then validation).
train_ds, val_ds = (split.map(normalize) for split in (train_ds, val_ds))
# Build the network layer by layer: three convolution + pooling stages with
# 16, 32 and 64 filters, followed by a dense classifier head.
model = tf.keras.Sequential()

# The first convolution also declares the expected input: single-channel
# images of img_height x img_width.
model.add(layers.Conv2D(16, 3, padding='same', activation='relu',
                        input_shape=(img_height, img_width, 1)))
model.add(layers.MaxPooling2D())

model.add(layers.Conv2D(32, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())

model.add(layers.Conv2D(64, 3, padding='same', activation='relu'))
model.add(layers.MaxPooling2D())

# Classifier head: flatten the feature maps, one hidden layer, then a softmax
# over num_classes outputs.
model.add(layers.Flatten())
model.add(layers.Dense(128, activation='relu'))
model.add(layers.Dense(num_classes, activation='softmax'))
# Compile the model.
# The final Dense layer already applies softmax, so the network outputs
# probabilities rather than logits; from_logits must therefore be False,
# otherwise the loss would apply softmax a second time.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
# Show the layer-by-layer overview of the network.
model.summary()

# Fit on the training split for a fixed number of epochs, passing the
# validation split as validation data.
epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=val_ds)

# Write the fitted model to disk.
model.save('Network/trained_model.h5')