Skip to content

Commit

Permalink
Merge pull request sympy#22993 from oscargus/cot_py_generation
Browse files Browse the repository at this point in the history
Enable automatic rewrite for cot, sec, csc, and related inverse and hyperbolic functions
  • Loading branch information
oscarbenjamin authored Jul 11, 2022
2 parents cace087 + 74c43ce commit 4e72b8b
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 43 deletions.
11 changes: 6 additions & 5 deletions doc/src/tutorial/simplification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,12 @@ another. This works for any function in SymPy, not just special functions.
To rewrite an expression in terms of a function, use
``expr.rewrite(function)``. For example,

>>> tan(x).rewrite(sin)
2
2⋅sin (x)
─────────
sin(2⋅x)
>>> tan(x).rewrite(cos)
⎛ π⎞
cos⎜x - ─⎟
⎝ 2⎠
──────────
cos(x)
>>> factorial(x).rewrite(gamma)
Γ(x + 1)

Expand Down
127 changes: 120 additions & 7 deletions sympy/functions/elementary/hyperbolic.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from sympy.core.logic import FuzzyBool

from sympy.core import S, sympify, cacheit, pi, I, Rational
from sympy.core.add import Add
from sympy.core.function import Function, ArgumentIndexError
from sympy.core.logic import fuzzy_or, fuzzy_and
from sympy.core.logic import fuzzy_or, fuzzy_and, FuzzyBool
from sympy.functions.combinatorial.factorials import (binomial, factorial,
RisingFactorial)
from sympy.functions.combinatorial.numbers import bernoulli, euler, nC
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp, log, match_real_imag
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.integers import floor
from sympy.functions.elementary.trigonometric import (acot, asin, atan, cos,
cot, sin, tan,
_imaginary_unit_as_coefficient)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (
acos, acot, asin, atan, cos, cot, csc, sec, sin, tan,
_imaginary_unit_as_coefficient)
from sympy.polys.specialpolys import symmetric_poly


Expand Down Expand Up @@ -221,6 +219,12 @@ def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs):
def _eval_rewrite_as_exp(self, arg, **kwargs):
return (exp(arg) - exp(-arg)) / 2

def _eval_rewrite_as_sin(self, arg, **kwargs):
return -I * sin(I * arg)

def _eval_rewrite_as_csc(self, arg, **kwargs):
return -I / csc(I * arg)

def _eval_rewrite_as_cosh(self, arg, **kwargs):
return -S.ImaginaryUnit*cosh(arg + S.Pi*S.ImaginaryUnit/2)

Expand All @@ -232,6 +236,9 @@ def _eval_rewrite_as_coth(self, arg, **kwargs):
coth_half = coth(S.Half*arg)
return 2*coth_half/(coth_half**2 - 1)

def _eval_rewrite_as_csch(self, arg, **kwargs):
return 1 / csch(arg)

def _eval_as_leading_term(self, x, logx=None, cdir=0):
arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir)
arg0 = arg.subs(x, 0)
Expand Down Expand Up @@ -409,6 +416,12 @@ def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs):
def _eval_rewrite_as_exp(self, arg, **kwargs):
return (exp(arg) + exp(-arg)) / 2

def _eval_rewrite_as_cos(self, arg, **kwargs):
return cos(I * arg)

def _eval_rewrite_as_sec(self, arg, **kwargs):
return 1 / sec(I * arg)

def _eval_rewrite_as_sinh(self, arg, **kwargs):
return -S.ImaginaryUnit*sinh(arg + S.Pi*S.ImaginaryUnit/2)

Expand All @@ -420,6 +433,9 @@ def _eval_rewrite_as_coth(self, arg, **kwargs):
coth_half = coth(S.Half*arg)**2
return (coth_half + 1)/(coth_half - 1)

def _eval_rewrite_as_sech(self, arg, **kwargs):
return 1 / sech(arg)

def _eval_as_leading_term(self, x, logx=None, cdir=0):
arg = self.args[0].as_leading_term(x, logx=logx, cdir=cdir)
arg0 = arg.subs(x, 0)
Expand Down Expand Up @@ -658,6 +674,12 @@ def _eval_rewrite_as_exp(self, arg, **kwargs):
neg_exp, pos_exp = exp(-arg), exp(arg)
return (pos_exp - neg_exp)/(pos_exp + neg_exp)

def _eval_rewrite_as_tan(self, arg, **kwargs):
return -I * tan(I * arg)

def _eval_rewrite_as_cot(self, arg, **kwargs):
return -I / cot(I * arg)

def _eval_rewrite_as_sinh(self, arg, **kwargs):
return S.ImaginaryUnit*sinh(arg)/sinh(S.Pi*S.ImaginaryUnit/2 - arg)

Expand Down Expand Up @@ -1017,9 +1039,18 @@ def taylor_term(n, x, *previous_terms):

return 2 * (1 - 2**n) * B/F * x**n

def _eval_rewrite_as_sin(self, arg, **kwargs):
return I / sin(I * arg)

def _eval_rewrite_as_csc(self, arg, **kwargs):
return I * csc(I * arg)

def _eval_rewrite_as_cosh(self, arg, **kwargs):
return S.ImaginaryUnit / cosh(arg + S.ImaginaryUnit * S.Pi / 2)

def _eval_rewrite_as_sinh(self, arg, **kwargs):
return 1 / sinh(arg)

def _eval_is_positive(self):
if self.args[0].is_extended_real:
return self.args[0].is_positive
Expand Down Expand Up @@ -1067,9 +1098,18 @@ def taylor_term(n, x, *previous_terms):
x = sympify(x)
return euler(n) / factorial(n) * x**(n)

def _eval_rewrite_as_cos(self, arg, **kwargs):
return 1 / cos(I * arg)

def _eval_rewrite_as_sec(self, arg, **kwargs):
return sec(I * arg)

def _eval_rewrite_as_sinh(self, arg, **kwargs):
return S.ImaginaryUnit / sinh(arg + S.ImaginaryUnit * S.Pi /2)

def _eval_rewrite_as_cosh(self, arg, **kwargs):
return 1 / cosh(arg)

def _eval_is_positive(self):
if self.args[0].is_extended_real:
return True
Expand Down Expand Up @@ -1187,6 +1227,19 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
def _eval_rewrite_as_log(self, x, **kwargs):
return log(x + sqrt(x**2 + 1))

def _eval_rewrite_as_atanh(self, x, **kwargs):
return atanh(x/sqrt(1 + x**2))

def _eval_rewrite_as_acosh(self, x, **kwargs):
ix = I*x
return I*(sqrt(1 - ix)/sqrt(ix - 1) * acosh(ix) - pi/2)

def _eval_rewrite_as_asin(self, x, **kwargs):
return -I * asin(I * x)

def _eval_rewrite_as_acos(self, x, **kwargs):
return I * acos(I * x) - I*pi/2

def inverse(self, argindex=1):
"""
Returns the inverse of this function.
Expand Down Expand Up @@ -1337,6 +1390,22 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
def _eval_rewrite_as_log(self, x, **kwargs):
return log(x + sqrt(x + 1) * sqrt(x - 1))

def _eval_rewrite_as_acos(self, x, **kwargs):
return sqrt(x - 1)/sqrt(1 - x) * acos(x)

def _eval_rewrite_as_asin(self, x, **kwargs):
return sqrt(x - 1)/sqrt(1 - x) * (pi/2 - asin(x))

def _eval_rewrite_as_asinh(self, x, **kwargs):
return sqrt(x - 1)/sqrt(1 - x) * (pi/2 + I*asinh(I*x))

def _eval_rewrite_as_atanh(self, x, **kwargs):
sxm1 = sqrt(x - 1)
s1mx = sqrt(1 - x)
sx2m1 = sqrt(x**2 - 1)
return (pi/2*sxm1/s1mx*(1 - x * sqrt(1/x**2)) +
sxm1*sqrt(x + 1)/sx2m1 * atanh(sx2m1/x))

def inverse(self, argindex=1):
"""
Returns the inverse of this function.
Expand Down Expand Up @@ -1442,6 +1511,11 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
def _eval_rewrite_as_log(self, x, **kwargs):
return (log(1 + x) - log(1 - x)) / 2

def _eval_rewrite_as_asinh(self, x, **kwargs):
f = sqrt(1/(x**2 - 1))
return (pi*x/(2*sqrt(-x**2)) -
sqrt(-x)*sqrt(1 - x**2)/sqrt(x)*f*asinh(f))

def _eval_is_zero(self):
if self.args[0].is_zero:
return True
Expand Down Expand Up @@ -1537,6 +1611,13 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
def _eval_rewrite_as_log(self, x, **kwargs):
return (log(1 + 1/x) - log(1 - 1/x)) / 2

def _eval_rewrite_as_atanh(self, x, **kwargs):
return atanh(1/x)

def _eval_rewrite_as_asinh(self, x, **kwargs):
return (pi*I/2*(sqrt((x - 1)/x)*sqrt(x/(x - 1)) - sqrt(1 + 1/x)*sqrt(x/(x + 1))) +
x*sqrt(1/x**2)*asinh(sqrt(1/(x**2 - 1))))

def inverse(self, argindex=1):
"""
Returns the inverse of this function.
Expand Down Expand Up @@ -1617,6 +1698,8 @@ def _asech_table():
(-1 - sqrt(5)): 3*S.Pi / 5,
(sqrt(6) + sqrt(2)): 5*S.Pi / 12,
(-sqrt(6) - sqrt(2)): 7*S.Pi / 12,
S.ImaginaryUnit*S.Infinity: -S.Pi*S.ImaginaryUnit / 2,
S.ImaginaryUnit*S.NegativeInfinity: S.Pi*S.ImaginaryUnit / 2,
}

@classmethod
Expand Down Expand Up @@ -1677,6 +1760,20 @@ def inverse(self, argindex=1):
def _eval_rewrite_as_log(self, arg, **kwargs):
return log(1/arg + sqrt(1/arg - 1) * sqrt(1/arg + 1))

def _eval_rewrite_as_acosh(self, arg, **kwargs):
return acosh(1/arg)

def _eval_rewrite_as_asinh(self, arg, **kwargs):
return sqrt(1/arg - 1)/sqrt(1 - 1/arg)*(S.ImaginaryUnit*asinh(S.ImaginaryUnit/arg)
+ S.Pi*S.Half)

def _eval_rewrite_as_atanh(self, x, **kwargs):
return (I*pi*(1 - sqrt(x)*sqrt(1/x) - I/2*sqrt(-x)/sqrt(x) - I/2*sqrt(x**2)/sqrt(-x**2))
+ sqrt(1/(x + 1))*sqrt(x + 1)*atanh(sqrt(1 - x**2)))

def _eval_rewrite_as_acsch(self, x, **kwargs):
return sqrt(1/x - 1)/sqrt(1 - 1/x)*(pi/2 - I*acsch(I*x))


class acsch(InverseHyperbolicFunction):
"""
Expand Down Expand Up @@ -1767,6 +1864,9 @@ def eval(cls, arg):
if arg is S.ComplexInfinity:
return S.Zero

if arg.is_infinite:
return S.Zero

if arg.is_zero:
return S.ComplexInfinity

Expand All @@ -1782,5 +1882,18 @@ def inverse(self, argindex=1):
def _eval_rewrite_as_log(self, arg, **kwargs):
return log(1/arg + sqrt(1/arg**2 + 1))

def _eval_rewrite_as_asinh(self, arg, **kwargs):
return asinh(1/arg)

def _eval_rewrite_as_acosh(self, arg, **kwargs):
return S.ImaginaryUnit*(sqrt(1 - S.ImaginaryUnit/arg)/sqrt(S.ImaginaryUnit/arg - 1)*
acosh(S.ImaginaryUnit/arg) - S.Pi*S.Half)

def _eval_rewrite_as_atanh(self, arg, **kwargs):
arg2 = arg**2
arg2p1 = arg2 + 1
return sqrt(-arg2)/arg*(S.Pi*S.Half -
sqrt(-arg2p1**2)/arg2p1*atanh(sqrt(arg2p1)))

def _eval_is_zero(self):
return self.args[0].is_infinite
29 changes: 27 additions & 2 deletions sympy/functions/elementary/tests/test_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,10 @@ def test_asinh():
def test_asinh_rewrite():
x = Symbol('x')
assert asinh(x).rewrite(log) == log(x + sqrt(x**2 + 1))
assert asinh(x).rewrite(atanh) == atanh(x/sqrt(1 + x**2))
assert asinh(x).rewrite(asin) == asinh(x)
assert asinh(x*(1 + I)).rewrite(asin) == -I*asin(I*x*(1+I))
assert asinh(x).rewrite(acos) == I*(-I*asinh(x) + pi/2) - I*pi/2


def test_asinh_series():
Expand Down Expand Up @@ -637,6 +641,14 @@ def test_acosh():
def test_acosh_rewrite():
x = Symbol('x')
assert acosh(x).rewrite(log) == log(x + sqrt(x - 1)*sqrt(x + 1))
assert acosh(x).rewrite(asin) == sqrt(x - 1)*(-asin(x) + pi/2)/sqrt(1 - x)
assert acosh(x).rewrite(asinh) == sqrt(x - 1)*(-asin(x) + pi/2)/sqrt(1 - x)
assert acosh(x).rewrite(atanh) == \
(sqrt(x - 1)*sqrt(x + 1)*atanh(sqrt(x**2 - 1)/x)/sqrt(x**2 - 1) +
pi*sqrt(x - 1)*(-x*sqrt(x**(-2)) + 1)/(2*sqrt(1 - x)))
x = Symbol('x', positive=True)
assert acosh(x).rewrite(atanh) == \
sqrt(x - 1)*sqrt(x + 1)*atanh(sqrt(x**2 - 1)/x)/sqrt(x**2 - 1)


def test_acosh_series():
Expand Down Expand Up @@ -721,6 +733,10 @@ def test_asech_series():
def test_asech_rewrite():
x = Symbol('x')
assert asech(x).rewrite(log) == log(1/x + sqrt(1/x - 1) * sqrt(1/x + 1))
assert asech(x).rewrite(acosh) == acosh(1/x)
assert asech(x).rewrite(asinh) == sqrt(-1 + 1/x)*(-asin(1/x) + pi/2)/sqrt(1 - 1/x)
assert asech(x).rewrite(atanh) == \
sqrt(x + 1)*sqrt(1/(x + 1))*atanh(sqrt(1 - x**2)) + I*pi*(-sqrt(x)*sqrt(1/x) + 1 - I*sqrt(x**2)/(2*sqrt(-x**2)) - I*sqrt(-x)/(2*sqrt(x)))


def test_asech_fdiff():
Expand Down Expand Up @@ -796,6 +812,10 @@ def test_acsch_infinities():
def test_acsch_rewrite():
x = Symbol('x')
assert acsch(x).rewrite(log) == log(1/x + sqrt(1/x**2 + 1))
assert acsch(x).rewrite(asinh) == asinh(1/x)
assert acsch(x).rewrite(atanh) == (sqrt(-x**2)*(-sqrt(-(x**2 + 1)**2)
*atanh(sqrt(x**2 + 1))/(x**2 + 1)
+ pi/2)/x)


def test_acsch_fdiff():
Expand Down Expand Up @@ -864,6 +884,8 @@ def test_atanh():
def test_atanh_rewrite():
x = Symbol('x')
assert atanh(x).rewrite(log) == (log(1 + x) - log(1 - x)) / 2
assert atanh(x).rewrite(asinh) == \
pi*x/(2*sqrt(-x**2)) - sqrt(-x)*sqrt(1 - x**2)*sqrt(1/(x**2 - 1))*asinh(sqrt(1/(x**2 - 1)))/sqrt(x)


def test_atanh_series():
Expand Down Expand Up @@ -920,6 +942,9 @@ def test_acoth():
def test_acoth_rewrite():
x = Symbol('x')
assert acoth(x).rewrite(log) == (log(1 + 1/x) - log(1 - 1/x)) / 2
assert acoth(x).rewrite(atanh) == atanh(1/x)
assert acoth(x).rewrite(asinh) == \
x*sqrt(x**(-2))*asinh(sqrt(1/(x**2 - 1))) + I*pi*(sqrt((x - 1)/x)*sqrt(x/(x - 1)) - sqrt(x/(x + 1))*sqrt(1 + 1/x))/2


def test_acoth_series():
Expand Down Expand Up @@ -956,8 +981,8 @@ def test_leading_term():
for func in [sinh, tanh, asinh, atanh]:
assert func(x).as_leading_term(x) == x
for func in [sinh, cosh, tanh, coth, asinh, acosh, atanh, acoth]:
for arg in (1/x, S.Half):
eq = func(arg)
for ar in (1/x, S.Half):
eq = func(ar)
assert eq.as_leading_term(x) == eq
for func in [csch, sech]:
eq = func(S.Half)
Expand Down
Loading

0 comments on commit 4e72b8b

Please sign in to comment.