# File listing: 70 lines, 2.0 KiB, Python
import tensorflow as tf
from keras import layers


def normalize(image, label):
    """Rescale an image's pixel values into [0, 1] and pass the label through.

    Used as a ``Dataset.map`` callback over ``(image, label)`` pairs. Assumes
    the incoming pixels are in the 0-255 range (TODO confirm for new data
    sources); dividing by 255 then maps them into [0, 1].

    Args:
        image: Image tensor (or any value supporting ``/``).
        label: Label for the image; returned unchanged.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = image / 255
    return scaled, label


# Path to the folder containing the training data. Labels are inferred from
# the sub-folder structure by image_dataset_from_directory.
train_data_dir = "Network/Training/"

# Batch size shared by the training and validation datasets.
batch_size = 32

# Image size and model input shape. The model takes single-channel images of
# shape (height, width, 1), so the datasets below are loaded with
# color_mode="grayscale". The loader's default is "rgb" (3 channels), which
# would not match this input shape.
img_width, img_height = 100, 100
input_shape = (img_height, img_width, 1)

# Load the training and validation data. Both calls use the same
# validation_split and seed so they draw complementary subsets of the folder.
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="training",
    shuffle=True,
    seed=123,
    image_size=(img_height, img_width),
    color_mode="grayscale",
    batch_size=batch_size,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    train_data_dir,
    validation_split=0.2,
    subset="validation",
    shuffle=True,
    seed=123,
    image_size=(img_height, img_width),
    color_mode="grayscale",
    batch_size=batch_size,
)

# Get the class names discovered in the data directory.
class_names = train_ds.class_names
print(class_names)

# Derive the number of classes from the data rather than hard-coding it, so the
# output layer always matches the number of label folders found on disk.
num_classes = len(class_names)

# Scale pixels to [0, 1]. Parallel map plus prefetch lets input preparation
# overlap with training instead of stalling each step.
train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
    tf.data.AUTOTUNE
)
val_ds = val_ds.map(normalize, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
    tf.data.AUTOTUNE
)

# Define the model architecture: three conv/pool stages followed by a dense
# classifier head that outputs a probability per class (softmax).
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, padding="same", activation="relu", input_shape=input_shape),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(num_classes, activation="softmax"),
])

# Compile the model. The last layer already applies softmax, so the loss
# must receive probabilities (from_logits=False). Passing from_logits=True
# here would apply softmax a second time. Sparse loss is used because the
# dataset yields integer labels.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# Print the model summary.
model.summary()

# Train the model.
epochs = 10
model.fit(train_ds, validation_data=val_ds, epochs=epochs)

# Save the trained model (HDF5 format, selected by the .h5 extension).
model.save("Network/trained_model.h5")