save model after fitting
This commit is contained in:
parent
af76c778d7
commit
8d02fa043c
5
main.py
5
main.py
@ -4,6 +4,7 @@ from src.loss import jaccard_loss
|
|||||||
from src.metrics import IOU
|
from src.metrics import IOU
|
||||||
from src.consts import EPOCHS, STEPS, SEED
|
from src.consts import EPOCHS, STEPS, SEED
|
||||||
from src.generators import create_generators
|
from src.generators import create_generators
|
||||||
|
from src.helpers import create_folder
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -17,9 +18,7 @@ if __name__ == "__main__":
|
|||||||
'metrics':[IOU]
|
'metrics':[IOU]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
model.compile(**compile_params)
|
model.compile(**compile_params)
|
||||||
# tf.keras.utils.plot_model(model, show_shapes=True)
|
|
||||||
|
|
||||||
model_name = "models/unet.h5"
|
model_name = "models/unet.h5"
|
||||||
modelcheckpoint = ModelCheckpoint(model_name,
|
modelcheckpoint = ModelCheckpoint(model_name,
|
||||||
@ -39,3 +38,5 @@ if __name__ == "__main__":
|
|||||||
validation_steps = STEPS,
|
validation_steps = STEPS,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
create_folder('models', '.')
|
||||||
|
model.save(filepath=model_name)
|
Loading…
Reference in New Issue
Block a user