""" Module contains classes for invertible (and differentiable) link functions. """ # Author: Christian Lorentzen from abc import ABC, abstractmethod from dataclasses import dataclass import numpy as np from scipy.special import expit, logit from scipy.stats import gmean from ..utils.extmath import softmax @dataclass class Interval: low: float high: float low_inclusive: bool high_inclusive: bool def __post_init__(self): """Check that low <= high""" if self.low > self.high: raise ValueError( f"One must have low <= high; got low={self.low}, high={self.high}." ) def includes(self, x): """Test whether all values of x are in interval range. Parameters ---------- x : ndarray Array whose elements are tested to be in interval range. Returns ------- result : bool """ if self.low_inclusive: low = np.greater_equal(x, self.low) else: low = np.greater(x, self.low) if not np.all(low): return False if self.high_inclusive: high = np.less_equal(x, self.high) else: high = np.less(x, self.high) # Note: np.all returns numpy.bool_ return bool(np.all(high)) def _inclusive_low_high(interval, dtype=np.float64): """Generate values low and high to be within the interval range. This is used in tests only. Returns ------- low, high : tuple The returned values low and high lie within the interval. """ eps = 10 * np.finfo(dtype).eps if interval.low == -np.inf: low = -1e10 elif interval.low < 0: low = interval.low * (1 - eps) + eps else: low = interval.low * (1 + eps) + eps if interval.high == np.inf: high = 1e10 elif interval.high < 0: high = interval.high * (1 + eps) - eps else: high = interval.high * (1 - eps) - eps return low, high class BaseLink(ABC): """Abstract base class for differentiable, invertible link functions. Convention: - link function g: raw_prediction = g(y_pred) - inverse link h: y_pred = h(raw_prediction) For (generalized) linear models, `raw_prediction = X @ coef` is the so called linear predictor, and `y_pred = h(raw_prediction)` is the predicted conditional (on X) expected value of the target `y_true`. The methods are not implemented as staticmethods in case a link function needs parameters. """ is_multiclass = False # used for testing only # Usually, raw_prediction may be any real number and y_pred is an open # interval. # interval_raw_prediction = Interval(-np.inf, np.inf, False, False) interval_y_pred = Interval(-np.inf, np.inf, False, False) @abstractmethod def link(self, y_pred, out=None): """Compute the link function g(y_pred). The link function maps (predicted) target values to raw predictions, i.e. `g(y_pred) = raw_prediction`. Parameters ---------- y_pred : array Predicted target values. out : array A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. Returns ------- out : array Output array, element-wise link function. """ @abstractmethod def inverse(self, raw_prediction, out=None): """Compute the inverse link function h(raw_prediction). The inverse link function maps raw predictions to predicted target values, i.e. `h(raw_prediction) = y_pred`. Parameters ---------- raw_prediction : array Raw prediction values (in link space). out : array A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. Returns ------- out : array Output array, element-wise inverse link function. """ class IdentityLink(BaseLink): """The identity link function g(x)=x.""" def link(self, y_pred, out=None): if out is not None: np.copyto(out, y_pred) return out else: return y_pred inverse = link class LogLink(BaseLink): """The log link function g(x)=log(x).""" interval_y_pred = Interval(0, np.inf, False, False) def link(self, y_pred, out=None): return np.log(y_pred, out=out) def inverse(self, raw_prediction, out=None): return np.exp(raw_prediction, out=out) class LogitLink(BaseLink): """The logit link function g(x)=logit(x).""" interval_y_pred = Interval(0, 1, False, False) def link(self, y_pred, out=None): return logit(y_pred, out=out) def inverse(self, raw_prediction, out=None): return expit(raw_prediction, out=out) class MultinomialLogit(BaseLink): """The symmetric multinomial logit function. Convention: - y_pred.shape = raw_prediction.shape = (n_samples, n_classes) Notes: - The inverse link h is the softmax function. - The sum is over the second axis, i.e. axis=1 (n_classes). We have to choose additional constraints in order to make y_pred[k] = exp(raw_pred[k]) / sum(exp(raw_pred[k]), k=0..n_classes-1) for n_classes classes identifiable and invertible. We choose the symmetric side constraint where the geometric mean response is set as reference category, see [2]: The symmetric multinomial logit link function for a single data point is then defined as raw_prediction[k] = g(y_pred[k]) = log(y_pred[k]/gmean(y_pred)) = log(y_pred[k]) - mean(log(y_pred)). Note that this is equivalent to the definition in [1] and implies mean centered raw predictions: sum(raw_prediction[k], k=0..n_classes-1) = 0. For linear models with raw_prediction = X @ coef, this corresponds to sum(coef[k], k=0..n_classes-1) = 0, i.e. the sum over classes for every feature is zero. Reference --------- .. [1] Friedman, Jerome; Hastie, Trevor; Tibshirani, Robert. "Additive logistic regression: a statistical view of boosting" Ann. Statist. 28 (2000), no. 2, 337--407. doi:10.1214/aos/1016218223. https://projecteuclid.org/euclid.aos/1016218223 .. [2] Zahid, Faisal Maqbool and Gerhard Tutz. "Ridge estimation for multinomial logit models with symmetric side constraints." Computational Statistics 28 (2013): 1017-1034. http://epub.ub.uni-muenchen.de/11001/1/tr067.pdf """ is_multiclass = True interval_y_pred = Interval(0, 1, False, False) def symmetrize_raw_prediction(self, raw_prediction): return raw_prediction - np.mean(raw_prediction, axis=1)[:, np.newaxis] def link(self, y_pred, out=None): # geometric mean as reference category gm = gmean(y_pred, axis=1) return np.log(y_pred / gm[:, np.newaxis], out=out) def inverse(self, raw_prediction, out=None): if out is None: return softmax(raw_prediction, copy=True) else: np.copyto(out, raw_prediction) softmax(out, copy=False) return out _LINKS = { "identity": IdentityLink, "log": LogLink, "logit": LogitLink, "multinomial_logit": MultinomialLogit, }