69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
from tensorflow.keras.models import Sequential, save_model
|
|
from tensorflow.keras.layers import Dense, Flatten, Conv2D
|
|
from tensorflow.keras.losses import sparse_categorical_crossentropy
|
|
from tensorflow.keras.optimizers import Adam
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
from config import RESOURCE_DIR
|
|
import os
|
|
|
|
# labels
|
|
|
|
labels = ["cabbage", "carrot", "corn", "lettuce", "paprika", "potato", "tomato"]
|
|
|
|
# Data configuration
|
|
training_set_folder = os.path.join(RESOURCE_DIR, 'smaller_train')
|
|
test_set_folder = os.path.join(RESOURCE_DIR, 'smaller_test')
|
|
|
|
# Model config
|
|
batch_size = 25
|
|
img_width, img_height, img_num_channels = 25, 25, 3
|
|
loss_function = sparse_categorical_crossentropy
|
|
no_classes = 7
|
|
no_epochs = 40
|
|
optimizer = Adam()
|
|
verbosity = 1
|
|
|
|
# Determine shape of the data
|
|
input_shape = (img_width, img_height, img_num_channels)
|
|
|
|
# Create a generator
|
|
train_datagen = ImageDataGenerator(
|
|
rescale=1./255
|
|
)
|
|
train_datagen = train_datagen.flow_from_directory(
|
|
training_set_folder,
|
|
save_to_dir=os.path.join(RESOURCE_DIR, "adapted-images"),
|
|
save_format='jpeg',
|
|
batch_size=batch_size,
|
|
target_size=(25, 25),
|
|
class_mode='sparse')
|
|
|
|
# Create the model
|
|
model = Sequential()
|
|
model.add(Conv2D(16, kernel_size=(5, 5), activation='relu', input_shape=input_shape))
|
|
model.add(Conv2D(32, kernel_size=(5, 5), activation='relu'))
|
|
model.add(Conv2D(64, kernel_size=(5, 5), activation='relu'))
|
|
model.add(Conv2D(128, kernel_size=(5, 5), activation='relu'))
|
|
model.add(Flatten())
|
|
model.add(Dense(16, activation='relu'))
|
|
model.add(Dense(no_classes, activation='softmax'))
|
|
|
|
# Display a model summary
|
|
model.summary()
|
|
|
|
# Compile the model
|
|
model.compile(loss=loss_function,
|
|
optimizer=optimizer,
|
|
metrics=['accuracy'])
|
|
|
|
# Start training
|
|
model.fit(
|
|
train_datagen,
|
|
epochs=no_epochs,
|
|
shuffle=False)
|
|
|
|
# Saving model
|
|
filepath = os.path.join(RESOURCE_DIR, 'saved_model')
|
|
save_model(model, filepath)
|
|
|