TypeError during gradient computation: type <class 'objax.variable.TrainVar'> is not a valid JAX type #260
Open
Description
I was trying to run a simple example, but a TypeError is raised when evaluating the gradients:
TypeError: Argument 'objax.TrainVar(Traced<ConcreteArray([-1.1010288 -0.6818452 -0.95236534], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([-1.1010288 , -0.6818452 , -0.95236534], dtype=float32)
tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[3]), None)
recipe = LambdaBinding(), reduce=reduce_mean)' of type <class 'objax.variable.TrainVar'> is not a valid JAX type.
Minimal example from the docs:
import objax
import jax.numpy as jn
n = 1000
ndim = 10
X = objax.random.normal((n, ndim))
y = objax.random.normal((n, 1))
w = objax.TrainVar(jn.zeros(ndim))
b = objax.TrainVar(jn.zeros(1))
def loss(x, y):
    pred = jn.dot(x, w) + b
    return 0.5 * ((y - pred) ** 2).mean()
g_fn = objax.Grad(loss,  # g_fn is Objax module
                  objax.VarCollection({'w': w, 'b': b}))
g_value = g_fn(X, y)
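One possible reading of the error message is that the `TrainVar` wrapper object itself, rather than the array it holds, reaches a JAX operation inside `loss`. The sketch below illustrates that mechanism in plain JAX so it runs without Objax. `Box` is a hypothetical stand-in class, not part of Objax or JAX, and the small all-ones inputs are made up for illustration. It assumes only that the wrapper exposes its array through a `.value` attribute, as Objax variables do.

```python
import jax
import jax.numpy as jn


class Box:
    """Hypothetical stand-in for a variable wrapper such as objax.TrainVar."""

    def __init__(self, value):
        self.value = value


# Tiny illustrative inputs (assumed, not taken from the report).
x = jn.ones((4, 3))
y = jn.ones(4)
w = Box(jn.zeros(3))
b = Box(jn.zeros(1))


def loss_wrapped(x, y):
    # Passes the wrapper objects themselves to JAX operations.
    pred = jn.dot(x, w) + b
    return 0.5 * ((y - pred) ** 2).mean()


def loss(w_val, b_val, x, y):
    # Receives the underlying arrays, which JAX can trace.
    pred = jn.dot(x, w_val) + b_val
    return 0.5 * ((y - pred) ** 2).mean()


# The wrapper is not an array, so the JAX op is expected to reject it.
try:
    loss_wrapped(x, y)
    wrapped_failed = False
except TypeError:
    wrapped_failed = True

# Unwrapping with .value gives JAX plain arrays to differentiate.
grads = jax.grad(loss, argnums=(0, 1))(w.value, b.value, x, y)
```

If the same reasoning applies to the reproduction above, the analogous change would be to write `jn.dot(x, w.value) + b.value` inside `loss`. That has not been verified here against a specific Objax or JAX version.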