Add auto variable sharding for all backbones/tasks #1689

Open
@mattdangerw

Description

We want model parallelism to be easy to use across the library. At a high level, a user should only need to describe their hardware and, optionally, the desired split between model parallelism and data parallelism across the device grid.
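As a rough illustration of "describe the hardware plus a model vs. data parallel split", here is a minimal pure-Python sketch that turns a device count and a desired model-parallel degree into a 2D `(batch, model)` grid of device indices. The helper name `make_device_grid` and its `model_parallel` parameter are hypothetical; in Keras, the resulting shape and device list would be what gets passed to `keras.distribution.DeviceMesh`.

```python
def make_device_grid(num_devices, model_parallel=1):
    """Arrange device indices into a (batch, model) grid.

    Hypothetical helper: `model_parallel` is how many devices share one
    copy of the model weights; the remaining factor is data parallelism.
    """
    if model_parallel < 1 or num_devices % model_parallel != 0:
        raise ValueError(
            f"{num_devices} devices cannot be split into model-parallel "
            f"groups of size {model_parallel}."
        )
    data_parallel = num_devices // model_parallel
    # Row i lists the devices that together hold one replica of the model.
    grid = [
        list(range(i * model_parallel, (i + 1) * model_parallel))
        for i in range(data_parallel)
    ]
    return (data_parallel, model_parallel), grid


shape, grid = make_device_grid(8, model_parallel=4)
print(shape)  # (2, 4)
print(grid)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With `model_parallel=1` this degenerates to pure data parallelism, which is why the split can reasonably be optional for the user.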

Currently, we have an auto layout helper for Gemma here, but it is not a scalable design. The correct layout map depends on the model's config. For example, a Gemma model with multi-head attention needs to be sharded differently than one with multi-query attention.
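To make the multi-head vs. multi-query point concrete: a common recipe shards attention kernels along the head axis over the `model` mesh axis, but that only works when the number of heads divides evenly across that axis. With multi-query attention there is a single key/value head, so the key/value kernels must be laid out differently (here, replicated on the head axis). This sketch assumes kernels shaped `(num_heads, hidden_dim, head_dim)` and uses `None` for a replicated dimension, mirroring the tuple layouts Keras's `LayoutMap` accepts; the config keys, the function name, and the example configs are hypothetical.

```python
def attention_layouts(config, model_axis_size, model_axis="model"):
    """Pick per-kernel attention layouts from a model config.

    Hypothetical recipe. Kernels are assumed to have shape
    (num_heads, hidden_dim, head_dim); None means "replicate this dim".
    """

    def head_sharded(num_heads):
        # Only shard the head axis if it splits evenly over the model axis.
        if num_heads % model_axis_size == 0:
            return (model_axis, None, None)
        return (None, None, None)

    q_heads = config["num_query_heads"]
    kv_heads = config["num_key_value_heads"]
    return {
        "query": head_sharded(q_heads),
        "key": head_sharded(kv_heads),
        "value": head_sharded(kv_heads),
    }


# Illustrative configs: multi-head (MHA) and multi-query (MQA) attention.
mha = {"num_query_heads": 16, "num_key_value_heads": 16}
mqa = {"num_query_heads": 8, "num_key_value_heads": 1}
print(attention_layouts(mha, model_axis_size=4)["key"])  # ('model', None, None)
print(attention_layouts(mqa, model_axis_size=4)["key"])  # (None, None, None)
```

The same model class therefore needs different layouts depending on its config, which a single fixed layout map cannot express.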

I think there are two main directions we can take with the API.

  1. Write our own manual sharding recipe for each model, based on the model's config. Do this for all models (most will share the same recipe, especially our transformer models).
  2. Use some form of autosharding functionality in JAX, or add an autosharding API to Keras. In this case, we would not need to write sharding recipes ourselves for each model.
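For option 1, one way to keep per-model recipes cheap is to express a recipe as an ordered list of regex rules over variable paths, so a single transformer recipe can be reused by every model whose layers follow similar naming. This is similar in spirit to Keras's `LayoutMap`, which also matches variable paths against regex keys. However, the `TRANSFORMER_RECIPE` rules, the variable paths, and the first-match-wins lookup below are assumptions for illustration, not existing KerasNLP code.

```python
import re

# Hypothetical shared recipe: (path regex, layout) pairs, checked in order.
# "model" is a mesh axis name; None means replicate that dimension.
TRANSFORMER_RECIPE = [
    (r"token_embedding/embeddings", ("model", None)),
    (r"attention.*(query|key|value)/kernel", ("model", None, None)),
    (r"attention_output/kernel", ("model", None, None)),
    (r"ffw_gating.*/kernel", (None, "model")),
    (r"ffw_linear/kernel", ("model", None)),
]


def layout_for(path, recipe=TRANSFORMER_RECIPE):
    """Return the layout of the first rule whose regex matches `path`.

    Unmatched variables (e.g. layer norm scales) are fully replicated,
    signalled here by returning None.
    """
    for pattern, layout in recipe:
        if re.search(pattern, path):
            return layout
    return None


print(layout_for("decoder_block_0/attention/query/kernel"))  # ('model', None, None)
print(layout_for("decoder_block_3/ffw_gating/kernel"))       # (None, 'model')
print(layout_for("final_normalization/scale"))               # None
```

Model-specific differences (such as the multi-query case) could then be handled by prepending a few rules to the shared recipe rather than writing a new one per model.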

One potential high-level API would be to take a device mesh directly when constructing the model. For both 1) and 2), we could support an API something like this:

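from keras.distribution import DeviceMesh, list_devices
import keras_nlp

devices = list_devices()  # a (2, 4) mesh assumes 8 devices
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'), devices=devices)
# Proposed API: `device_mesh` is not an existing `from_preset` argument.
gemma_model = keras_nlp.models.GemmaCausalLM.from_preset(
    "gemma_2b_en",
    device_mesh=device_mesh,
)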

For 1), we would need to enter a LayoutMap scope after loading the model's config but before loading its weights. For 2), it would depend on the details of the autosharding API we use.

Metadata

Labels: Gemma (Gemma model specific issues), team-created (Issues created by Keras Hub team as part of development roadmap), type:feature (New feature or request)