"""Training script: build the U-Net from ``src.Unet``, compile it, and fit it
on the generators provided by ``src.generators``."""
from src.Unet import Unet
|
|
from src.loss import jaccard_loss
|
|
from src.metrics import IOU
|
|
from src.consts import EPOCHS, STEPS, SEED
|
|
from src.generators import create_generators
|
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
|
import tensorflow as tf
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the network with a single output class.
    model = Unet(num_classes=1).build_model()

    # Compile with the Jaccard-based loss, RMSprop, and IOU as the tracked
    # metric.
    compile_params = {
        'loss': jaccard_loss(smooth=90),
        'optimizer': 'rmsprop',
        'metrics': [IOU],
    }
    model.compile(**compile_params)

    # Debugging aid: uncomment to render the architecture diagram.
    # tf.keras.utils.plot_model(model, show_shapes=True)

    # Keep only the weights with the best validation loss seen so far.
    model_name = "models/unet.h5"
    modelcheckpoint = ModelCheckpoint(
        model_name,
        monitor='val_loss',
        mode='auto',
        verbose=1,
        save_best_only=True,
    )

    # Same seed for both splits, as in the original script.
    train_gen = create_generators('training', SEED)
    val_gen = create_generators('validation', SEED)

    # `Model.fit` accepts generators directly; `fit_generator` is deprecated
    # in TF 2.x and removed in recent Keras releases.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        validation_steps=STEPS,
        # NOTE(review): Keras documents `shuffle` as ignored for plain
        # generators; kept for parity with the original call.
        shuffle=True,
        # The checkpoint was previously created but never registered, so no
        # model file was ever written.
        callbacks=[modelcheckpoint],
    )