78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
|
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Saving utilities to support Python's Pickle protocol."""
|
||
|
import os
|
||
|
import tempfile
|
||
|
|
||
|
import tensorflow.compat.v2 as tf
|
||
|
|
||
|
from keras.saving import saving_lib
|
||
|
|
||
|
|
||
|
def deserialize_model_from_bytecode(serialized_model):
    """Reconstruct a Model from the output of `serialize_model_as_bytecode`.

    The bytes are written to a `model.keras` file inside a fresh temporary
    directory and loaded back with `saving_lib.load_model`. The temporary
    directory is removed afterwards, whether or not loading succeeded.

    Args:
        serialized_model: (bytes) return value from
          `serialize_model_as_bytecode`.

    Returns:
        Keras Model instance.

    Raises:
        Exception: any error raised while writing the temporary file or
          while loading the model propagates to the caller unchanged.
    """
    # Note: we don't use a RAM path for this because zipfile cannot write
    # to such paths.
    temp_dir = tempfile.mkdtemp()
    try:
        filepath = os.path.join(temp_dir, "model.keras")
        with open(filepath, "wb") as f:
            f.write(serialized_model)
        # When loading, direct import will work for most custom objects
        # though it will require get_config() to be implemented.
        # Some custom objects (e.g. an activation in a Dense layer,
        # serialized as a string by Dense.get_config()) will require
        # a custom_object_scope.
        #
        # NOTE(review): safe_mode=False presumably relaxes the loader's
        # safety checks on the deserialized content; like unpickling in
        # general, this should only be fed bytes from a trusted source.
        # TODO confirm against `saving_lib.load_model`.
        return saving_lib.load_model(filepath, safe_mode=False)
    finally:
        # Runs on both the success and the failure path, so the temporary
        # directory never outlives this call.
        tf.io.gfile.rmtree(temp_dir)
|
||
|
|
||
|
|
||
|
def serialize_model_as_bytecode(model):
    """Convert a Keras Model into a bytecode representation for pickling.

    The model is saved with `saving_lib.save_model` to a `model.keras` file
    inside a fresh temporary directory, and the content of that file is
    read back as bytes. The temporary directory is removed afterwards,
    whether or not saving succeeded.

    Args:
        model: Keras Model instance.

    Returns:
        The bytes of the saved `model.keras` file, which can be read by
        `deserialize_model_from_bytecode`.

    Raises:
        Exception: any error raised while saving the model or reading the
          saved file back propagates to the caller unchanged.
    """
    # Note: we don't use a RAM path for this because zipfile cannot write
    # to such paths.
    temp_dir = tempfile.mkdtemp()
    try:
        filepath = os.path.join(temp_dir, "model.keras")
        saving_lib.save_model(model, filepath)
        with open(filepath, "rb") as f:
            return f.read()
    finally:
        # Runs on both the success and the failure path, so the temporary
        # directory never outlives this call.
        tf.io.gfile.rmtree(temp_dir)
|