1226 lines
40 KiB
Python
1226 lines
40 KiB
Python
from collections import defaultdict
|
|
|
|
from sympy.core import sympify, S, Mul, Derivative, Pow
|
|
from sympy.core.add import _unevaluated_Add, Add
|
|
from sympy.core.assumptions import assumptions
|
|
from sympy.core.exprtools import Factors, gcd_terms
|
|
from sympy.core.function import _mexpand, expand_mul, expand_power_base
|
|
from sympy.core.mul import _keep_coeff, _unevaluated_Mul, _mulsort
|
|
from sympy.core.numbers import Rational, zoo, nan
|
|
from sympy.core.parameters import global_parameters
|
|
from sympy.core.sorting import ordered, default_sort_key
|
|
from sympy.core.symbol import Dummy, Wild, symbols
|
|
from sympy.functions import exp, sqrt, log
|
|
from sympy.functions.elementary.complexes import Abs
|
|
from sympy.polys import gcd
|
|
from sympy.simplify.sqrtdenest import sqrtdenest
|
|
from sympy.utilities.iterables import iterable, sift
|
|
|
|
|
|
|
|
|
|
def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True):
    """
    Collect additive terms of an expression.

    Explanation
    ===========

    This function collects additive terms of an expression with respect
    to a list of expressions up to powers with rational exponents. By the
    term symbol here are meant arbitrary expressions, which can contain
    powers, products, sums etc. In other words symbol is a pattern which
    will be searched for in the expression's terms.

    The input expression is not expanded by :func:`collect`, so user is
    expected to provide an expression in an appropriate form. This makes
    :func:`collect` more predictable as there is no magic happening behind the
    scenes. However, it is important to note, that powers of products are
    converted to products of powers using the :func:`~.expand_power_base`
    function.

    There are two possible types of output. First, if ``evaluate`` flag is
    set, this function will return an expression with collected terms or
    else it will return a dictionary with expressions up to rational powers
    as keys and collected coefficients as values.

    Parameters
    ==========

    expr : Expr
        The expression whose additive terms are collected (it is passed
        through ``sympify`` first).
    syms : Expr or iterable of Expr
        The pattern(s) to collect on. Patterns are tried in the given
        order and each term is filed under the first pattern that matches.
    func : callable, optional
        If given, it is applied to every collected coefficient.
    evaluate : bool, optional
        If None, ``global_parameters.evaluate`` is used. When true an
        expression is returned; otherwise a dict mapping each collected
        pattern power to its coefficient, with unmatched terms stored
        under ``S.One``.
    exact : bool or None, optional
        With False (the default) any rational power of a pattern is
        matched. With True the rational exponent and derivative order
        must match exactly. With None, the parts of the terms of ``expr``
        that depend on ``syms`` are added as further patterns (see the
        examples below).
    distribute_order_term : bool, optional
        If True and ``expr`` has an order term that does not contain any
        of ``syms``, that order term is removed before collecting and
        then added to every collected coefficient.

    Raises
    ======

    AttributeError
        If a matching pattern is noncommutative.
    NotImplementedError
        If a derivative with respect to several different variables is
        encountered while parsing.

    Examples
    ========

    >>> from sympy import S, collect, expand, factor, Wild
    >>> from sympy.abc import a, b, c, x, y

    This function can collect symbolic coefficients in polynomials or
    rational expressions. It will manage to find all integer or rational
    powers of collection variable::

        >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x)
        c + x**2*(a + b) + x*(a - b)

    The same result can be achieved in dictionary form::

        >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False)
        >>> d[x**2]
        a + b
        >>> d[x]
        a - b
        >>> d[S.One]
        c

    You can also work with multivariate polynomials. However, remember that
    this function is greedy so it will care only about a single symbol at time,
    in specification order::

        >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y])
        x**2*(y + 1) + x*y + y*(a + 1)

    Also more complicated expressions can be used as patterns::

        >>> from sympy import sin, log
        >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x))
        (a + b)*sin(2*x)

        >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x))
        x*(a + b)*log(x)

    You can use wildcards in the pattern::

        >>> w = Wild('w1')
        >>> collect(a*x**y - b*x**y, w**y)
        x**y*(a - b)

    It is also possible to work with symbolic powers, although it has more
    complicated behavior, because in this case power's base and symbolic part
    of the exponent are treated as a single symbol::

        >>> collect(a*x**c + b*x**c, x)
        a*x**c + b*x**c
        >>> collect(a*x**c + b*x**c, x**c)
        x**c*(a + b)

    However if you incorporate rationals to the exponents, then you will get
    well known behavior::

        >>> collect(a*x**(2*c) + b*x**(2*c), x**c)
        x**(2*c)*(a + b)

    Note also that all previously stated facts about :func:`collect` function
    apply to the exponential function, so you can get::

        >>> from sympy import exp
        >>> collect(a*exp(2*x) + b*exp(2*x), exp(x))
        (a + b)*exp(2*x)

    If you are interested only in collecting specific powers of some symbols
    then set ``exact`` flag to True::

        >>> collect(a*x**7 + b*x**7, x, exact=True)
        a*x**7 + b*x**7
        >>> collect(a*x**7 + b*x**7, x**7, exact=True)
        x**7*(a + b)

    If you want to collect on any object containing symbols, set
    ``exact`` to None:

        >>> collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None)
        x*exp(x) + 3*x + (y + 2)*sin(x)
        >>> collect(a*x*y + x*y + b*x + x, [x, y], exact=None)
        x*y*(a + 1) + x*(b + 1)

    You can also apply this function to differential equations, where
    derivatives of arbitrary order can be collected. Note that if you
    collect with respect to a function or a derivative of a function, all
    derivatives of that function will also be collected. Use
    ``exact=True`` to prevent this from happening::

        >>> from sympy import Derivative as D, collect, Function
        >>> f = Function('f') (x)

        >>> collect(a*D(f,x) + b*D(f,x), D(f,x))
        (a + b)*Derivative(f(x), x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f)
        (a + b)*Derivative(f(x), (x, 2))

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True)
        a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2))

        >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f)
        (a + b)*f(x) + (a + b)*Derivative(f(x), x)

    Or you can even match both derivative order and exponent at the same time::

        >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x))
        (a + b)*Derivative(f(x), (x, 2))**2

    Finally, you can apply a function to each of the collected coefficients.
    For example you can factorize symbolic coefficients of polynomial::

        >>> f = expand((x + a + 1)**3)

        >>> collect(f, x, factor)
        x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3

    .. note:: Arguments are expected to be in expanded form, so you might have
              to call :func:`~.expand` prior to calling this function.

    See Also
    ========

    collect_const, collect_sqrt, rcollect
    """
    expr = sympify(expr)
    syms = [sympify(i) for i in (syms if iterable(syms) else [syms])]

    # replace syms[i] if it is not x, -x or has Wild symbols
    cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool(
        x.atoms(Wild))
    _, nonsyms = sift(syms, cond, binary=True)
    if nonsyms:
        # Collect on Dummy stand-ins (carrying the same assumptions) for
        # the compound patterns, then map the Dummies back in the result,
        # whether it is an expression or a dict.
        reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms]))
        syms = [reps.get(s, s) for s in syms]
        rv = collect(expr.subs(reps), syms,
            func=func, evaluate=evaluate, exact=exact,
            distribute_order_term=distribute_order_term)
        urep = {v: k for k, v in reps.items()}
        if not isinstance(rv, dict):
            return rv.xreplace(urep)
        else:
            return {urep.get(k, k).xreplace(urep): v.xreplace(urep)
                for k, v in rv.items()}

    # see if other expressions should be considered
    if exact is None:
        # Gather extra patterns from the additive terms: a non-Mul term
        # that contains a sym is used whole; for a Mul, only the factors
        # that depend on syms are kept (the "compound generator").
        _syms = set()
        for i in Add.make_args(expr):
            if not i.has_free(*syms) or i in syms:
                continue
            if not i.is_Mul and i not in syms:
                _syms.add(i)
            else:
                # identify compound generators
                g = i._new_rawargs(*i.as_coeff_mul(*syms)[1])
                if g not in syms:
                    _syms.add(g)
        simple = all(i.is_Pow and i.base in syms for i in _syms)
        syms = syms + list(ordered(_syms))
        if not simple:
            # Retry with the enlarged pattern list in non-exact mode.
            # Otherwise (all extras are powers of existing syms) execution
            # continues below with exact still None, which the matcher in
            # parse_expression treats like exact=True.
            return collect(expr, syms,
            func=func, evaluate=evaluate, exact=False,
            distribute_order_term=distribute_order_term)

    if evaluate is None:
        evaluate = global_parameters.evaluate

    def make_expression(terms):
        # Rebuild a product from parse_term tuples
        # (base, rational exponent, symbolic exponent, derivative info).
        product = []

        for term, rat, sym, deriv in terms:
            if deriv is not None:
                var, order = deriv

                # re-apply the derivative ``order`` times w.r.t. ``var``
                while order > 0:
                    term, order = Derivative(term, var), order - 1

            if sym is None:
                if rat is S.One:
                    product.append(term)
                else:
                    product.append(Pow(term, rat))
            else:
                product.append(Pow(term, rat*sym))

        return Mul(*product)

    def parse_derivative(deriv):
        # scan derivatives tower in the input expression and return
        # underlying function and maximal differentiation order
        expr, sym, order = deriv.expr, deriv.variables[0], 1

        for s in deriv.variables[1:]:
            if s == sym:
                order += 1
            else:
                raise NotImplementedError(
                    'Improve MV Derivative support in collect')

        while isinstance(expr, Derivative):
            s0 = expr.variables[0]

            for s in expr.variables:
                if s != s0:
                    raise NotImplementedError(
                        'Improve MV Derivative support in collect')

            if s0 == sym:
                expr, order = expr.expr, order + len(expr.variables)
            else:
                break

        return expr, (sym, Rational(order))

    def parse_term(expr):
        """Parses expression expr and outputs tuple (sexpr, rat_expo,
        sym_expo, deriv)
        where:
         - sexpr is the base expression
         - rat_expo is the rational exponent that sexpr is raised to
         - sym_expo is the symbolic exponent that sexpr is raised to
         - deriv contains the derivatives of the expression

         For example, the output of x would be (x, 1, None, None)
         the output of 2**x would be (2, 1, x, None).
        """
        rat_expo, sym_expo = S.One, None
        sexpr, deriv = expr, None

        if expr.is_Pow:
            if isinstance(expr.base, Derivative):
                sexpr, deriv = parse_derivative(expr.base)
            else:
                sexpr = expr.base

            if expr.base == S.Exp1:
                # E**arg: split a rational coefficient off the exponent so
                # that exp(2*x) is seen as exp(x) raised to 2
                arg = expr.exp
                if arg.is_Rational:
                    sexpr, rat_expo = S.Exp1, arg
                elif arg.is_Mul:
                    coeff, tail = arg.as_coeff_Mul(rational=True)
                    sexpr, rat_expo = exp(tail), coeff

            elif expr.exp.is_Number:
                rat_expo = expr.exp
            else:
                coeff, tail = expr.exp.as_coeff_Mul()

                if coeff.is_Number:
                    rat_expo, sym_expo = coeff, tail
                else:
                    sym_expo = expr.exp
        elif isinstance(expr, exp):
            arg = expr.exp
            if arg.is_Rational:
                sexpr, rat_expo = S.Exp1, arg
            elif arg.is_Mul:
                coeff, tail = arg.as_coeff_Mul(rational=True)
                sexpr, rat_expo = exp(tail), coeff
        elif isinstance(expr, Derivative):
            sexpr, deriv = parse_derivative(expr)

        return sexpr, rat_expo, sym_expo, deriv

    def parse_expression(terms, pattern):
        """Parse terms searching for a pattern.
        Terms is a list of tuples as returned by parse_term;
        Pattern is an expression treated as a product of factors.

        Returns None when some factor of the pattern is not found;
        otherwise a tuple (unmatched terms, matched terms, common_expo,
        has_deriv).
        """
        pattern = Mul.make_args(pattern)

        if len(terms) < len(pattern):
            # pattern is longer than matched product
            # so no chance for positive parsing result
            return None
        else:
            pattern = [parse_term(elem) for elem in pattern]

            terms = terms[:]  # need a copy
            elems, common_expo, has_deriv = [], None, False

            for elem, e_rat, e_sym, e_ord in pattern:

                if elem.is_Number and e_rat == 1 and e_sym is None:
                    # a constant is a match for everything
                    continue

                for j in range(len(terms)):
                    if terms[j] is None:
                        continue

                    term, t_rat, t_sym, t_ord = terms[j]

                    # keeping track of whether one of the terms had
                    # a derivative or not as this will require rebuilding
                    # the expression later
                    if t_ord is not None:
                        has_deriv = True

                    if (term.match(elem) is not None and
                            (t_sym == e_sym or t_sym is not None and
                            e_sym is not None and
                            t_sym.match(e_sym) is not None)):
                        if exact is False:
                            # we don't have to be exact so find common exponent
                            # for both expression's term and pattern's element
                            expo = t_rat / e_rat

                            if common_expo is None:
                                # first time
                                common_expo = expo
                            else:
                                # common exponent was negotiated before so
                                # there is no chance for a pattern match unless
                                # common and current exponents are equal
                                if common_expo != expo:
                                    common_expo = 1
                        else:
                            # we ought to be exact so all fields of
                            # interest must match in every details
                            if e_rat != t_rat or e_ord != t_ord:
                                continue

                        # found common term so remove it from the expression
                        # and try to match next element in the pattern
                        elems.append(terms[j])
                        terms[j] = None

                        break

                else:
                    # pattern element not found
                    return None

            return [_f for _f in terms if _f], elems, common_expo, has_deriv

    if evaluate:
        # In evaluate mode, first collect inside the arguments of an
        # Add (keeping any order term aside), a Mul or the base of a Pow.
        if expr.is_Add:
            o = expr.getO() or 0
            expr = expr.func(*[
                    collect(a, syms, func, True, exact, distribute_order_term)
                    for a in expr.args if a != o]) + o
        elif expr.is_Mul:
            return expr.func(*[
                collect(term, syms, func, True, exact, distribute_order_term)
                for term in expr.args])
        elif expr.is_Pow:
            b = collect(
                expr.base, syms, func, True, exact, distribute_order_term)
            return Pow(b, expr.exp)

    syms = [expand_power_base(i, deep=False) for i in syms]

    order_term = None

    if distribute_order_term:
        order_term = expr.getO()

        if order_term is not None:
            if order_term.has(*syms):
                # the order term involves a pattern: leave it in place
                order_term = None
            else:
                expr = expr.removeO()

    summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)]

    # collected maps each matched pattern power ("index") to the list of
    # remaining cofactors; disliked accumulates terms no pattern matched
    collected, disliked = defaultdict(list), S.Zero
    for product in summa:
        c, nc = product.args_cnc(split_1=False)
        args = list(ordered(c)) + nc
        terms = [parse_term(i) for i in args]
        small_first = True

        for symbol in syms:
            # for a Derivative pattern the factor list is reversed once
            # before matching
            if isinstance(symbol, Derivative) and small_first:
                terms = list(reversed(terms))
                small_first = not small_first
            result = parse_expression(terms, symbol)

            if result is not None:
                if not symbol.is_commutative:
                    raise AttributeError("Can not collect noncommutative symbol")

                terms, elems, common_expo, has_deriv = result

                # when there was derivative in current pattern we
                # will need to rebuild its expression from scratch
                if not has_deriv:
                    margs = []
                    for elem in elems:
                        if elem[2] is None:
                            e = elem[1]
                        else:
                            e = elem[1]*elem[2]
                        margs.append(Pow(elem[0], e))
                    index = Mul(*margs)
                else:
                    index = make_expression(elems)
                terms = expand_power_base(make_expression(terms), deep=False)
                index = expand_power_base(index, deep=False)
                collected[index].append(terms)
                break
        else:
            # none of the patterns matched
            disliked += product
    # add terms now for each key
    collected = {k: Add(*v) for k, v in collected.items()}

    if disliked is not S.Zero:
        collected[S.One] = disliked

    if order_term is not None:
        for key, val in collected.items():
            collected[key] = val + order_term

    if func is not None:
        collected = {
            key: func(val) for key, val in collected.items()}

    if evaluate:
        return Add(*[key*val for key, val in collected.items()])
    else:
        return collected
|
|
|
|
|
|
def rcollect(expr, *vars):
    """
    Recursively collect sums in an expression.

    The expression tree is rebuilt bottom-up. Every rebuilt node that is
    an Add is passed to :func:`collect` with respect to ``vars``.
    Subtrees that are atoms, or that do not contain any of ``vars``, are
    returned unchanged.

    Examples
    ========

    >>> from sympy.simplify import rcollect
    >>> from sympy.abc import x, y

    >>> expr = (x**2*y + x*y + x + y)/(x + y)

    >>> rcollect(expr, y)
    (x + y*(x**2 + x + 1))/(x + y)

    See Also
    ========

    collect, collect_const, collect_sqrt
    """
    # Nothing to do for leaves or for subtrees free of the targets.
    if expr.is_Atom or not expr.has(*vars):
        return expr

    rebuilt = expr.__class__(*(rcollect(sub, *vars) for sub in expr.args))

    return collect(rebuilt, vars) if rebuilt.is_Add else rebuilt
|
|
|
|
|
|
def collect_sqrt(expr, evaluate=None):
    """Return expr with terms having common square roots collected together.

    A "radical" here is a commutative factor that is either a power whose
    exponent is a Rational with denominator 2, or the imaginary unit.
    When searching ``expr`` for radicals to collect, only numeric ones are
    considered. Terms sharing such a radical are gathered with
    :func:`collect_const` (with ``Numbers=False``).

    If ``evaluate`` is False, a tuple of the resulting terms is returned
    together with a count of how many of them contain a radical. If no
    collection happened and the count is zero, that tuple holds the whole
    expression as a single term. If ``evaluate`` is True, the expression
    with any collected terms is returned. When ``evaluate`` is None, the
    global evaluate setting is used.

    Note: since I = sqrt(-1), it is collected, too.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b

    >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]]
    >>> collect_sqrt(a*r2 + b*r2)
    sqrt(2)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3)
    sqrt(2)*(a + b) + sqrt(3)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5)
    sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b)

    If evaluate is False then the arguments will be sorted and
    returned as a list and a count of the number of sqrt-containing
    terms will be returned:

    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False)
    ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3)
    >>> collect_sqrt(a*sqrt(2) + b, evaluate=False)
    ((b, sqrt(2)*a), 1)
    >>> collect_sqrt(a + b, evaluate=False)
    ((a + b,), 0)

    See Also
    ========

    collect, collect_const, rcollect
    """
    if evaluate is None:
        evaluate = global_parameters.evaluate

    def _is_radical(factor):
        # half-integer power (denominator 2) or the imaginary unit
        return (factor.is_Pow and factor.exp.is_Rational and factor.exp.q == 2
                or factor is S.ImaginaryUnit)

    # pulling the content out first helps standardize any complex
    # arguments of the sqrts
    coeff, expr = expr.as_content_primitive()

    radicals = set()
    for term in Add.make_args(expr):
        commutative_factors = term.args_cnc()[0]
        radicals.update(
            f for f in commutative_factors if f.is_number and _is_radical(f))

    # Numbers=False: only the radicals are targeted; the result is evaluated
    d = collect_const(expr, *radicals, Numbers=False)
    hit = expr != d

    if evaluate:
        return coeff*d

    nrad = 0
    # canonical ordering of the (evaluated) terms
    args = list(ordered(Add.make_args(d)))
    for idx, term in enumerate(args):
        commutative_factors, _ = term.args_cnc()
        # XXX should this be restricted to is_number as above?
        if any(_is_radical(f) for f in commutative_factors):
            nrad += 1
        args[idx] = term*coeff

    if not (hit or nrad):
        args = [Add(*args)]
    return tuple(args), nrad
|
|
|
|
|
|
def collect_abs(expr):
    """Return ``expr`` with arguments of multiple Abs in a term collected
    under a single instance.

    Every Mul, and then every Pow, in ``expr`` is inspected. Its Abs
    factors, and its real powers of Abs, are merged into one ``Abs`` of the
    product of their arguments. Noncommutative factors are always kept last.

    Examples
    ========

    >>> from sympy.simplify.radsimp import collect_abs
    >>> from sympy.abc import x
    >>> collect_abs(abs(x + 1)/abs(x**2 - 1))
    Abs((x + 1)/(x**2 - 1))
    >>> collect_abs(abs(1/x))
    Abs(1/x)
    """
    def _merge(node):
        commutative, noncommutative = node.args_cnc()
        inside = []  # arguments lifted out of the Abs factors
        rest = []    # all other commutative factors
        for factor in commutative:
            if isinstance(factor, Abs):
                inside.append(factor.args[0])
            elif (isinstance(factor, Pow) and isinstance(factor.base, Abs)
                    and factor.exp.is_real):
                inside.append(factor.base.args[0]**factor.exp)
            else:
                rest.append(factor)

        # a lone Abs is only rewritten when it carries a negative power
        if len(inside) < 2 and not any(
                p.exp.is_negative for p in inside if isinstance(p, Pow)):
            return node

        merged_arg = Mul(*inside)
        merged = Abs(merged_arg)
        if not merged.has(Abs):
            return Mul(merged, *rest, *noncommutative)

        if not isinstance(merged, Abs):
            # reevaluate and make it unevaluated
            merged = Abs(merged_arg, evaluate=False)
        factors = [merged, *rest]
        _mulsort(factors)
        factors.extend(noncommutative)  # nc always go last
        return Mul._from_args(factors, is_commutative=not noncommutative)

    return expr.replace(
        lambda e: isinstance(e, Mul), _merge
    ).replace(
        lambda e: isinstance(e, Pow), _merge)
|
|
|
|
|
|
def collect_const(expr, *vars, Numbers=True):
    """A non-greedy collection of terms with similar number coefficients in
    an Add expr. If ``vars`` is given then only those constants will be
    targeted. Although any Number can also be targeted, if this is not
    desired set ``Numbers=False`` and no Float or Rational will be collected.

    Parameters
    ==========

    expr : SymPy expression
        This parameter defines the expression from which
        terms with similar coefficients are to be collected. A non-Add
        expression is returned as it is.

    vars : variable length collection of Numbers, optional
        Specifies the constants to target for collection. Can be multiple in
        number.

    Numbers : bool
        Specifies to target all instances of
        :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then
        no Float or Rational will be collected.

    Returns
    =======

    expr : Expr
        Returns an expression with similar coefficient terms collected.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import s, x, y, z
    >>> from sympy.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    The collection is sign-sensitive, giving higher precedence to the
    unsigned values:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    recurse = False

    if not vars:
        # No targets given: use every numeric factor of every term, and
        # allow sums built below to be appended as further targets.
        recurse = True
        vars = set()
        for a in expr.args:
            for m in Mul.make_args(a):
                if m.is_number:
                    vars.add(m)
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    vars = list(ordered(vars))
    # NOTE: ``vars`` may grow inside this loop (``vars.append`` below when
    # recurse is True); iterating a list while appending means newly added
    # targets are also visited.
    for v in vars:
        # terms[v] holds the cofactors of v; terms[S.One] the terms that
        # v does not divide
        terms = defaultdict(list)
        Fv = Factors(v)
        for m in Add.make_args(expr):
            f = Factors(m)
            q, r = f.div(Fv)
            if r.is_one:
                # only accept this as a true factor if
                # it didn't change an exponent from an Integer
                # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2)
                # -- we aren't looking for this sort of change
                fwas = f.factors.copy()
                fnow = q.factors
                if not any(k in fwas and fwas[k].is_Integer and not
                        fnow[k].is_Integer for k in fnow):
                    terms[v].append(q.as_expr())
                    continue
            terms[S.One].append(m)

        args = []
        hit = False
        uneval = False
        for k in ordered(terms):
            # NOTE: this rebinds ``v`` (the outer loop variable) to the
            # cofactor list; the outer loop itself is unaffected
            v = terms[k]
            if k is S.One:
                args.extend(v)
                continue

            if len(v) > 1:
                v = Add(*v)
                hit = True
                if recurse and v != expr:
                    vars.append(v)
            else:
                v = v[0]

            # be careful not to let uneval become True unless
            # it must be because it's going to be more expensive
            # to rebuild the expression as an unevaluated one
            if Numbers and k.is_Number and v.is_Add:
                args.append(_keep_coeff(k, v, sign=True))
                uneval = True
            else:
                args.append(k*v)

        if hit:
            if uneval:
                expr = _unevaluated_Add(*args)
            else:
                expr = Add(*args)
            if not expr.is_Add:
                break

    return expr
|
|
|
|
|
|
def radsimp(expr, symbolic=True, max_terms=4):
    r"""
    Rationalize the denominator by removing square roots.

    Explanation
    ===========

    The expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumptions is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Examples
    ========

    >>> from sympy import radsimp, sqrt, Symbol, pprint
    >>> from sympy import factor_terms, fraction, signsimp
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b, c

    >>> radsimp(1/(2 + sqrt(2)))
    (2 - sqrt(2))/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True))
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If ``symbolic=False``, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from sympy.simplify.simplify import signsimp

    # template symbols used by _num; the actual radicands/coefficients are
    # substituted in with xreplace
    syms = symbols("a:d A:D")
    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]
        # (e.g. for two terms sqrt(A)*a + sqrt(B)*b the multiplier is
        # sqrt(A)*a - sqrt(B)*b, giving A*a**2 - B*b**2)
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i])))
            return (
            sqrt(A)*a - sqrt(B)*b).xreplace(reps)
        if len(rterms) == 3:
            reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i])))
            return (
            (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 -
            B*b**2 + C*c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i])))
            return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b
             - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 +
             D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 -
             2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 -
             2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 +
             D**2*d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        # True if d is a Pow whose exponent has denominator 2 (for symbolic
        # exponents only when ``symbolic`` is set); with log2=True, also
        # True when that denominator is an integer power of 2 other than 1.
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                d = denom(e)
                if d.is_Integer:
                    q = d
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # Handle first reduces to the case
        # expr = 1/d, where d is an add, or d is base**p/2.
        # We do this by recursively calling handle on each piece.
        from sympy.simplify.simplify import nsimplify

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1/d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1/d))
        elif d.is_Mul:
            return _unevaluated_Mul(*[handle(1/d) for d in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**numer(d.exp)
            if d2 != d:
                return handle(1/d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1/d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1/d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1/d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1/_d

        while True:
            # collect similar terms
            # (key: sorted tuple of radicands found in a term, where I
            # contributes -1; value: the remaining cofactors)
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            # number of radical terms (a term with radicand 1 has no radical)
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point
                # so don't keep changes, e.g. this expression is troublesome
                # in collecting terms so as not to raise the issue of 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all(x.is_Integer and (y**2).is_Rational for x, y in rterms):
                    nd, d = rad_rationalize(S.One, Add._from_args(
                        [sqrt(x)*y for x, y in rterms]))
                    n *= nd
                else:
                    # is there anything else that might be attempted?
                    keep = False
                break
            from sympy.simplify.powsimp import powsimp, powdenest

            # multiply numerator and denominator by the rationalizing factor
            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            if d.has(S.Zero, nan, zoo):
                return expr
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1/d)

    # separate any additive Number; it is added back unchanged at the end
    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            # try to tidy signs/factors; revert if that just recreates
            # the original fraction
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1/d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            if old == (n, d):
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            # remove common terms; accept only if the denominator did not
            # get more complicated
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1/d)
|
|
|
|
|
|
def rad_rationalize(num, den):
    """
    Rationalize ``num/den`` by removing square roots in the denominator;
    num and den are sum of terms whose squares are positive rationals.

    While ``den`` is an Add, it is split with ``split_surds`` into
    ``(g, a, b)``. The numerator is then multiplied by ``a*sqrt(g) - b`` and
    the denominator replaced by ``(a*sqrt(g))**2 - b**2``, both
    multiplicatively expanded. The final ``(num, den)`` pair is returned.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import rad_rationalize
    >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3)
    (-sqrt(3) + sqrt(6)/3, -7/9)
    """
    while den.is_Add:
        g, a, b = split_surds(den)
        surd_part = a*sqrt(g)
        num = _mexpand((surd_part - b)*num)
        den = _mexpand(surd_part**2 - b**2)
    return num, den
|
|
|
|
|
|
def fraction(expr, exact=False):
    """Returns a pair with expression's numerator and denominator.
    If the given expression is not a fraction then this function
    will return the tuple (expr, 1).

    This function will not make any attempt to simplify nested
    fractions or to do any term rewriting at all.

    If only one of the numerator/denominator pair is needed then
    use numer(expr) or denom(expr) functions respectively.

    >>> from sympy import fraction, Rational, Symbol
    >>> from sympy.abc import x, y

    >>> fraction(x/y)
    (x, y)
    >>> fraction(x)
    (x, 1)

    >>> fraction(1/y**2)
    (1, y**2)

    >>> fraction(x*y/2)
    (x*y, 2)
    >>> fraction(Rational(1, 2))
    (1, 2)

    This function will also work fine with assumptions:

    >>> k = Symbol('k', negative=True)
    >>> fraction(x * y**k)
    (x, y**(-k))

    If we know nothing about sign of some exponent and ``exact``
    flag is unset, then this exponent's structure will
    be analyzed and pretty fraction will be returned:

    >>> from sympy import exp, Mul
    >>> fraction(2*x**(-y))
    (2, x**y)

    >>> fraction(exp(-x))
    (1, exp(x))

    >>> fraction(exp(-x), exact=True)
    (exp(-x), 1)

    The ``exact`` flag will also keep any unevaluated Muls from
    being evaluated:

    >>> u = Mul(2, x + 1, evaluate=False)
    >>> fraction(u)
    (2*x + 2, 1)
    >>> fraction(u, exact=True)
    (2*(x + 1), 1)
    """
    expr = sympify(expr)

    num_factors = []
    den_factors = []

    for factor in Mul.make_args(expr):
        is_power = factor.is_commutative and (
            factor.is_Pow or isinstance(factor, exp))
        if is_power:
            base, ex = factor.as_base_exp()
            if ex.is_negative:
                if ex is S.NegativeOne:
                    den_factors.append(base)
                elif not exact or ex.is_constant():
                    # with exact=True only constant negative exponents
                    # are moved to the denominator
                    den_factors.append(Pow(base, -ex))
                else:
                    num_factors.append(factor)
            elif ex.is_positive:
                num_factors.append(factor)
            elif not exact and ex.is_Mul:
                # exponent of unknown sign: let as_numer_denom decide
                n, d = factor.as_numer_denom()
                if n != 1:
                    num_factors.append(n)
                den_factors.append(d)
            else:
                num_factors.append(factor)
        elif factor.is_Rational and not factor.is_Integer:
            if factor.p != 1:
                num_factors.append(factor.p)
            den_factors.append(factor.q)
        else:
            num_factors.append(factor)

    return (Mul(*num_factors, evaluate=not exact),
            Mul(*den_factors, evaluate=not exact))
|
|
|
|
|
|
def numer(expr, exact=False):
    """Return the numerator of ``expr``.

    Shorthand for ``fraction(expr, exact=exact)[0]``. ``exact`` defaults to
    ``False``, the same default as :func:`fraction`, so existing calls
    ``numer(expr)`` behave as before. Pass ``exact=True`` to get the
    numerator under :func:`fraction`'s exact mode, which keeps unevaluated
    Muls and non-constant negative exponents as they are.
    """
    return fraction(expr, exact=exact)[0]
|
|
|
|
|
|
def denom(expr, exact=False):
    """Return the denominator of ``expr``.

    Shorthand for ``fraction(expr, exact=exact)[1]``. ``exact`` defaults to
    ``False``, the same default as :func:`fraction`, so existing calls
    ``denom(expr)`` behave as before. Pass ``exact=True`` to get the
    denominator under :func:`fraction`'s exact mode, which keeps unevaluated
    Muls and non-constant negative exponents as they are.
    """
    return fraction(expr, exact=exact)[1]
|
|
|
|
|
|
def fraction_expand(expr, **hints):
    """Return ``expr.expand(frac=True, **hints)``.

    Thin wrapper that switches on the ``frac`` hint of ``expr.expand``.
    Any further keyword ``hints`` are forwarded unchanged. Passing ``frac``
    again in ``hints`` raises ``TypeError`` (duplicate keyword argument).
    """
    expand = expr.expand
    return expand(frac=True, **hints)
|
|
|
|
|
|
def numer_expand(expr, **hints):
    """Expand the numerator of ``expr`` and divide by the untouched denominator.

    The numerator returned by :func:`fraction` is expanded with
    ``.expand(numer=True, **hints)``. The denominator is used as-is.
    """
    top, bottom = fraction(expr)
    expanded_top = top.expand(numer=True, **hints)
    return expanded_top / bottom
|
|
|
|
|
|
def denom_expand(expr, **hints):
    """Divide the untouched numerator of ``expr`` by its expanded denominator.

    The denominator returned by :func:`fraction` is expanded with
    ``.expand(denom=True, **hints)``. The numerator is used as-is.
    """
    top, bottom = fraction(expr)
    return top / bottom.expand(denom=True, **hints)
|
|
|
|
|
|
# ``expand_*`` spellings of the ``*_expand`` helpers; each pair of names
# refers to the very same function object.
expand_numer, expand_denom, expand_fraction = (
    numer_expand, denom_expand, fraction_expand)
|
|
|
|
|
|
def split_surds(expr):
    """
    Split a sum of terms whose squares are positive rationals as
    ``expr == a*sqrt(g) + b``.

    The squared surds (radicands) are grouped by a common factor ``g``
    using ``_split_gcd``. Terms whose radicand belongs to that group are
    divided by ``sqrt(g)`` and collected in ``a``. All other terms go
    unchanged into ``b``. Returns ``(g, a, b)``.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import split_surds
    >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15))
    (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10))
    """
    terms = sorted(expr.args, key=default_sort_key)
    pairs = [term.as_coeff_Mul() for term in terms]
    radicands = sorted((m**2 for _, m in pairs if m.is_Pow),
                       key=default_sort_key)

    g, b1, b2 = _split_gcd(*radicands)
    common = g
    if len(b1) >= 2 and not b2:
        # Every radicand shares g, so only a common factor has been
        # found. Divide it out and split again to look for a further one.
        reduced = [r for r in (x/g for x in b1) if r != 1]
        common = g*_split_gcd(*reduced)[0]

    inside, outside = [], []
    for c, s in pairs:
        if s.is_Pow and s.exp == S.Half and s.base in b1:
            inside.append(c*sqrt(s.base/common))
        else:
            outside.append(c*s)
    return common, Add(*inside), Add(*outside)
|
|
|
|
|
|
def _split_gcd(*a):
    """
    Partition the integers ``a`` using a running greatest common divisor.

    Starting from ``g = a[0]``, each later value ``x`` is tested against the
    current ``g``. If ``gcd(g, x) == 1``, ``x`` goes into the second list.
    Otherwise ``g`` is narrowed to ``gcd(g, x)`` and ``x`` joins the first
    list.

    Returns ``(g, a1, a2)``: ``g`` is the gcd of ``a1``, and every element
    of ``a2`` is coprime to ``g``. Calling it with no arguments raises
    ``IndexError``.

    Examples
    ========

    >>> from sympy.simplify.radsimp import _split_gcd
    >>> _split_gcd(55, 35, 22, 14, 77, 10)
    (5, [55, 35, 10], [22, 14, 77])
    """
    common = a[0]
    shared, coprime = [common], []
    for value in a[1:]:
        narrowed = gcd(common, value)
        if narrowed == 1:
            coprime.append(value)
            continue
        common = narrowed
        shared.append(value)
    return common, shared, coprime
|