widzenie-komputerowe-projekt/training/model_freeze.py

15 lines
619 B
Python
Raw Normal View History

2023-02-01 18:42:47 +01:00
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
def freeze_model(model, output_path, name):
    """Freeze a model into a constant-only TensorFlow graph and write it as ``.pb``.

    The model call is wrapped in a ``tf.function`` and traced once, using a
    ``tf.TensorSpec`` built from the shape and dtype of ``model.inputs[0]``.
    The variables captured by the resulting concrete function are then
    replaced by constants via ``convert_variables_to_constants_v2``. The
    frozen GraphDef is written in binary protobuf form.

    Args:
        model: Callable model exposing an ``inputs`` list (Keras-style).
            Only ``model.inputs[0]`` is used to build the trace signature, so
            this presumably assumes a single-input model. TODO confirm for
            multi-input models.
        output_path: Directory passed to ``tf.io.write_graph`` as ``logdir``.
        name: Base file name; ``".pb"`` is appended to it.

    Returns:
        The path of the written protobuf file, as returned by
        ``tf.io.write_graph``.
    """
    first_input = model.inputs[0]
    input_spec = tf.TensorSpec(first_input.shape, first_input.dtype)

    # Trace the model once with a fixed signature so there is a single
    # concrete graph whose captured variables can be inlined as constants.
    concrete_func = tf.function(lambda x: model(x)).get_concrete_function(input_spec)
    frozen_func = convert_variables_to_constants_v2(concrete_func)

    # Serialize the frozen graph explicitly; the original computed this and
    # discarded it, then let write_graph redo the same conversion internally.
    graph_def = frozen_func.graph.as_graph_def()
    return tf.io.write_graph(
        graph_or_graph_def=graph_def,
        logdir=output_path,
        name=f"{name}.pb",
        as_text=False,
    )