130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
|
# Copyright 2022 The JAX Authors.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
|
||
|
import scipy.stats as osp_stats
|
||
|
|
||
|
from jax import lax
|
||
|
import jax.numpy as jnp
|
||
|
from jax._src.numpy.util import _wraps, promote_args_inexact
|
||
|
from jax._src.scipy.stats import norm
|
||
|
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
|
||
|
|
||
|
|
||
|
def _log_diff(x, y):
  """Return ``log(exp(x) - exp(y))`` computed in log space.

  ``x`` and ``y`` are stacked along a new leading axis and reduced with
  ``logsumexp``, weighting ``x`` by ``+1`` and ``y`` by ``-1``.
  """
  log_terms = jnp.array([x, y])
  signs = jnp.array([jnp.ones_like(x), -jnp.ones_like(y)])
  return logsumexp(log_terms, axis=0, b=signs)
|
||
|
|
||
|
|
||
|
def _log_gauss_mass(a, b):
  """Log of Gaussian probability mass within the interval ``[a, b]``.

  ``a`` and ``b`` are converted to arrays and broadcast against each other.
  All three formulas below are evaluated for every element; the one that
  matches the element's position relative to zero is selected at the end.
  """
  lower, upper = jnp.broadcast_arrays(jnp.array(a), jnp.array(b))

  # Note: comment carried over from scipy
  # Calculations in right tail are inaccurate, so we'll exploit the
  # symmetry and work only in the left tail
  in_left_tail = upper <= 0
  in_right_tail = lower > 0
  straddles_zero = ~(in_left_tail | in_right_tail)

  # Interval entirely at or below zero: difference of the two log-CDF values.
  left_mass = _log_diff(log_ndtr(upper), log_ndtr(lower))

  # Interval entirely above zero: mirror it to [-b, -a] and reuse the
  # left-tail formula.
  right_mass = _log_diff(log_ndtr(-lower), log_ndtr(-upper))

  # Note: comment carried over from scipy
  # Previously, this was implemented as:
  # left_mass = mass_case_left(a, 0)
  # right_mass = mass_case_right(0, b)
  # return _log_sum(left_mass, right_mass)
  # Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
  # Correct for this with an alternative formulation.
  # We're not concerned with underflow here: if only one term
  # underflows, it was insignificant; if both terms underflow,
  # the result can't accurately be represented in logspace anyway
  # because sc.log1p(x) ~ x for small x.
  central_mass = jnp.log1p(-ndtr(lower) - ndtr(-upper))

  return jnp.select(
      [in_left_tail, in_right_tail, straddles_zero],
      [left_mass, right_mass, central_mass],
  )
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
def logpdf(x, a, b, loc=0, scale=1):
  # Log-density of the truncated normal: the untruncated normal log-density
  # minus the log of the Gaussian mass inside [a, b].
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)

  log_norm_density = norm.logpdf(x, loc, scale)
  log_interval_mass = _log_gauss_mass(a, b)
  log_density = lax.sub(log_norm_density, log_interval_mass)

  # The truncation bounds are compared against the standardised point
  # (x - loc) / scale; points outside [a, b] get -inf.
  standardized = lax.div(lax.sub(x, loc), scale)
  outside_support = (standardized < a) | (standardized > b)
  log_density = jnp.where(outside_support, -jnp.inf, log_density)

  # Bounds with a >= b yield NaN.
  return jnp.where(a >= b, jnp.nan, log_density)
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
def pdf(x, a, b, loc=0, scale=1):
  # The density is the exponential of the log-density from ``logpdf``.
  log_density = logpdf(x, a, b, loc, scale)
  return lax.exp(log_density)
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
  # Log survival function of the truncated normal with bounds [a, b]
  # (expressed in standardised units), location ``loc`` and scale ``scale``.
  #
  # Returns -inf at or above the upper bound, 0 at or below the lower bound,
  # NaN where a >= b, and propagates NaN inputs.
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)

  # Both tails share the same normaliser, so compute it once.
  log_total_mass = _log_gauss_mass(a, b)
  log_sf = _log_gauss_mass(x, b) - log_total_mass
  log_cdf = _log_gauss_mass(a, x) - log_total_mass

  log_sf = jnp.select(
      # third condition: avoid catastrophic cancellation (from scipy)
      [x >= b, x <= a, log_sf > -0.1],
      [-jnp.inf, 0, jnp.log1p(-jnp.exp(log_cdf))],
      # Inside the support the directly computed value is used. Passing it as
      # the default (rather than guarding it with ``x > a``) lets NaN inputs
      # propagate as NaN; select's implicit default of 0 would otherwise
      # report a survival probability of 1 for them.
      default=log_sf,
  )
  return jnp.where(a >= b, jnp.nan, log_sf)
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
def sf(x, a, b, loc=0, scale=1):
  # The survival function is the exponential of ``logsf``.
  log_survival = logsf(x, a, b, loc, scale)
  return lax.exp(log_survival)
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
def logcdf(x, a, b, loc=0, scale=1):
  # Log cumulative distribution function of the truncated normal with bounds
  # [a, b] (expressed in standardised units), location ``loc`` and scale
  # ``scale``.
  #
  # Returns 0 at or above the upper bound, -inf at or below the lower bound,
  # NaN where a >= b, and propagates NaN inputs.
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)

  # Both tails share the same normaliser, so compute it once.
  log_total_mass = _log_gauss_mass(a, b)
  log_cdf = _log_gauss_mass(a, x) - log_total_mass
  log_sf = _log_gauss_mass(x, b) - log_total_mass

  log_cdf = jnp.select(
      # third condition: avoid catastrophic cancellation (from scipy)
      [x >= b, x <= a, log_cdf > -0.1],
      [0, -jnp.inf, jnp.log1p(-jnp.exp(log_sf))],
      # Inside the support the directly computed value is used. Passing it as
      # the default (rather than guarding it with ``x > a``) lets NaN inputs
      # propagate as NaN; select's implicit default of 0 would otherwise
      # report a cumulative probability of 1 for them.
      default=log_cdf,
  )
  return jnp.where(a >= b, jnp.nan, log_cdf)
|
||
|
|
||
|
|
||
|
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
def cdf(x, a, b, loc=0, scale=1):
  # The cumulative distribution function is the exponential of ``logcdf``.
  log_cumulative = logcdf(x, a, b, loc, scale)
  return lax.exp(log_cumulative)
|