188 lines
6.3 KiB
Python
188 lines
6.3 KiB
Python
|
import math
|
||
|
from sympy.sets.sets import Interval
|
||
|
from sympy.calculus.singularities import is_increasing, is_decreasing
|
||
|
from sympy.codegen.rewriting import Optimization
|
||
|
from sympy.core.function import UndefinedFunction
|
||
|
|
||
|
"""
|
||
|
This module collects classes useful for approximate rewriting of expressions.
|
||
|
This can be beneficial when generating numeric code for which performance is
|
||
|
of greater importance than precision (e.g. for preconditioners used in iterative
|
||
|
methods).
|
||
|
"""
|
||
|
|
||
|
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic, then
    bounds for those expressions are added.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
        The mapping is copied on construction: bounds derived for monotonic
        terms are cached in that private copy, never in the caller's dict.
    reltol : number
        Threshold for when to ignore a term. A term is dropped when the
        largest magnitude it can attain, ``max(abs(low), abs(high))``, is
        below ``reltol`` times the largest magnitude guaranteed by any term
        of the sum (``min(abs(low), abs(high))`` taken over the terms whose
        bounds exclude zero).

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        # Private copy: value() caches bounds it derives for monotonic terms.
        # Writing those into the caller's dict would leave stale entries (e.g.
        # for exp(z)) behind if the caller later changed bounds[z] and reused
        # the dict for another instance.
        self.bounds = dict(bounds)
        self.reltol = reltol

    def __call__(self, expr):
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        return expr.is_Add

    def _derive_monotonic_bounds(self, add):
        """Cache bounds for the single-variable monotonic terms of ``add``.

        Returns False as soon as a term is met whose only free symbol is
        bounded but which cannot be shown to be monotonic on that interval;
        True otherwise.
        """
        for term in add.args:
            if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
                continue
            fs, = term.free_symbols
            if fs not in self.bounds:
                continue
            intrvl = Interval(*self.bounds[fs])
            if is_increasing(term, intrvl, fs):
                self.bounds[term] = (
                    term.subs({fs: self.bounds[fs][0]}),
                    term.subs({fs: self.bounds[fs][1]})
                )
            elif is_decreasing(term, intrvl, fs):
                # A decreasing term attains its minimum at the upper bound.
                self.bounds[term] = (
                    term.subs({fs: self.bounds[fs][1]}),
                    term.subs({fs: self.bounds[fs][0]})
                )
            else:
                return False
        return True

    def value(self, add):
        """Return ``add`` with terms that are negligible (per ``reltol``) removed.

        ``add`` is returned unchanged when some term has no known bounds.
        """
        if not self._derive_monotonic_bounds(add):
            return add
        if not all(term.is_number or term in self.bounds for term in add.args):
            return add

        bounds = [(term, term) if term.is_number else self.bounds[term]
                  for term in add.args]

        largest_abs_guarantee = 0
        for lo, hi in bounds:
            if lo <= 0 <= hi:
                # The term may vanish, so it guarantees no magnitude at all.
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        if not new_terms:
            # Only reachable with reltol > 1; approximating the sum by an empty
            # Add (i.e. 0) would be meaningless, so leave it untouched.
            return add
        return add.func(*new_terms)
|
||
|
|
||
|
|
||
|
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest relative error, ``abs(1 - approximation/exact)``, tolerated
        at every check point for an expansion to be accepted.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.
        It must be even (an odd count would place a point at the expansion
        point, the midpoint of the bounds, where the check is trivially
        satisfied) and at least 2.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """
    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # With zero (or a negative number of) points the check loop would
            # never run and every expansion, down to the constant term, would
            # be accepted unverified.
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Decimal digits used when numerically evaluating the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        # For an applied undefined function f(x) it is the class f (i.e.
        # expr.func) that is an instance of the UndefinedFunction metaclass;
        # isinstance(expr, UndefinedFunction) would never be true.
        return (expr.is_Function and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """Return the lowest-order series of ``fexpr`` meeting ``reltol``.

        ``fexpr`` is returned unchanged if it does not depend on exactly one
        bounded symbol, or if even the order-``max_order`` expansion fails
        the check.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2

        # Check points and exact values do not depend on the expansion order,
        # so compute them once instead of once per order.
        n_pts = self.n_point_checks
        points = [lo + idx*(hi - lo)/(n_pts - 1) for idx in range(n_pts)]
        refs = [fexpr.xreplace({symb: x}) for x in points]

        cheapest = None
        # Walk from the highest order downwards; stop at the first order that
        # violates the tolerance and keep the last one that satisfied it.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            if any(abs((1 - fseri.xreplace({symb: x})/ref).evalf(self._prec)) > self.reltol
                   for x, ref in zip(points, refs)):
                break
            cheapest = fseri

        return fexpr if cheapest is None else cheapest
|