
Keras 3 preference: should I build my distributed-training CTR model on JAX or TensorFlow? #19866

Open
@MoFHeka

Description

I am a developer of TensorFlow recommenders-addons, and I now need to build an all-to-all embedding layer for multi-GPU distributed training of recommendation models. The older TensorFlow distribution strategies clearly do not meet this need.
So the question is: should I build on TF DTensor or on JAX? Keras support for TF DTensor does not seem very friendly. On the other hand, JAX lacks support for online inference serving and the functional components that various recommendation algorithms rely on. recommenders-addons also depends on many custom operators.
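For context, an all-to-all embedding layer typically shards the embedding table row-wise across GPUs and performs two all-to-all exchanges per lookup: ids go out to the devices that own them, and the looked-up vectors come back. The single-process sketch below only simulates that data flow with plain Python lists. The modulo sharding scheme and the names `make_shards`, `all_to_all`, and `embedding_lookup` are hypothetical illustrations, not recommenders-addons, TensorFlow, or JAX APIs; on real hardware the `all_to_all` step would be a framework collective such as `jax.lax.all_to_all`.

```python
from typing import Dict, List

# Hypothetical layout: device d owns every embedding row whose id satisfies id % num_devices == d.
Shard = Dict[int, List[float]]


def make_shards(vocab_size: int, num_devices: int, dim: int = 2) -> List[Shard]:
    # Row i is filled with float(i) so lookup results are easy to check.
    shards: List[Shard] = [dict() for _ in range(num_devices)]
    for i in range(vocab_size):
        shards[i % num_devices][i] = [float(i)] * dim
    return shards


def all_to_all(send: List[List[list]]) -> List[List[list]]:
    # send[src][dst] is the payload device `src` sends to device `dst`;
    # after the exchange, recv[dst][src] holds that same payload.
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]


def embedding_lookup(batches: List[List[int]], shards: List[Shard]) -> List[List[List[float]]]:
    n = len(batches)  # one batch of feature ids per simulated device
    # 1. Each device buckets its ids by the device that owns them.
    send_ids = [[[i for i in batch if i % n == dst] for dst in range(n)] for batch in batches]
    # 2. First exchange: every device receives the ids it owns.
    recv_ids = all_to_all(send_ids)
    # 3. Each device looks up the received ids in its local shard.
    send_vecs = [[[shards[dev][i] for i in ids] for ids in recv_ids[dev]] for dev in range(n)]
    # 4. Second exchange: vectors travel back to the requesting devices.
    recv_vecs = all_to_all(send_vecs)
    # 5. Restore each device's original id order.
    results = []
    for dev, batch in enumerate(batches):
        cursors = [0] * n
        out = []
        for i in batch:
            owner = i % n
            out.append(recv_vecs[dev][owner][cursors[owner]])
            cursors[owner] += 1
        results.append(out)
    return results
```

For example, with `shards = make_shards(8, num_devices=2)`, calling `embedding_lookup([[0, 3, 5], [2, 7, 1, 4]], shards)[0]` returns `[[0.0, 0.0], [3.0, 3.0], [5.0, 5.0]]`. In training, the backward pass runs the same exchanges in reverse so that gradients reach the device owning each row.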
