Skip to content

Commit

Permalink
Merge pull request #22135 from oscargus/bettertypechecking
Browse files Browse the repository at this point in the history
Improved type checking
  • Loading branch information
smichr authored Sep 22, 2021
2 parents 9dd704b + 7431cfd commit c4bcd75
Show file tree
Hide file tree
Showing 49 changed files with 102 additions and 107 deletions.
2 changes: 1 addition & 1 deletion sympy/assumptions/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def refine_Pow(expr, assumptions):
if ask(Q.odd(expr.exp), assumptions):
return sign(expr.base) * abs(expr.base) ** expr.exp
if isinstance(expr.exp, Rational):
if type(expr.base) is Pow:
if isinstance(expr.base, Pow):
return abs(expr.base.base) ** (expr.base.exp * expr.exp)

if expr.base is S.NegativeOne:
Expand Down
2 changes: 1 addition & 1 deletion sympy/categories/baseclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def __new__(cls, *args):
for morphism in premises_arg:
objects |= FiniteSet(morphism.domain, morphism.codomain)
Diagram._add_morphism_closure(premises, morphism, empty)
elif isinstance(premises_arg, dict) or isinstance(premises_arg, Dict):
elif isinstance(premises_arg, (dict, Dict)):
# The user has supplied a dictionary of morphisms and
# their properties.
for morphism, props in premises_arg.items():
Expand Down
2 changes: 1 addition & 1 deletion sympy/categories/diagram_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def group_to_finiteset(group):
# should be converted to a FiniteSet, because that is what the
# following code expects.

if isinstance(groups, dict) or isinstance(groups, Dict):
if isinstance(groups, (dict, Dict)):
finiteset_groups = {}
for group, local_hints in groups.items():
finiteset_group = group_to_finiteset(group)
Expand Down
2 changes: 1 addition & 1 deletion sympy/combinatorics/graycode.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def current(self):
'100'
"""
rv = self._current or '0'
if type(rv) is not str:
if not isinstance(rv, str):
rv = bin(rv)[2:]
return rv.rjust(self.n, '0')

Expand Down
2 changes: 1 addition & 1 deletion sympy/combinatorics/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def bitlist_from_subset(self, subset, superset):
subset_from_bitlist
"""
bitlist = ['0'] * len(superset)
if type(subset) is Subset:
if isinstance(subset, Subset):
subset = subset.subset
for i in Subset.subset_indices(subset, superset):
bitlist[i] = '1'
Expand Down
6 changes: 3 additions & 3 deletions sympy/concrete/expr_with_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class ExprWithLimits(Expr):

def __new__(cls, function, *symbols, **assumptions):
pre = _common_new(cls, function, *symbols, **assumptions)
if type(pre) is tuple:
if isinstance(pre, tuple):
function, limits, _ = pre
else:
return pre
Expand Down Expand Up @@ -358,7 +358,7 @@ def _eval_subs(self, old, new):
if len(xab[0].free_symbols.intersection(old.free_symbols)) != 0:
sub_into_func = False
break
if isinstance(old, AppliedUndef) or isinstance(old, UndefinedFunction):
if isinstance(old, (AppliedUndef, UndefinedFunction)):
sy2 = set(self.variables).intersection(set(new.atoms(Symbol)))
sy1 = set(self.variables).intersection(set(old.args))
if not sy2.issubset(sy1):
Expand Down Expand Up @@ -497,7 +497,7 @@ class AddWithLimits(ExprWithLimits):

def __new__(cls, function, *symbols, **assumptions):
pre = _common_new(cls, function, *symbols, **assumptions)
if type(pre) is tuple:
if isinstance(pre, tuple):
function, limits, orientation = pre
else:
return pre
Expand Down
2 changes: 1 addition & 1 deletion sympy/concrete/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def doit(self, **hints):
if reps:
undo = {v: k for k, v in reps.items()}
did = self.xreplace(reps).doit(**hints)
if type(did) is tuple: # when separate=True
if isinstance(did, tuple): # when separate=True
did = tuple([i.xreplace(undo) for i in did])
else:
did = did.xreplace(undo)
Expand Down
2 changes: 1 addition & 1 deletion sympy/concrete/summations.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def doit(self, **hints):
if reps:
undo = {v: k for k, v in reps.items()}
did = self.xreplace(reps).doit(**hints)
if type(did) is tuple: # when separate=True
if isinstance(did, tuple): # when separate=True
did = tuple([i.xreplace(undo) for i in did])
elif did is not None:
did = did.xreplace(undo)
Expand Down
2 changes: 1 addition & 1 deletion sympy/core/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def as_int(n, strict=True):
"""
if strict:
try:
if type(n) is bool:
if isinstance(n, bool):
raise TypeError
return operator.index(n)
except TypeError:
Expand Down
2 changes: 1 addition & 1 deletion sympy/core/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def tuple_wrapper(method):
def wrap_tuples(*args, **kw_args):
newargs = []
for arg in args:
if type(arg) is tuple:
if isinstance(arg, tuple):
newargs.append(Tuple(*arg))
else:
newargs.append(arg)
Expand Down
8 changes: 4 additions & 4 deletions sympy/core/evalf.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def scaled_zero(mag, sign=1):
>>> Float(scaled_zero(ok), p)
-0.e+30
"""
if type(mag) is tuple and len(mag) == 4 and iszero(mag, scaled=True):
if isinstance(mag, tuple) and len(mag) == 4 and iszero(mag, scaled=True):
return (mag[0][0],) + mag[1:]
elif isinstance(mag, SYMPY_INTS):
if sign not in [-1, 1]:
Expand All @@ -189,7 +189,7 @@ def scaled_zero(mag, sign=1):
def iszero(mpf, scaled=False):
if not scaled:
return not mpf or not mpf[1] and not mpf[-1]
return mpf and type(mpf[0]) is list and mpf[1] == mpf[-1] == 1
return mpf and isinstance(mpf[0], list) and mpf[1] == mpf[-1] == 1


def complex_accuracy(result):
Expand Down Expand Up @@ -915,9 +915,9 @@ def evalf_piecewise(expr, prec, options):
del newopts['subs']
if hasattr(expr, 'func'):
return evalf(expr, prec, newopts)
if type(expr) == float:
if isinstance(expr, float):
return evalf(Float(expr), prec, newopts)
if type(expr) == int:
if isinstance(expr, int):
return evalf(Integer(expr), prec, newopts)

# We still have undefined symbols
Expand Down
10 changes: 5 additions & 5 deletions sympy/core/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ def __new__(cls, num, dps=None, prec=None, precision=None):
else:
raise ValueError("unexpected decimal value %s" % str(num))
elif isinstance(num, tuple) and len(num) in (3, 4):
if type(num[1]) is str:
if isinstance(num[1], str):
# it's a hexadecimal (coming from a pickled object)
num = list(num)
# If we're loading an object pickled in Python 2 into
Expand Down Expand Up @@ -1952,31 +1952,31 @@ def __gt__(self, other):
rv = self._Rrel(other, '__lt__')
if rv is None:
rv = self, other
elif not type(rv) is tuple:
elif not isinstance(rv, tuple):
return rv
return Expr.__gt__(*rv)

def __ge__(self, other):
rv = self._Rrel(other, '__le__')
if rv is None:
rv = self, other
elif not type(rv) is tuple:
elif not isinstance(rv, tuple):
return rv
return Expr.__ge__(*rv)

def __lt__(self, other):
rv = self._Rrel(other, '__gt__')
if rv is None:
rv = self, other
elif not type(rv) is tuple:
elif not isinstance(rv, tuple):
return rv
return Expr.__lt__(*rv)

def __le__(self, other):
rv = self._Rrel(other, '__ge__')
if rv is None:
rv = self, other
elif not type(rv) is tuple:
elif not isinstance(rv, tuple):
return rv
return Expr.__le__(*rv)

Expand Down
2 changes: 1 addition & 1 deletion sympy/crypto/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def AZ(s=None):
"""
if not s:
return uppercase
t = type(s) is str
t = isinstance(s, str)
if t:
s = [s]
rv = [check_and_join(i.upper().split(), uppercase, filter=True)
Expand Down
8 changes: 4 additions & 4 deletions sympy/functions/elementary/trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _eval_nseries(self, x, n, logx, cdir=0):

def _eval_rewrite_as_exp(self, arg, **kwargs):
I = S.ImaginaryUnit
if isinstance(arg, TrigonometricFunction) or isinstance(arg, HyperbolicFunction):
if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)):
arg = arg.func(arg.args[0]).rewrite(exp)
return (exp(arg*I) - exp(-arg*I))/(2*I)

Expand Down Expand Up @@ -734,7 +734,7 @@ def _eval_nseries(self, x, n, logx, cdir=0):

def _eval_rewrite_as_exp(self, arg, **kwargs):
I = S.ImaginaryUnit
if isinstance(arg, TrigonometricFunction) or isinstance(arg, HyperbolicFunction):
if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)):
arg = arg.func(arg.args[0]).rewrite(exp)
return (exp(arg*I) + exp(-arg*I))/2

Expand Down Expand Up @@ -1222,7 +1222,7 @@ def _eval_expand_trig(self, **hints):

def _eval_rewrite_as_exp(self, arg, **kwargs):
I = S.ImaginaryUnit
if isinstance(arg, TrigonometricFunction) or isinstance(arg, HyperbolicFunction):
if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)):
arg = arg.func(arg.args[0]).rewrite(exp)
neg_exp, pos_exp = exp(-arg*I), exp(arg*I)
return I*(neg_exp - pos_exp)/(neg_exp + pos_exp)
Expand Down Expand Up @@ -1498,7 +1498,7 @@ def as_real_imag(self, deep=True, **hints):

def _eval_rewrite_as_exp(self, arg, **kwargs):
I = S.ImaginaryUnit
if isinstance(arg, TrigonometricFunction) or isinstance(arg, HyperbolicFunction):
if isinstance(arg, (TrigonometricFunction, HyperbolicFunction)):
arg = arg.func(arg.args[0]).rewrite(exp)
neg_exp, pos_exp = exp(-arg*I), exp(arg*I)
return I*(pos_exp + neg_exp)/(pos_exp - neg_exp)
Expand Down
2 changes: 1 addition & 1 deletion sympy/geometry/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def encloses(self, o):
return self.encloses_point(o)
elif isinstance(o, Segment):
return all(self.encloses_point(x) for x in o.points)
elif isinstance(o, Ray) or isinstance(o, Line):
elif isinstance(o, (Ray, Line)):
return False
elif isinstance(o, Ellipse):
return self.encloses_point(o.center) and \
Expand Down
4 changes: 2 additions & 2 deletions sympy/holonomic/holonomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2185,7 +2185,7 @@ def from_hyper(func, x0=0, evalf=False):

simp = hyperexpand(func)

if isinstance(simp, Infinity) or isinstance(simp, NegativeInfinity):
if simp in (Infinity, NegativeInfinity):
return HolonomicFunction(sol, x).composition(z)

def _find_conditions(simp, x, x0, order, evalf=False):
Expand Down Expand Up @@ -2271,7 +2271,7 @@ def from_meijerg(func, x0=0, evalf=False, initcond=True, domain=QQ):

simp = hyperexpand(func)

if isinstance(simp, Infinity) or isinstance(simp, NegativeInfinity):
if simp in (Infinity, NegativeInfinity):
return HolonomicFunction(sol, x).composition(z)

def _find_conditions(simp, x, x0, order, evalf=False):
Expand Down
2 changes: 1 addition & 1 deletion sympy/integrals/integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def doit(self, **hints):
if reps:
undo = {v: k for k, v in reps.items()}
did = self.xreplace(reps).doit(**hints)
if type(did) is tuple: # when separate=True
if isinstance(did, tuple): # when separate=True
did = tuple([i.xreplace(undo) for i in did])
else:
did = did.xreplace(undo)
Expand Down
2 changes: 1 addition & 1 deletion sympy/integrals/manualintegrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ def make_second_step(steps, integrand):

def trig_rule(integral):
integrand, symbol = integral
if isinstance(integrand, sympy.sin) or isinstance(integrand, sympy.cos):
if isinstance(integrand, (sympy.sin, sympy.cos)):
arg = integrand.args[0]

if not isinstance(arg, sympy.Symbol):
Expand Down
4 changes: 2 additions & 2 deletions sympy/integrals/meijerint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ def meijerint_indefinite(f, x):
rv = meijerint_indefinite(
_rewrite_hyperbolics_as_exp(f), x)
if rv:
if not type(rv) is list:
if not isinstance(rv, list):
return collect(factor_terms(rv), rv.atoms(exp))
results.extend(rv)
if results:
Expand Down Expand Up @@ -1890,7 +1890,7 @@ def meijerint_definite(f, x, a, b):
rv = meijerint_definite(
_rewrite_hyperbolics_as_exp(f_), x_, a_, b_)
if rv:
if not type(rv) is list:
if not isinstance(rv, list):
rv = (collect(factor_terms(rv[0]), rv[0].atoms(exp)),) + rv[1:]
return rv
results.extend(rv)
Expand Down
13 changes: 6 additions & 7 deletions sympy/integrals/rationaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def ratint(f, x, **flags):
sympy.integrals.rationaltools.ratint_ratpart
"""
if type(f) is not tuple:
p, q = f.as_numer_denom()
else:
if isinstance(f, tuple):
p, q = f
else:
p, q = f.as_numer_denom()

p, q = Poly(p, x, composite=False, field=True), Poly(q, x, composite=False, field=True)

Expand Down Expand Up @@ -78,12 +78,11 @@ def ratint(f, x, **flags):
real = flags.get('real')

if real is None:
if type(f) is not tuple:
atoms = f.atoms()
else:
if isinstance(f, tuple):
p, q = f

atoms = p.atoms() | q.atoms()
else:
atoms = f.atoms()

for elt in atoms - {x}:
if not elt.is_extended_real:
Expand Down
2 changes: 1 addition & 1 deletion sympy/integrals/risch.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def frac_in(f, t, *, cancel=False, **kwargs):
where fa and fd are either basic expressions or Polys, and f == fa/fd.
**kwargs are applied to Poly.
"""
if type(f) is tuple:
if isinstance(f, tuple):
fa, fd = f
f = fa.as_expr()/fd.as_expr()
fa, fd = f.as_expr().as_numer_denom()
Expand Down
2 changes: 1 addition & 1 deletion sympy/integrals/rubi/rubimain.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def rubi_integrate(expr, var, showsteps=False):
expr = expr.replace(sym_exp, rubi_exp)
expr = process_trig(expr)
expr = rubi_powsimp(expr)
if isinstance(expr, (int, Integer)) or isinstance(expr, (float, Float)):
if isinstance(expr, (int, Integer, float, Float)):
return S(expr)*var
if isinstance(expr, Add):
results = 0
Expand Down
4 changes: 2 additions & 2 deletions sympy/liealgebras/cartan_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ class CartanType_generator(Basic):
"""
def __call__(self, *args):
c = args[0]
if type(c) == list:
if isinstance(c, list):
letter, n = c[0], int(c[1])
elif type(c) == str:
elif isinstance(c, str):
letter, n = c[0], int(c[1:])
else:
raise TypeError("Argument must be a string (e.g. 'A3') or a list (e.g. ['A', 3])")
Expand Down
2 changes: 1 addition & 1 deletion sympy/matrices/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,7 +3234,7 @@ def _matrixify(mat):

def a2idx(j, n=None):
"""Return integer after making positive and validating against n."""
if type(j) is not int:
if not isinstance(j, int):
jindex = getattr(j, '__index__', None)
if jindex is not None:
j = jindex()
Expand Down
19 changes: 9 additions & 10 deletions sympy/matrices/eigen.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,14 +612,14 @@ def _bidiagonal_decomposition(M, upper=True):
"""

if type(upper) is not bool:
if not isinstance(upper, bool):
raise ValueError("upper must be a boolean")

if not upper:
X = _bidiagonal_decmp_hholder(M.H)
return X[2].H, X[1].H, X[0].H
if upper:
return _bidiagonal_decmp_hholder(M)

return _bidiagonal_decmp_hholder(M)
X = _bidiagonal_decmp_hholder(M.H)
return X[2].H, X[1].H, X[0].H


def _bidiagonalize(M, upper=True):
Expand All @@ -642,13 +642,12 @@ def _bidiagonalize(M, upper=True):
"""

if type(upper) is not bool:
if not isinstance(upper, bool):
raise ValueError("upper must be a boolean")

if not upper:
return _eval_bidiag_hholder(M.H).H

return _eval_bidiag_hholder(M)
if upper:
return _eval_bidiag_hholder(M)
return _eval_bidiag_hholder(M.H).H


def _diagonalize(M, reals_only=False, sort=False, normalize=False):
Expand Down
Loading

0 comments on commit c4bcd75

Please sign in to comment.