# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import scipy.stats as osp_stats from jax import lax from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps, promote_args_inexact from jax.numpy import arctan from jax._src.typing import Array, ArrayLike @_wraps(osp_stats.cauchy.logpdf, update_doc=False) def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale) pi = _lax_const(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) @_wraps(osp_stats.cauchy.pdf, update_doc=False) def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.exp(logpdf(x, loc, scale)) @_wraps(osp_stats.cauchy.cdf, update_doc=False) def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale) pi = _lax_const(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) return lax.add(_lax_const(x, 0.5), lax.mul(lax.div(_lax_const(x, 1.), pi), arctan(scaled_x))) @_wraps(osp_stats.cauchy.logcdf, update_doc=False) def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: return lax.log(cdf(x, loc, scale)) @_wraps(osp_stats.cauchy.sf, update_doc=False) def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: x, = promote_args_inexact("cauchy.sf", x) cdf_result = cdf(x, loc, scale) return lax.sub(_lax_const(cdf_result, 1), cdf_result) @_wraps(osp_stats.cauchy.isf, update_doc=False) def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale) pi = _lax_const(q, np.pi) half_pi = _lax_const(q, np.pi / 2) unscaled = lax.tan(lax.sub(half_pi, lax.mul(pi, q))) return lax.add(lax.mul(unscaled, scale), loc) @_wraps(osp_stats.cauchy.ppf, update_doc=False) def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale) pi = _lax_const(q, np.pi) half_pi = _lax_const(q, np.pi / 2) unscaled = lax.tan(lax.sub(lax.mul(pi, q), half_pi)) return lax.add(lax.mul(unscaled, scale), loc)