34 lines
1.5 KiB
Python
34 lines
1.5 KiB
Python
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
class ClassificationModel:
    """Image classifier backed by a frozen TensorFlow graph (``.pb`` file).

    The serialized ``GraphDef`` is loaded once at construction time and
    wrapped as a callable; :meth:`predict` resizes a single image, feeds it
    to that callable and returns the name of the highest-scoring class.
    """

    def __init__(
        self,
        model_path: str = './frozen_models/frozen_graph_best_vgg.pb',
        model_type: str = "VGG16",
    ) -> None:
        """Load the frozen graph stored at *model_path*.

        Args:
            model_path: Path of the serialized ``GraphDef``; it is opened with
                ``tf.io.gfile.GFile`` in binary mode.
            model_type: Label stored on ``self.model_type``; this class does
                not otherwise use it.
        """
        print("loading classification model")
        self.model_path = model_path
        self.model_func = self.init_frozen_func()
        self.model_type = model_type
        # Sorted so that output index i maps to the i-th name.
        # NOTE(review): assumes the model was trained with alphabetically
        # ordered labels (e.g. directory-based Keras datasets) — confirm.
        self.class_names = sorted(['Fish', "Jellyfish", 'Lionfish', 'Shark', 'Stingray', 'Turtle'])

    def wrap_frozen_graph(self, graph_def, inputs, outputs):
        """Import *graph_def* and return a callable pruned to *inputs*/*outputs*.

        Args:
            graph_def: A ``tf.compat.v1.GraphDef`` to import.
            inputs: Tensor name(s) to feed, e.g. ``["x:0"]``; may be nested.
            outputs: Tensor name(s) to fetch, e.g. ``["Identity:0"]``; may be
                nested.

        Returns:
            A callable that takes the *inputs* tensors and returns the
            *outputs* tensors, in the same structure as *outputs*.
        """
        def _imports_graph_def():
            # name="" disables the default "import/" prefix so the tensor
            # names in *inputs*/*outputs* resolve exactly as given.
            tf.compat.v1.import_graph_def(graph_def, name="")

        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))

    def init_frozen_func(self):
        """Read ``self.model_path`` and return the wrapped inference callable.

        The callable is fed through the graph's ``x:0`` tensor (invoked as
        ``model_func(x=...)`` by :meth:`predict`) and fetches ``Identity:0``,
        returned in a one-element list mirroring the ``outputs`` argument.
        """
        with tf.io.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            # Parsing fills graph_def in place; the return value is not needed.
            graph_def.ParseFromString(f.read())
        # The file is closed before the (potentially slow) graph import.
        return self.wrap_frozen_graph(graph_def=graph_def,
                                      inputs=["x:0"],
                                      outputs=["Identity:0"])

    def predict(self, image: np.ndarray, shape=(224, 224)) -> str:
        """Classify a single image and return its class name.

        Args:
            image: Image array accepted by ``cv2.resize``.
                NOTE(review): presumably ``(H, W, C)`` as loaded by OpenCV —
                confirm channel order/scaling match what the model expects.
            shape: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
                ``(width, height)`` in OpenCV order, not NumPy's
                ``(rows, cols)``.

        Returns:
            The entry of ``self.class_names`` at the index of the largest
            model output value.
        """
        resized = cv2.resize(image, shape)
        # Add a leading batch axis and cast in NumPy, so inputs of any
        # numeric dtype (uint8, float64, ...) reach the graph as float32.
        batch = np.asarray(resized[np.newaxis, ...], dtype=np.float32)
        pred = self.model_func(x=tf.convert_to_tensor(batch))
        # Only one image is in the batch, so argmax over the flattened
        # output is that image's class index.
        return self.class_names[np.argmax(pred)]