Skip to content

Commit

Permalink
Merge pull request sympy#23932 from bjodah/printing-scipy-polygamma
Browse files Browse the repository at this point in the history
SciPyPrinter: polygamma (fix sympygh-23924)
  • Loading branch information
oscarbenjamin authored Nov 24, 2022
2 parents e40fe6e + 9c7ca29 commit faa0865
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions sympy/printing/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _print_NDimArray(self, expr):
'gamma': 'gamma',
'loggamma': 'gammaln',
'digamma': 'psi',
'polygamma': 'polygamma',
'RisingFactorial': 'poch',
'jacobi': 'eval_jacobi',
'gegenbauer': 'eval_gegenbauer',
Expand Down
5 changes: 5 additions & 0 deletions sympy/printing/tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from sympy.concrete.summations import Sum
from sympy.core.mod import Mod
from sympy.core.relational import (Equality, Unequality)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.special.gamma_functions import polygamma
from sympy.matrices.expressions.blockmatrix import BlockMatrix
from sympy.matrices.expressions.matexpr import MatrixSymbol
from sympy.matrices.expressions.special import Identity
Expand Down Expand Up @@ -341,3 +343,6 @@ def test_scipy_print_methods():
assert hasattr(prntr, '_print_erf')
assert hasattr(prntr, '_print_factorial')
assert hasattr(prntr, '_print_chebyshevt')
k = Symbol('k', integer=True, nonnegative=True)
x = Symbol('x', real=True)
assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)"
14 changes: 8 additions & 6 deletions sympy/utilities/tests/test_lambdify.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized)
from sympy.functions.special.delta_functions import (Heaviside)
from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels)
from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma)
from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma)
from sympy.integrals.integrals import Integral
from sympy.logic.boolalg import (And, false, ITE, Not, Or, true)
from sympy.matrices.expressions.dotproduct import DotProduct
Expand Down Expand Up @@ -1081,7 +1081,7 @@ def test_scipy_fns():
single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma]
single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc,
scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln,
scipy.special.psi]
scipy.special.psi]
numpy.random.seed(0)
for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns):
f = lambdify(x, sympy_fn(x), modules="scipy")
Expand All @@ -1105,18 +1105,20 @@ def test_scipy_fns():
assert abs(f(tv) - scipy_fn(tv)) < 1e-13*(1 + abs(sympy_result))

double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli,
besselk]
besselk, polygamma]
double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv,
scipy.special.yv, scipy.special.iv, scipy.special.kv]
scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma]
for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns):
f = lambdify((x, y), sympy_fn(x, y), modules="scipy")
for i in range(20):
# SciPy supports only real orders of Bessel functions
tv1 = numpy.random.uniform(-10, 10)
tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
# SciPy supports poch for real arguments only
if sympy_fn == RisingFactorial:
# SciPy requires a real valued 2nd argument for: poch, polygamma
if sympy_fn in (RisingFactorial, polygamma):
tv2 = numpy.real(tv2)
if sympy_fn == polygamma:
tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integral.
sympy_result = sympy_fn(tv1, tv2).evalf()
assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result))
assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result))
Expand Down

0 comments on commit faa0865

Please sign in to comment.