Accessing incoming gradient and activations in custom_vjp #22993

Open

yidingjiang opened this issue Aug 12, 2024 · 4 comments

Labels: enhancement (New feature or request)
@yidingjiang

yidingjiang commented Aug 12, 2024

Hi! I am trying to write a custom backward pass for a linear layer where I can simultaneously compute the pairwise inner products of all the gradients in the batch. More specifically, suppose I have $a = W x$, where $x \in \mathbb{R}^{B\times N}$ is the input and $\delta \in \mathbb{R}^{B\times M}$ is the incoming gradient. I want to compute the two $B \times B$ matrices $x x^\top$ and $\delta \delta^\top$.

Both of these quantities are already computed during the backward pass, but as far as I can tell they are only reachable through the val field of the batch tracer objects. This becomes very inconvenient when the function has gone through several levels of vmap, because I then need to call .val several times to reach the full tensor, and it is hard to know a priori how many levels of vmap it has gone through. Is it possible to make something like this easier? A related question: what is the best way to compute the gradient inside custom_vjp in this case? And is this generally a good idea under jit?

import jax

jnp = jax.numpy

@jax.custom_vjp
def linear_vjp(weight_MxN, x_N):
    return weight_MxN @ x_N

def linear_vjp_fwd(weight_MxN, x_N):
    y = weight_MxN @ x_N
    return y, (weight_MxN, x_N)

def linear_vjp_bwd(res, grad):
    weight_MxN, x_N = res

    if type(grad) != type(x_N):
        grad_weight = jnp.tile(grad[:, None], (1, x_N.val.shape[0])) @ x_N.val
    else:
        grad_weight = grad.val @ x_N.val

    grad_x = weight_MxN.T @ grad

    if type(grad) != type(x_N):
        delta_inner = grad @ grad
    elif grad.batch_dim == 0:
        delta_inner = grad.val @ grad.val.T
    else:
        delta_inner = grad.val.T @ grad.val # what is the best way to do this?

    activation_inner = x_N.val @ x_N.val.T
    # additional processing
    return grad_weight, grad_x

linear_vjp.defvjp(linear_vjp_fwd, linear_vjp_bwd)
@yidingjiang yidingjiang added the enhancement New feature or request label Aug 12, 2024
@dfm
Member

dfm commented Aug 12, 2024

I'm not sure I totally follow the goal here, but your intuition that you don't want to rely on the .val attribute is a good one! Do you think you could put together a simpler end-to-end example of the expected behavior that I could run?

At a high level, I expect that the key assumption here is that this custom_vjp will always be called inside a vmap? Perhaps it would be better to define the linear_vjp function to directly operate on batched inputs rather than calling it from within a vmap?

I'm happy to try to give more concrete suggestions if you can put together a runnable demo!
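
To make that more concrete, here is a rough, untested sketch of the kind of batched implementation I have in mind (batched_linear is just a placeholder name, and I'm assuming a single leading batch axis). Because the custom_vjp sees the whole (B, N) batch, the (B, B) Gram matrices are ordinary arrays and no tracer internals are needed:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def batched_linear(weight_MxN, x_BxN):
    return x_BxN @ weight_MxN.T

def batched_linear_fwd(weight_MxN, x_BxN):
    return x_BxN @ weight_MxN.T, (weight_MxN, x_BxN)

def batched_linear_bwd(res, grad_BxM):
    weight_MxN, x_BxN = res
    delta_inner_BxB = grad_BxM @ grad_BxM.T    # pairwise inner products of the incoming gradients
    activation_inner_BxB = x_BxN @ x_BxN.T     # pairwise inner products of the activations
    # ... use delta_inner_BxB / activation_inner_BxB to adjust grad_weight here ...
    grad_weight = grad_BxM.T @ x_BxN           # (M, N) weight gradient, summed over the batch
    grad_x = grad_BxM @ weight_MxN             # (B, N) gradient w.r.t. the input
    return grad_weight, grad_x

batched_linear.defvjp(batched_linear_fwd, batched_linear_bwd)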

@dfm dfm self-assigned this Aug 12, 2024
@yidingjiang
Copy link
Author

yidingjiang commented Aug 12, 2024

Thank you! I can explain a bit more. The goal is to compute the pairwise inner products between the gradients of every batch element. Normally, in the backward pass, the gradient of the weight is computed as the matrix multiplication $\delta^\top x$, where $\delta \in \mathbb{R}^{B\times M}$, $x\in \mathbb{R}^{B\times N}$, and $B$ is the batch dimension. Within the bwd function, I wish to compute $\delta \delta^\top \in \mathbb{R}^{B\times B}$, which contains (part of) the pairwise inner products between the gradients of the different batch elements, so the outcome is not really a "batched" object anymore. I want to integrate this into any model I might define, so the function may go through more than one vmap (e.g., $\delta \in \mathbb{R}^{B\times T\times M}$ for data with two batch axes). It would be inconvenient to rewrite all existing models to explicitly account for batching (and during inference none of this matters, so forcing the function to be batched is not ideal). Below is an end-to-end example:

import numpy as np
import jax

jnp = jax.numpy
vmap = jax.vmap

@jax.custom_vjp
def linear_vjp(weight_MxN, x_N):
    return weight_MxN @ x_N

def linear_vjp_fwd(weight_MxN, x_N):
    y = weight_MxN @ x_N
    return y, (weight_MxN, x_N)

def linear_vjp_bwd(res, grad):
    weight_MxN, x_N = res

    if type(grad) != type(x_N):
        grad_weight = jnp.tile(grad[:, None], (1, x_N.val.shape[0])) @ x_N.val
    else:
        grad_weight = grad.val @ x_N.val

    grad_x = weight_MxN.T @ grad

    if type(grad) != type(x_N):
        delta_inner = grad @ grad
    elif grad.batch_dim == 0:
        delta_inner = grad.val @ grad.val.T
    else:
        delta_inner = grad.val.T @ grad.val # This line computes the pairwise inner product

    print(delta_inner) # This line prints out the pairwise inner product

    return grad_weight, grad_x

linear_vjp.defvjp(linear_vjp_fwd, linear_vjp_bwd)

x_BxN = jnp.array(np.ones((2, 2))) # B=2, N=2
w_MxN = jnp.array(np.ones((3, 2))) # M=3
target_BxM = jnp.array(np.zeros((2, 3)))

def loss_fn(w, x):
    def batch_linear(x):
        def partial(x):
            return linear_vjp(w, x)
        return vmap(partial)(x)
    return jnp.sum(batch_linear(x) - target_BxM)

value, grads = jax.value_and_grad(loss_fn, argnums=0)(w_MxN, x_BxN)

This script should print the $2\times 2$ pairwise inner product between the incoming gradients.

[[3. 3.]
 [3. 3.]]

@dfm
Member

dfm commented Aug 13, 2024

I see. And I guess you're using delta_inner for some sort of logging, or does it get used for something else? You can probably hack something like this by combining a batched implementation with custom_vmap, but it depends a bit on what you want to do with delta_inner...
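
As a very rough, untested sketch of what I mean by combining the two (the names are placeholders, and how this composes with reverse-mode autodiff and your custom_vjp would still need checking): jax.custom_batching.custom_vmap lets you attach a hand-written batching rule to the per-example function, and that rule can dispatch to a batched implementation like the one sketched above.

import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.custom_vmap
def linear(weight_MxN, x_N):
    # per-example definition, used when there is no vmap
    return weight_MxN @ x_N

@linear.def_vmap
def linear_vmap_rule(axis_size, in_batched, weight_MxN, x_BxN):
    w_batched, x_batched = in_batched
    assert not w_batched and x_batched  # assume only x carries the batch axis (axis 0)
    out_BxM = batched_linear(weight_MxN, x_BxN)  # the batched custom_vjp sketched earlier
    return out_BxM, True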

@yidingjiang
Author

yidingjiang commented Aug 13, 2024

I plan to use it to modify the gradient of the weight (not shown here for brevity), so it does change the backward pass. I also want to get it out of the grad call by passing a dummy $B \times B$ matrix into the function and returning this matrix as its gradient, but I think that part should be relatively easy (one way to think about it is as additional "gradient metadata" for that parameter). Could you elaborate on custom_vmap? A related question: do you foresee this breaking anything if I am using distributed training across 4 TPU nodes with sharding?
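
For concreteness, the dummy-argument idea is roughly the following (untested sketch written against a fully batched version; the names are just placeholders). The (B, B) matrix is returned as the "gradient" of a dummy input, so jax.grad with respect to that input hands it back:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def linear_with_gram(weight_MxN, x_BxN, dummy_BxB):
    return x_BxN @ weight_MxN.T

def linear_with_gram_fwd(weight_MxN, x_BxN, dummy_BxB):
    return x_BxN @ weight_MxN.T, (weight_MxN, x_BxN)

def linear_with_gram_bwd(res, grad_BxM):
    weight_MxN, x_BxN = res
    delta_inner_BxB = grad_BxM @ grad_BxM.T      # pairwise gradient inner products
    grad_weight = grad_BxM.T @ x_BxN
    grad_x = grad_BxM @ weight_MxN
    return grad_weight, grad_x, delta_inner_BxB  # the Gram matrix flows out through the dummy slot

linear_with_gram.defvjp(linear_with_gram_fwd, linear_with_gram_bwd)

B, N, M = 2, 2, 3
w = jnp.ones((M, N))
x = jnp.ones((B, N))
dummy = jnp.zeros((B, B))

def loss(w, x, dummy):
    return jnp.sum(linear_with_gram(w, x, dummy))

grad_w, gram = jax.grad(loss, argnums=(0, 2))(w, x, dummy)
# gram == delta @ delta.T, i.e. [[3., 3.], [3., 3.]] for this toy example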
