"""Train the U-Net segmentation model, or plot predictions from a saved one.

Usage:
    python <script>                 # train; best model is checkpointed to models/unet.h5
    python <script> --predictions   # plot predictions for random images in ./images/rgb/img
"""
import os
import random
import sys

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

from src.Unet import Unet
from src.consts import EPOCHS, STEPS, SEED
from src.generators import create_generators
from src.helpers import create_folder
from src.loss import jaccard_loss
from src.metrics import IOU
from src.utils import plot_predictions_grid

# Same location as the original "models/unet.h5" / "./models/unet.h5".
MODEL_PATH = os.path.join("models", "unet.h5")
# Directory sampled for prediction plots.
IMAGE_DIR = os.path.join(".", "images", "rgb", "img")
# Number of images shown in the prediction grid.
NUM_PREDICTIONS = 3


def train(model_path=MODEL_PATH):
    """Build, compile and train a single-class U-Net.

    The callback checkpoints the model with the lowest validation loss to
    ``model_path``.

    Args:
        model_path: Destination of the HDF5 checkpoint.
    """
    model = Unet(num_classes=1).build_model()
    model.compile(
        loss=jaccard_loss(smooth=90),
        optimizer="rmsprop",
        metrics=[IOU],
    )

    # The checkpoint callback writes into models/ while training runs, so the
    # folder must exist before fit() starts, not after it finishes.
    create_folder("models", ".")
    checkpoint = ModelCheckpoint(
        model_path,
        monitor="val_loss",
        mode="auto",
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators("training", SEED)
    val_gen = create_generators("validation", SEED)

    # fit_generator is deprecated; fit() accepts generators directly.
    # The checkpoint must be passed as a callback, otherwise it never runs.
    model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        validation_steps=STEPS,
        shuffle=True,
        callbacks=[checkpoint],
    )
    # Deliberately no final model.save(model_path): it would overwrite the
    # best checkpoint with the last-epoch weights.


def predict(model_path=MODEL_PATH, num_images=NUM_PREDICTIONS):
    """Load a saved model and plot its predictions for random images.

    Args:
        model_path: HDF5 file produced by :func:`train`.
        num_images: How many images from ``IMAGE_DIR`` to plot.

    Raises:
        IndexError: If ``IMAGE_DIR`` is empty.
    """
    available = os.listdir(IMAGE_DIR)
    if len(available) >= num_images:
        # Distinct images, so the grid never repeats a picture.
        img_names = random.sample(available, num_images)
    else:
        # Too few images for distinct picks: allow repeats, as before.
        img_names = random.choices(available, k=num_images)

    loaded_model = tf.keras.models.load_model(model_path, compile=False)
    plot_predictions_grid(img_names, loaded_model)


def main(argv=None):
    """Dispatch on the command line.

    No arguments runs training; ``--predictions`` plots predictions.

    Args:
        argv: Arguments without the program name; defaults to ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        train()
    elif args[0] == "--predictions":
        predict()
    else:
        # Previously an unknown flag silently did nothing.
        sys.exit(f"Unknown argument {args[0]!r}. Usage: {sys.argv[0]} [--predictions]")


if __name__ == "__main__":
    main()