42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
|
|
from src.Unet import Unet
|
|
from src.loss import jaccard_loss
|
|
from src.metrics import IOU
|
|
from src.consts import EPOCHS, STEPS, SEED
|
|
from src.generators import create_generators
|
|
from src.helpers import create_folder
|
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
|
import tensorflow as tf
|
|
|
|
|
|
def main():
    """Train a single-class U-Net and save it under ``models/``.

    During training, the model with the lowest validation loss is
    checkpointed to ``models/unet.h5``. After training, the model from the
    final epoch is saved separately to ``models/unet_last.h5``, so it does
    not overwrite the best checkpoint.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    model = Unet(num_classes=1).build_model()

    compile_params = {
        'loss': jaccard_loss(smooth=90),
        'optimizer': 'rmsprop',
        'metrics': [IOU],
    }
    model.compile(**compile_params)

    best_model_path = "models/unet.h5"
    # Kept separate from best_model_path so the final-epoch save does not
    # clobber the best checkpoint written by ModelCheckpoint.
    last_model_path = "models/unet_last.h5"

    # ModelCheckpoint writes into models/ while training runs, so the folder
    # must exist before fit() starts, not after it finishes.
    # NOTE(review): assumes create_folder(name, parent) creates ./models when
    # missing — confirm against src.helpers.
    create_folder('models', '.')

    checkpoint = ModelCheckpoint(
        best_model_path,
        monitor='val_loss',
        mode='auto',
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators('training', SEED)
    val_gen = create_generators('validation', SEED)

    # Model.fit accepts generators directly; fit_generator is deprecated.
    # The checkpoint must be registered here, otherwise it never runs.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        # NOTE(review): reuses the training step count for validation —
        # confirm this matches the size of the validation set.
        validation_steps=STEPS,
        # Per the Keras docs, shuffle is ignored for generator inputs; kept
        # for parity with the original call.
        shuffle=True,
        callbacks=[checkpoint],
    )

    model.save(filepath=last_model_path)
    return history


if __name__ == "__main__":
    main()