187 lines
6.6 KiB
Python
187 lines
6.6 KiB
Python
import torch
|
|
from ..modules import Module
|
|
from . import comm
|
|
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast
|
|
from torch._utils import _get_device_index
|
|
|
|
from collections import OrderedDict
|
|
|
|
if TYPE_CHECKING:
|
|
import torch.jit
|
|
import torch.jit._state
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["replicate"]
|
|
|
|
def _is_script_module(module: Module) -> bool:
|
|
import torch.jit
|
|
return isinstance(module, torch.jit.ScriptModule)
|
|
|
|
|
|
def _is_script_method(module: Module) -> bool:
|
|
import torch.jit
|
|
return isinstance(module, torch._C.ScriptMethod)
|
|
|
|
|
|
def _init_script_module() -> "torch.jit.ScriptModule":
|
|
import torch.jit
|
|
return torch.jit.ScriptModule()
|
|
|
|
|
|
def _is_jit_enabled() -> "torch.jit._state.EnabledProxy":
|
|
import torch.jit._state
|
|
return torch.jit._state._enabled
|
|
|
|
|
|
# Check whether the module can be safely replicated.
# There are two types of module:
#   1. Python modules
#   2. ScriptModule
#
# Currently, a module cannot be replicated properly if any descendant of a
# ScriptModule is a Python module (type 1 above).
|
|
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
    """Return True if ``module`` can be replicated safely.

    When the JIT is disabled, every module is considered replicatable.
    Otherwise, a ScriptModule is replicatable only if every module beneath
    it is also a ScriptModule. A Python module is replicatable if all of its
    not-yet-visited children are.

    Args:
        module: the module to check.
        memo: modules already visited. This set is shared with, and mutated
            by, the recursive calls. A new set is created when None.
    """

    def _strict_descendants(root: Module) -> Iterator[Module]:
        # ``root.modules()`` yields ``root`` itself first; drop it so that
        # only the modules beneath ``root`` remain.
        walker = root.modules()
        next(walker)
        return walker

    if not _is_jit_enabled():
        return True

    visited: Set[Module] = set() if memo is None else memo
    visited.add(module)  # memoize this module before inspecting it

    if _is_script_module(module):
        # Record the whole script subtree as visited, then require that all
        # of it is scripted.
        visited.update(_strict_descendants(module))
        return all(_is_script_module(sub) for sub in _strict_descendants(module))

    # Any unreplicatable module makes the check return False immediately,
    # so children already in ``visited`` can safely be skipped. The filter
    # is evaluated lazily, after earlier recursive calls have updated
    # ``visited``.
    return all(
        _replicatable_module(child, visited)
        for child in module.children()
        if child not in visited
    )
|
|
|
|
def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    """Broadcast ``tensors`` to ``devices`` and group the copies per device.

    Args:
        tensors: the tensors to copy.
        devices: the destination devices.
        detach: if True, copy with ``comm.broadcast_coalesced``. Otherwise,
            copy through the autograd function ``Broadcast``.

    Returns:
        Nested copies consumed as ``result[j][k]``, the copy of
        ``tensors[k]`` for the ``j``-th device. On the non-detach path, an
        empty ``tensors`` returns ``[]`` rather than one empty group per
        device.
    """
    from ._functions import Broadcast

    if detach:
        return comm.broadcast_coalesced(tensors, devices)

    group_size = len(tensors)
    if group_size == 0:
        return []

    # Broadcast.apply returns one flat sequence of copies. Cut it into
    # consecutive chunks of len(tensors), one chunk per device.
    flat_copies = Broadcast.apply(devices, *tensors)
    return [
        flat_copies[start:start + group_size]
        for start in range(0, len(flat_copies), group_size)
    ]
|
|
|
|
|
|
T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    """Create one shallow replica of ``network`` per entry in ``devices``.

    Every module in ``network.modules()`` is duplicated with
    ``_replicate_for_data_parallel``. The replicas are then rewired: their
    children point at the other replicas on the same device, and their
    parameters and buffers point at copies broadcast to that device.
    Parameter copies are set as plain attributes and are also recorded in
    ``replica._former_parameters``.

    Args:
        network: the root module to replicate.
        devices: the target devices. Each entry is normalised with
            ``_get_device_index(x, True)``.
        detach: passed to ``_broadcast_coalesced_reshape`` for parameters
            and for buffers that require grad. Buffers that do not require
            grad are always broadcast with ``detach=True``.

    Returns:
        One replica of ``network`` per device, in the order of ``devices``.
        The list is empty when ``devices`` is empty.

    Raises:
        RuntimeError: if ``_replicatable_module(network)`` is False.
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    def keeps_grad(tensor: torch.Tensor) -> bool:
        # A buffer is copied through the grad-aware path only if it requires
        # grad and the caller did not ask for detached copies.
        return tensor.requires_grad and not detach

    # Parameters: a single broadcast for all of them.
    params = list(network.parameters())
    param_slot = {p: n for n, p in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Buffers: split by whether the copy should keep grad, then broadcast
    # each group separately.
    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = [b for b in buffers if keeps_grad(b)]
    buffers_not_rg: List[torch.Tensor] = [b for b in buffers if not keeps_grad(b)]
    rg_slot = {b: n for n, b in enumerate(buffers_rg)}
    not_rg_slot = {b: n for n, b in enumerate(buffers_not_rg)}
    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    # Create the replicas. module_copies[j][i] is device j's copy of modules[i].
    modules = list(network.modules())
    module_indices: Dict[Module, int] = {}
    module_copies: List[List[Module]] = [[] for _ in devices]
    for i, module in enumerate(modules):
        module_indices[module] = i
        for per_device in module_copies:
            replica = module._replicate_for_data_parallel()
            # Temporary fix for DDP. DDP needs the replicated parameters, which
            # it used to reach through `mode.parameters()`. The DP fix from
            # #33907 stopped `parameters()` from exposing them, so they are
            # re-exposed through `_former_parameters`.
            replica._former_parameters = OrderedDict()
            per_device.append(replica)

    # Rewire every replica so it points at per-device children and tensors.
    for i, module in enumerate(modules):
        replicas_of_module = [per_device[i] for per_device in module_copies]

        for key, child in module._modules.items():
            if child is None:
                for replica in replicas_of_module:
                    replica._modules[key] = None
                continue
            child_idx = module_indices[child]
            for j, replica in enumerate(replicas_of_module):
                setattr(replica, key, module_copies[j][child_idx])

        for key, param in module._parameters.items():
            if param is None:
                for replica in replicas_of_module:
                    replica._parameters[key] = None
                continue
            slot = param_slot[param]
            for j, replica in enumerate(replicas_of_module):
                param_copy = param_copies[j][slot]
                # Parameters in replicas are no longer leaves, so they are set
                # as plain (non-Parameter) attributes...
                setattr(replica, key, param_copy)
                # ...and exposed for DDP.
                replica._former_parameters[key] = param_copy

        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for replica in replicas_of_module:
                    replica._buffers[key] = None
                continue
            if keeps_grad(buf):
                buffer_copies, slot = buffer_copies_rg, rg_slot[buf]
            else:
                buffer_copies, slot = buffer_copies_not_rg, not_rg_slot[buf]
            for j, replica in enumerate(replicas_of_module):
                setattr(replica, key, buffer_copies[j][slot])

    # modules[0] is ``network`` itself, so row[0] is each device's root replica.
    return [cast(T, per_device[0]) for per_device in module_copies]
|