63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
|
import torch.cuda
|
||
|
|
||
|
try:
|
||
|
from torch._C import _cudnn
|
||
|
except ImportError:
|
||
|
# Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
|
||
|
# so it's safe to not emit any checks here.
|
||
|
_cudnn = None # type: ignore[assignment]
|
||
|
|
||
|
|
||
|
def get_cudnn_mode(mode):
|
||
|
if mode == "RNN_RELU":
|
||
|
return int(_cudnn.RNNMode.rnn_relu)
|
||
|
elif mode == "RNN_TANH":
|
||
|
return int(_cudnn.RNNMode.rnn_tanh)
|
||
|
elif mode == "LSTM":
|
||
|
return int(_cudnn.RNNMode.lstm)
|
||
|
elif mode == "GRU":
|
||
|
return int(_cudnn.RNNMode.gru)
|
||
|
else:
|
||
|
raise Exception(f"Unknown mode: {mode}")
|
||
|
|
||
|
|
||
|
# NB: We don't actually need this class anymore (in fact, we could serialize the
|
||
|
# dropout state for even better reproducibility), but it is kept for backwards
|
||
|
# compatibility for old models.
|
||
|
class Unserializable:
|
||
|
def __init__(self, inner):
|
||
|
self.inner = inner
|
||
|
|
||
|
def get(self):
|
||
|
return self.inner
|
||
|
|
||
|
def __getstate__(self):
|
||
|
# Note: can't return {}, because python2 won't call __setstate__
|
||
|
# if the value evaluates to False
|
||
|
return "<unserializable>"
|
||
|
|
||
|
def __setstate__(self, state):
|
||
|
self.inner = None
|
||
|
|
||
|
|
||
|
def init_dropout_state(dropout, train, dropout_seed, dropout_state):
|
||
|
dropout_desc_name = "desc_" + str(torch.cuda.current_device())
|
||
|
dropout_p = dropout if train else 0
|
||
|
if (dropout_desc_name not in dropout_state) or (
|
||
|
dropout_state[dropout_desc_name].get() is None
|
||
|
):
|
||
|
if dropout_p == 0:
|
||
|
dropout_state[dropout_desc_name] = Unserializable(None)
|
||
|
else:
|
||
|
dropout_state[dropout_desc_name] = Unserializable(
|
||
|
torch._cudnn_init_dropout_state( # type: ignore[call-arg]
|
||
|
dropout_p,
|
||
|
train,
|
||
|
dropout_seed,
|
||
|
self_ty=torch.uint8,
|
||
|
device=torch.device("cuda"),
|
||
|
)
|
||
|
)
|
||
|
dropout_ts = dropout_state[dropout_desc_name].get()
|
||
|
return dropout_ts
|