336 lines
13 KiB
Python
336 lines
13 KiB
Python
import math
|
|
import warnings
|
|
from numbers import Number
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
from torch import nan
|
|
from torch.distributions import constraints
|
|
from torch.distributions.exp_family import ExponentialFamily
|
|
from torch.distributions.multivariate_normal import _precision_to_scale_tril
|
|
from torch.distributions.utils import lazy_property
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["Wishart"]
|
|
|
|
# Natural logarithm of 2, computed once at import time.
_log_2 = math.log(2)
|
|
|
|
|
|
def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
|
|
assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
|
|
return torch.digamma(
|
|
x.unsqueeze(-1)
|
|
- torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
|
|
).sum(-1)
|
|
|
|
|
|
def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
|
|
# We assume positive input for this function
|
|
return x.clamp(min=torch.finfo(x.dtype).eps)
|
|
|
|
|
|
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
        validate_args (bool, optional): forwarded unchanged to the base
            distribution constructor.

    Raises:
        AssertionError: unless exactly one of :attr:`covariance_matrix`,
            :attr:`precision_matrix` or :attr:`scale_tril` is given.
        ValueError: if the matrix parameter has fewer than two dimensions, or
            if any entry of :attr:`df` is not greater than ``ndim - 1``.

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    # Class-level defaults. ``__init__`` stores a per-instance copy whose "df"
    # bound is tightened to ``ndim - 1``; this dict itself is never mutated.
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        # Kept as an ``assert`` so that callers relying on AssertionError keep working.
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        # The single matrix-valued parameter that was actually supplied.
        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        # Store the supplied parameter under its own name, broadcast to the batch
        # shape; the other two remain lazy properties.
        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # Tighten the ``df`` constraint for THIS instance only. Writing into the
        # class-level dict would leak this instance's dimension into every
        # other ``Wishart`` object.
        self.arg_constraints = {
            **self.arg_constraints,
            "df": constraints.greater_than(event_shape[-1] - 1),
        }
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Negative indices of the batch axes of a per-draw mask shaped
        # ``sample_shape + batch_shape``; used to reduce over the batch.
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Wishart`` whose parameters are expanded to ``batch_shape``.

        Args:
            batch_shape: the desired batch shape.
            _instance: optional pre-allocated instance to fill in (used by
                subclasses that override ``expand``).

        Returns:
            The expanded distribution; argument validation is skipped while
            building it, and the ``_validate_args`` flag is copied from ``self``.
        """
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        # Carry over the per-instance constraint table (with the tightened "df" bound).
        new.arg_constraints = self.arg_constraints

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        # Only parameters already materialised on ``self`` are expanded; the
        # rest stay lazy on the new instance.
        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower-triangular scale factor, expanded to ``batch_shape + event_shape``."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Scale matrix ``L @ L^T``, expanded to ``batch_shape + event_shape``."""
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Inverse of the scale matrix, obtained by a Cholesky solve against the identity."""
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        """Mean of the distribution: ``df * covariance_matrix``."""
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        """Mode ``(df - p - 1) * covariance_matrix``; ``nan`` wherever ``df - p - 1 <= 0``."""
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        # ``factor`` is a fresh tensor, so the in-place write does not touch ``self.df``.
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        """Elementwise variance: ``df * (V_ij ** 2 + V_ii * V_jj)``."""
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Builds a lower-triangular noise matrix whose diagonal holds square
        roots of chi-square draws (clamped to at least machine epsilon) and
        whose strictly-lower part holds standard normal draws, then returns
        ``(L @ noise) @ (L @ noise)^T`` with ``L`` the scale factor.
        """
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            # p * (p - 1) // 2 entries lie strictly below the diagonal.
            torch.Size(sample_shape) + self._batch_shape + (p * (p - 1) // 2,),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        Draw reparameterized samples, redrawing those that fail the support check.

        Args:
            sample_shape: leading shape of the returned sample.
            max_try_correction (int, optional): maximum number of redraw rounds
                for singular samples. Defaults to 3 while tracing and 10 otherwise.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape``.

        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporally and should be removed in the future
        # A draw is singular when it FAILS the support check, hence the negation
        # (the same convention as the re-checks inside both loops below).
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            # Collapse the batch axes: a draw is redrawn if any batch member is singular.
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Redraw only as many samples as are currently flagged.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    # Index with a copy so the mask is not read while being written.
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        """Log-density of ``value``.

        Args:
            value: tensor whose trailing two dimensions match ``event_shape``;
                checked against the support when argument validation is enabled.

        Returns:
            Tensor of log-densities with the event dimensions reduced away.
        """
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        """Differential entropy of the distribution, one value per batch element."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        """Natural parameters ``(-precision_matrix / 2, (df - p - 1) / 2)``."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        """Log-normalizer as a function of the natural parameters ``(x, y)``."""
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
|