52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
|
|
from src.Unet import Unet
|
|
from src.loss import jaccard_loss
|
|
from src.metrics import IOU
|
|
from src.consts import EPOCHS, STEPS, SEED
|
|
from src.generators import create_generators
|
|
from src.helpers import create_folder
|
|
from src.utils import plot_predictions_grid
|
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
|
import tensorflow as tf
|
|
import sys
|
|
import random
|
|
import os
|
|
|
|
MODEL_NAME = "models/unet.h5"
IMAGE_DIR = "./images/rgb/img"
NUM_PREVIEW_IMAGES = 3


def _train(model_name):
    """Build, compile and train a single-class U-Net.

    The model with the lowest ``val_loss`` seen during training is written to
    *model_name* by a ``ModelCheckpoint`` callback.

    Args:
        model_name: Path of the ``.h5`` file the best model is saved to.

    Returns:
        The ``History`` object returned by ``Model.fit``.
    """
    model = Unet(num_classes=1).build_model()
    model.compile(
        loss=jaccard_loss(smooth=90),
        optimizer="rmsprop",
        metrics=[IOU],
    )

    # The checkpoint writes into models/ during training, so the folder must
    # exist before fit() starts (presumably this creates ./models).
    create_folder("models", ".")
    checkpoint = ModelCheckpoint(
        model_name,
        monitor="val_loss",
        mode="auto",
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators("training", SEED)
    val_gen = create_generators("validation", SEED)

    # Model.fit accepts generators directly (fit_generator is deprecated).
    # `shuffle` is ignored for generator input, so it is not passed.
    # NOTE(review): validation reuses the training STEPS count — confirm.
    # No model.save() afterwards: it would overwrite the best checkpoint with
    # the last-epoch weights, defeating save_best_only.
    return model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        validation_steps=STEPS,
        callbacks=[checkpoint],
    )


def _pick_images(directory, count):
    """Return *count* randomly chosen file names from *directory*.

    Names are distinct whenever the directory holds at least *count* entries;
    otherwise they are drawn with replacement, so exactly *count* names are
    always returned.

    Raises:
        ValueError: if *directory* contains no entries.
    """
    names = os.listdir(directory)  # list once, not once per pick
    if not names:
        raise ValueError(f"no images found in {directory!r}")
    if len(names) >= count:
        return random.sample(names, count)
    return random.choices(names, k=count)


def main(argv=None):
    """Script entry point.

    With no arguments, train the model. With ``--predictions``, load the
    saved model and plot predictions for a few random images.

    Args:
        argv: Argument vector (defaults to ``sys.argv``).

    Returns:
        Process exit status: 0 on success, 2 on an unrecognised argument.
    """
    argv = sys.argv if argv is None else argv

    if len(argv) <= 1:
        _train(MODEL_NAME)
    elif argv[1] == "--predictions":
        img_names = _pick_images(IMAGE_DIR, NUM_PREVIEW_IMAGES)
        # compile=False: only inference is needed, so the custom loss/metric
        # objects do not have to be deserialised.
        loaded_model = tf.keras.models.load_model("./" + MODEL_NAME, compile=False)
        plot_predictions_grid(img_names, loaded_model)
    else:
        prog = os.path.basename(argv[0]) if argv else "train.py"
        print(f"usage: {prog} [--predictions]", file=sys.stderr)
        return 2
    return 0


if __name__ == "__main__":
    sys.exit(main())