Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/scipy/stats/truncnorm.py

130 lines
4.3 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import scipy.stats as osp_stats
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps, promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
def _log_diff(x, y):
return logsumexp(
jnp.array([x, y]),
b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
axis=0
)
def _log_gauss_mass(a, b):
  """Log of the standard-normal probability mass on the interval ``[a, b]``.

  ``a`` and ``b`` are converted to arrays and broadcast against each other;
  the result has their broadcast shape. All three candidate formulas are
  evaluated for every element, and ``jnp.select`` keeps the appropriate one
  per element.
  """
  lo, hi = jnp.broadcast_arrays(jnp.array(a), jnp.array(b))

  # Note: reasoning carried over from scipy. Calculations in the right tail
  # are inaccurate, so intervals lying entirely to the right of zero are
  # mirrored into the left tail using the symmetry of the normal distribution.
  in_left_tail = hi <= 0
  in_right_tail = lo > 0
  straddles_zero = ~(in_left_tail | in_right_tail)

  # Interval at or below zero: log(Phi(hi) - Phi(lo)).
  left_tail_mass = _log_diff(log_ndtr(hi), log_ndtr(lo))
  # Interval above zero: same mass as the mirrored interval [-hi, -lo].
  right_tail_mass = _log_diff(log_ndtr(-lo), log_ndtr(-hi))
  # Interval containing zero: log(1 - Phi(lo) - Phi(-hi)).
  # Note: carried over from scipy. Summing the left-of-zero and right-of-zero
  # masses in log space suffers catastrophic cancellation as the total mass
  # approaches 1, hence this complement form. Underflow is not a concern: if
  # one term underflows it was insignificant, and if both do, the result
  # cannot be represented accurately in log space anyway because
  # log1p(x) ~ x for small x.
  central_mass = jnp.log1p(-ndtr(lo) - ndtr(-hi))

  return jnp.select(
      [in_left_tail, in_right_tail, straddles_zero],
      [left_tail_mass, right_tail_mass, central_mass],
  )
@_wraps(osp_stats.truncnorm.logpdf, update_doc=False)
def logpdf(x, a, b, loc=0, scale=1):
  """Log-density of a normal(loc, scale) truncated to ``[a, b]``.

  ``a`` and ``b`` are compared against the standardized value
  ``(x - loc) / scale``. Points outside the support get ``-inf``; an empty or
  inverted interval (``a >= b``) yields NaN.
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
  standardized = lax.div(lax.sub(x, loc), scale)
  # Untruncated normal log-density, renormalized by the log of the Gaussian
  # mass that remains inside the truncation interval.
  log_density = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
  outside_support = (standardized < a) | (standardized > b)
  # The invalid-parameter NaN takes precedence over the out-of-support -inf.
  return jnp.where(
      a >= b,
      jnp.nan,
      jnp.where(outside_support, -jnp.inf, log_density),
  )
@_wraps(osp_stats.truncnorm.pdf, update_doc=False)
def pdf(x, a, b, loc=0, scale=1):
  """Density of the truncated normal, obtained by exponentiating ``logpdf``."""
  log_density = logpdf(x, a, b, loc, scale)
  return lax.exp(log_density)
@_wraps(osp_stats.truncnorm.logsf, update_doc=False)
def logsf(x, a, b, loc=0, scale=1):
  """Log survival function, ``log(1 - cdf)``, of the truncated normal.

  ``x`` is standardized as ``(x - loc) / scale`` and compared against the
  truncation bounds ``a`` and ``b``.

  Returns:
    ``-inf`` where the standardized ``x >= b``, ``0`` where it is ``<= a``,
    the log survival probability in between, and NaN when ``a < b`` does not
    hold (including NaN bounds) or when ``x``, ``loc`` or ``scale`` is NaN.
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)
  # Normalizing mass of the truncation interval, shared by both ratios.
  log_total_mass = _log_gauss_mass(a, b)
  log_sf = _log_gauss_mass(x, b) - log_total_mass
  log_cdf = _log_gauss_mass(a, x) - log_total_mass
  log_sf = jnp.select(
      [x >= b, x <= a, log_sf > -0.1],
      # Third choice: when sf is close to 1, compute it as 1 - cdf to avoid
      # catastrophic cancellation (strategy carried over from scipy).
      [-jnp.inf, 0, jnp.log1p(-jnp.exp(log_cdf))],
      # Remaining case a < x < b uses the direct ratio. Making it the default,
      # rather than the constant 0, lets a NaN x (which fails every comparison
      # above) propagate as NaN instead of silently becoming log(sf) = 0.
      default=log_sf,
  )
  # `a < b` is False both for a >= b and when a or b is NaN (as in scipy's
  # argument check), so invalid bounds map to NaN.
  return jnp.where(a < b, log_sf, jnp.nan)
@_wraps(osp_stats.truncnorm.sf, update_doc=False)
def sf(x, a, b, loc=0, scale=1):
  """Survival function of the truncated normal, computed as ``exp(logsf)``."""
  log_survival = logsf(x, a, b, loc, scale)
  return lax.exp(log_survival)
@_wraps(osp_stats.truncnorm.logcdf, update_doc=False)
def logcdf(x, a, b, loc=0, scale=1):
  """Log cumulative distribution function of the truncated normal.

  ``x`` is standardized as ``(x - loc) / scale`` and compared against the
  truncation bounds ``a`` and ``b``.

  Returns:
    ``0`` where the standardized ``x >= b``, ``-inf`` where it is ``<= a``,
    the log cumulative probability in between, and NaN when ``a < b`` does
    not hold (including NaN bounds) or when ``x``, ``loc`` or ``scale`` is NaN.
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  x = lax.div(lax.sub(x, loc), scale)
  # Normalizing mass of the truncation interval, shared by both ratios.
  log_total_mass = _log_gauss_mass(a, b)
  log_cdf = _log_gauss_mass(a, x) - log_total_mass
  log_sf = _log_gauss_mass(x, b) - log_total_mass
  log_cdf = jnp.select(
      [x >= b, x <= a, log_cdf > -0.1],
      # Third choice: when cdf is close to 1, compute it as 1 - sf to avoid
      # catastrophic cancellation (strategy carried over from scipy).
      [0, -jnp.inf, jnp.log1p(-jnp.exp(log_sf))],
      # Remaining case a < x < b uses the direct ratio. Making it the default,
      # rather than the constant 0, lets a NaN x (which fails every comparison
      # above) propagate as NaN instead of silently becoming log(cdf) = 0.
      default=log_cdf,
  )
  # `a < b` is False both for a >= b and when a or b is NaN (as in scipy's
  # argument check), so invalid bounds map to NaN.
  return jnp.where(a < b, log_cdf, jnp.nan)
@_wraps(osp_stats.truncnorm.cdf, update_doc=False)
def cdf(x, a, b, loc=0, scale=1):
  """Cumulative distribution function of the truncated normal, ``exp(logcdf)``."""
  log_cumulative = logcdf(x, a, b, loc, scale)
  return lax.exp(log_cumulative)