Traktor/myenv/Lib/site-packages/sympy/codegen/approximations.py
2024-05-23 01:57:24 +02:00

188 lines
6.3 KiB
Python

import math
from sympy.sets.sets import Interval
from sympy.calculus.singularities import is_increasing, is_decreasing
from sympy.codegen.rewriting import Optimization
from sympy.core.function import UndefinedFunction
"""
This module collects classes useful for approximate rewriting of expressions.
This can be beneficial when generating numeric code for which performance is
of greater importance than precision (e.g. for preconditioners used in iterative
methods).
"""
class SumApprox(Optimization):
    """
    Approximates sums by dropping terms that are negligibly small.

    Explanation
    ===========

    Every ``Add`` in the factored expression is inspected. For a term whose
    single free symbol has known bounds and which is monotonic over that
    interval, bounds of the term itself are derived and stored in ``bounds``.
    Once every term of a sum is either a number or bounded, a term is kept
    only if its largest possible magnitude is at least ``reltol`` times the
    largest magnitude guaranteed by any term whose bounds exclude zero.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
        Note that this very mapping is updated in place with the bounds
        derived for monotonic terms.
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        magnitude that some term (whose bounds do not contain zero) is
        guaranteed to have.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        self.reltol = reltol
        # Kept by reference on purpose: derived term bounds are cached here.
        self.bounds = bounds

    def __call__(self, expr):
        """Factor ``expr`` and approximate every sum inside it."""
        factored = expr.factor()
        return factored.replace(self.query, self.value)

    def query(self, expr):
        """Select the nodes to rewrite: sums (``Add`` instances)."""
        return expr.is_Add

    def value(self, add):
        """Return ``add`` with negligible terms removed, or ``add`` itself
        when the terms cannot all be bounded."""
        terms = add.args

        # Pass 1: derive (and cache) bounds for bounded, monotonic terms.
        for term in terms:
            if term.is_number or term in self.bounds:
                continue
            symbols = term.free_symbols
            if len(symbols) != 1:
                continue
            (sym,) = symbols
            if sym not in self.bounds:
                continue
            limits = self.bounds[sym]
            domain = Interval(*limits)
            if is_increasing(term, domain, sym):
                start, stop = limits[0], limits[1]
            elif is_decreasing(term, domain, sym):
                start, stop = limits[1], limits[0]
            else:
                # Monotonicity unknown: the term cannot be bounded this way.
                return add
            self.bounds[term] = (term.subs({sym: start}),
                                 term.subs({sym: stop}))

        # Pass 2: only possible once every term is a number or bounded.
        if not all(t.is_number or t in self.bounds for t in terms):
            return add

        ranges = [(t, t) if t.is_number else self.bounds[t] for t in terms]

        # Largest magnitude some term is certain to reach; terms whose range
        # straddles zero guarantee nothing and are skipped here.
        guaranteed = max(
            [0] + [min(abs(low), abs(high)) for low, high in ranges
                   if not (low <= 0 <= high)])

        cutoff = guaranteed*self.reltol
        kept = [t for t, (low, high) in zip(terms, ranges)
                if max(abs(low), abs(high)) >= cutoff]
        return add.func(*kept)
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Explanation
    ===========

    Every single-argument function application (undefined functions such
    as ``f(x)`` excluded) whose only free symbol has known bounds is
    expanded about the midpoint of those bounds. Orders are tried from
    ``max_order`` downwards; the last expansion that still meets ``reltol``
    at all check points before the first failing order is used. If even
    the ``max_order`` expansion fails, the function is left unchanged.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest relative error ``|1 - approx/exact|`` tolerated at each of
        the check points.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (positive, even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.
        It must be even (an odd count would place a point at the expansion
        point itself, where the check is trivially satisfied) and at least 2.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """

    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # With no check points every order would be accepted vacuously
            # and the function would collapse to its constant term.
            raise ValueError("n_point_checks must be a positive even number")
        self.n_point_checks = n_point_checks
        # Significant digits used by evalf when measuring the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        """Factor ``expr`` and replace eligible functions by series."""
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        """Select single-argument function applications, skipping undefined
        functions.

        An applied undefined function such as ``f(x)`` is an instance of
        ``AppliedUndef``; it is its class (``expr.func``) that is an
        ``UndefinedFunction``. Its series contains unevaluated derivatives,
        so the numerical accuracy check could not be decided for it.
        """
        return (expr.is_Function and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """Return the cheapest acceptable series for ``fexpr``, or ``fexpr``
        itself when it has no single bounded symbol or no order qualifies."""
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2  # expand about the midpoint of the bounds
        cheapest = None
        # Try decreasing orders; stop at the first one that fails the check.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            if not self._is_accurate(fseri, fexpr, symb, lo, hi):
                break
            cheapest = fseri
        return fexpr if cheapest is None else cheapest

    def _is_accurate(self, approx, exact, symb, lo, hi):
        """Return True if ``approx`` agrees with ``exact`` to within
        ``reltol`` (relative error) at ``n_point_checks`` evenly spaced
        points of ``[lo, hi]`` (both end points included)."""
        for idx in range(self.n_point_checks):
            x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
            val = approx.xreplace({symb: x})
            ref = exact.xreplace({symb: x})
            if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
                return False
        return True