Skip to content

Commit

Permalink
Merge pull request #1790 from jrioux/solvers
Browse files Browse the repository at this point in the history
solve: return all independent solutions.
  • Loading branch information
smichr committed Feb 21, 2013
2 parents aaa5ae5 + 9d751ce commit bb44ac7
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
78 changes: 52 additions & 26 deletions sympy/solvers/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,12 @@ def solve(f, *symbols, **flags):
>>> solve(x - 3)
[3]
>>> solve(x**2 - y**2) # doctest: +SKIP
[{x: -y}, {x: y}]
>>> solve(z**2*x**2 - z**2*y**2) # doctest: +SKIP
>>> solve(x**2 - y**2)
[{x: -y}, {x: y}]
>>> solve(z**2*x**2 - z**2*y**2)
[{x: -y}, {x: y}, {z: 0}]
>>> solve(z**2*x - z**2*y**2)
[{x: y**2}]
[{x: y**2}, {z: 0}]
* when an object other than a Symbol is given as a symbol, it is
isolated algebraically and an implicit solution may be obtained.
Expand Down Expand Up @@ -544,11 +544,11 @@ def solve(f, *symbols, **flags):
* if there is no linear solution then the first successful
attempt for a nonlinear solution will be returned
>>> solve(x**2 - y**2, x, y) # doctest: +SKIP
>>> solve(x**2 - y**2, x, y)
[{x: -y}, {x: y}]
>>> solve(x**2 - y**2/exp(x), x, y)
[{x: 2*LambertW(y/2)}]
>>> solve(x**2 - y**2/exp(x), y, x) # doctest: +SKIP
>>> solve(x**2 - y**2/exp(x), y, x)
[{y: -x*exp(x/2)}, {y: x*exp(x/2)}]
* iterable of one or more of the above
Expand Down Expand Up @@ -1024,23 +1024,36 @@ def _solve(f, *symbols, **flags):
return soln
# find first successful solution
failed = []
got_s = set([])
result = []
for s in symbols:
n, d = solve_linear(f, symbols=[s])
if n.is_Symbol:
# no need to check but we should simplify if desired
if flags.get('simplify', True):
d = simplify(d)
return [{n: d}]
if got_s and any([ss in d.free_symbols for ss in got_s]):
# sol depends on previously solved symbols: discard it
continue
got_s.add(n)
result.append({n: d})
elif n and d: # otherwise there was no solution for s
failed.append(s)
if not failed:
return []
return result
for s in failed:
try:
soln = _solve(f, s, **flags)
return [{s: sol} for sol in soln]
for sol in soln:
if got_s and any([ss in sol.free_symbols for ss in got_s]):
# sol depends on previously solved symbols: discard it
continue
got_s.add(s)
result.append({s: sol})
except NotImplementedError:
continue
if got_s:
return result
else:
msg = "No algorithms are implemented to solve equation %s"
raise NotImplementedError(msg % f)
Expand Down Expand Up @@ -1340,15 +1353,28 @@ def _solve_system(exprs, symbols, **flags):
free = set_union(*[p.free_symbols for p in polys])
free = list(free.intersection(symbols))
free.sort(key=default_sort_key)
got_s = set([])
result = []
for syms in subsets(free, len(polys)):
try:
# returns [] or list of tuples of solutions for syms
result = solve_poly_system(polys, *syms)
if result:
solved_syms = syms
break
res = solve_poly_system(polys, *syms)
if res:
for r in res:
skip = False
for r1 in r:
if got_s and any([ss in r1.free_symbols
for ss in got_s]):
# sol depends on previously
# solved symbols: discard it
skip = True
if not skip:
got_s.update(syms)
result.extend([dict(zip(syms, r))])
except NotImplementedError:
pass
if got_s:
solved_syms = list(got_s)
else:
raise NotImplementedError('no valid subset found')
else:
Expand All @@ -1358,14 +1384,13 @@ def _solve_system(exprs, symbols, **flags):
except NotImplementedError:
failed.extend([g.as_expr() for g in polys])
solved_syms = []

if result:
# we don't know here if the symbols provided were given
# or not, so let solve resolve that. A list of dictionaries
# is going to always be returned from here.
#
# We do not check the solution obtained from polys, either.
result = [dict(zip(solved_syms, r)) for r in result]
if result:
# we don't know here if the symbols provided were given
# or not, so let solve resolve that. A list of dictionaries
# is going to always be returned from here.
#
# We do not check the solution obtained from polys, either.
result = [dict(zip(solved_syms, r)) for r in result]

if failed:
# For each failed equation, see if we can solve for one of the
Expand Down Expand Up @@ -1394,7 +1419,7 @@ def _ok_syms(e, sort=False):
for eq in ordered(failed, lambda _: len(_ok_syms(_))):
newresult = []
bad_results = []
got_s = None
got_s = set([])
u = Dummy()
for r in result:
# update eq with everything that is known so far
Expand Down Expand Up @@ -1430,6 +1455,9 @@ def _ok_syms(e, sort=False):
if do_simplify:
flags['simplify'] = False # for checksol's sake
for sol in soln:
if got_s and any([ss in sol.free_symbols for ss in got_s]):
# sol depends on previously solved symbols: discard it
continue
if check:
# check that it satisfies *other* equations
ok = False
Expand All @@ -1452,13 +1480,11 @@ def _ok_syms(e, sort=False):
newresult.append(rnew)
if simplify_flag is not None:
flags['simplify'] = simplify_flag
got_s = s
break
else:
got_s.add(s)
if not got_s:
raise NotImplementedError('could not solve %s' % eq2)
if got_s:
result = newresult
solved_syms.add(got_s)
for b in bad_results:
result.remove(b)
# if there is only one result should we return just the dictionary?
Expand Down
18 changes: 15 additions & 3 deletions sympy/solvers/tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ def test_tsolve():
assert solve(z**x - y, x) == [log(y)/log(z)]
# issue #1405
assert solve(2**x - 10, x) == [log(10)/log(2)]
# issue #3645
assert solve(x*y) == [{x: 0}, {y: 0}]
assert solve([x*y]) == [{x: 0}, {y: 0}]
assert solve(x**y - 1) == [{x: 1}, {y: 0}]
assert solve([x**y - 1]) == [{x: 1}, {y: 0}]
assert solve(x*y*(x**2 - y**2)) == [{x: 0}, {x: -y}, {x: y}, {y: 0}]
assert solve([x*y*(x**2 - y**2)]) == [{x: 0}, {x: -y}, {x: y}, {y: 0}]


def test_solve_for_functions_derivatives():
Expand Down Expand Up @@ -451,8 +458,7 @@ def test_issue_1694():
([y], set([
(-sqrt(exp(x)*log(x**2)),),
(sqrt(exp(x)*log(x**2)),)]))
assert solve(
x**2*z**2 - z**2*y**2) in ([{x: y}, {x: -y}], [{x: -y}, {x: y}])
assert solve(x**2*z**2 - z**2*y**2) == [{x: -y}, {x: y}, {z: 0}]
assert solve((x - 1)/(1 + 1/(x - 1))) == []
assert solve(x**(y*z) - x, x) == [1]
raises(NotImplementedError, lambda: solve(log(x) - exp(x), x))
Expand Down Expand Up @@ -1065,7 +1071,13 @@ def test_exclude():
Rf: Ri*(C*R*s + 1)**2/(C*R*s),
Vminus: Vplus,
V1: Vplus*(2*C*R*s + 1)/(C*R*s),
Vout: Vplus*(C**2*R**2*s**2 + 3*C*R*s + 1)/(C*R*s)}]
Vout: Vplus*(C**2*R**2*s**2 + 3*C*R*s + 1)/(C*R*s)},
{
Vplus: 0,
Vminus: 0,
V1: 0,
Vout: 0},
]
assert solve(eqs, exclude=[Vplus, s, C]) == [
{
Rf: Ri*(V1 - Vplus)**2/(Vplus*(V1 - 2*Vplus)),
Expand Down

0 comments on commit bb44ac7

Please sign in to comment.