101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
"""Tools for arithmetic error propagation."""
|
|
|
|
from itertools import repeat, combinations
|
|
|
|
from sympy.core.add import Add
|
|
from sympy.core.mul import Mul
|
|
from sympy.core.power import Pow
|
|
from sympy.core.singleton import S
|
|
from sympy.core.symbol import Symbol
|
|
from sympy.functions.elementary.exponential import exp
|
|
from sympy.simplify.simplify import simplify
|
|
from sympy.stats.symbolic_probability import RandomSymbol, Variance, Covariance
|
|
from sympy.stats.rv import is_random
|
|
|
|
def _arg0_or_var(var):
    """Return ``var.args[0]`` if *var* has arguments, else *var* itself.

    Used by ``variance_prop`` to recover the operand ``X`` from a
    ``Variance(X)`` term so it can be passed to ``Covariance``. Leaf
    terms (objects with an empty ``args`` tuple, e.g. ``S.Zero``) are
    returned unchanged.
    """
    return var.args[0] if var.args else var
|
|
|
|
|
|
def variance_prop(expr, consts=(), include_covar=False):
    r"""Symbolically propagates variance (`\sigma^2`) for expressions.

    This is computed as seen in [1]_, by applying first-order error
    propagation recursively over the expression tree.

    Parameters
    ==========

    expr : Expr
        A SymPy expression to compute the variance for.
    consts : sequence of Symbols, optional
        Represents symbols that are known constants in the expr,
        and thus have zero variance. All symbols not in consts are
        assumed to be variant.
    include_covar : bool, optional
        Flag for whether or not to include covariances, default=False.

    Returns
    =======

    var_expr : Expr
        An expression for the total variance of the expr.
        The variance for the original symbols (e.g. x) is represented
        via instances of the Variance symbol (e.g. Variance(x)).
        Sub-expressions that are not an Add, Mul, Pow or exp are
        returned as ``Variance(subexpr)``.

    Examples
    ========

    >>> from sympy import symbols, exp
    >>> from sympy.stats.error_prop import variance_prop
    >>> x, y = symbols('x y')

    >>> variance_prop(x + y)
    Variance(x) + Variance(y)

    >>> variance_prop(x * y)
    x**2*Variance(y) + y**2*Variance(x)

    >>> variance_prop(exp(2*x))
    4*exp(4*x)*Variance(x)

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Propagation_of_uncertainty

    """
    args = expr.args
    if not args:
        # Leaf: declared constants and non-symbol atoms (numbers) have zero
        # variance; plain symbols are promoted to random symbols.
        if expr in consts:
            return S.Zero
        elif is_random(expr):
            return Variance(expr).doit()
        elif isinstance(expr, Symbol):
            return Variance(RandomSymbol(expr)).doit()
        else:
            return S.Zero

    # Variance of each direct argument, computed recursively.
    var_args = [variance_prop(arg, consts, include_covar) for arg in args]

    # NOTE(review): covariances below are formed from _arg0_or_var(var_arg),
    # which recovers the random symbol only when var_arg is a bare
    # Variance(X). For compound variances (e.g. 4*Variance(x) coming from a
    # 2*x argument) it yields the first sub-argument instead, so covariances
    # between nested sub-expressions are presumably not tracked — confirm.
    if isinstance(expr, Add):
        # var(sum a_i) = sum var(a_i) + 2*sum_{i<j} cov(a_i, a_j)
        var_expr = Add(*var_args)
        if include_covar:
            terms = [2*Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand()
                     for x, y in combinations(var_args, 2)]
            var_expr += Add(*terms)
    elif isinstance(expr, Mul):
        # var(f)/f**2 = sum var(a_i)/a_i**2 + 2*sum_{i<j} cov_ij/(a_i*a_j)
        terms = [v/a**2 for a, v in zip(args, var_args)]
        var_expr = simplify(expr**2 * Add(*terms))
        if include_covar:
            # The f**2 prefactor applies to the covariance terms as well;
            # omitting it gave e.g. 2*Cov(x, y)/(x*y) instead of
            # 2*x*y*Cov(x, y) for f = x*y.
            pairs = zip(combinations(args, 2), combinations(var_args, 2))
            terms = [2*expr**2/(a*b)
                     * Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand()
                     for (a, b), (x, y) in pairs]
            var_expr += Add(*terms)
    elif isinstance(expr, Pow):
        base, ex = args
        # d(base**ex)/d(base) = ex*base**(ex - 1) = expr*ex/base
        d_base = expr * ex / base
        v = var_args[0] * d_base**2
        # Only add exponent contributions when the exponent actually carries
        # variance, so constant exponents give exactly the previous result.
        if var_args[1] != 0:
            from sympy.functions.elementary.exponential import log
            # d(base**ex)/d(ex) = expr*log(base)
            d_ex = expr * log(base)
            v += var_args[1] * d_ex**2
            if include_covar:
                cov = Covariance(_arg0_or_var(var_args[0]),
                                 _arg0_or_var(var_args[1])).expand()
                v += 2 * d_base * d_ex * cov
        var_expr = simplify(v)
    elif isinstance(expr, exp):
        # d(exp(u))/du = exp(u), so var = var(u)*exp(u)**2
        var_expr = simplify(var_args[0] * expr**2)
    else:
        # unknown how to proceed, return variance of whole expr.
        var_expr = Variance(expr)
    return var_expr
|