save model after fitting
This commit is contained in:
parent
af76c778d7
commit
8d02fa043c
7
main.py
7
main.py
@ -4,6 +4,7 @@ from src.loss import jaccard_loss
|
|||||||
from src.metrics import IOU
|
from src.metrics import IOU
|
||||||
from src.consts import EPOCHS, STEPS, SEED
|
from src.consts import EPOCHS, STEPS, SEED
|
||||||
from src.generators import create_generators
|
from src.generators import create_generators
|
||||||
|
from src.helpers import create_folder
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@ -16,10 +17,8 @@ if __name__ == "__main__":
|
|||||||
'optimizer':'rmsprop',
|
'optimizer':'rmsprop',
|
||||||
'metrics':[IOU]
|
'metrics':[IOU]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
model.compile(**compile_params)
|
model.compile(**compile_params)
|
||||||
# tf.keras.utils.plot_model(model, show_shapes=True)
|
|
||||||
|
|
||||||
model_name = "models/unet.h5"
|
model_name = "models/unet.h5"
|
||||||
modelcheckpoint = ModelCheckpoint(model_name,
|
modelcheckpoint = ModelCheckpoint(model_name,
|
||||||
@ -38,4 +37,6 @@ if __name__ == "__main__":
|
|||||||
steps_per_epoch=STEPS,
|
steps_per_epoch=STEPS,
|
||||||
validation_steps = STEPS,
|
validation_steps = STEPS,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
create_folder('models', '.')
|
||||||
|
model.save(filepath=model_name)
|
Loading…
Reference in New Issue
Block a user