Skip to content

Commit

Permalink
Add column wise sharding support for EmbeddingCollection (sequence em…
Browse files Browse the repository at this point in the history
…beddings) (#432)

Summary:
Pull Request resolved: #432

Support for CW and TWCW sharding in EmbeddingCollection.

Added logic to stitch feature outputs with local embedding dim after output dist to match original dim.

Outputs are stored in order of rank, however, in column-wise sharding, there can be multiple shards of a table on the same rank and thereby multiple outputs on the same rank.

i.e.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0, f_0, f_1, f_0, f_1]

        f_0 shard ranks = [0, 1, 0]

Since outputs are stored by rank, the inter-shard order is lost and the shards on rank 0 would be combined first, making an incorrect combination of f_0's output with the shard ranks = [0, 0, 1].

To keep the correct shard rank of [0, 1, 0] when combining outputs, we generate permute indices for each feature to match the shard ranks.

Differential Revision: D36944684

fbshipit-source-id: 67b5a4c3825583dd675926a37db79b79b7653a97
  • Loading branch information
joshuadeng authored and facebook-github-bot committed Jun 17, 2022
1 parent 50c1b05 commit 99251a8
Showing 8 changed files with 254 additions and 79 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/dist_data.py
Original file line number Diff line number Diff line change
@@ -623,7 +623,7 @@ class PooledEmbeddingsAllToAll(nn.Module):
dim_sum_per_rank (List[int]): number of features (sum of dimensions) of the
embedding in each rank.
device (Optional[torch.device]): device on which buffers will be allocated.
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]])
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]):
Example::
@@ -673,6 +673,8 @@ def forward(
Args:
local_embs (torch.Tensor): tensor of values to distribute.
batch_size_per_rank (Optional[List[int]]): batch size per rank, to support
variable batch size.
Returns:
PooledEmbeddingsAwaitable: awaitable of pooled embeddings.
@@ -842,7 +844,7 @@ def _wait_impl(self) -> torch.Tensor:
Syncs sequence embeddings after collective operation.
Returns:
torch.Tensor: synced pooled embeddings.
torch.Tensor: synced sequence embeddings.
"""

ret = self._tensor_awaitable.wait()
196 changes: 137 additions & 59 deletions torchrec/distributed/embedding.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
Original file line number Diff line number Diff line change
@@ -126,6 +126,9 @@ def test_sequence_2_table_perf(self) -> None:
("batched_fused", "table_wise"): [0.001880471390093715],
("batched_fused_uvm", "table_wise"): [0.25958192114736517],
("batched_fused_uvm_caching", "table_wise"): [0.060433813055248066],
("batched_fused", "column_wise"): [0.001880471390093715],
("batched_fused_uvm", "column_wise"): [0.25958192114736517],
("batched_fused_uvm_caching", "column_wise"): [0.060433813055248066],
("batched_fused", "row_wise"): [
0.0007915177871551004,
0.0007915177871551004,
17 changes: 3 additions & 14 deletions torchrec/distributed/quant_embedding.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@


from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, cast, Dict, List, Optional, Type

import torch
from torch import nn
@@ -42,7 +42,6 @@
ShardingEnv,
)
from torchrec.distributed.utils import filter_state_dict
from torchrec.modules.embedding_configs import EmbeddingTableConfig
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
)
@@ -62,7 +61,6 @@ def create_infer_embedding_sharding(
device: Optional[torch.device] = None,
) -> EmbeddingSharding[SparseFeaturesList, List[torch.Tensor]]:
if sharding_type == ShardingType.TABLE_WISE.value:

return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
raise ValueError(f"Sharding type not supported {sharding_type}")
@@ -109,9 +107,6 @@ def __init__(
self._has_uninitialized_output_dist: bool = True

self._embedding_dim: int = module.embedding_dim
self._embedding_names_per_sharding: List[List[str]] = []
for sharding in self._sharding_type_to_sharding.values():
self._embedding_names_per_sharding.append(sharding.embedding_names())
self._need_indices: bool = module.need_indices

def _create_input_dist(
@@ -210,26 +205,20 @@ def compute(
def output_dist(
self, ctx: ShardedModuleContext, output: List[List[torch.Tensor]]
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
for odist, embeddings, sharding_ctx in zip(
self._output_dists,
output,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
):
awaitables_per_sharding.append(odist(embeddings, sharding_ctx))
features_before_all2all_per_sharding.append(
sharding_ctx.features_before_input_dist
)
return EmbeddingCollectionAwaitable(
# pyre-fixme[6]: For 1st param expected `List[Awaitable[Tensor]]` but
# got `List[Awaitable[Dict[str, JaggedTensor]]]`.
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
# pyre-fixme[6]: For 3rd param expected `List[str]` but got
# `List[List[str]]`.
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
)

67 changes: 67 additions & 0 deletions torchrec/distributed/sharding/cw_sequence_sharding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
from torchrec.distributed.embedding_lookup import GroupedEmbeddingsLookup
from torchrec.distributed.embedding_sharding import (
BaseEmbeddingLookup,
BaseSparseFeaturesDist,
)
from torchrec.distributed.embedding_types import (
BaseGroupedFeatureProcessor,
SparseFeatures,
)
from torchrec.distributed.sharding.cw_sharding import BaseCwEmbeddingSharding
from torchrec.distributed.sharding.sequence_sharding import BaseSequenceEmbeddingDist
from torchrec.distributed.sharding.tw_sequence_sharding import TwSequenceEmbeddingDist
from torchrec.distributed.sharding.tw_sharding import TwSparseFeaturesDist


class CwSequenceEmbeddingSharding(
BaseCwEmbeddingSharding[SparseFeatures, torch.Tensor]
):
"""
Shards sequence (unpooled) embeddings column-wise, i.e.. a given embedding is
partitioned along its columns and placed on specified ranks.
"""

def create_input_dist(
self,
device: Optional[torch.device] = None,
) -> BaseSparseFeaturesDist[SparseFeatures]:
return TwSparseFeaturesDist(
self._pg,
self._id_list_features_per_rank(),
self._id_score_list_features_per_rank(),
device if device is not None else self._device,
)

def create_lookup(
self,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
) -> BaseEmbeddingLookup:
assert feature_processor is None
return GroupedEmbeddingsLookup(
grouped_configs=self._grouped_embedding_configs,
fused_params=fused_params,
pg=self._pg,
device=device if device is not None else self._device,
)

def create_output_dist(
self,
device: Optional[torch.device] = None,
) -> BaseSequenceEmbeddingDist[torch.Tensor]:
return TwSequenceEmbeddingDist(
self._pg,
self._id_list_features_per_rank(),
device if device is not None else self._device,
)
2 changes: 1 addition & 1 deletion torchrec/distributed/sharding/twcw_sharding.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple
from typing import List, Optional

import torch
from torchrec.distributed.embedding_sharding import EmbeddingShardingInfo
3 changes: 1 addition & 2 deletions torchrec/distributed/test_utils/test_model_parallel.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Dict, List, Optional, Type, Union
from typing import Dict, List, Optional, Type

import torch.distributed as dist # noqa
import torch.nn as nn
39 changes: 38 additions & 1 deletion torchrec/distributed/tests/test_sequence_model_parallel.py
Original file line number Diff line number Diff line change
@@ -7,13 +7,14 @@


import unittest
from typing import List, Optional, Type
from typing import Dict, List, Optional, Type

import hypothesis.strategies as st
import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from hypothesis import given, settings, Verbosity
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import ParameterConstraints
from torchrec.distributed.test_utils.multi_process import MultiProcessTestBase
from torchrec.distributed.test_utils.test_model import TestSparseNNBase
from torchrec.distributed.test_utils.test_sharding import sharding_single_rank_test
@@ -117,6 +118,40 @@ def test_sharding_nccl_tw(self, sharding_type: str, kernel_type: str) -> None:
backend="nccl",
)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
# pyre-fixme[56]
@given(
sharding_type=st.sampled_from(
[
ShardingType.COLUMN_WISE.value,
]
),
kernel_type=st.sampled_from(
[
EmbeddingComputeKernel.BATCHED_DENSE.value,
EmbeddingComputeKernel.BATCHED_FUSED.value,
]
),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
def test_sharding_nccl_cw(self, sharding_type: str, kernel_type: str) -> None:
self._test_sharding(
sharders=[
TestEmbeddingCollectionSharder(
sharding_type=sharding_type,
kernel_type=kernel_type,
)
],
backend="nccl",
constraints={
table.name: ParameterConstraints(min_partition=4)
for table in self.tables
},
)

@seed_and_log
def setUp(self) -> None:
super().setUp()
@@ -142,6 +177,7 @@ def _test_sharding(
backend: str = "gloo",
world_size: int = 2,
local_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
) -> None:
self._run_multi_process_test(
@@ -154,4 +190,5 @@ def _test_sharding(
sharders=sharders,
optim=EmbOptimType.EXACT_SGD,
backend=backend,
constraints=constraints,
)

0 comments on commit 99251a8

Please sign in to comment.