
Add column-wise sharding support for EmbeddingCollection (sequence embeddings) #432

Closed
wants to merge 2 commits

Conversation

joshuadeng (Contributor)

Summary:
Support for CW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.
Respects the correct shard order, since the output dist loses the order of the shards (it provides them in rank order).

Differential Revision: D36944684
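
A minimal sketch of the stitching step (illustrative only, not the code in this diff), assuming a table column-wise sharded into two shards of local embedding dim 4 that are re-concatenated to the original dim of 8:

        import torch

        N = 5  # number of ids looked up for one feature
        # Outputs of the two CW shards after the output dist, each with the
        # local embedding dim (4) rather than the table's original dim (8).
        shard_outputs = [torch.randn(N, 4), torch.randn(N, 4)]

        # Stitch along the embedding dimension, in shard order, to restore
        # the original embedding dim.
        stitched = torch.cat(shard_outputs, dim=1)
        assert stitched.shape == (N, 8)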

@facebook-github-bot added the CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) and fb-exported labels on Jun 14, 2022
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D36944684


joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 14, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.
Respects the correct shard order, since the output dist loses the order of the shards (it provides them in rank order).

Differential Revision: D36944684

fbshipit-source-id: 12e698093d0aa5e09441084d2eba77e08b53d9f1
joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 15, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.
Respects the correct shard order, since the output dist loses the order of the shards (it provides them in rank order).

Differential Revision: D36944684

fbshipit-source-id: 88a9e08989cf85739a0ec815838dfe34d140690a

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 15, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: f196a5c85a27beaae87d2bc7f06c784064e3bb67

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 16, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: 30ac10cae786df3a955cf2b8f44bcb77e6088cd6
joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 17, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: 67b5a4c3825583dd675926a37db79b79b7653a97

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 20, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: 2c5c6a5d6742eed396c766bfa3eff4f4446a5b9a

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 20, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: 95afdb66589d52665ef351e40515e4ff2a0363a6

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 20, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]

        f_0 shard ranks = [0, 1, 0]  # [f_0(shard_0) = rank0, f_0(shard_1) = rank1, f_0(shard_2) = rank0]

        f_0 output ranks = [0, 0, 1]  # [f_0(shard_0) = rank0, f_0(shard_2) = rank0, f_0(shard_1) = rank1]

        # To get the correct order of outputs we want permute indices for output_ranks -> shard_ranks
        permute_indices = [0, 2, 1]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: ccfcc552d9b04e858335ed28f915a9811f727bba

joshuadeng added a commit to joshuadeng/torchrec that referenced this pull request Jun 20, 2022
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]

        f_0 shard ranks = [0, 1, 0]  # [f_0(shard_0) = rank0, f_0(shard_1) = rank1, f_0(shard_2) = rank0]

        f_0 output ranks = [0, 0, 1]  # [f_0(shard_0) = rank0, f_0(shard_2) = rank0, f_0(shard_1) = rank1]

        # To get the correct order of outputs we want permute indices for output_ranks -> shard_ranks
        permute_indices = [0, 2, 1]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Reviewed By: dstaay-fb

Differential Revision: D36944684

fbshipit-source-id: bd579b41f21c51de4f99b653e6ae34779e783717

Summary: Refactor `SequenceEmbeddingAllToAll` -> `SequenceEmbeddingsAllToAll` to match the naming convention of other modules, including SequenceEmbeddingsAwaitable, PooledEmbeddingsAllToAll, and PooledEmbeddingsAwaitable.

Differential Revision: D37291358

fbshipit-source-id: bfe90316a656f7f8c095e4498adc71d4eb836875
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs, each with its local embedding dim, back together after the output dist to match the original embedding dim.

Outputs are stored in order of rank; however, in column-wise sharding there can be multiple shards of a table on the same rank, and therefore multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]

        f_0 shard ranks = [0, 1, 0]  # [f_0(shard_0) = rank0, f_0(shard_1) = rank1, f_0(shard_2) = rank0]

        f_0 output ranks = [0, 0, 1]  # [f_0(shard_0) = rank0, f_0(shard_2) = rank0, f_0(shard_1) = rank1]

        # To get the correct order of outputs we want permute indices for output_ranks -> shard_ranks
        permute_indices = [0, 2, 1]

Since outputs are stored by rank, the inter-shard order is lost: the shards on rank 0 would be combined first, producing an incorrect combination of f_0's output with shard ranks = [0, 0, 1].

To keep the correct shard rank order of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Reviewed By: dstaay-fb

Differential Revision: D36944684

fbshipit-source-id: f5af89e1fe01dc9ee2e9a72ed43adaf4a06d24e0
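
A minimal sketch of the permute-index derivation described above (illustrative only, not the code in this diff), using the f_0 example with shard ranks [0, 1, 0]:

        import torch

        shard_ranks = [0, 1, 0]  # rank owning each of f_0's shards, in shard order

        # The output dist groups shards by rank (shard order preserved within a
        # rank), so f_0's shards arrive in order [shard_0, shard_2, shard_1].
        output_order = sorted(range(len(shard_ranks)), key=lambda i: shard_ranks[i])
        # output_order == [0, 2, 1]

        # permute_indices[s] = output position holding shard_s, i.e. the inverse
        # permutation of output_order.
        permute_indices = [0] * len(output_order)
        for pos, shard in enumerate(output_order):
            permute_indices[shard] = pos
        # permute_indices == [0, 2, 1]

        # Applying the permutation before stitching restores shard (column) order.
        N, local_dim = 5, 4  # illustrative shapes
        rank_ordered = [torch.randn(N, local_dim) for _ in output_order]
        stitched = torch.cat([rank_ordered[p] for p in permute_indices], dim=1)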
