188 lines
5.1 KiB
Python
188 lines
5.1 KiB
Python
from ..libmp.backend import xrange
|
|
from .functions import defun, defun_wrapped
|
|
|
|
@defun
def gammaprod(ctx, a, b, _infsign=False):
    """
    Evaluate prod(gamma(x) for x in a) / prod(gamma(x) for x in b),
    treating arguments at the poles of gamma (nonpositive integers,
    as reported by ctx.isnpint) as limits.

    a, b     -- iterables of numerator / denominator arguments; every
                entry is passed through ctx.convert
    _infsign -- when the numerator has more poles than the denominator,
                return a signed infinity (sign taken from slightly
                perturbed arguments) instead of plain ctx.inf

    Returns ctx.zero if the denominator has more poles, an infinity if
    the numerator has more poles, and otherwise the finite value in
    which the poles have been cancelled pairwise.
    """
    numer = [ctx.convert(t) for t in a]
    denom = [ctx.convert(t) for t in b]

    # Separate arguments sitting on a pole of gamma from regular ones
    num_poles, num_regular = [], []
    for t in numer:
        if ctx.isnpint(t):
            num_poles.append(t)
        else:
            num_regular.append(t)
    den_poles, den_regular = [], []
    for t in denom:
        if ctx.isnpint(t):
            den_poles.append(t)
        else:
            den_regular.append(t)

    # A surplus of poles on one side forces the result to 0 or inf
    surplus = len(num_poles) - len(den_poles)
    if surplus < 0:
        return ctx.zero
    if surplus > 0:
        if not _infsign:
            return ctx.inf
        # Sign of infinity for x+h, h -> 0 from above: nudge every pole
        # off the singularity (relative shift, or absolute shift at 0)
        # and look at the sign of the perturbed product.
        # XXX: hack, this should be done properly
        shifted_num = [(t*(1+ctx.eps) if t else t+ctx.eps)
                       for t in num_poles]
        shifted_den = [(t*(1+ctx.eps) if t else t+ctx.eps)
                       for t in den_poles]
        probe = ctx.gammaprod(shifted_num+num_regular,
                              shifted_den+den_regular)
        return ctx.sign(probe) * ctx.inf

    # Equal number of poles on both sides: cancel them in pairs using
    # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
    result = ctx.one
    saved_prec = ctx.prec
    try:
        # 15 guard bits while the product is accumulated
        ctx.prec = saved_prec + 15
        # Pairs are consumed from the end of each list
        for i, j in zip(reversed(num_poles), reversed(den_poles)):
            result *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
        for t in num_regular:
            result *= ctx.gamma(t)
        for t in den_regular:
            result /= ctx.gamma(t)
    finally:
        ctx.prec = saved_prec
    # Unary plus re-evaluates the result at the restored precision
    return +result
|
|
|
|
@defun
def beta(ctx, x, y):
    """
    Beta function, gamma(x)*gamma(y)/gamma(x+y), evaluated through
    ctx.gammaprod with the sum x+y formed at doubled precision.

    Infinite arguments are handled separately; every infinite case not
    listed in the body yields ctx.nan.
    """
    x = ctx.convert(x)
    y = ctx.convert(y)
    # Arrange for an infinite argument, if any, to be in x
    if ctx.isinf(y):
        x, y = y, x

    if not ctx.isinf(x):
        # Doubled precision for the sum, so that cancellation in x+y
        # does not hide a pole of the denominator
        total = ctx.fadd(x, y, prec=2*ctx.prec)
        return ctx.gammaprod([x, y], [total])

    # x is infinite from here on
    if x == ctx.inf and not ctx._im(y):
        if y == ctx.ninf:
            return ctx.nan
        elif y > 0:
            return ctx.zero
        elif ctx.isint(y):
            return ctx.nan
        elif y < 0:
            # Negative non-integer y: infinity carrying the sign of gamma(y)
            return ctx.sign(ctx.gamma(y)) * ctx.inf
    return ctx.nan
|
|
|
|
@defun
def binomial(ctx, n, k):
    """
    Binomial coefficient expressed with gamma functions:
    gamma(n+1) / (gamma(k+1) * gamma(n-k+1)), evaluated by ctx.gammaprod.

    The three shifted arguments are formed at twice the working
    precision before being handed to gammaprod.
    """
    wp = 2*ctx.prec
    n_plus_1 = ctx.fadd(n, 1, prec=wp)
    k_plus_1 = ctx.fadd(k, 1, prec=wp)
    # (n+1) - k, i.e. n-k+1
    diff_plus_1 = ctx.fsub(n_plus_1, k, prec=wp)
    return ctx.gammaprod([n_plus_1], [k_plus_1, diff_plus_1])
|
|
|
|
@defun
def rf(ctx, x, n):
    """
    Rising factorial (Pochhammer symbol) of x and n, computed as
    gamma(x+n) / gamma(x) through ctx.gammaprod.

    The sum x+n is formed at twice the working precision.
    """
    upper = ctx.fadd(x, n, prec=2*ctx.prec)
    return ctx.gammaprod([upper], [x])
|
|
|
|
@defun
def ff(ctx, x, n):
    """
    Falling factorial of x and n, computed as
    gamma(x+1) / gamma(x-n+1) through ctx.gammaprod.

    Both shifted arguments are formed at twice the working precision.
    """
    wp = 2*ctx.prec
    top = ctx.fadd(x, 1, prec=wp)
    # x-n+1, built as (x-n) + 1
    gap = ctx.fsub(x, n, prec=wp)
    bottom = ctx.fadd(gap, 1, prec=wp)
    return ctx.gammaprod([top], [bottom])
|
|
|
|
@defun_wrapped
def fac2(ctx, x):
    """
    Double factorial x!!, evaluated from the closed form

        2**(x/2) * (pi/2)**((cos(pi*x)-1)/4) * gamma(x/2+1)

    Returns x itself for +inf and ctx.nan for any other infinite x.
    """
    if ctx.isinf(x):
        return x if x == ctx.inf else ctx.nan
    half = x/2
    power_part = 2**half
    # Exponent is 0 when cos(pi*x) = 1 and -1/2 when cos(pi*x) = -1
    parity_part = (ctx.pi/2)**((ctx.cospi(x)-1)/4)
    gamma_part = ctx.gamma(half+1)
    return power_part*parity_part*gamma_part
|
|
|
|
@defun_wrapped
def barnesg(ctx, z):
    """
    Barnes G-function of z.

    Special values: +inf is returned unchanged, any other infinity
    gives ctx.nan, nan is returned unchanged, and real nonpositive
    integers give zero (as z*0).

    Otherwise a reflection formula is applied when re(z) < -ctx.dps;
    in the remaining cases z is shifted upward with the recurrence
    (dividing by gamma at each step) and an asymptotic expansion for
    log G is summed at the shifted point.

    NOTE(review): ctx.dps is raised inside this function and not
    lowered again here; presumably the defun_wrapped decorator restores
    the caller's precision -- confirm.
    """
    if ctx.isinf(z):
        return z if z == ctx.inf else ctx.nan
    if ctx.isnan(z):
        return z
    # Zeros of G at the real nonpositive integers
    on_real_axis = not ctx._im(z)
    if on_real_axis and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
        return z*0

    # Account for size (would not be needed if computing log(G))
    if abs(z) > 5:
        ctx.dps += 2*ctx.log(abs(z),2)

    # Reflection formula
    if ctx.re(z) < -ctx.dps:
        w = 1-z
        two_pi = 2*ctx.pi
        u = ctx.expjpi(2*w)
        const_term = ctx.j*ctx.pi/12
        quad_term = ctx.j*ctx.pi*w**2/2
        log_term = w*ctx.ln(1-u)
        dilog_term = ctx.j*ctx.polylog(2, u)/two_pi
        exponent = const_term - quad_term + log_term - dilog_term
        reflected = ctx.barnesg(2-z)*ctx.exp(exponent)/two_pi**w
        # Real input should give a real result
        if ctx._is_real_type(z):
            reflected = ctx._re(reflected)
        return reflected

    # Estimate terms for asymptotic expansion
    # TODO: fixme, obviously
    N = ctx.dps // 2 + 5

    # Shift z upward until it is large enough for the expansion,
    # collecting the 1/gamma factors picked up along the way
    G = 1
    while abs(z) < N or ctx.re(z) < 1:
        G /= ctx.gamma(z)
        z += 1
    z -= 1

    # Leading terms of the expansion of log G at the shifted point
    log_g = ctx.mpf(1)/12
    log_g -= ctx.log(ctx.glaisher)
    log_g += z*ctx.log(2*ctx.pi)/2
    log_g += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
    log_g -= 3*z**2/4

    # Bernoulli-number tail: B(2k+2) / (4*k*(k+1)*z**(2k))
    z_sq = z**2
    z_pow = z_sq
    for k in xrange(1, N+1):
        term = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z_pow)
        if abs(term) < ctx.eps:
            break
        log_g += term
        z_pow *= z_sq
    # NOTE: nothing is reported if all N terms are used without the
    # terms dropping below ctx.eps
    return G*ctx.exp(log_g)
|
|
|
|
@defun
def superfac(ctx, z):
    """
    Superfactorial of z, defined here through the Barnes G-function
    as barnesg(z+2).
    """
    shifted = z+2
    return ctx.barnesg(shifted)
|
|
|
|
@defun_wrapped
def hyperfac(ctx, z):
    """
    Hyperfactorial of z, computed as
    exp(z*loggamma(z+1)) / barnesg(z+1).

    +inf is returned unchanged. Negative real integers are mapped to a
    nonnegative argument via hyperfac(-z-1), with a sign that depends
    on z.

    NOTE(review): the negative-integer path returns without undoing
    the precision increase made in this function; presumably the
    defun_wrapped decorator restores the caller's precision -- confirm.
    """
    if z == ctx.inf:
        return z

    # XXX: estimate needed extra bits accurately
    extra = 4*int(ctx.log(abs(z),2)) if abs(z) > 5 else 0
    ctx.prec += extra

    # Negative real integers: reflect to -n-1 and fix the sign
    if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
        n = int(ctx.re(z))
        reflected = ctx.hyperfac(-n-1)
        if ((n+1)//2) % 2:
            reflected = -reflected
        # Keep a complex result for complex-typed input
        if ctx._is_complex_type(z):
            return reflected + 0j
        return reflected

    z_plus_1 = z+1
    # exp(z*loggamma(z+1)) is used rather than gamma(z+1)**z, which
    # gives the wrong branch cut
    numerator = ctx.exp(z*ctx.loggamma(z_plus_1))
    # The extra bits are dropped before the Barnes G evaluation
    ctx.prec -= extra
    return numerator / ctx.barnesg(z_plus_1)
|
|
|
|
'''
|
|
@defun
|
|
def psi0(ctx, z):
|
|
"""Shortcut for psi(0,z) (the digamma function)"""
|
|
return ctx.psi(0, z)
|
|
|
|
@defun
|
|
def psi1(ctx, z):
|
|
"""Shortcut for psi(1,z) (the trigamma function)"""
|
|
return ctx.psi(1, z)
|
|
|
|
@defun
|
|
def psi2(ctx, z):
|
|
"""Shortcut for psi(2,z) (the tetragamma function)"""
|
|
return ctx.psi(2, z)
|
|
|
|
@defun
|
|
def psi3(ctx, z):
|
|
"""Shortcut for psi(3,z) (the pentagamma function)"""
|
|
return ctx.psi(3, z)
|
|
'''
|