How to implement train_step with multiple gradient calculations with a JAX backend? Ex. GAN #18881

Open
craig-sony opened this issue Dec 4, 2023 · 5 comments
Labels
stat:awaiting response from contributor, type:support

Comments

@craig-sony

I have been trying to figure out how to write a GAN using Keras 3 with a JAX backend using the stateless_call API.
I cannot figure out a clean way to deal with the need to have separate gradients computed for the discriminator and generator.
The closest I've gotten is to build a mapping between the model's trainable_variables/non_trainable_variables lists and the layers they belong to. Each time I call stateless_call, I first have to extract that layer's trainable_variables/non_trainable_variables from the lists passed into train_step, and afterwards reinsert the non_trainable_variables that stateless_call returns. It's a mess.
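For reference, the core requirement (two gradients per step, each taken with respect to only one network's parameters) can be sketched in plain JAX, outside Keras. This is a minimal sketch under stated assumptions: the linear generator, logistic discriminator, non-saturating loss, and plain SGD update are hypothetical stand-ins, not the Keras API or the keras.io example. The point it shows is that jax.value_and_grad differentiates only its first argument by default, so passing one network's parameters first and the other's as a constant yields separate gradients.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical plain-SGD step size


# Hypothetical toy models, each with its own parameter pytree.
def generator(g_params, z):
    return z @ g_params["w"] + g_params["b"]


def discriminator(d_params, x):
    return (x @ d_params["w"] + d_params["b"])[:, 0]  # one logit per sample


def d_loss_fn(d_params, g_params, real, z):
    fake = generator(g_params, z)
    # Binary cross-entropy on logits: real -> 1, fake -> 0.
    return (jnp.mean(jax.nn.softplus(-discriminator(d_params, real)))
            + jnp.mean(jax.nn.softplus(discriminator(d_params, fake))))


def g_loss_fn(g_params, d_params, z):
    # Non-saturating generator loss: fake -> 1.
    return jnp.mean(jax.nn.softplus(-discriminator(d_params, generator(g_params, z))))


def sgd(params, grads):
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


@jax.jit
def gan_step(g_params, d_params, real, z):
    # Gradient w.r.t. the discriminator only; g_params is a constant here.
    d_loss, d_grads = jax.value_and_grad(d_loss_fn)(d_params, g_params, real, z)
    d_params = sgd(d_params, d_grads)
    # Gradient w.r.t. the generator only, against the updated discriminator.
    g_loss, g_grads = jax.value_and_grad(g_loss_fn)(g_params, d_params, z)
    g_params = sgd(g_params, g_grads)
    return g_params, d_params, d_loss, g_loss


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
g_params = {"w": 0.1 * jax.random.normal(k1, (2, 2)), "b": jnp.zeros(2)}
d_params = {"w": 0.1 * jax.random.normal(k2, (2, 1)), "b": jnp.zeros(1)}
real = jax.random.normal(k3, (8, 2)) + 3.0
z = jax.random.normal(k4, (8, 2))
g_params, d_params, d_loss, g_loss = gan_step(g_params, d_params, real, z)
```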

Can you please update the following example for Keras 3?
https://keras.io/examples/generative/conditional_gan/

Thanks.

@fchollet
Collaborator

fchollet commented Dec 4, 2023

I think you can use a StatelessScope and then just write a stateful train_step, which is 10x easier.

Probably something like

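import jax
import keras

def train_step(self, state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metrics_variables,
    ) = state
    # compute_loss_and_updates_gen / _disc are hypothetical helpers returning
    # (loss, (y_pred, updated_non_trainable_variables)) for one sub-model.
    grad_fn_gen = jax.value_and_grad(self.compute_loss_and_updates_gen, has_aux=True)
    grad_fn_disc = jax.value_and_grad(self.compute_loss_and_updates_disc, has_aux=True)
    # Map every model variable to the value passed in through `state`.
    state_mapping = list(zip(self.trainable_variables, trainable_variables)) + list(
        zip(self.non_trainable_variables, non_trainable_variables)
    )
    # gen_x, gen_y, disc_x, disc_y: inputs/targets built from `data` (not shown).
    with keras.StatelessScope(state_mapping=state_mapping) as scope:
        # Inside the scope, `v.value` returns the mapped (current) value of `v`;
        # pass arrays, not variable objects, to the differentiated function.
        (loss_gen, (y_pred_gen, non_trainable_variables_gen)), grads_gen = grad_fn_gen(
            [v.value for v in self.gen.trainable_variables],
            [v.value for v in self.gen.non_trainable_variables],
            gen_x,
            gen_y,
            training=True,
        )
        (loss_disc, (y_pred_disc, non_trainable_variables_disc)), grads_disc = grad_fn_disc(
            [v.value for v in self.disc.trainable_variables],
            [v.value for v in self.disc.non_trainable_variables],
            disc_x,
            disc_y,
            training=True,
        )
        ...  # apply gradients, assign updated non-trainable values, update metrics, etc.
    # After the scope exits, read back the (possibly updated) value of every variable.
    trainable_variables = [scope.get_current_value(w) for w in self.trainable_variables]
    non_trainable_variables = [scope.get_current_value(w) for w in self.non_trainable_variables]
    # return logs, (trainable_variables, non_trainable_variables, optimizer_variables, metrics_variables)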

You get the idea. Set variable values through the scope, and inside it you can use self.gen.variables, etc. as ordinary stateful variables. When the scope exits, collect the updated variable values and return those.

@fchollet
Collaborator

fchollet commented Dec 4, 2023

For a real-world example see how we handle stateful metrics in the JAX backend: https://github.com/keras-team/keras/blob/master/keras/backend/jax/trainer.py#L130-L145

In general, working with JAX statelessness is pretty terrible, so the solution is to open a StatelessScope and pretend everything is stateful 👍

@sachinprasadhs added the type:support and stat:awaiting response from contributor labels Dec 4, 2023

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions bot added the stale label Dec 19, 2023
@craig-sony
Author

Thanks for the info. Any chance of getting the example updated, though?

@dhantule
Contributor

dhantule commented Dec 27, 2024

Hi @craig-sony, thanks for reporting this. The example has been updated and runs fine with Keras 3 in this gist.

@dhantule added the stat:awaiting response from contributor label and removed the stat:awaiting keras-eng label Dec 27, 2024