556 lines
19 KiB
Python
556 lines
19 KiB
Python
from __future__ import (absolute_import, division,
|
|
print_function, unicode_literals)
|
|
import warnings,operator
|
|
warnings.filterwarnings("ignore")
|
|
def vargen(n):
    """Default generator of variable names.

    Cycles through ``shiro.alphabet``; once the index runs past the end of
    the alphabet, the number of completed cycles is appended to the letter
    (with the default alphabet: 0 -> 'a', 25 -> 'z', 26 -> 'a1').
    """
    alphabet = shiro.alphabet
    size = len(alphabet)
    cycles, position = divmod(n, size)
    name = alphabet[position]
    # The numeric suffix only appears after the first full pass.
    return name + str(cycles) if n >= size else name
|
|
class Vars:
    """Empty attribute container; used as a namespace for module settings."""
|
|
# Module-wide settings and state shared by the functions below.
shiro = Vars()

# Callables used to emit regular output and warnings.
shiro.display = print
shiro.warning = shiro.display

# Seed is needed to select the weights in linprog function.
# None means that the seed is random.
shiro.seed = None

# Maps every message to its translation (filled in below).
shiro.translation = {}

# Index of the next candidate name handed to shiro.vargen.
shiro.varind = 0

# Set of used symbols in the proof.
shiro.varset = set()

# Function generating names for variables.
shiro.vargen = vargen

# Characters used as names for variables.
shiro.alphabet = 'abcdefghijklmnopqrstuvwxyz'

translationList = [
    'numerator:',
    'denominator:',
    'status:',
    'Substitute',
    "Formula after substitution:",
    "Numerator after substitutions:",
    "From weighted AM-GM inequality:",
    'The sum of all inequalities gives us a proof of the inequality.',
    "Program couldn't find a solution with integer coefficients. Try "+
    "to multiple the formula by some integer and run this function again.",
    "Program couldn't find any proof.",
    "Try to set higher linprogiter parameter.",
    "It looks like the formula is symmetric. You can assume without loss of"+
    " generality that ",
    "Try",
    'From Jensen inequality:',
    'Warning: intervals contain backwards dependencies. Consider changing order of variables and intervals.',
]

# Initialize the English-to-English dictionary: every phrase maps to itself.
for phrase in translationList:
    shiro.translation[phrase] = phrase
|
|
from scipy.optimize import linprog,fmin
|
|
import random
|
|
from sympy import S,cancel,fraction,Pow,expand,solve,latex,oo,Poly,lambdify,srepr,gcd,Symbol
|
|
from sympy.parsing.sympy_parser import parse_expr, standard_transformations,\
|
|
implicit_multiplication_application, convert_xor
|
|
from collections import Counter
|
|
import re
|
|
def addsymbols(formula):
    """Register every name occurring in ``formula`` as used.

    Both the free symbols and the names of undefined functions (found in the
    ``srepr`` of the formula) are added to ``shiro.varset``.
    """
    expr = S(formula)
    function_names = set()
    for match in re.findall(r"Function\(\'.*?\'\)", srepr(expr)):
        # Strip the leading "Function('" (10 chars) and the trailing "')".
        function_names.add(match[10:-2])
    symbol_names = {str(symbol) for symbol in expr.free_symbols}
    shiro.varset |= function_names | symbol_names
|
|
def newvar():
    """Return a fresh symbol whose name is not in ``shiro.varset``.

    Candidate names come from ``shiro.vargen``; ``shiro.varind`` is advanced
    for every candidate tried, including the one finally returned.
    """
    while True:
        candidate = shiro.vargen(shiro.varind)
        shiro.varind += 1
        if candidate in shiro.varset:
            continue
        return S(candidate)
|
|
def newproof():
    """Forget all used symbols and restart variable naming from index 0."""
    shiro.varset, shiro.varind = set(), 0
|
|
def _remzero(coef,fun):
|
|
"""coef, fun represents an expression.
|
|
For example, if expression=5f(2,3)+0f(4,6)+8f(1,4)
|
|
then coef=[5,0,8], fun=[[2,3],[4,6],[1,4]]
|
|
_remzero removes addends with coefficient equal to zero.
|
|
In this example ncoef=[5,8], nfun=[[2,3],[1,4]]"""
|
|
ncoef=[]
|
|
nfun=[]
|
|
for c,f in zip(coef,fun):
|
|
if c>0:
|
|
ncoef+=[c]
|
|
nfun+=[f]
|
|
return ncoef,nfun
|
|
# ~ def slatex(formula): #fancy function which makes latex code more readable, but still correct
|
|
# ~ formula=re.sub(r'\^{(.)}',r'^\1',latex(formula,fold_short_frac=True).replace(' ','').replace('\\left(','(').replace('\\right)',')'))
|
|
# ~ return re.sub(r'\{(\(.+?\))\}',r'\1',formula)
|
|
def _writ2(coef, fun, variables):
    """Return LaTeX for the monomial with coefficient ``coef`` and exponent
    tuple ``fun`` over the generators ``variables``."""
    monomial = Poly({fun: coef}, gens=variables)
    return latex(monomial.as_expr())
|
|
def _writ(coef,fun,nullvar):
|
|
return str(coef)+'f('+str(fun)[1:-1-(len(fun)==1)]+')'
|
|
def _check(coef,fun,res,rfun):
|
|
#checks if rounding and all the floating point stuff works
|
|
res2=[int(round(x)) for x in res]
|
|
b1=[coef*x for x in fun]
|
|
b2=[[x*y for y in rfuni] for x,rfuni in zip(res2,rfun)]
|
|
return b1==[sum(x) for x in zip(*b2)] and coef==sum(res2)
|
|
def _powr(formula):
    """Split ``formula`` into base and exponent.

    A ``Pow`` yields its ``args``; anything else is treated as
    ``formula**1`` and yields ``[formula, 1]``.
    """
    if formula.func != Pow:
        return [formula, S('1')]
    return formula.args
|
|
def fractioncancel(formula):
    """Workaround for buggy cancel function.

    The formula is divided by the auxiliary symbol ``tmp`` before calling
    ``cancel``; afterwards ``tmp`` is replaced by 1 in the denominator.
    Returns the pair ``(numerator, denominator)``.
    """
    marker = S('tmp')
    num, den = fraction(cancel(formula / marker))
    return num, den.subs(marker, S('1'))
|
|
def ssolve(formula, variables):
    """Workaround for inconsistent solve function.

    ``solve`` may return either a list of solutions or a single dict.
    A dict is converted to a one-element list holding the values in the
    order given by ``variables``.

    A variable missing from the dict (``solve`` left it free, e.g. for an
    under-determined system) is kept as itself instead of raising KeyError.
    """
    result = solve(formula, variables)
    if isinstance(result, dict):
        result = [[result.get(var, var) for var in variables]]
    return result
|
|
#def sstr(formula):
|
|
# return str(formula).replace('**','^').replace('*','').replace(' ','')
|
|
def assumeall(formula, **kwargs):
    """Adds assumptions to all free symbols in formula.

    Every free symbol is replaced by a symbol of the same name created
    with the given keyword assumptions.

    >>> assumeall('sqrt(x*y)-sqrt(x)*sqrt(y)',positive=True)
    0
    """
    expr = S(formula)
    # A list of pairs is substituted one pair after another, in order.
    replacements = [
        (symbol, Symbol(str(symbol), **kwargs))
        for symbol in expr.free_symbols
    ]
    for old, new in replacements:
        expr = expr.subs(old, new)
    return expr
|
|
def reducegens(formula):
    """Reduces size of the generator of the polynomial

    >>>Poly('x+sqrt(x)')
    Poly(x + (sqrt(x)), x, sqrt(x), domain='ZZ')
    >>>reducegens('x+sqrt(x)')
    Poly((sqrt(x))**2 + (sqrt(x)), sqrt(x), domain='ZZ')

    Generators that are powers of a common root are merged: each generator
    is written as ``root**coef`` and the gcd of the coefficients seen for a
    root becomes the exponent of the single new generator for that root.
    """
    pol = Poly(formula)

    def split(gen):
        # gen == root**coef, where coef is the numeric part of the exponent.
        base, power = _powr(gen)
        coef, _ = power.as_coeff_mul()
        return base ** (power / coef), coef

    decomposed = [(gen,) + tuple(split(gen)) for gen in pol.gens]

    newgens = {}   # root -> gcd of the exponents seen for this root
    ind = {}       # root -> temporary placeholder symbol
    for _, root, coef in decomposed:
        if root in newgens:
            newgens[root] = gcd(newgens[root], coef)
        else:
            newgens[root] = coef
            ind[root] = S('tmp' + str(len(ind)))

    # Express every old generator through the placeholder of its root.
    for gen, root, coef in decomposed:
        pol = pol.replace(gen, ind[root] ** (coef / newgens[root]))

    # Rebuild the polynomial and swap the placeholders for real generators.
    newpol = Poly(pol.as_expr())
    for root in newgens:
        newpol = newpol.replace(ind[root], root ** newgens[root])
    return newpol
|
|
def Sm(formula):
    """Adds multiplication signs and sympifies a formula.

    For example, Sm('(2x+y)(7+5xz)') -> S('(2*x+y)*(7+5*x*z)')

    Strings are parsed with implicit multiplication and ``^`` as power;
    '{' is removed and '}' becomes a space first. Other iterables are
    converted element by element (keeping the container type); anything
    else is passed to ``S``.
    """
    if type(formula) == str:
        cleaned = formula.translate({ord('{'): None, ord('}'): ' '})
        rules = standard_transformations + (
            implicit_multiplication_application, convert_xor)
        return parse_expr(cleaned, transformations=rules)
    if _isiterable(formula):
        return type(formula)(map(Sm, formula))
    return S(formula)
|
|
# ~ def Sm(formula):
|
|
# ~ #Adds multiplication signs and sympifies a formula.
|
|
# ~ #For example, Sm('(2x+y)(7+5xz)') -> S('(2*x+y)*(7+5*x*z)')
|
|
# ~ if type(formula)==str:
|
|
# ~ formula=formula.replace(' ','')
|
|
# ~ for i in range(2):
|
|
# ~ formula=re.sub(r'([0-9a-zA-Z)])([(a-zA-Z])',r'\1*\2',formula)
|
|
# ~ formula=S(formula)
|
|
# ~ return formula
|
|
def _input2fraction(formula, variables, values):
    """Makes some substitutions and converts formula to a fraction
    with expanded numerator and denominator.

    Every variable whose paired value differs from 1 is replaced by
    ``value * new_variable``; each substitution, the numerator and the
    denominator are reported through ``shiro.display``.

    Returns the pair ``(numerator, denominator)``.
    """
    expr = S(formula)
    replacements = []
    for var, value in zip(variables, values):
        if value == 1:
            continue
        fresh = newvar()
        shiro.display(shiro.translation['Substitute']+' $'+latex(var)+'\\to '+latex(S(value)*S(fresh))+'$')
        replacements.append((var, fresh * value))
    expr = expr.subs(replacements)
    numerator, denominator = fractioncancel(expr)
    for label, part in (('numerator:', numerator), ('denominator:', denominator)):
        shiro.display(shiro.translation[label]+' $'+latex(part)+'$')
    return (numerator, denominator)
|
|
def _formula2list(formula):
    """Splits a polynomial to a difference of two polynomials with positive
    coefficients and extracts coefficients and powers of both polynomials.

    For example, if formula=5x^2-4xy+8y^3, then the program tries to
    prove that
    0<=5x^2-4xy+8y^3
    4xy<=5x^2+8y^3
    and this function returns [4],[(1,1)], [5,8],[(2,0),(0,3)], (x,y)

    All free symbols are assumed positive and the generators are reduced
    with ``reducegens`` first; the generators of the resulting polynomial
    fix the order of the powers.
    """
    poly = reducegens(assumeall(formula, positive=True))
    half = S('1/2')
    magnitude = poly.abs()
    # (|p| - p)/2 keeps the negated negative terms, (|p| + p)/2 the positive.
    lhs = (magnitude - poly) * half
    rhs = (magnitude + poly) * half
    return (lhs.coeffs(), lhs.monoms(),
            rhs.coeffs(), rhs.monoms(),
            Poly(poly).gens)
|
|
def _list2proof(lcoef,lfun,rcoef,rfun,variables,itermax,linprogiter,_writ2=_writ2,theorem="From weighted AM-GM inequality:"):
    """Try to prove LHS<=RHS by splitting it into elementary inequalities.

    At this point the formula is split into two polynomials with positive
    coefficients. We call them LHS and RHS, and the inequality to prove is
    LHS<=RHS (instead of 0<=RHS-LHS).

    Suppose we are trying to prove that
    30x^2y^2+60xy^4<=48x^3+56y^6 (assuming x,y>0).
    The program will try to find some a,b,c,d such that
    30x^2y^2<=ax^3+by^6
    60xy^4<=cx^3+dy^6
    where a+c<=48 and b+d<=56 (assumption 1).
    We need some additional equalities to meet the assumptions
    of the weighted AM-GM inequality:
    a+b=30 and c+d=60 (assumption 2)
    3a+0b=30*2, 0a+6b=30*2, 3c+0d=60*1, 0c+6d=60*4 (assumption 3)

    Sketch of the algorithm.
    for i in range(itermax):
        1. Create a vector of random numbers (weights).
        2. Try to find a real solution of the problem (with linprog).
        3. If there is no solution (status: 2)
            3a. If a solution was never found, break.
            3b. Else, step back (to the bigger inequality).
        4. If a solution was found (status: 0)
            Check which of the variables (in the example: a,b,c,d) look
            like integers. If there are some inequalities with all integer
            coefficients, subtract them from the original one.
            If LHS is empty, then break.

    Parameters:
    lcoef, lfun -- coefficients and exponent tuples of LHS
    rcoef, rfun -- coefficients and exponent tuples of RHS
    variables -- passed unchanged to the formatter ``_writ2``
    itermax -- maximal number of passes of the main loop
    linprogiter -- 'maxiter' option passed to ``linprog``
    _writ2 -- formatter called as ``_writ2(coef, fun, variables)``
    theorem -- key of ``shiro.translation`` displayed before the proof

    All output goes through ``shiro.display``. Returns the last status:
    0 or 2 from the shortcut branches below, otherwise ``res.status`` of
    the last ``linprog`` call.
    """
    localseed=shiro.seed
    # Proven inequalities are buffered and displayed only after the next
    # linear program succeeds, so that a step back can discard them.
    bufer=[]
    lcoef,lfun=_remzero(lcoef,lfun)
    rcoef,rfun=_remzero(rcoef,rfun)
    itern=0
    if len(lcoef)==0: #if LHS is empty
        shiro.display(shiro.translation['status:']+' 0')
        status=0
    elif len(rcoef)==0:
        #if RHS is empty, but LHS is not
        shiro.display(shiro.translation['status:']+' 2')
        status=2
        itermax=0
    # NOTE(review): if both sides are non-empty and itermax<=0, the loop
    # below never runs and `status` is read unassigned -- confirm callers
    # always pass a positive itermax.
    foundreal=0
    while len(lcoef)>0 and itern<itermax:
        itern+=1
        m=len(lcoef)
        n=len(rcoef)
        #lfunt=transposed matrix lfun (in fact, it's
        #a list of lists)
        lfunt=list(map(list, zip(*lfun)))
        rfunt=list(map(list, zip(*rfun)))
        #A,b, - set of linear equalities
        #A_ub,b_ub - set of linear inequalities
        #from linear program
        # The unknowns are stored row by row: index i*n+j is the weight of
        # the j-th RHS term in the inequality for the i-th LHS term.
        A=[]
        b=[]
        A_ub=[]
        b_ub=[]
        for i in range(m):
            for j in range(len(rfunt)): #assumption 3
                A+=[[0]*(m*n)]
                A[-1][i*n:i*n+n]=rfunt[j]
                b+=[lfun[i][j]*lcoef[i]]
            A+=[[0]*(m*n)] #assumption 2
            A[-1][i*n:i*n+n]=[1]*n
            b+=[lcoef[i]]
        for j in range(n): #assumption 1
            A_ub+=[[0]*(m*n)]
            A_ub[-1][j::n]=[1]*m
            b_ub+=[rcoef[j]]
        # Random objective weights; the next seed is drawn from the current
        # generator state.
        random.seed(localseed)
        vecc=[random.random() for i in range(m*n)]
        localseed=random.randint(1,1000000000)
        res=linprog(vecc,A_eq=A,b_eq=b,A_ub=A_ub,b_ub=b_ub,options={'maxiter':linprogiter})
        status=res.status
        if itern==1:
            # The status of the first linear program is always reported.
            shiro.display(shiro.translation['status:']+' '+str(status))
            if status==0:
                shiro.display(shiro.translation[theorem])
        if status==2: #if real solution of current inequality doesn't exist
            if foundreal==0: #if this is the first inequality, then break
                break
            else:
                #step back
                lcoef,lfun=oldlcoef,oldlfun
                rcoef,rfun=oldrcoef,oldrfun
                bufer=[]
                continue
        if status==0:#if found a solution with real coefficients
            # The previous step proved to be safe: publish its inequalities.
            for ineq in bufer:
                shiro.display(ineq)
            foundreal=1
            bufer=[]
            # Remember the current state in case the next step fails.
            oldlfun,oldrfun=lfun,rfun
            oldlcoef,oldrcoef=lcoef[:],rcoef[:]
            for i in range(m):
                c=0
                for j in res.x[i*n:i*n+n]:#check if all coefficients
                    #in an equality looks like integers
                    if(abs(round(j)-j)>0.0001):
                        break
                else:
                    # Reached only when the loop above did not break.
                    #checks if rounding all coefficients doesn't make
                    #inequality false
                    isok=_check(lcoef[i],lfun[i],res.x[i*n:i*n+n],rfun)
                    if not isok:
                        continue
                    bufer+=['']
                    bufer[-1]+='$$'+_writ2(lcoef[i],lfun[i],variables)+' \\le '
                    # Remove the proven term from LHS and subtract the used
                    # amounts from RHS.
                    lcoef[i]=0
                    for j in range(n):
                        rcoef[j]-=int(round(res.x[i*n+j]))
                    for j,k in zip(res.x[i*n:i*n+n],rfun):
                        if j<0.0001:
                            continue
                        # c is 0 only for the first printed term (no '+').
                        if(c):bufer[-1]+='+'
                        else:c=1
                        bufer[-1]+=_writ2(int(round(j)),k,variables)
                    bufer[-1]+='$$'
            lcoef,lfun=_remzero(lcoef,lfun)
            rcoef,rfun=_remzero(rcoef,rfun)
    # Display what is still buffered, then the remaining inequality.
    for ineq in bufer:
        shiro.display(ineq)
    lhs='+'.join([_writ2(c,f,variables) for c,f in zip(lcoef,lfun)])
    if lhs=='':
        lhs='0'
    elif status==0:
        shiro.display(shiro.translation[
        "Program couldn't find a solution with integer coefficients. Try "+
        "to multiple the formula by some integer and run this function again."])
    elif(status==2):
        shiro.display(shiro.translation["Program couldn't find any proof."])
        #return res.status
    elif status==1:
        shiro.display(shiro.translation["Try to set higher linprogiter parameter."])
    rhs='+'.join([_writ2(c,f,variables) for c,f in zip(rcoef,rfun)])
    if rhs=='':
        rhs='0'
    shiro.display('$$ '+latex(lhs)+' \\le '+latex(rhs)+' $$')
    if lhs=='0':
        shiro.display(shiro.translation['The sum of all inequalities gives us a proof of the inequality.'])
    return status
|
|
def _isiterable(obj):
|
|
try:
|
|
_ = (e for e in obj)
|
|
return True
|
|
except TypeError:
|
|
return False
|
|
def _smakeiterable(x):
    """Sympify ``x`` and make sure the result is iterable.

    An empty string gives an empty list; a sympified value that is not
    iterable is wrapped in a one-element tuple.
    """
    if x == '':
        return []
    value = S(x)
    return value if _isiterable(value) else (value,)
|
|
def _smakeiterable2(x):
    """Sympify ``x`` and make sure the result is a sequence of iterables.

    An empty string or an empty sequence gives an empty list; a sequence
    whose first element is not iterable is wrapped in a one-element tuple.
    """
    if x == '':
        return []
    value = S(x)
    if len(value) == 0:
        return []
    return value if _isiterable(value[0]) else (value,)
|
|
def prove(formula, values=None, variables=None, niter=200, linprogiter=10000):
    """Tries to prove that formula>=0 assuming all variables are positive.

    Parameters:
    formula -- anything accepted by ``S``
    values -- optional scaling values paired with ``variables``; a variable
        with value v != 1 is replaced by v * new_variable (default: all 1)
    variables -- optional variables (default: free symbols sorted by name)
    niter -- maximal number of passes in ``_list2proof``
    linprogiter -- 'maxiter' option for ``linprog``

    The proof (or the reason of failure) is written with ``shiro.display``.
    If no proof is found (status 2) and the numerator is symmetric, a hint
    with a ready-to-run ``prove(makesubs(...))`` call is displayed.

    Returns the status from ``_list2proof``.
    """
    formula = S(formula)
    addsymbols(formula)
    if variables:
        variables = _smakeiterable(variables)
    else:
        variables = sorted(formula.free_symbols, key=str)
    if values:
        values = _smakeiterable(values)
    else:
        values = [1] * len(variables)
    num, den = _input2fraction(formula, variables, values)
    st = _list2proof(*(_formula2list(num) + (niter, linprogiter)))
    if st == 2 and issymetric(num):
        fs = sorted([str(x) for x in num.free_symbols])
        shiro.display(shiro.translation["It looks like the formula is symmetric. "+
            "You can assume without loss of generality that "]+
            ' >= '.join(fs)+'. '+shiro.translation['Try'])
        # makesubs pairs the intervals with the numerator's symbols sorted
        # by name, so the hint is built from `fs` (the symbols of `num`)
        # and not from the caller's `variables`, which may use different
        # names or order. Both opened calls are closed so the hint can be
        # pasted as is.
        shiro.display('prove(makesubs(S("'+str(num)+'"),'+
            str([(name, 'oo') for name in fs[1:]])+'))')
    return st
|
|
def powerprove(formula, values=None, variables=None, niter=200, linprogiter=10000):
    """This is a bruteforce and ineffective function for proving inequalities.
    It can be used as the last resort.

    For each of the 2**len(variables) subsets of variables, the variables
    in the subset are substituted by 1/(1+new) and all the others by
    1+new; a proof of the resulting numerator is then attempted with
    ``_list2proof``.

    Returns a ``Counter`` of the statuses of all attempts.
    """
    formula = S(formula)
    addsymbols(formula)
    if variables:
        variables = _smakeiterable(variables)
    else:
        variables = sorted(formula.free_symbols, key=str)
    if values:
        values = _smakeiterable(values)
    else:
        values = [1] * len(variables)
    num, den = _input2fraction(formula, variables, values)
    statusses = []

    for mask in range(1 << len(variables)):  # tricky substitutions to improve speed
        shiro.display(r'_______________________')
        inversions = []    # x -> 1/x for the variables selected by the mask
        shifts = []        # x -> 1+z for every variable
        descriptions = []  # LaTeX description of the combined substitution
        for bit, var in enumerate(variables):
            fresh = newvar()
            shifts.append((var, 1 + fresh))
            if mask & (1 << bit):
                inversions.append((var, 1 / var))
                descriptions.append(latex(var)+'\\to 1/(1+'+latex(fresh)+')')
            else:
                descriptions.append(latex(var)+'\\to 1+'+latex(fresh))
        shiro.display(shiro.translation['Substitute']+ ' $'+','.join(descriptions)+'$')
        inverted = fractioncancel(num.subs(inversions))[0]
        shifted = expand(inverted.subs(shifts))
        shiro.display(shiro.translation["Numerator after substitutions:"]+' $'+latex(shifted)+'$')
        attempt = _list2proof(*(_formula2list(shifted) + (niter, linprogiter)))
        statusses.append(attempt)
    return Counter(statusses)
|
|
def makesubs(formula, intervals, values=None, variables=None, numden=False):
    """Generates a new formula which satisfies this condition:
    for all positive variables new formula is nonnegative iff
    for all variables in corresponding intervals old formula is nonnegative.

    Parameters:
    formula -- anything accepted by ``S``
    intervals -- pairs of interval ends, paired with ``variables`` in order
    values -- optional values of the old variables; they are converted to
        values of the new variables by solving the substituted equations
    variables -- optional variables (default: free symbols sorted by name)
    numden -- return numerator and denominator separately

    Returns ``num/den`` (or ``num, den`` if ``numden``), followed by the
    converted values when they are available.
    """
    formula = S(formula)
    addsymbols(formula)
    intervals = _smakeiterable2(intervals)
    if variables:
        variables = _smakeiterable(variables)
    else:
        variables = sorted(formula.free_symbols, key=str)

    equations = []
    if values is not None:
        values = _smakeiterable(values)
        equations = [var - value for var, value in zip(variables, values)]

    plus_inf, minus_inf = S('oo'), S('-oo')
    infinities = {minus_inf, plus_inf}

    newvars = []
    warn = False
    usedvars = set()
    for var, interval in zip(variables, intervals):
        end1, end2 = interval
        z = newvar()
        newvars.append(z)
        usedvars.add(var)
        # An interval end depending on this or an earlier variable.
        if (end1.free_symbols | end2.free_symbols) & usedvars:
            warn = True
        # Keep a possible infinite end in end2.
        if end1 in infinities:
            end1, end2 = end2, end1
        if {end1, end2} == infinities:
            applied = shown = (z - 1 / z)
        elif end2 == plus_inf:
            applied = shown = (end1 + z)
        elif end2 == minus_inf:
            applied = shown = end1 - z
        else:
            # Bounded interval: z is shifted by 1 after cancelling (below),
            # so the displayed substitution already uses 1+z.
            applied = end2 + (end1 - end2) / z
            shown = end2 + (end1 - end2) / (1 + z)
        formula = formula.subs(var, applied)
        shiro.display(shiro.translation['Substitute']+" $"+latex(var)+'\\to '+latex(shown)+'$')
        equations = [equation.subs(var, applied) for equation in equations]

    num, den = fractioncancel(formula)
    for z, interval in zip(newvars, intervals):
        if not ({interval[0], interval[1]} & infinities):
            shifted = z + 1
            num = num.subs(z, shifted)
            den = den.subs(z, shifted)
            equations = [equation.subs(z, shifted) for equation in equations]

    if values:
        values = ssolve(equations, newvars)
        if len(values):
            values = values[0]
    num, den = expand(num), expand(den)
    if warn:
        shiro.warning(shiro.translation[
            'Warning: intervals contain backwards dependencies. Consider changing order of variables and intervals.'])

    if numden:
        return (num, den, values) if values else (num, den)
    return (num / den, values) if values else num / den
|
|
def _formula2listf(formula):
    """Splits a polynomial to a difference of two formulas with positive
    coefficients and extracts coefficients and function
    arguments of both formulas.

    Terms with a negative coefficient go to the left-hand side (with the
    sign flipped), all others to the right-hand side. The last element of
    the returned tuple is always None (there are no polynomial variables).
    """
    lcoef, lfun = [], []
    rcoef, rfun = [], []
    for term in formula.as_ordered_terms():
        coef, factors = term.as_coeff_mul()
        arguments = factors[0].args
        if coef < 0:
            lcoef.append(-coef)
            lfun.append(arguments)
        else:
            rcoef.append(coef)
            rfun.append(arguments)
    return (lcoef, lfun, rcoef, rfun, None)
|
|
def provef(formula, niter=200, linprogiter=10000):
    """This function is similar to prove, formula is a linear combination of
    values of f:R^k->R instead of a polynomial. provef checks if a formula
    is nonnegative for any nonnegative and convex function f. If so, it
    provides a proof of nonnegativity.

    Returns the status from ``_list2proof``.
    """
    formula = S(formula)
    addsymbols(formula)
    num, den = _input2fraction(formula, [], [])
    lcoef, lfun, rcoef, rfun, variables = _formula2listf(num)
    return _list2proof(lcoef, lfun, rcoef, rfun, variables,
                       niter, linprogiter, _writ, 'From Jensen inequality:')
|
|
def issymetric(formula):
    """Checks if formula is symmetric
    and has at least two variables.

    The formula is accepted when swapping the first free symbol with each
    of the other free symbols leaves it unchanged after expansion.
    """
    formula = S(formula)
    addsymbols(formula)
    symbols = list(formula.free_symbols)
    if len(symbols) < 2:
        return False
    pivot = symbols[0]
    for other in symbols[1:]:
        swapped = formula.subs({pivot: other, other: pivot}, simultaneous=True)
        if expand(formula - swapped) != S(0):
            return False
    return True
|
|
def cyclize(formula, oper=operator.add, variables=None, init=None):
    """cyclize('a^2*b')=S('a^2*b+b^2*a')
    cyclize('a^2*b',variables='a,b,c')=S('a^2*b+b^2*c+c^2*a')

    Combines, with ``oper``, the formula and all its images under cyclic
    shifts of ``variables`` (default: free symbols sorted by name).
    ``init``, if given, is the starting value of the accumulation.
    With no variables at all, ``init`` is returned unchanged.
    """
    formula = S(formula)
    addsymbols(formula)
    if variables is None:
        ring = sorted(formula.free_symbols, key=str)
    else:
        ring = S(variables)
    if len(ring) == 0:
        return init
    ring = list(ring)  # if variables is a tuple, change it to a list
    # Every variable is sent to its successor, the last one to the first.
    rotation = list(zip(ring, ring[1:] + ring[:1]))
    total = formula if init is None else oper(init, formula)
    for _ in range(len(ring) - 1):
        formula = formula.subs(rotation, simultaneous=True)
        total = oper(total, formula)
    return total
|
|
def symmetrize(formula, oper=operator.add, variables=None, init=None):
    """symmetrize('a^2*b')=S('a^2*b+b^2*a')
    symmetrize('a^2*b',variables='a,b,c')=
    =S('a^2*b+a^2*c+b^2*a+b^2*c+c^2*a+c^2*b')

    Applies ``cyclize`` repeatedly over the first 2, 3, ... variables.
    NOTE(review): ``init`` is accepted but not used by this function.
    """
    formula = S(formula)
    addsymbols(formula)
    if variables is None:
        symbols = sorted(formula.free_symbols, key=str)
    else:
        symbols = S(variables)
    for size in range(2, len(symbols) + 1):
        formula = cyclize(formula, oper, symbols[:size])
    return formula
|
|
def findvalues(formula, values=None, variables=None, **kwargs):
    """Finds a candidate for parameter "values" in "prove" function.

    Parameters:
    formula -- anything accepted by ``S``
    values -- optional starting point of the search (default: all 1.0)
    variables -- optional variables, in any form accepted by the other
        functions of this module, e.g. 'a,b' (default: free symbols of the
        numerator sorted by name)
    kwargs -- passed to ``scipy.optimize.fmin``

    Every variable is replaced by its square, the numerator is split into
    the sums of its positive and negative terms, and the quotient
    positive/negative is minimized numerically with ``fmin``.

    Returns a tuple with the squares of the point found by ``fmin``.
    """
    formula = S(formula)
    addsymbols(formula)
    num, den = fractioncancel(formula)
    if variables is None:
        variables = sorted(num.free_symbols, key=str)
    else:
        # Sympify user input; a plain string such as 'a,b' would otherwise
        # be iterated character by character and fail on `x**2`.
        variables = _smakeiterable(variables)
    num = num.subs([(var, var ** 2) for var in variables])
    num = Poly(num)
    # (|p|+p) holds twice the positive terms, (|p|-p) twice the negated
    # negative terms.
    newformula = S((num.abs() + num) / (num.abs() - num))
    f = lambdify(variables, newformula)
    f2 = lambda x: f(*x)
    if values is None:
        values = [1.0] * len(variables)
    else:
        values = S(values)
    tup = tuple(fmin(f2, values, **kwargs))
    return tuple([x * x for x in tup])
|