41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
|
|
||
|
from src.Unet import Unet
|
||
|
from src.loss import jaccard_loss
|
||
|
from src.metrics import IOU
|
||
|
from src.consts import EPOCHS, STEPS, SEED
|
||
|
from src.generators import create_generators
|
||
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Training entry point: build the U-Net, compile it, and train it from
    # the project's generators while checkpointing the best weights.
    model = Unet(num_classes=1).build_model()

    compile_params = {
        'loss': jaccard_loss(smooth=90),
        'optimizer': 'rmsprop',
        'metrics': [IOU],
    }
    model.compile(**compile_params)
    # tf.keras.utils.plot_model(model, show_shapes=True)

    # With save_best_only=True the file is only rewritten when the monitored
    # quantity (validation loss) improves.
    model_name = "models/unet.h5"
    modelcheckpoint = ModelCheckpoint(
        model_name,
        monitor='val_loss',
        mode='auto',
        verbose=1,
        save_best_only=True,
    )

    # The same SEED is handed to both generator factories.
    train_gen = create_generators('training', SEED)
    val_gen = create_generators('validation', SEED)

    # NOTE(review): fit_generator is deprecated in recent TensorFlow releases
    # in favour of Model.fit, which accepts generators directly. Kept here
    # because the installed TF version is not visible from this file -- confirm
    # before switching.
    history = model.fit_generator(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        validation_steps=STEPS,
        # NOTE(review): per the Keras docs, shuffle only affects
        # keras.utils.Sequence inputs -- confirm what create_generators returns.
        shuffle=True,
        # Fix: the checkpoint callback was created but never registered, so
        # no model file was ever written during training.
        callbacks=[modelcheckpoint],
    )
|