widzenie-komputerowe-projekt/models.py

34 lines
1.5 KiB
Python
Raw Permalink Normal View History

2023-02-01 18:42:47 +01:00
import cv2
import numpy as np
import tensorflow as tf
class ClassificationModel:
    """Image classifier backed by a frozen TensorFlow graph.

    The graph is loaded once at construction time and wrapped as a callable
    function; ``predict`` resizes an image, runs it through that function and
    maps the highest score to a class label.
    """

    def __init__(self, model_path: str = './frozen_models/frozen_graph_best_vgg.pb', model_type: str = "VGG16") -> None:
        """Load the frozen graph stored at *model_path*.

        Args:
            model_path: Path to a serialized ``GraphDef`` (``.pb``) file, read
                through ``tf.io.gfile``.
            model_type: Free-form architecture label; it is only stored on the
                instance and not otherwise used by this class.
        """
        print("loading classification model")
        self.model_path = model_path
        self.model_func = self.init_frozen_func()
        self.model_type = model_type
        # Labels are sorted alphabetically, presumably to match the class-index
        # order used at training time (e.g. directory-based Keras datasets
        # order classes alphabetically) — TODO confirm against training code.
        self.class_names = sorted(['Fish', "Jellyfish", 'Lionfish', 'Shark', 'Stingray', 'Turtle'])

    def wrap_frozen_graph(self, graph_def, inputs, outputs):
        """Wrap *graph_def* as a function mapping *inputs* to *outputs*.

        Args:
            graph_def: A parsed ``tf.compat.v1.GraphDef``.
            inputs: Tensor name(s) such as ``"x:0"`` to expose as inputs; any
                ``tf.nest`` structure of names is accepted.
            outputs: Tensor name(s) to expose as outputs, same structure rules.

        Returns:
            The pruned function produced by ``WrappedFunction.prune``.
        """
        def _imports_graph_def():
            # name="" keeps the original tensor names (no "import/" prefix),
            # so the names in `inputs`/`outputs` resolve as given.
            tf.compat.v1.import_graph_def(graph_def, name="")

        wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
        import_graph = wrapped_import.graph
        return wrapped_import.prune(
            tf.nest.map_structure(import_graph.as_graph_element, inputs),
            tf.nest.map_structure(import_graph.as_graph_element, outputs))

    def init_frozen_func(self):
        """Read ``self.model_path`` and return a callable for the frozen graph.

        The returned function is fed the ``x:0`` tensor (called as
        ``model_func(x=...)`` in ``predict``) and yields the ``Identity:0``
        output in the same list structure that was requested.
        """
        # Only the read needs the open file; parse and wrap after closing it.
        with tf.io.gfile.GFile(self.model_path, "rb") as f:
            serialized = f.read()
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(serialized)
        return self.wrap_frozen_graph(graph_def=graph_def,
                                      inputs=["x:0"],
                                      outputs=["Identity:0"])

    def predict(self, image, shape=(224, 224)):
        """Classify a single image and return its label.

        Args:
            image: Image array accepted by ``cv2.resize``; presumably an
                H x W x C array as produced by ``cv2.imread`` — NOTE(review):
                confirm the channel order (BGR vs RGB) matches training.
            shape: Passed to ``cv2.resize`` as ``dsize``, which OpenCV reads
                as ``(width, height)`` — not NumPy's ``(rows, cols)`` order.
                This does not matter for the square default.

        Returns:
            The entry of ``self.class_names`` with the highest model score.
        """
        image = cv2.resize(image, shape)
        # Prepend a batch axis of size 1; the graph is fed float32 input.
        batch = tf.convert_to_tensor(image[None, :], dtype='float32')
        pred = self.model_func(x=batch)
        # np.argmax without an axis works on the flattened scores; with one
        # output and one batch item that flat index is the class index.
        return self.class_names[int(np.argmax(pred))]