wko-on-cloud-n/main.py

52 lines
1.7 KiB
Python
Raw Normal View History

2022-02-16 22:29:01 +01:00
from src.Unet import Unet
from src.loss import jaccard_loss
from src.metrics import IOU
from src.consts import EPOCHS, STEPS, SEED
from src.generators import create_generators
2022-02-16 22:47:40 +01:00
from src.helpers import create_folder
2022-02-17 03:04:00 +01:00
from src.utils import plot_predictions_grid
2022-02-16 22:29:01 +01:00
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf
2022-02-17 03:04:00 +01:00
import sys
import random
import os
2022-02-16 22:29:01 +01:00
def _train(model_path):
    """Build, compile and train the U-Net, checkpointing the best model.

    The ModelCheckpoint callback rewrites *model_path* each time ``val_loss``
    improves, so after training the file holds the lowest-validation-loss
    model seen during training.

    Args:
        model_path: Destination ``.h5`` file for the checkpoint.

    Returns:
        The History object returned by ``model.fit``.
    """
    model = Unet(num_classes=1).build_model()
    model.compile(
        loss=jaccard_loss(smooth=90),
        optimizer='rmsprop',
        metrics=[IOU],
    )

    # The checkpoint writes under models/ during training, so the folder has
    # to exist before the first checkpoint is saved (presumably this helper
    # creates ./models — it is the same call previously made after training).
    create_folder('models', '.')
    checkpoint = ModelCheckpoint(
        model_path,
        monitor='val_loss',
        mode='auto',
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators('training', SEED)
    val_gen = create_generators('validation', SEED)

    # `fit_generator` is deprecated in favor of `fit`, which accepts generators
    # directly with the same arguments. The checkpoint must be passed as a
    # callback; otherwise save_best_only never takes effect.
    # NOTE(review): validation reuses the training STEPS count — confirm intended.
    #
    # No model.save() afterwards: writing to the same path would overwrite the
    # best checkpoint with the last epoch's weights.
    return model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        validation_steps=STEPS,
        shuffle=True,
        callbacks=[checkpoint],
    )


def _show_predictions(model_path, image_dir='./images/rgb/img', count=3):
    """Plot predictions of the saved model for randomly chosen images.

    Args:
        model_path: Path of the saved Keras model to load (loaded uncompiled,
            so the custom loss/metric are not needed for inference).
        image_dir: Directory whose entries are sampled; only the file names
            are passed on to ``plot_predictions_grid``.
        count: Number of images to show.

    Exits with a message if *image_dir* is empty.
    """
    # List the directory once instead of once per pick.
    available = os.listdir(image_dir)
    if not available:
        sys.exit(f"no images found in {image_dir}")

    # Prefer distinct images; fall back to sampling with replacement when the
    # folder holds fewer than `count` entries.
    if len(available) >= count:
        img_names = random.sample(available, k=count)
    else:
        img_names = random.choices(available, k=count)

    loaded_model = tf.keras.models.load_model(model_path, compile=False)
    plot_predictions_grid(img_names, loaded_model)


def main(argv=None):
    """Dispatch on the command line.

    With no arguments, train the model. With ``--predictions``, plot
    predictions of the saved model. Any other argument prints a usage
    message and exits with status 1.

    Args:
        argv: Argument vector; defaults to ``sys.argv``.
    """
    argv = sys.argv if argv is None else argv
    model_path = "models/unet.h5"

    if len(argv) <= 1:
        _train(model_path)
    elif argv[1] == '--predictions':
        _show_predictions(model_path)
    else:
        # Previously an unknown flag silently did nothing.
        sys.exit(f"usage: {argv[0]} [--predictions]")


if __name__ == "__main__":
    main()