95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
|
from numbers import Number
|
||
|
|
||
|
import torch
|
||
|
from torch.distributions import constraints
|
||
|
from torch.distributions.distribution import Distribution
|
||
|
from torch.distributions.utils import broadcast_all
|
||
|
|
||
|
# Public names exported by this module (used by ``from ... import *``).
__all__ = ["Laplace"]
|
||
|
|
||
|
|
||
|
class Laplace(Distribution):
    r"""
    Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.

    The probability density is

    .. math::
        f(x) = \frac{1}{2b} \exp\left(-\frac{|x - \mu|}{b}\right)

    where :math:`\mu` is :attr:`loc` and :math:`b` is :attr:`scale`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
        >>> m.sample()  # Laplace distributed with loc=0, scale=1
        tensor([ 0.1046])

    Args:
        loc (float or Tensor): mean of the distribution
        scale (float or Tensor): scale of the distribution
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    @property
    def mean(self):
        """Mean of the distribution; equal to ``loc``."""
        return self.loc

    @property
    def mode(self):
        """Mode of the distribution; equal to ``loc``."""
        return self.loc

    @property
    def variance(self):
        """Variance of the distribution, ``2 * scale**2``."""
        return 2 * self.scale.pow(2)

    @property
    def stddev(self):
        """Standard deviation of the distribution, ``sqrt(2) * scale``."""
        return (2**0.5) * self.scale

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        # Two Python scalars give a scalar (empty) batch shape; otherwise the
        # batch shape is the broadcast shape of the parameters.
        if isinstance(loc, Number) and isinstance(scale, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a Laplace instance whose parameters are expanded to ``batch_shape``.

        The new instance skips argument validation in its constructor (the
        parameters were already validated here) but inherits this instance's
        ``_validate_args`` setting for later calls.
        """
        new = self._get_checked_instance(Laplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Laplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch_shape``.

        Uses the inverse-CDF transform ``loc - scale * sign(u) * log1p(-|u|)``
        of a uniform variate ``u`` (``u = 2p - 1``, the same expression as
        :meth:`icdf`), so gradients flow to ``loc`` and ``scale``.
        """
        shape = self._extended_shape(sample_shape)
        finfo = torch.finfo(self.loc.dtype)
        if torch._C._get_tracing_state():
            # [JIT WORKAROUND] lack of support for .uniform_()
            u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
            # torch.rand may return exactly 0, making u == -1 and
            # log1p(-1) == -inf. Capping |u| at 1 - eps keeps the traced path
            # in the same range as the eager path below ([eps - 1, 1)), so the
            # sample is always finite.
            return self.loc - self.scale * u.sign() * torch.log1p(
                -u.abs().clamp(min=finfo.tiny, max=1 - finfo.eps)
            )
        # The lower bound eps - 1 (instead of -1) keeps |u| < 1, so log1p(-|u|)
        # stays finite.
        u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
        # TODO: If we ever implement tensor.nextafter, below is what we want ideally.
        # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
        return self.loc - self.scale * u.sign() * torch.log1p(-u.abs())

    def log_prob(self, value):
        """Log density: ``-log(2 * scale) - |value - loc| / scale``.

        When argument validation is enabled, ``value`` is checked against
        :attr:`support` first.
        """
        if self._validate_args:
            self._validate_sample(value)
        return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale

    def cdf(self, value):
        """Cumulative distribution function evaluated at ``value``.

        Equals ``0.5 * exp((value - loc) / scale)`` below ``loc`` and
        ``1 - 0.5 * exp(-(value - loc) / scale)`` above it. It is written with
        ``expm1`` as a single sign-symmetric expression.
        """
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        return 0.5 - 0.5 * diff.sign() * torch.expm1(-diff.abs() / self.scale)

    def icdf(self, value):
        """Inverse CDF (quantile function) evaluated at probability ``value``.

        ``value`` is not validated here.
        NOTE(review): inputs outside [0, 1] presumably yield NaN or inf via
        log1p; confirm whether callers rely on that.
        """
        term = value - 0.5
        return self.loc - self.scale * term.sign() * torch.log1p(-2 * term.abs())

    def entropy(self):
        """Differential entropy, ``1 + log(2 * scale)``."""
        return 1 + torch.log(2 * self.scale)
|