Add column wise sharding support for EmbeddingCollection (sequence embeddings) #432
Conversation
This pull request was exported from Phabricator. Differential Revision: D36944684
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary: Pull Request resolved: pytorch#432
Support for CW sharding in EmbeddingCollection. Added logic to stitch feature outputs with the local embedding dim after the output dist to match the original dim. Respects the correct order according to sharding, since the output loses the order of shards (it is provided in order of rank).
Differential Revision: D36944684
fbshipit-source-id: 12e698093d0aa5e09441084d2eba77e08b53d9f1
Force-pushed from 1fa47b7 to 2b1710f (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432); summary unchanged from the previous push. Differential Revision: D36944684. fbshipit-source-id: 88a9e08989cf85739a0ec815838dfe34d140690a
Force-pushed from 2b1710f to 9242c92 (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary: Pull Request resolved: pytorch#432
Support for CW sharding in EmbeddingCollection. Added logic to stitch feature outputs with the local embedding dim after the output dist to match the original dim.
Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank and therefore multiple outputs on the same rank, e.g.:
rank 0: [f_0, f_0, f_1]
rank 1: [f_0, f_1]
output: [f_0, f_0, f_1, f_0, f_1]
f_0 shard ranks = [0, 1, 0]
Since outputs are stored by rank, the inter-shard order is lost and the shards on rank 0 would be combined first, giving an incorrect combination of f_0's output with shard ranks [0, 0, 1]. To keep the correct shard-rank order [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.
Differential Revision: D36944684
fbshipit-source-id: f196a5c85a27beaae87d2bc7f06c784064e3bb67
Force-pushed from 9242c92 to 0ab6312 (Compare)
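To make the ordering issue concrete, here is a tiny sketch in plain Python of how the rank-grouped output order in the example above arises. It is not torchrec code; the per-rank lists and variable names are made up for illustration.

```python
# Hypothetical per-rank output layout from the example above: rank 0 holds two
# shards of f_0 and one shard of f_1, rank 1 holds one shard of each.
outputs_per_rank = [["f_0", "f_0", "f_1"], ["f_0", "f_1"]]

# The output dist returns everything grouped by rank, in rank order.
flat_output = [feat for rank_outputs in outputs_per_rank for feat in rank_outputs]
assert flat_output == ["f_0", "f_0", "f_1", "f_0", "f_1"]

# Naively taking f_0's outputs in this order pairs them with ranks [0, 0, 1],
# not the true shard ranks [0, 1, 0]; hence the per-feature permute indices.
naive_f0_ranks = [rank for rank, feats in enumerate(outputs_per_rank)
                  for f in feats if f == "f_0"]
assert naive_f0_ranks == [0, 0, 1]
```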
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary: Pull Request resolved: pytorch#432
Support for CW and TWCW sharding in EmbeddingCollection. Added logic to stitch feature outputs with the local embedding dim after the output dist to match the original dim.
Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank and therefore multiple outputs on the same rank, e.g.:
rank 0: [f_0, f_0, f_1]
rank 1: [f_0, f_1]
output: [f_0, f_0, f_1, f_0, f_1]
f_0 shard ranks = [0, 1, 0]
Since outputs are stored by rank, the inter-shard order is lost and the shards on rank 0 would be combined first, giving an incorrect combination of f_0's output with shard ranks [0, 0, 1]. To keep the correct shard-rank order [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.
Differential Revision: D36944684
fbshipit-source-id: 30ac10cae786df3a955cf2b8f44bcb77e6088cd6
Force-pushed from 0ab6312 to 137229e (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432); summary unchanged from the previous push. Differential Revision: D36944684. fbshipit-source-id: 67b5a4c3825583dd675926a37db79b79b7653a97
Force-pushed from 137229e to 99251a8 (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432); summary unchanged from the previous push. Differential Revision: D36944684. fbshipit-source-id: 2c5c6a5d6742eed396c766bfa3eff4f4446a5b9a
Force-pushed from 99251a8 to 0b62f63 (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432); summary unchanged from the previous push. Differential Revision: D36944684. fbshipit-source-id: 95afdb66589d52665ef351e40515e4ff2a0363a6
Force-pushed from 0b62f63 to 358f095 (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary: Pull Request resolved: pytorch#432
Support for CW and TWCW sharding in EmbeddingCollection. Added logic to stitch feature outputs with the local embedding dim after the output dist to match the original dim.
Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank and therefore multiple outputs on the same rank, e.g.:
rank 0: [f_0, f_0, f_1]
rank 1: [f_0, f_1]
output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]
f_0 shard ranks = [0, 1, 0]   # f_0(shard_0) on rank 0, f_0(shard_1) on rank 1, f_0(shard_2) on rank 0
f_0 output ranks = [0, 0, 1]  # output order: f_0(shard_0) from rank 0, f_0(shard_2) from rank 0, f_0(shard_1) from rank 1
# to get the correct order of outputs we want permute indices mapping output_ranks -> shard_ranks
permute_indices = [0, 2, 1]
Since outputs are stored by rank, the inter-shard order is lost and the shards on rank 0 would be combined first, giving an incorrect combination of f_0's output with shard ranks [0, 0, 1]. To keep the correct shard-rank order [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.
Differential Revision: D36944684
fbshipit-source-id: ccfcc552d9b04e858335ed28f915a9811f727bba
Force-pushed from 358f095 to ae95d9f (Compare)
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary: Pull Request resolved: pytorch#432
Support for CW and TWCW sharding in EmbeddingCollection. Added logic to stitch feature outputs with the local embedding dim after the output dist to match the original dim.
Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank and therefore multiple outputs on the same rank, e.g.:
rank 0: [f_0, f_0, f_1]
rank 1: [f_0, f_1]
output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]
f_0 shard ranks = [0, 1, 0]   # f_0(shard_0) on rank 0, f_0(shard_1) on rank 1, f_0(shard_2) on rank 0
f_0 output ranks = [0, 0, 1]  # output order: f_0(shard_0) from rank 0, f_0(shard_2) from rank 0, f_0(shard_1) from rank 1
# to get the correct order of outputs we want permute indices mapping output_ranks -> shard_ranks
permute_indices = [0, 2, 1]
Since outputs are stored by rank, the inter-shard order is lost and the shards on rank 0 would be combined first, giving an incorrect combination of f_0's output with shard ranks [0, 0, 1]. To keep the correct shard-rank order [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.
Reviewed By: dstaay-fb
Differential Revision: D36944684
fbshipit-source-id: bd579b41f21c51de4f99b653e6ae34779e783717
Force-pushed from ae95d9f to 169f291 (Compare)
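As a concrete illustration of the permute-indices scheme in the commit message above, here is a minimal sketch in plain PyTorch. It is not the torchrec implementation; the helper name `feature_permute_indices` and the dims are invented, with values taken from the f_0 example (three CW shards on ranks [0, 1, 0]).

```python
import torch


def feature_permute_indices(shard_ranks):
    """Indices that map rank-ordered shard outputs back into shard order.

    After the output dist, shard outputs arrive grouped by rank (all rank-0
    shards first, then rank 1, ...), so the original shard order is lost.
    Entry i of the result is the position, in that rank-ordered stream, of
    shard i.
    """
    # Rank-ordered position of each shard: stable sort of shard indices by rank.
    output_order = sorted(range(len(shard_ranks)), key=lambda i: shard_ranks[i])
    permute = [0] * len(shard_ranks)
    for pos, shard in enumerate(output_order):
        permute[shard] = pos
    return permute


# f_0 is CW-sharded into three column shards placed on ranks [0, 1, 0].
shard_ranks = [0, 1, 0]
permute = feature_permute_indices(shard_ranks)
assert permute == [0, 2, 1]  # matches permute_indices in the summary above

# Stitch: each shard output covers local_dim columns of the full embedding;
# gathering in shard order and concatenating restores the original dim.
num_ids, local_dim = 4, 4
rank_ordered_outputs = [torch.randn(num_ids, local_dim) for _ in shard_ranks]
stitched = torch.cat([rank_ordered_outputs[p] for p in permute], dim=1)
assert stitched.shape == (num_ids, local_dim * len(shard_ranks))
```

The stable sort assumes that, within a given rank, shards appear in increasing shard-index order, which matches the output ordering shown in the example above.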
Summary: Refactor `SequenceEmbeddingAllToAll` -> `SequenceEmbeddingsAllToAll` to match the convention of other modules, incl. SequenceEmbeddingsAwaitable, PooledEmbeddingsAllToAll, PooledEmbeddingsAwaitable.
Differential Revision: D37291358
fbshipit-source-id: bfe90316a656f7f8c095e4498adc71d4eb836875
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432); summary unchanged from the previous push. Reviewed By: dstaay-fb. Differential Revision: D36944684. fbshipit-source-id: f5af89e1fe01dc9ee2e9a72ed43adaf4a06d24e0
Force-pushed from 169f291 to 19c6444 (Compare)
Summary:
Support for CW sharding in EmbeddingCollection.
Added logic to stitch feature outputs with local embedding dim after output dist to match original dim.
Respects the correct order according to sharding, since the output loses the order of shards (it is provided in order of rank).
Differential Revision: D36944684
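As a sanity check of the stitching claim, here is a toy example in plain PyTorch (not torchrec; the table size, number of shards, and ids are made up): column-wise shards of an embedding table, looked up separately and concatenated along the embedding dim in shard order, reproduce the unsharded lookup.

```python
import torch

torch.manual_seed(0)
full = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)

# Column-wise shard the 12-dim table into three 4-dim column shards.
shards = torch.chunk(full.weight.detach(), chunks=3, dim=1)

ids = torch.tensor([1, 3, 7])             # ids for one sequence feature
per_shard_out = [w[ids] for w in shards]  # each shard produces its local dim
stitched = torch.cat(per_shard_out, dim=1)

# Stitching in shard order recovers the original embedding dim and values.
assert torch.equal(stitched, full.weight.detach()[ids])
```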