Accessing incoming gradient and activations in custom_vjp #22993
I'm not sure I totally follow the goal here, but your intuition that you don't want to rely on the tracer internals is right. At a high level, I expect that the key assumption here is that this custom_vjp will always be called inside a vmap. I'm happy to try to give more concrete suggestions if you can put together a runnable demo!
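One possible reading of that suggestion (a minimal sketch, not necessarily the concrete recipe intended here; all names below are made up for illustration): define the custom_vjp around the already-batched matmul, so the backward rule receives the full (B, M) cotangent as an ordinary array and no tracer internals are needed.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: put the custom_vjp around the *batched* linear op,
# so the backward rule sees the full (B, M) cotangent as a plain array.
@jax.custom_vjp
def batched_linear(w_MxN, x_BxN):
    return x_BxN @ w_MxN.T                 # (B, M)

def batched_linear_fwd(w_MxN, x_BxN):
    return x_BxN @ w_MxN.T, (w_MxN, x_BxN)

def batched_linear_bwd(res, g_BxM):
    w_MxN, x_BxN = res
    # Pairwise inner products of the per-example incoming gradients,
    # available here as an ordinary (B, B) array; print it, log it, or route it out.
    delta_inner_BxB = g_BxM @ g_BxM.T
    grad_w = g_BxM.T @ x_BxN               # (M, N), summed over the batch
    grad_x = g_BxM @ w_MxN                 # (B, N)
    return grad_w, grad_x

batched_linear.defvjp(batched_linear_fwd, batched_linear_bwd)
```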
Thank you! I can explain a bit more. The goal I am trying to achieve is to compute the pairwise inner products between the gradients of every batch element. Normally, in the backward pass, the gradient with respect to the weight would be computed as the matrix multiplication of the incoming gradient with the input.

```python
import numpy as np
import jax

jnp = jax.numpy
vmap = jax.vmap

@jax.custom_vjp
def linear_vjp(weight_MxN, x_N):
    return weight_MxN @ x_N

def linear_vjp_fwd(weight_MxN, x_N):
    y = weight_MxN @ x_N
    return y, (weight_MxN, x_N)

def linear_vjp_bwd(res, grad):
    weight_MxN, x_N = res
    # Peek at the BatchTracer's .val to recover the full batched arrays.
    if type(grad) != type(x_N):
        grad_weight = jnp.tile(grad[:, None], (1, x_N.val.shape[0])) @ x_N.val
    else:
        grad_weight = grad.val @ x_N.val
    grad_x = weight_MxN.T @ grad
    if type(grad) != type(x_N):
        delta_inner = grad @ grad
    elif grad.batch_dim == 0:
        delta_inner = grad.val @ grad.val.T
    else:
        delta_inner = grad.val.T @ grad.val  # This line computes the pairwise inner product
    print(delta_inner)  # This line prints out the pairwise inner product
    return grad_weight, grad_x

linear_vjp.defvjp(linear_vjp_fwd, linear_vjp_bwd)

x_BxN = jnp.array(np.ones((2, 2)))    # B=2, N=2
w_MxN = jnp.array(np.ones((3, 2)))    # M=3
target_BxM = jnp.array(np.zeros((2, 3)))

def loss_fn(w, x):
    def batch_linear(x):
        def partial(x):
            return linear_vjp(w, x)
        return vmap(partial)(x)
    return jnp.sum(batch_linear(x) - target_BxM)

value, grads = jax.value_and_grad(loss_fn, argnums=0)(w_MxN, x_BxN)
```

This script should print the pairwise inner products of the incoming gradients.
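As a quick hand check of what that print should show for the all-ones example above: the loss is a plain sum, so the cotangent arriving at the (2, 3) batched output is all ones, and the expected B x B pairwise inner-product matrix can be computed directly.

```python
import jax.numpy as jnp

# The incoming per-example gradients are all-ones rows of length 3,
# so every pairwise inner product equals 3.
delta = jnp.ones((2, 3))
print(delta @ delta.T)  # [[3. 3.]
                        #  [3. 3.]]
```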
I see. And I guess you're using the delta_inner you compute there for something beyond just printing it?
I plan to use it to modify the gradient of the weight (not shown here for brevity), so it modifies the backward pass. I also want to get it out of the grad by passing in a dummy argument.
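For what it's worth, here is a minimal sketch of that dummy-argument idea, assuming the custom_vjp wraps the already-batched op (the `probe` name and shapes are made up for illustration): the extra zero-valued argument contributes nothing to the primal output, and its cotangent slot is reused to carry `delta_inner` out through `jax.grad`.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def linear_with_probe(w, x, probe):
    # `probe` does not affect the primal output; its cotangent slot is reused below.
    return x @ w.T

def linear_with_probe_fwd(w, x, probe):
    return x @ w.T, (w, x)

def linear_with_probe_bwd(res, g):
    w, x = res
    grad_w = g.T @ x             # usual weight gradient, summed over the batch
    grad_x = g @ w               # usual input gradient
    delta_inner = g @ g.T        # (B, B) pairwise inner products of incoming grads
    return grad_w, grad_x, delta_inner  # delta_inner leaves via probe's cotangent slot

linear_with_probe.defvjp(linear_with_probe_fwd, linear_with_probe_bwd)

B, N, M = 2, 2, 3
w = jnp.ones((M, N))
x = jnp.ones((B, N))
probe = jnp.zeros((B, B))        # shaped like the quantity we want back
loss = lambda w, x, probe: jnp.sum(linear_with_probe(w, x, probe))
grad_w, delta_inner = jax.grad(loss, argnums=(0, 2))(w, x, probe)
# delta_inner is now the (B, B) Gram matrix of the incoming gradients.
```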
Hi! I am trying to write a custom backward pass for a linear layer where I can simultaneously compute the pairwise inner products of all the gradients in the batch. More specifically, suppose I have $a = W x$, where $x \in \mathbb{R}^{B\times N}$ is the input, and let $\delta \in \mathbb{R}^{B\times M}$ be the incoming gradient. I want to compute the two $B \times B$ matrices $x x^\top$ and $\delta \delta^\top$.
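For concreteness, under these shape conventions (rows of $x$ and $\delta$ are the per-example vectors, $W \in \mathbb{R}^{M\times N}$), the usual backward-pass quantities and the two Gram matrices are

$$
\nabla_W L = \delta^\top x \in \mathbb{R}^{M\times N},\qquad
\nabla_x L = \delta W \in \mathbb{R}^{B\times N},\qquad
(x x^\top)_{ij} = \langle x_i, x_j\rangle,\qquad
(\delta \delta^\top)_{ij} = \langle \delta_i, \delta_j\rangle .
$$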
Both of these quantities are already computed for the backward pass, but as far as I can tell they are only available through the .val fields of the BatchTracer. This becomes very inconvenient when the function has gone through several vmaps, which means I need to call .val several times to reach the full tensor, and it is hard to know a priori how many levels of vmap it has gone through. Is it possible to make it easier to do something like this? Another related question: what is the best way to compute the gradient in custom_vjp in this case? Also, is this generally a good idea under jit?
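On the jit question: a Python-level `print` inside the backward rule runs only while the function is being traced, so under `jit` it fires once with abstract values rather than on every call; `jax.debug.print` is the usual way to emit runtime values. A small sketch of a backward rule in that style (a drop-in for the batched variant sketched earlier; the names are illustrative, not from the original code):

```python
import jax
import jax.numpy as jnp

def batched_linear_bwd(res, g):
    w, x = res
    delta_inner = g @ g.T
    # Under jit, a Python `print` here would only run at trace time;
    # jax.debug.print emits the runtime value on every backward pass.
    jax.debug.print("pairwise inner products:\n{}", delta_inner)
    return g.T @ x, g @ w
```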