adsdfasdga
This commit is contained in:
parent
b35f0b879e
commit
59fe53fa40
56
main.ipynb
56
main.ipynb
File diff suppressed because one or more lines are too long
14
main.py
14
main.py
@ -5,11 +5,16 @@ from src.metrics import IOU
|
|||||||
from src.consts import EPOCHS, STEPS, SEED
|
from src.consts import EPOCHS, STEPS, SEED
|
||||||
from src.generators import create_generators
|
from src.generators import create_generators
|
||||||
from src.helpers import create_folder
|
from src.helpers import create_folder
|
||||||
|
from src.utils import plot_predictions_grid
|
||||||
from tensorflow.keras.callbacks import ModelCheckpoint
|
from tensorflow.keras.callbacks import ModelCheckpoint
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import os
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
model_name = "models/unet.h5"
|
||||||
|
if(len(sys.argv) <= 1):
|
||||||
model = Unet(num_classes=1).build_model()
|
model = Unet(num_classes=1).build_model()
|
||||||
|
|
||||||
compile_params ={
|
compile_params ={
|
||||||
@ -20,7 +25,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
model.compile(**compile_params)
|
model.compile(**compile_params)
|
||||||
|
|
||||||
model_name = "models/unet.h5"
|
|
||||||
modelcheckpoint = ModelCheckpoint(model_name,
|
modelcheckpoint = ModelCheckpoint(model_name,
|
||||||
monitor='val_loss',
|
monitor='val_loss',
|
||||||
mode='auto',
|
mode='auto',
|
||||||
@ -40,3 +44,9 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
create_folder('models', '.')
|
create_folder('models', '.')
|
||||||
model.save(filepath=model_name)
|
model.save(filepath=model_name)
|
||||||
|
|
||||||
|
elif(sys.argv[1] == '--predictions'):
|
||||||
|
|
||||||
|
img_names = [random.choice(os.listdir('./images/rgb/img')) for _ in range(3)]
|
||||||
|
loaded_model = tf.keras.models.load_model("./" + model_name, compile=False)
|
||||||
|
plot_predictions_grid(img_names, loaded_model)
|
@ -5,13 +5,12 @@ import matplotlib.pyplot as plt
|
|||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from consts import SEED, JPG_IMAGES, FC_DIR, FEATURES, MASK_DIR, LABELS, RGB_DIR
|
from src.consts import SEED, JPG_IMAGES, FC_DIR, FEATURES, MASK_DIR, LABELS, RGB_DIR
|
||||||
from helpers import create_folder, convert_tif_to_jpg, progress_bar
|
from src.helpers import create_folder, convert_tif_to_jpg, progress_bar
|
||||||
from utils import plot_image_grid
|
from src.utils import plot_image_grid
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dirs = os.listdir(FEATURES)
|
dirs = os.listdir(FEATURES)
|
@ -1,6 +1,6 @@
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
FEATURES ='../data/train_features'
|
FEATURES ='./data/train_features'
|
||||||
LABELS = '../data/train_labels'
|
LABELS = './data/train_labels'
|
||||||
JPG_IMAGES = 'images'
|
JPG_IMAGES = 'images'
|
||||||
RGB_DIR = "rgb/img"
|
RGB_DIR = "rgb/img"
|
||||||
FC_DIR = "fc/img"
|
FC_DIR = "fc/img"
|
||||||
@ -11,5 +11,5 @@ BATCH = 8
|
|||||||
IMG_SIZE = (512,512)
|
IMG_SIZE = (512,512)
|
||||||
SEED = 7
|
SEED = 7
|
||||||
|
|
||||||
EPOCHS = 10
|
EPOCHS = 5
|
||||||
STEPS = 10
|
STEPS = 10
|
34
src/utils.py
34
src/utils.py
@ -1,7 +1,9 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import cv2
|
import cv2
|
||||||
import os
|
import os
|
||||||
from consts import JPG_IMAGES, RGB_DIR, MASK_DIR
|
import numpy as np
|
||||||
|
from src.consts import JPG_IMAGES, RGB_DIR, MASK_DIR
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
def plot_image_grid(image_names):
|
def plot_image_grid(image_names):
|
||||||
|
|
||||||
@ -25,3 +27,33 @@ def plot_image_grid(image_names):
|
|||||||
#show
|
#show
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_predictions_grid(image_names, model):
|
||||||
|
#number of img rows
|
||||||
|
n_row= len(image_names)
|
||||||
|
|
||||||
|
plt.subplots(n_row, 3, figsize=(6, 3*n_row))
|
||||||
|
for i, img in enumerate(image_names):
|
||||||
|
r_img = cv2.imread(os.path.join(JPG_IMAGES, RGB_DIR, img))
|
||||||
|
m_img = cv2.imread(os.path.join(JPG_IMAGES, MASK_DIR, img))
|
||||||
|
r_img_ext = np.expand_dims(r_img /255.0, axis=0)
|
||||||
|
pred = model.predict(r_img_ext)
|
||||||
|
|
||||||
|
plt.subplot(n_row, 3, i*3+1)
|
||||||
|
plt.title(f'RGB - {img}')
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(r_img, cmap=None)
|
||||||
|
|
||||||
|
plt.subplot(n_row, 3, i*3+2)
|
||||||
|
plt.title(f'Mask - {img}')
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(m_img, cmap='gray')
|
||||||
|
|
||||||
|
plt.subplot(n_row, 3, i*3+3)
|
||||||
|
plt.title(f'Predicted Mask - {img}')
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(tf.keras.preprocessing.image.array_to_img(pred[0]>0.5), cmap='gray')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user