"""Train a single-class U-Net segmentation model.

Builds the model, compiles it with a Jaccard loss and the IOU metric, trains it
on the training/validation generators, and checkpoints the model with the
lowest validation loss to ``MODEL_PATH``.
"""
import os

from src.Unet import Unet
from src.loss import jaccard_loss
from src.metrics import IOU
from src.consts import EPOCHS, STEPS, SEED
from src.generators import create_generators
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf  # noqa: F401

# Where the best model (lowest val_loss) is written.
MODEL_PATH = "models/unet.h5"


def main():
    """Build, compile and train the U-Net, checkpointing the best model.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    model = Unet(num_classes=1).build_model()
    model.compile(
        # NOTE(review): smooth=90 is an unusually large smoothing constant
        # for a Jaccard loss -- confirm it is intentional.
        loss=jaccard_loss(smooth=90),
        optimizer="rmsprop",
        metrics=[IOU],
    )

    # Ensure the checkpoint directory exists before the first save attempt.
    checkpoint_dir = os.path.dirname(MODEL_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = ModelCheckpoint(
        MODEL_PATH,
        monitor="val_loss",
        mode="auto",
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators("training", SEED)
    val_gen = create_generators("validation", SEED)

    # `fit` accepts generators directly; `fit_generator` is deprecated (and
    # removed in Keras 3). The checkpoint must be passed via `callbacks`,
    # otherwise no model file is ever written.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        # NOTE(review): validation reuses the training step count -- confirm
        # this matches the size of the validation set.
        validation_steps=STEPS,
        shuffle=True,
        callbacks=[checkpoint],
    )
    return history


if __name__ == "__main__":
    main()