63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
|
import torch
|
||
|
from torch.distributions.distribution import Distribution
|
||
|
|
||
|
# Names exported by a star-import of this module.
__all__ = ["ExponentialFamily"]
|
||
|
|
||
|
|
||
|
class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
    measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which belong
        to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
        divergence methods. We use this class to compute the entropy and KL divergence using the AD
        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for log normalizer function. Returns a log normalizer based on
        the distribution and input.

        Args:
            *natural_params: the natural parameters, passed positionally in the
                same order as the tuple returned by ``_natural_params``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for expected carrier measure, which is required for computing
        entropy.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def entropy(self):
        r"""
        Method to compute the entropy using Bregman divergence of the log normalizer.

        The value returned is

        .. math::

            F(\theta) - \langle \theta, \nabla F(\theta) \rangle - E[k(x)]

        where :math:`\nabla F` is obtained by automatic differentiation of
        ``_log_normalizer`` and :math:`E[k(x)]` is ``_mean_carrier_measure``.

        Returns:
            The entropy. Each inner-product term is reduced to the distribution's
            batch shape; the final shape is whatever that broadcasts to together
            with the log normalizer and the mean carrier measure.

        Raises:
            NotImplementedError: if the subclass does not implement
                ``_natural_params``, ``_log_normalizer`` or ``_mean_carrier_measure``.
        """
        # Read the carrier term first, as a plain Python number or a tensor.
        neg_carrier = -self._mean_carrier_measure

        # Detached copies flagged with requires_grad_ are leaf tensors, which is
        # what torch.autograd.grad needs to differentiate the log normalizer.
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)

        # Accumulate out-of-place: in-place `+=` / `-=` cannot broadcast or
        # type-promote the destination, so a 0-dim (or lower-precision) carrier
        # tensor would raise instead of combining with the log normalizer.
        result = neg_carrier + lg_normal
        for param, grad in zip(nparams, gradients):
            # Flatten everything past the batch dimensions and sum it, giving
            # the inner product <theta, grad F(theta)> per batch element.
            inner = (param * grad).reshape(self._batch_shape + (-1,)).sum(-1)
            result = result - inner
        return result
|