Skip to content

Commit

Permalink
[Oryx] Fix use of jax.core.literalable_types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 334964477
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Oct 2, 2020
1 parent aaed4d3 commit 900c08d
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions spinoffs/oryx/oryx/core/interpreters/unzip.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ def instantiate_const(self, tracer):
if isinstance(pv, jax_core.AbstractValue):
return tracer
elif not pv:
if (isinstance(const, tuple(jax_core.literalable_types)) and
not onp.shape(const)):
if type(const) in jax_core.literalable_types and not onp.shape(const): # pylint: disable=unidiomatic-typecheck
return self.new_instantiated_literal(const)
else:
return self.new_instantiated_const(const)
Expand Down

0 comments on commit 900c08d

Please sign in to comment.