wko-on-cloud-n/main.py
2022-02-16 22:47:40 +01:00

42 lines
1.3 KiB
Python

from src.Unet import Unet
from src.loss import jaccard_loss
from src.metrics import IOU
from src.consts import EPOCHS, STEPS, SEED
from src.generators import create_generators
from src.helpers import create_folder
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf
if __name__ == "__main__":
    # Train a single-output-class U-Net and persist it under ./models.
    model = Unet(num_classes=1).build_model()
    compile_params = {
        # smooth=90 is forwarded unchanged to the project's jaccard_loss;
        # NOTE(review): the rationale for this value is not visible here.
        'loss': jaccard_loss(smooth=90),
        'optimizer': 'rmsprop',
        'metrics': [IOU],
    }
    model.compile(**compile_params)

    # Create the output directory BEFORE training: the checkpoint callback
    # below writes into it during fit(). Depending on the TF version, saving
    # into a missing directory can fail. The original only created it after
    # training had finished.
    # NOTE(review): presumably create_folder(name, parent) makes ./models if
    # absent - confirm against src.helpers.
    create_folder('models', '.')

    # Best model (lowest val_loss) is kept here by ModelCheckpoint.
    model_name = "models/unet.h5"
    # Last-epoch model is saved separately so it does not overwrite the
    # best checkpoint above.
    final_model_name = "models/unet_final.h5"
    modelcheckpoint = ModelCheckpoint(
        model_name,
        monitor='val_loss',
        mode='auto',  # Keras resolves 'auto' to 'min' for val_loss
        verbose=1,
        save_best_only=True,
    )

    train_gen = create_generators('training', SEED)
    val_gen = create_generators('validation', SEED)

    # Model.fit accepts generators / keras Sequences directly;
    # fit_generator is deprecated. The checkpoint must be passed via
    # `callbacks`, otherwise it is never invoked.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=STEPS,
        # NOTE(review): validation reuses the training step count - confirm
        # this matches the size of the validation split.
        validation_steps=STEPS,
        # Keras ignores `shuffle` for plain Python generators; it only
        # affects keras.utils.Sequence inputs.
        shuffle=True,
        callbacks=[modelcheckpoint],
    )
    model.save(filepath=final_model_name)