66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
from jax import lax
|
|
import jax.numpy as jnp
|
|
from jax._src.typing import Array, ArrayLike
|
|
from jax._src.numpy.util import promote_args_inexact
|
|
|
|
# Note: for mysterious reasons, annotating this leads to very slow mypy runs.
|
|
# def algdiv(a: ArrayLike, b: ArrayLike) -> Array:
|
|
|
|
def algdiv(a, b):
|
|
"""
|
|
Compute ``log(gamma(a))/log(gamma(a + b))`` when ``b >= 8``.
|
|
|
|
Derived from scipy's implementation of `algdiv`_.
|
|
|
|
This differs from the scipy implementation in that it assumes a <= b
|
|
because recomputing ``a, b = jnp.minimum(a, b), jnp.maximum(a, b)`` might
|
|
be expensive and this is only called by ``betaln``.
|
|
|
|
.. _algdiv:
|
|
https://github.com/scipy/scipy/blob/c89dfc2b90d993f2a8174e57e0cbc8fbe6f3ee19/scipy/special/cdflib/algdiv.f
|
|
"""
|
|
c0 = 0.833333333333333e-01
|
|
c1 = -0.277777777760991e-02
|
|
c2 = 0.793650666825390e-03
|
|
c3 = -0.595202931351870e-03
|
|
c4 = 0.837308034031215e-03
|
|
c5 = -0.165322962780713e-02
|
|
h = a / b
|
|
c = h / (1 + h)
|
|
x = h / (1 + h)
|
|
d = b + (a - 0.5)
|
|
# Set sN = (1 - x**n)/(1 - x)
|
|
x2 = x * x
|
|
s3 = 1.0 + (x + x2)
|
|
s5 = 1.0 + (x + x2 * s3)
|
|
s7 = 1.0 + (x + x2 * s5)
|
|
s9 = 1.0 + (x + x2 * s7)
|
|
s11 = 1.0 + (x + x2 * s9)
|
|
# Set w = del(b) - del(a + b)
|
|
# where del(x) is defined by ln(gamma(x)) = (x - 0.5)*ln(x) - x + 0.5*ln(2*pi) + del(x)
|
|
t = (1.0 / b) ** 2
|
|
w = ((((c5 * s11 * t + c4 * s9) * t + c3 * s7) * t + c2 * s5) * t + c1 * s3) * t + c0
|
|
w = w * (c / b)
|
|
# Combine the results
|
|
u = d * lax.log1p(a / b)
|
|
v = a * (lax.log(b) - 1.0)
|
|
return jnp.where(u <= v, (w - v) - u, (w - u) - v)
|
|
|
|
|
|
def betaln(a: ArrayLike, b: ArrayLike) -> Array:
|
|
"""Compute the log of the beta function.
|
|
|
|
Derived from scipy's implementation of `betaln`_.
|
|
|
|
This implementation does not handle all branches of the scipy implementation, but is still much more accurate
|
|
than just doing lgamma(a) + lgamma(b) - lgamma(a + b) when inputs are large (> 1M or so).
|
|
|
|
.. _betaln:
|
|
https://github.com/scipy/scipy/blob/ef2dee592ba8fb900ff2308b9d1c79e4d6a0ad8b/scipy/special/cdflib/betaln.f
|
|
"""
|
|
a, b = promote_args_inexact("betaln", a, b)
|
|
a, b = jnp.minimum(a, b), jnp.maximum(a, b)
|
|
small_b = lax.lgamma(a) + (lax.lgamma(b) - lax.lgamma(a + b))
|
|
large_b = lax.lgamma(a) + algdiv(a, b)
|
|
return jnp.where(b < 8, small_b, large_b)
|