projekt_widzenie/training_model/train_resnet.py
2023-02-02 18:12:17 +01:00

75 lines
1.5 KiB
Python

# %% [markdown]
# # Import the required libraries
# %%
import os

# Expose only the GPU with index 1 to this process. This is set before
# TensorFlow is imported so the restriction is already in place whenever the
# CUDA runtime gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import pandas as pd
import numpy as np
import tensorflow as tf

# Input pipeline settings: every image is resized to img_height x img_width
# and the datasets yield batches of `batch_size` images.
batch_size = 64
img_height = 224
img_width = 224

# Training images; labels are inferred from the directory structure
# (one sub-directory per class).
train_ds = tf.keras.utils.image_dataset_from_directory(
    "combined_data_skeleton/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# Separately collected images, loaded with the same image size and batch size.
test_ds = tf.keras.utils.image_dataset_from_directory(
    "test_data_own_skeleton",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

class_names = train_ds.class_names
print("Class names:", class_names)
print("Total classes:", len(class_names))

# Integer labels are assigned per directory, so the two datasets only agree on
# what each label means if they expose the same class list in the same order.
if test_ds.class_names != class_names:
    print(
        "WARNING: test dataset classes differ from training classes:",
        test_ds.class_names,
    )
from keras import models, layers
from tensorflow.keras.applications import ResNet101

# Backbone: ResNet101 initialised with ImageNet weights and created without
# its original classification head (include_top=False).
# NOTE(review): no `preprocess_input`/rescaling step is visible in this script
# before the images reach the backbone - confirm this is intended.
img_size = [img_height, img_width]
res_model = ResNet101(
    include_top=False,
    weights='imagenet',
    input_shape=img_size + [3],
)

# New head: flatten the backbone's feature map and classify it with a single
# softmax layer that has one unit per class.
flattened = layers.Flatten()(res_model.output)
x = layers.Dense(len(class_names), activation='softmax')(flattened)
model = models.Model(inputs=res_model.input, outputs=x)

# No layer is frozen, so the pretrained backbone is trained together with the
# new head. The disabled snippet below would freeze everything but the last
# layer:
# for layer in model.layers[:-1]:
#     layer.trainable = False

# Sparse categorical cross-entropy: the loss expects integer class labels.
model.compile(
    optimizer='rmsprop',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
from tensorflow.keras.callbacks import EarlyStopping

# Stop once validation accuracy has not improved for 12 consecutive epochs and
# roll the model back to the weights of its best epoch.
early_stopping = EarlyStopping(
    monitor="val_accuracy",
    patience=12,
    restore_best_weights=True,
)

# `train_ds` and `test_ds` are already batched datasets, so `batch_size` and
# `validation_batch_size` are not passed to `fit`: the Keras documentation
# says not to specify them for dataset inputs.
# NOTE(review): the test set doubles as validation data here, so early
# stopping and the restored weights are selected on it - consider a separate
# validation split if an unbiased test score is needed.
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=50,
    callbacks=[early_stopping],
)

# Persist the trained model under this path.
model.save('resnet101_sign_char_detection_model')