Skip to content

Commit

Permalink
Removed _IdDict
Browse files Browse the repository at this point in the history
The use of _IdDict was causing issues with symbol lookups, as we can
obtain strings in more than one way. (Which I knew when I originally
wrote this, but for some reason thought we avoided.)

There was a good reason for looking things up by identity, namely that
it makes most sense to look up each AbstractNode by identity. (As
they're not otherwise hashable in any meaningful way.) So we've switched
to doing that just for AbstractNode instead.
  • Loading branch information
patrick-kidger committed Aug 3, 2022
1 parent 5c6e80e commit 9781eae
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions sympy2jax/sympy_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,40 +84,33 @@ def fn_(*args):
assert len(_reverse_lookup) == len(_lookup)


class _IdDict:
def __init__(self, **values):
self._dict = {id(k): v for k, v in values.items()}

def __getitem__(self, item):
return self._dict[id(item)]

def __setitem__(self, item, value):
self._dict[id(item)] = value


class _AbstractNode(eqx.Module):
@abc.abstractmethod
def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
...

@abc.abstractmethod
def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
...

# Comparisons based on identity
__hash__ = object.__hash__
__eq__ = object.__eq__


class _Symbol(_AbstractNode):
_name: str

def __init__(self, expr: sympy.Expr):
self._name = expr.name

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
try:
return memodict[self._name]
except KeyError as e:
raise KeyError(f"Missing input for symbol {self._name}") from e

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Symbol(self._name)

Expand All @@ -136,10 +129,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Integer)
self._value = _maybe_array(int(expr), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._value

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Integer(self._value.item())

Expand All @@ -151,10 +144,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
assert isinstance(expr, sympy.Float)
self._value = _maybe_array(float(expr), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._value

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Float(self._value.item())

Expand All @@ -168,10 +161,10 @@ def __init__(self, expr: sympy.Expr, make_array: bool):
self._numerator = _maybe_array(int(expr.numerator), make_array)
self._denominator = _maybe_array(int(expr.denominator), make_array)

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
return self._numerator / self._denominator

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
# memodict not needed as sympy deduplicates internally
return sympy.Integer(self._numerator) / sympy.Integer(self._denominator)

Expand All @@ -181,7 +174,7 @@ class _Func(_AbstractNode):
_args: list

def __init__(
self, expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
self, expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
):
try:
self._func = func_lookup[expr.func]
Expand All @@ -191,7 +184,7 @@ def __init__(
_sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
]

def __call__(self, memodict: _IdDict):
def __call__(self, memodict: dict):
args = []
for arg in self._args:
try:
Expand All @@ -202,7 +195,7 @@ def __call__(self, memodict: _IdDict):
args.append(arg_call)
return self._func(*args)

def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
try:
return memodict[self]
except KeyError:
Expand All @@ -214,7 +207,7 @@ def sympy(self, memodict: _IdDict, func_lookup: dict) -> sympy.Expr:


def _sympy_to_node(
expr: sympy.Expr, memodict: _IdDict, func_lookup: dict, make_array: bool
expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
) -> _AbstractNode:
try:
return memodict[expr]
Expand Down Expand Up @@ -257,7 +250,7 @@ def __init__(
self.has_extra_funcs = True
_convert = ft.partial(
_sympy_to_node,
memodict=_IdDict(),
memodict=dict(),
func_lookup=lookup,
make_array=make_array,
)
Expand All @@ -268,11 +261,11 @@ def sympy(self) -> sympy.Expr:
raise NotImplementedError(
"SymbolicModule cannot be converted back to SymPy if `extra_funcs` is passed"
)
memodict = _IdDict()
memodict = dict()
return jax.tree_map(
lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
)

def __call__(self, **symbols):
memodict = _IdDict(**symbols)
memodict = symbols
return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)

0 comments on commit 9781eae

Please sign in to comment.