# projekt_widzenie/training_model/train_vgg16.py
# Snapshot of 2023-02-02 18:12:17 +01:00 (79 lines, 1.7 KiB, Python)

# %% [markdown]
# # Import the required libraries
# %%
# Select the GPU before TensorFlow is imported, so the setting is guaranteed to
# be in place before the CUDA runtime is initialised.
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import pandas as pd
import numpy as np
import tensorflow as tf

batch_size=64
img_height=224
img_width=224

# Load both splits from class sub-folders. Images are resized to the VGG16
# input size and grouped into batches of `batch_size`. The seed makes the
# shuffling order reproducible between runs.
train_ds = tf.keras.utils.image_dataset_from_directory(
    "combined_data_skeleton/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
test_ds = tf.keras.utils.image_dataset_from_directory(
    "test_data_own_skeleton",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

# Read the class names before any further tf.data transformation. The
# `class_names` attribute exists only on the object returned by
# image_dataset_from_directory.
class_names = train_ds.class_names
print("Class names:",class_names)
print("Total classes:",len(class_names))

# Integer labels are assigned from the (alphabetical) folder order. If the two
# splits do not contain exactly the same class folders, the same integer would
# mean different classes in train and test, so fail loudly instead.
if test_ds.class_names != class_names:
    raise ValueError(
        f"Class folders differ between training data {class_names} "
        f"and test data {test_ds.class_names}"
    )

# Overlap input loading with training. This does not change the data.
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.prefetch(tf.data.AUTOTUNE)
from tensorflow.keras import Sequential
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras import models, layers
from tensorflow.keras.applications import VGG16
# Build the classifier: an ImageNet-pretrained VGG16 convolutional base
# (without its fully connected top), followed by Flatten and a softmax layer
# with one unit per class.
img_size = [img_height, img_width]
input_shape = tuple(img_size) + (3,)
vgg_model = VGG16(input_shape=input_shape, weights='imagenet', include_top=False)

# By default the whole network is fine-tuned. Set this to True to train only
# the new classification head. It replaces the previously commented-out loop
# that froze model.layers[:-1].
freeze_base = False
vgg_model.trainable = not freeze_base

inputs = tf.keras.Input(shape=input_shape)
# The ImageNet weights expect inputs prepared by vgg16.preprocess_input
# (RGB->BGR conversion and per-channel ImageNet mean subtraction).
# image_dataset_from_directory yields raw RGB values in [0, 255].
# Preprocessing inside the model keeps the saved model's input contract
# unchanged: callers still feed raw 224x224 RGB images.
x = tf.keras.applications.vgg16.preprocess_input(inputs)
x = vgg_model(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(len(class_names), activation='softmax')(x)
# Use tf.keras throughout instead of mixing in the standalone `keras` package.
model = tf.keras.Model(inputs, outputs)

# The datasets carry integer labels, which is what sparse categorical
# cross-entropy expects.
# NOTE(review): RMSprop runs at its default learning rate. When fine-tuning the
# whole pretrained base, a smaller rate is usually advisable; confirm.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
from tensorflow.keras.callbacks import EarlyStopping
# Stop training once validation accuracy has not improved for 12 consecutive
# epochs, and roll the weights back to the best epoch seen.
# NOTE(review): test_ds is used as the validation set, so it also decides which
# epoch's weights are kept. Any accuracy later reported on test_ds is therefore
# optimistically biased; consider holding out a separate validation split.
early_stopping = EarlyStopping(
    monitor="val_accuracy", patience=12,
    restore_best_weights=True,
)

# The tf.data datasets are already batched (batch size `batch_size` at load
# time). Keras documents that batch_size / validation_batch_size must not be
# given for dataset inputs; they would have no effect and only misstate the
# batch size actually used.
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    callbacks=[early_stopping],
)

# No file extension: under TF 2.x Keras this writes a SavedModel directory.
# NOTE(review): Keras 3 requires a ".keras" (or ".h5") suffix here; update the
# path if upgrading.
model.save('VGG16_sign_char_detection_model')