"""Serialization.
|
|
|
|
This module contains functionality for serializing TorchScript modules, notably:
|
|
* torch.jit.save
|
|
* torch.jit.load
|
|
|
|
This is not intended to be imported directly; please use the exposed
|
|
functionalities in `torch.jit`.
|
|
"""
|
|
import os
|
|
|
|
import torch
|
|
from torch.jit._recursive import wrap_cpp_module
|
|
from torch.serialization import validate_cuda_device
|
|
|
|
|
|
def save(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process.

    The saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load(filename)`` or into the Python API with
    :func:`torch.jit.load <torch.jit.load>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write and flush) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    .. note::
        torch.jit.save attempts to preserve the behavior of some operators
        across versions. For example, dividing two integer tensors in
        PyTorch 1.5 performed floor division, and if the module
        containing that code is saved in PyTorch 1.5 and loaded in PyTorch 1.6
        its division behavior will be preserved. The same module saved in
        PyTorch 1.6 will fail to load in PyTorch 1.5, however, since the
        behavior of division changed in 1.6, and 1.5 does not know how to
        replicate the 1.6 behavior.

    Example:

    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save(m, 'scriptmodule.pt')
        # This line is equivalent to the previous
        m.save("scriptmodule.pt")

        # Save to io.BytesIO buffer
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)

        # Save with extra files
        extra_files = {'foo.txt': b'bar'}
        torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
    """
    if _extra_files is None:
        _extra_files = {}
    if isinstance(f, (str, os.PathLike)):
        m.save(f, _extra_files=_extra_files)
    else:
        ret = m.save_to_buffer(_extra_files=_extra_files)
        f.write(ret)


def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
    r"""
    Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save <torch.jit.save>`.

    All previously saved modules, no matter their device, are first loaded onto CPU,
    and then are moved to the devices they were saved from. If this fails (e.g.
    because the runtime system doesn't have certain devices), an exception is
    raised.

    Args:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location (string or torch.device): A simplified version of
            ``map_location`` in `torch.load` used to dynamically remap
            storages to an alternative set of devices.
        _extra_files (dictionary of filename to content): The extra
            filenames given in the map would be loaded and their content
            would be stored in the provided map.
        _restore_shapes (bool): Whether or not to retrace the module on load using stored inputs

    Returns:
        A :class:`ScriptModule` object.

    Example:

    .. testcode::

        import torch
        import io

        torch.jit.load('scriptmodule.pt')

        # Load ScriptModule from io.BytesIO object
        with open('scriptmodule.pt', 'rb') as f:
            buffer = io.BytesIO(f.read())

        # Load all tensors to the original device
        torch.jit.load(buffer)

        # Load all tensors onto CPU, using a device
        buffer.seek(0)
        torch.jit.load(buffer, map_location=torch.device('cpu'))

        # Load all tensors onto CPU, using a string
        buffer.seek(0)
        torch.jit.load(buffer, map_location='cpu')

        # Load with extra files.
        extra_files = {'foo.txt': ''}  # values will be replaced with data
        torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
        print(extra_files['foo.txt'])

    .. testoutput::
        :hide:

        ...

    .. testcleanup::

        import os
        os.remove("scriptmodule.pt")
    """
    if isinstance(f, (str, os.PathLike)):
        if not os.path.exists(f):  # type: ignore[type-var]
            raise ValueError(f"The provided filename {f} does not exist")  # type: ignore[str-bytes-safe]
        if os.path.isdir(f):
            raise ValueError(f"The provided filename {f} is a directory")  # type: ignore[str-bytes-safe]

    map_location = validate_map_location(map_location)
    if _extra_files is None:
        _extra_files = {}

    cu = torch._C.CompilationUnit()
    if isinstance(f, (str, os.PathLike)):
        cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
    else:
        cpp_module = torch._C.import_ir_module_from_buffer(
            cu, f.read(), map_location, _extra_files, _restore_shapes
        )  # type: ignore[call-arg]

    # TODO: Pretty sure this approach loses ConstSequential status and such
    return wrap_cpp_module(cpp_module)


def validate_map_location(map_location=None):
    r"""
    Normalize and validate a ``map_location`` argument.

    A string is converted to a :class:`torch.device`; ``None`` and
    :class:`torch.device` values are passed through unchanged, and any other
    type raises a :class:`ValueError`. CUDA devices are additionally checked
    with :func:`torch.serialization.validate_cuda_device`.
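
    A minimal usage sketch (illustrative; this example is not part of the
    original docs and uses ``'cpu'`` so it runs on any machine):

    .. testcode::

        import torch
        from torch.jit._serialization import validate_map_location

        assert validate_map_location('cpu') == torch.device('cpu')
        assert validate_map_location(None) is None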
    """
    if isinstance(map_location, str):
        map_location = torch.device(map_location)
    elif not (map_location is None or isinstance(map_location, torch.device)):
        raise ValueError(
            "map_location should be either None, string or torch.device, "
            "but got type: " + str(type(map_location))
        )

    if str(map_location).startswith("cuda"):
        validate_cuda_device(map_location)

    return map_location


def jit_module_from_flatbuffer(f):
    r"""
    Load a :class:`ScriptModule` previously saved in the flatbuffer format
    with :func:`torch.jit.save_jit_module_to_flatbuffer`.

    Args:
        f: A file-like object (has to implement read) or a string
           containing a file name.

    Returns:
        A :class:`ScriptModule` object.
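
    Example (a brief sketch; assumes ``scriptmodule.ff`` was written by
    :func:`torch.jit.save_jit_module_to_flatbuffer` as in its docstring):

    .. testcode::

        import torch

        m = torch.jit.jit_module_from_flatbuffer('scriptmodule.ff')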
    """
    if isinstance(f, (str, os.PathLike)):
        f = os.fspath(f)
        return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
    else:
        return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read()))


def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
    r"""
    Save an offline version of this module for use in a separate process.

    The saved module serializes all of the methods, submodules, parameters, and
    attributes of this module. It can be loaded into the C++ API using
    ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
    :func:`torch.jit.jit_module_from_flatbuffer <torch.jit.jit_module_from_flatbuffer>`.

    To be able to save a module, it must not make any calls to native Python
    functions. This means that all submodules must be subclasses of
    :class:`ScriptModule` as well.

    .. DANGER::
        All modules, no matter their device, are always loaded onto the CPU
        during loading. This is different from :func:`torch.load`'s semantics
        and may change in the future.

    Args:
        m: A :class:`ScriptModule` to save.
        f: A file-like object (has to implement write) or a string
           containing a file name.
        _extra_files: Map from filename to contents which will be stored as part of `f`.

    Example:

    .. testcode::

        import torch
        import io

        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x + 10

        m = torch.jit.script(MyModule())

        # Save to file
        torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
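
        # Save to an io.BytesIO buffer (an illustrative sketch; the
        # buffer variant follows from the file-like branch of this function)
        buffer = io.BytesIO()
        torch.jit.save_jit_module_to_flatbuffer(m, buffer)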
"""
|
|
extra_files = _extra_files
|
|
if extra_files is None:
|
|
extra_files = {}
|
|
|
|
if isinstance(f, (str, os.PathLike)):
|
|
f = os.fspath(f)
|
|
torch._C._save_jit_module(m._c, f, extra_files)
|
|
else:
|
|
s = torch._C._save_jit_module_to_bytes(m._c, extra_files)
|
|
f.write(s)
|
|
|
|
|
|
def get_flatbuffer_module_info(path_or_file):
    r"""Get some information regarding a model file in flatbuffer format.

    Args:
        path_or_file: Either str, Path, or a file-like object (BytesIO OK).
            If it's str or Path, we will read the file referenced by that
            path as bytes.

    Returns:
        A dict with metadata on what that file contains, currently looking
        like this::

            {
                'bytecode_version': 4,  # int
                'operator_version': 4,  # int
                'function_names': {
                    '__torch__.___torch_mangle_0.Foo.forward'},  # set
                'type_names': set(),  # set
                'opname_to_num_args': {'aten::linear': 3}  # Dict[str, int]
            }
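
    Example (an illustrative sketch; assumes ``scriptmodule.ff`` exists,
    e.g. as written by :func:`torch.jit.save_jit_module_to_flatbuffer`):

    .. testcode::

        from torch.jit._serialization import get_flatbuffer_module_info

        info = get_flatbuffer_module_info('scriptmodule.ff')
        print(sorted(info.keys()))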
"""
|
|
if isinstance(path_or_file, (str, os.PathLike)):
|
|
with open(path_or_file, "rb") as f:
|
|
all_bytes = f.read()
|
|
else:
|
|
all_bytes = path_or_file.read()
|
|
return torch._C._get_module_info_from_flatbuffer(all_bytes)
|