# Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import namedtuple from functools import partial import math from typing import Optional, Tuple import jax import jax.numpy as jnp from jax import jit from jax._src import dtypes from jax._src.api import vmap from jax._src.numpy.util import check_arraylike, _wraps from jax._src.typing import ArrayLike, Array from jax._src.util import canonicalize_axis import scipy ModeResult = namedtuple('ModeResult', ('mode', 'count')) @_wraps(scipy.stats.mode, lax_description="""\ Currently the only supported nan_policy is 'propagate' """) @partial(jit, static_argnames=['axis', 'nan_policy', 'keepdims']) def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult: check_arraylike("mode", a) x = jnp.atleast_1d(a) if nan_policy not in ["propagate", "omit", "raise"]: raise ValueError( f"Illegal nan_policy value {nan_policy!r}; expected one of " "{'propagate', 'omit', 'raise'}" ) if nan_policy == "omit": # TODO: return answer without nans included. raise NotImplementedError( f"Logic for `nan_policy` of {nan_policy} is not implemented" ) if nan_policy == "raise": raise NotImplementedError( "In order to best JIT compile `mode`, we cannot know whether `x` contains nans. " "Please check if nans exist in `x` outside of the `mode` function." ) input_shape = x.shape if keepdims: if axis is None: output_shape = tuple(1 for i in input_shape) else: output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape)) else: if axis is None: output_shape = () else: output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis) if axis is None: axis = 0 x = x.ravel() def _mode_helper(x: jax.Array) -> Tuple[jax.Array, jax.Array]: """Helper function to return mode and count of a given array.""" if x.size == 0: return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)) else: vals, counts = jnp.unique(x, return_counts=True, size=x.size) return vals[jnp.argmax(counts)], counts.max() axis = canonicalize_axis(axis, x.ndim) x = jnp.moveaxis(x, axis, 0) x = x.reshape(x.shape[0], math.prod(x.shape[1:])) vals, counts = vmap(_mode_helper, in_axes=1)(x) return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape)) def invert_permutation(i: Array) -> Array: """Helper function that inverts a permutation array.""" return jnp.empty_like(i).at[i].set(jnp.arange(i.size, dtype=i.dtype)) @_wraps(scipy.stats.rankdata, lax_description="""\ Currently the only supported nan_policy is 'propagate' """) @partial(jit, static_argnames=["method", "axis", "nan_policy"]) def rankdata( a: ArrayLike, method: str = "average", *, axis: Optional[int] = None, nan_policy: str = "propagate", ) -> Array: check_arraylike("rankdata", a) if nan_policy not in ["propagate", "omit", "raise"]: raise ValueError( f"Illegal nan_policy value {nan_policy!r}; expected one of " "{'propoagate', 'omit', 'raise'}" ) if nan_policy == "omit": raise NotImplementedError( f"Logic for `nan_policy` of {nan_policy} is not implemented" ) if nan_policy == "raise": raise NotImplementedError( "In order to best JIT compile `mode`, we cannot know whether `x` " "contains nans. Please check if nans exist in `x` outside of the " "`rankdata` function." ) if method not in ("average", "min", "max", "dense", "ordinal"): raise ValueError(f"unknown method '{method}'") a = jnp.asarray(a) if axis is not None: return jnp.apply_along_axis(rankdata, axis, a, method) arr = jnp.ravel(a) sorter = jnp.argsort(arr) inv = invert_permutation(sorter) if method == "ordinal": return inv + 1 arr = arr[sorter] obs = jnp.insert(arr[1:] != arr[:-1], 0, True) dense = obs.cumsum()[inv] if method == "dense": return dense count = jnp.nonzero(obs, size=arr.size + 1, fill_value=len(obs))[0] if method == "max": return count[dense] if method == "min": return count[dense - 1] + 1 if method == "average": return .5 * (count[dense] + count[dense - 1] + 1).astype(dtypes.canonicalize_dtype(jnp.float_)) raise ValueError(f"unknown method '{method}'")