# %% [markdown]
# # Import the required libraries
# %%
# Expose only the GPU with CUDA index 1 to TensorFlow. CUDA_VISIBLE_DEVICES is
# read when TensorFlow initialises its CUDA devices, so it is set *before*
# `tensorflow` is imported; setting it afterwards only works as long as nothing
# touches a device in between.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# NOTE(review): pandas and numpy are not used anywhere else in this script as
# shown; kept in case other cells rely on them.
import pandas as pd
import numpy as np
import tensorflow as tf

# Input-pipeline configuration shared by the dataset loaders below.
batch_size = 64   # images per batch yielded by train_ds / test_ds
img_height = 224  # images are resized to (img_height, img_width) on load
img_width = 224

def _load_image_split(directory):
    """Build a batched, labelled image dataset from ``directory``.

    Relies on ``image_dataset_from_directory``'s default label inference
    (one class per sub-directory). Images are resized to
    ``(img_height, img_width)`` and grouped into batches of ``batch_size``.
    ``seed=123`` makes the shuffling order reproducible.
    """
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


# Training images, and the separate image set later passed to model.fit as
# validation_data.
train_ds = _load_image_split("combined_data_skeleton/")
test_ds = _load_image_split("test_data_own_skeleton")

# Integer labels are derived separately from each directory's own sub-folder
# names, so the training and test directories must expose the same classes in
# the same order. Otherwise validation labels would silently refer to
# different classes.
class_names = train_ds.class_names
if test_ds.class_names != class_names:
    raise ValueError(
        "Class mismatch between training and test data: "
        f"train={class_names}, test={test_ds.class_names}"
    )

print("Class names:", class_names)
print("Total classes:", len(class_names))

# All Keras objects come from tf.keras. VGG16 below is a tf.keras model, and
# combining it with layers from the standalone `keras` package breaks whenever
# the two resolve to different Keras installations (e.g. Keras 3 next to a
# legacy tf.keras).
# NOTE(review): Sequential, Conv2D, MaxPooling2D, Flatten and Dense are not used
# by the model built below (it uses layers.Flatten / layers.Dense).
from tensorflow.keras import Sequential, layers, models
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.applications import VGG16

from tensorflow.keras.applications.vgg16 import preprocess_input

img_size = [img_height, img_width]

# ImageNet-pretrained VGG16 convolutional base, without its classifier head.
vgg_model = VGG16(input_shape=img_size + [3], weights='imagenet', include_top=False)

# image_dataset_from_directory yields unscaled RGB pixel values (0-255), but the
# ImageNet weights expect inputs transformed by vgg16.preprocess_input (RGB->BGR
# plus per-channel mean subtraction). Doing this inside the model keeps the
# pretrained features meaningful, and the saved model still accepts raw images.
inputs = layers.Input(shape=(img_height, img_width, 3))
x = preprocess_input(inputs)
x = vgg_model(x)
x = layers.Flatten()(x)
# One softmax output per class found in the training directory.
x = layers.Dense(len(class_names), activation='softmax')(x)

model = models.Model(inputs, x)

# The whole network, including the VGG16 base, is trained. To train only the
# new head instead, set `vgg_model.trainable = False` before compiling.

# Labels are integer class indices, hence the sparse categorical loss.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

from tensorflow.keras.callbacks import EarlyStopping

# Stop training once `val_accuracy` (computed on the validation_data given to
# model.fit) has not improved for 12 consecutive epochs. When it stops, restore
# the weights from the epoch with the best `val_accuracy`.
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    patience=12,
    restore_best_weights=True,
)

# train_ds and test_ds are already batched when loaded (batch_size=64), so no
# batch_size / validation_batch_size is passed here. Keras does not re-batch
# tf.data datasets, and such arguments would be ignored while misstating the
# effective batch size.
# NOTE(review): test_ds doubles as validation data, and EarlyStopping picks the
# restored weights by its val_accuracy. Scores on test_ds are therefore not an
# unbiased held-out estimate; consider a validation split of the training data.
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    callbacks=[early_stopping],
)

# NOTE(review): under Keras 2 this extension-less path writes a TensorFlow
# SavedModel directory. Keras 3 (TF >= 2.16) rejects paths without a
# `.keras` / `.h5` extension in model.save, so confirm the Keras version in use.
model.save('VGG16_sign_char_detection_model')