How to implement train_step with multiple gradient calculations with a JAX backend? Ex. GAN #18881
I think you can use a `StatelessScope`. Probably something like:

```python
def train_step(self, state, data):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
        metrics_variables,
    ) = state
    grad_fn_gen = jax.value_and_grad(self.compute_loss_and_updates_gen, has_aux=True)
    grad_fn_disc = jax.value_and_grad(self.compute_loss_and_updates_disc, has_aux=True)
    state_mapping = list(zip(self.trainable_variables, trainable_variables)) + list(
        zip(self.non_trainable_variables, non_trainable_variables)
    )
    with keras.StatelessScope(state_mapping) as scope:
        (loss_gen, (y_pred_gen, non_trainable_variables_gen)), grads_gen = grad_fn_gen(
            self.gen.trainable_variables,
            self.gen.non_trainable_variables,
            gen_x,
            gen_y,
            training=True,
        )
        (loss_disc, (y_pred_disc, non_trainable_variables_disc)), grads_disc = grad_fn_disc(
            self.disc.trainable_variables,
            self.disc.non_trainable_variables,
            disc_x,
            disc_y,
            training=True,
        )
        ...
    trainable_variables = [scope.get_current_value(w) for w in self.trainable_variables]
    non_trainable_variables = [scope.get_current_value(w) for w in self.non_trainable_variables]
```

You get the idea. Just set variable values with the scope, and then you can use them as usual.
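The `(loss, aux), grads` unpacking in the snippet above follows from how `jax.value_and_grad(..., has_aux=True)` structures its return value. Here is a minimal, self-contained illustration (the loss function and variable names are made up for this demo, not taken from the model above):

```python
import jax
import jax.numpy as jnp

def loss_and_updates(w, x, y):
    """Toy loss with an auxiliary output, mimicking a compute_loss_and_updates fn."""
    y_pred = w * x
    loss = jnp.mean((y_pred - y) ** 2)
    # The auxiliary payload rides along and is not differentiated.
    return loss, y_pred

grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)

w = jnp.array(2.0)
x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 4.0])

# With has_aux=True the return value is ((loss, aux), grads).
(loss, y_pred), grads = grad_fn(w, x, y)
print(float(loss), float(grads))  # loss is 0.0 at w=2.0 here, so the gradient is 0.0 too
```

Calling two such `grad_fn`s inside one `train_step`, as in the GAN snippet, is just two independent invocations, each returning its own gradient pytree.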
For a real-world example see how we handle stateful metrics in the JAX backend: https://github.com/keras-team/keras/blob/master/keras/backend/jax/trainer.py#L130-L145 In general, working with JAX statelessness is pretty terrible, so the solution is to open a StatelessScope and pretend everything is stateful 👍
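The "pretend everything is stateful" trick can be illustrated with a tiny dependency-free stand-in. This is NOT Keras's actual `StatelessScope` implementation, just the shape of the idea: variables keep their identity, while their current values live in a side table for the duration of the scope.

```python
# Toy stand-in for the StatelessScope pattern -- not the real Keras class.
class ToyStatelessScope:
    def __init__(self, state_mapping):
        # Map each variable (by identity) to its externally supplied value.
        self._values = {id(var): value for var, value in state_mapping}

    def update(self, var, value):
        self._values[id(var)] = value

    def get_current_value(self, var):
        # Fall back to the variable's own value if it was never remapped.
        return self._values.get(id(var), var.value)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


class Variable:
    def __init__(self, value):
        self.value = value


w1, w2 = Variable(1.0), Variable(2.0)

with ToyStatelessScope([(w1, 10.0)]) as scope:
    # Inside the scope, code reads and writes values without mutating the variables.
    scope.update(w2, scope.get_current_value(w2) + 1.0)
    new_values = [scope.get_current_value(v) for v in (w1, w2)]

print(new_values)          # [10.0, 3.0]
print(w1.value, w2.value)  # the variables themselves were never mutated: 1.0 2.0
```

The payoff is that model code written against "stateful" variables keeps working, while the new values can be collected at the end and returned functionally, as JAX requires.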
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Thanks for the info, any chance of getting the example updated though?
Hi @craig-sony, thanks for reporting this. The example has been updated and runs fine with Keras 3 in this gist.
I have been trying to figure out how to write a GAN using Keras 3 with a JAX backend using the `stateless_call` API. I cannot figure out a clean way to deal with the need to have separate gradients computed for the discriminator and generator.

The only approach I've gotten close with is to create a mapping between the trainable_variables/non_trainable_variables lists used by the model and the corresponding layers. Then when I call `stateless_call` I first have to extract the corresponding trainable_variables/non_trainable_variables for the layer being called from those passed into the `train_step` function, but then I need to reinsert the `non_trainable_variables` returned by `stateless_call`. It's a mess.

Can you please update the following example for Keras 3?
https://keras.io/examples/generative/conditional_gan/
Thanks.
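The bookkeeping described above, extracting each sub-model's values from the flat lists passed to `train_step` and reinserting the updated ones, can be sketched in plain Python. The `gen_vars`/`disc_vars` names and the `Var` class are hypothetical placeholders standing in for the actual layer variables:

```python
# Hypothetical placeholder "variables": any distinct objects work for this demo.
class Var:
    def __init__(self, name):
        self.name = name

gen_vars = [Var("g0"), Var("g1")]
disc_vars = [Var("d0"), Var("d1"), Var("d2")]
model_vars = gen_vars + disc_vars  # order of the model-level trainable_variables

# Values passed into train_step, aligned with model_vars by position.
values = [10, 11, 20, 21, 22]

# Build an index from variable identity to its slot in the flat list.
index = {id(v): i for i, v in enumerate(model_vars)}

# Extract the discriminator's values, e.g. for a stateless_call on that sub-model.
disc_values = [values[index[id(v)]] for v in disc_vars]

# ...the stateless call returns updated values; reinsert them at the right slots.
updated_disc_values = [x + 1 for x in disc_values]  # stand-in for real updates
for v, new in zip(disc_vars, updated_disc_values):
    values[index[id(v)]] = new

print(values)  # generator values untouched, discriminator values updated
```

This is exactly the manual extract/reinsert dance the issue calls "a mess"; the `StatelessScope` approach suggested in the thread hides this mapping behind the scope instead.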