from collections import namedtuple

import torch

from torch import Tensor
from typing import List, Sequence

from . import Sequential, ModuleList, Linear
from .module import Module
from ..functional import log_softmax

__all__ = ['AdaptiveLogSoftmaxWithLoss']

_ASMoutput = namedtuple('_ASMoutput', ['output', 'loss'])


class AdaptiveLogSoftmaxWithLoss(Module):
    r"""Efficient softmax approximation.

    As described in
    `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
    Moustapha Cissé, David Grangier, and Hervé Jégou
    <https://arxiv.org/abs/1609.04309>`__.

    Adaptive softmax is an approximate strategy for training models with large
    output spaces. It is most effective when the label distribution is highly
    imbalanced, for example in natural language modelling, where the word
    frequency distribution approximately follows `Zipf's law`_.

    Adaptive softmax partitions the labels into several clusters, according to
    their frequency. These clusters may each contain a different number of targets.
    Additionally, clusters containing less frequent labels assign lower
    dimensional embeddings to those labels, which speeds up the computation.
    For each minibatch, only the clusters for which at least one target is
    present are evaluated.

    The idea is that the clusters which are accessed frequently
    (like the first one, containing the most frequent labels) should also be cheap
    to compute -- that is, contain a small number of assigned labels.

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`). A worked example is given below.

    * :attr:`head_bias`, if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See the paper for details. Set to False in the official
      implementation.

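    For instance, with ``in_features = 64``, the default ``div_value = 4.0`` and
    three tail clusters (as in the ``cutoffs = [10, 100, 1000]`` example above),
    the per-cluster projection sizes from the formula above work out to::

        cluster 1: floor(64 / 4.0**1) = 16
        cluster 2: floor(64 / 4.0**2) = 4
        cluster 3: floor(64 / 4.0**3) = 1
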
    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): Value used as an exponent to compute sizes
            of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:

        * **output** is a Tensor of size ``N`` containing computed target
          log probabilities for each example
        * **loss** is a Scalar representing the computed negative
          log likelihood loss

    Shape:
        - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
        - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} < \texttt{n\_classes}`
        - output1: :math:`(N)` or :math:`()`
        - output2: ``Scalar``

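    A minimal usage sketch (sizes and values below are illustrative only)::

        >>> asm = AdaptiveLogSoftmaxWithLoss(in_features=64, n_classes=1000,
        ...                                  cutoffs=[10, 100, 500])
        >>> input = torch.randn(128, 64)              # (N, in_features)
        >>> target = torch.randint(0, 1000, (128,))   # values in [0, n_classes - 1]
        >>> out, loss = asm(input, target)            # NamedTuple(output, loss)
        >>> log_probs = asm.log_prob(input)           # (N, n_classes)
        >>> predictions = asm.predict(input)          # (N,)
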
    .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    in_features: int
    n_classes: int
    cutoffs: List[int]
    div_value: float
    head_bias: bool
    head: Linear
    tail: ModuleList

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        cutoffs: Sequence[int],
        div_value: float = 4.,
        head_bias: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        cutoffs = list(cutoffs)

        if len(cutoffs) == 0:
            raise ValueError("cutoffs should be a sequence of length larger than 0")

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) > (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any(int(c) != c for c in cutoffs):

            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias,
                           **factory_kwargs)
        self.tail = ModuleList()

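        # Each tail cluster is a pair of bias-free Linear layers: a down-projection
        # to a smaller hidden size, followed by a projection onto that cluster's labels.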
        for i in range(self.n_clusters):

            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False, **factory_kwargs),
                Linear(hsz, osz, bias=False, **factory_kwargs),
            )

            self.tail.append(projection)

    def reset_parameters(self) -> None:
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
        targ_dim = target_.dim()

        if targ_dim == 1:
            if input_.size(0) != target_.size(0):
                raise RuntimeError('Input and target should have the same size '
                                   'in the batch dimension.')
            if input_.dim() != 2:
                raise RuntimeError('1D target tensor expects 2D input tensors, '
                                   f'but found inputs with size {input_.size()}')
        elif targ_dim == 0:
            if input_.dim() != 1:
                raise RuntimeError('0D target tensor expects 1D input tensors, '
                                   f'but found inputs with size {input_.size()}')
        else:
            raise RuntimeError('0D or 1D target tensor expected, '
                               'multi-target not supported')

        is_batched = targ_dim > 0
        input = input_ if is_batched else input_.unsqueeze(0)
        target = target_ if is_batched else target_.unsqueeze(0)

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

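        # Walk over the shortlist and each tail cluster. Rows whose target falls in
        # the shortlist gather their own index from the head output; rows whose
        # target falls in a tail cluster gather that cluster's token in the head,
        # and their within-cluster log-probability is accumulated into ``output``.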
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):

            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                gather_inds.index_fill_(0, row_indices, cluster_index)
                cluster_logprob = log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError(f"Target values should be in [0, {self.n_classes - 1}], "
                               f"but values in range [{target.min().item()}, {target.max().item()}] "
                               "were found.")

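        # Combine with the head: gather the head log-probability of each row's own
        # label (shortlist rows) or of its cluster token (tail rows), and add it to
        # the within-cluster term accumulated above.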
        head_output = self.head(input)
        head_logprob = log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        if not is_batched:
            output = output.squeeze(0)

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given the input tensor and the output of ``self.head``, compute log probabilities over the full distribution."""
        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = log_softmax(head_output, dim=1)

        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

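        # Tail classes factor as p(class) = p(cluster | head) * p(class | cluster),
        # so their log-probabilities are the sum of the cluster's head log-probability
        # and the within-cluster log-softmax.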
        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = log_softmax(cluster_output, dim=1)
            output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input: Tensor) -> Tensor:
        r"""Compute log probabilities for all :math:`\texttt{n\_classes}` classes.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N, \texttt{n\_classes})`

        """
        head_output = self.head(input)
        return self._get_full_log_prob(input, head_output)

    def predict(self, input: Tensor) -> Tensor:
        r"""Return the class with the highest probability for each example in the input minibatch.

        This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): the class with the highest probability for each example

        Shape:
            - Input: :math:`(N, \texttt{in\_features})`
            - Output: :math:`(N)`
        """
        head_output = self.head(input)
        output = torch.argmax(head_output, dim=1)
        not_in_shortlist = (output >= self.shortlist_size)
        all_in_shortlist = not (not_in_shortlist.any())

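        # Fast path: if every head argmax already lands in the shortlist, that is the
        # prediction. Otherwise the full distribution is evaluated, either for the
        # whole batch or only for the rows whose head argmax points at a tail cluster.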
        if all_in_shortlist:
            return output

        elif not_in_shortlist.all():
            log_prob = self._get_full_log_prob(input, head_output)
            return torch.argmax(log_prob, dim=1)

        else:
            log_prob = self._get_full_log_prob(input[not_in_shortlist],
                                               head_output[not_in_shortlist])
            output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
            return output