import cv2
import numpy as np
import tensorflow as tf
from typing import Optional, Sequence


class ClassificationModel:
    """Image classifier backed by a frozen TensorFlow graph (``.pb`` file).

    The graph is loaded once at construction time and wrapped into a callable
    that feeds the tensor named ``x:0`` and returns the tensor named
    ``Identity:0``. ``predict`` maps the arg-max of that output to a class name.
    """

    # Default label set. It is sorted in ``__init__`` so that output index i
    # maps to the i-th name alphabetically; presumably the training pipeline
    # assigned label indices in the same alphabetical order — TODO confirm.
    DEFAULT_CLASS_NAMES = ('Fish', "Jellyfish", 'Lionfish', 'Shark', 'Stingray', 'Turtle')

    def __init__(self,
                 model_path: str = './frozen_models/frozen_graph_best_vgg.pb',
                 model_type: str = "VGG16",
                 class_names: Optional[Sequence[str]] = None) -> None:
        """Load the frozen graph and set up the label mapping.

        Args:
            model_path: Path to the serialized ``GraphDef`` (frozen graph).
            model_type: Descriptive model name; stored on the instance but not
                used by any method in this class.
            class_names: Optional label names. They are sorted before use,
                exactly like the default set. Defaults to
                ``DEFAULT_CLASS_NAMES``.
        """
        print("loading classification model")
        self.model_path = model_path
        self.model_func = self.init_frozen_func()
        self.model_type = model_type
        names = self.DEFAULT_CLASS_NAMES if class_names is None else class_names
        self.class_names = sorted(names)

    def wrap_frozen_graph(self, graph_def, inputs, outputs):
        """Wrap a ``GraphDef`` into a callable pruned to the given tensors.

        Args:
            graph_def: A parsed ``tf.compat.v1.GraphDef``.
            inputs: Tensor name(s) (e.g. ``"x:0"``) to use as feeds.
            outputs: Tensor name(s) (e.g. ``"Identity:0"``) to fetch.

        Returns:
            The function produced by ``prune`` that computes ``outputs``
            from ``inputs``.
        """
        def _imports_graph_def():
            # name="" keeps the original tensor names (no import prefix), so
            # lookups like "x:0" resolve as-is.
            tf.compat.v1.import_graph_def(graph_def, name="")

        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))

    def init_frozen_func(self):
        """Read ``self.model_path`` and return the wrapped inference function."""
        with tf.io.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        return self.wrap_frozen_graph(graph_def=graph_def,
                                      inputs=["x:0"],
                                      outputs=["Identity:0"])

    def predict(self, image, shape=(224, 224)):
        """Classify a single image and return its class name.

        Args:
            image: Image array accepted by ``cv2.resize``. It is presumably an
                HxWxC array in the colour order used during training — TODO confirm.
            shape: Passed to ``cv2.resize`` as ``dsize``, which OpenCV
                interprets as (width, height).

        Returns:
            The entry of ``self.class_names`` at the arg-max of the model output.
        """
        resized = cv2.resize(image, shape)
        # Prepend a batch axis of size 1 and cast to float32 for the graph input.
        batch = tf.convert_to_tensor(resized[None, :], dtype='float32')
        pred = self.model_func(x=batch)
        # np.argmax flattens the output; with a batch of one, the flat index
        # equals the class index.
        return self.class_names[int(np.argmax(pred))]