Add column wise sharding support for EmbeddingCollection (sequence embeddings) (#432)

Summary:
Pull Request resolved: #432

Adds support for CW (column-wise) and TWCW (table-wise column-wise) sharding in EmbeddingCollection.

Adds logic to stitch together per-feature outputs (each produced with the local embedding dim) after output dist so that they match the original embedding dim.

Outputs are stored in order of rank. With column-wise sharding, however, there can be multiple shards of a table on the same rank, and therefore multiple outputs for the same feature on that rank.

e.g.
        rank 0: [f_0, f_0, f_1]
        rank 1: [f_0, f_1]
        output: [f_0(shard_0), f_0(shard_2), f_1(shard_0), f_0(shard_1), f_1(shard_1)]

        f_0 shard ranks = [0, 1, 0]  # [f_0(shard_0) = rank 0, f_0(shard_1) = rank 1, f_0(shard_2) = rank 0]

        f_0 output ranks = [0, 0, 1]  # [f_0(shard_0) = rank 0, f_0(shard_2) = rank 0, f_0(shard_1) = rank 1]

        # To restore the correct order of outputs, we want permute indices mapping output_ranks -> shard_ranks
        permute_indices = [0, 2, 1]

Because outputs are stored by rank, the original inter-shard order is lost: the two shards on rank 0 would be combined first, producing an incorrect combination of f_0's outputs as if the shard ranks were [0, 0, 1].

To preserve the correct shard order of [0, 1, 0] when combining outputs, we generate permute indices for each feature that map the output order back to the shard order.
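
A minimal standalone sketch of the permute-index computation described above (the helper name and the usage example are illustrative only, not part of this commit; the committed logic lives in _generate_permute_indices_per_feature):

    from collections import defaultdict, deque
    from typing import List

    def generate_permute_indices(shard_ranks: List[int]) -> List[int]:
        # shard_ranks[i] is the rank that holds shard i of a CW-sharded table.
        # Outputs arrive grouped by rank (all of rank 0's shards first, then rank 1's, ...),
        # so output position j corresponds to the j-th entry of sorted(shard_ranks).
        rank_to_output_positions = defaultdict(deque)
        for output_position, rank in enumerate(sorted(shard_ranks)):
            rank_to_output_positions[rank].append(output_position)
        # For each shard (in shard order), take the next unused output position on its rank.
        return [rank_to_output_positions[rank].popleft() for rank in shard_ranks]

    # Example from above: f_0's shards live on ranks [0, 1, 0].
    assert generate_permute_indices([0, 1, 0]) == [0, 2, 1]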

Reviewed By: dstaay-fb

Differential Revision: D36944684

fbshipit-source-id: 10f24419e896db0fcd57540a61b767fc62d5e50d
joshuadeng authored and facebook-github-bot committed Jun 21, 2022
1 parent 1b90350 commit 54bcaa8
Showing 7 changed files with 254 additions and 76 deletions.
199 changes: 141 additions & 58 deletions torchrec/distributed/embedding.py
@@ -7,9 +7,21 @@


import copy
from collections import OrderedDict
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union
from typing import (
Any,
cast,
Dict,
Iterator,
List,
MutableMapping,
Optional,
Set,
Tuple,
Type,
Union,
)

import torch
from torch import nn
@@ -27,6 +39,9 @@
SparseFeatures,
SparseFeaturesList,
)
from torchrec.distributed.sharding.cw_sequence_sharding import (
CwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.dp_sequence_sharding import (
DpSequenceEmbeddingSharding,
)
@@ -46,9 +61,14 @@
ShardedModuleContext,
ShardedTensor,
ShardingEnv,
ShardMetadata,
)
from torchrec.distributed.utils import append_prefix, filter_state_dict
from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType
from torchrec.modules.embedding_configs import (
EmbeddingConfig,
EmbeddingTableConfig,
PoolingType,
)
from torchrec.modules.embedding_modules import (
EmbeddingCollection,
EmbeddingCollectionInterface,
@@ -76,6 +96,8 @@ def create_embedding_sharding(
return RwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.DATA_PARALLEL.value:
return DpSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return CwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
raise ValueError(f"Sharding not supported {sharding_type}")

@@ -134,27 +156,46 @@ def create_sharding_infos_by_sharding(
def _construct_jagged_tensors(
embeddings: torch.Tensor,
features: KeyedJaggedTensor,
embedding_names: List[str],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, JaggedTensor]:
ret: Dict[str, JaggedTensor] = {}
lengths = features.lengths().view(-1, features.stride())
values = features.values()
length_per_key = features.length_per_key()
values_list = torch.split(values, length_per_key) if need_indices else None
embeddings_list = torch.split(embeddings, length_per_key, dim=0)
stride = features.stride()
length_per_key = features.length_per_key()
values = features.values()

lengths = features.lengths().view(-1, stride)
lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
embeddings_list = torch.split(embeddings, length_per_key, dim=0)
values_list = torch.split(values, length_per_key) if need_indices else None

key_indices = defaultdict(list)
for i, key in enumerate(features.keys()):
key_indices[key].append(i)
for key, indices in key_indices.items():
# combines outputs in correct order for CW sharding
indices = (
_permute_indices(indices, features_to_permute_indices[key])
if features_to_permute_indices and key in features_to_permute_indices
else indices
)
ret[key] = JaggedTensor(
lengths=lengths_tuple[i],
values=embeddings_list[i],
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
weights=values_list[i] if need_indices else None,
lengths=lengths_tuple[indices[0]],
values=embeddings_list[indices[0]]
if len(indices) == 1
else torch.cat([embeddings_list[i] for i in indices], dim=1),
weights=values_list[indices[0]] if values_list else None,
)
return ret


def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
permuted_indices = [0] * len(indices)
for i, permuted_index in enumerate(permute):
permuted_indices[i] = indices[permuted_index]
return permuted_indices


@dataclass
class EmbeddingCollectionContext(ShardedModuleContext):
sharding_contexts: List[SequenceShardingContext]
@@ -169,24 +210,28 @@ def __init__(
self,
awaitables_per_sharding: List[Awaitable[torch.Tensor]],
features_per_sharding: List[KeyedJaggedTensor],
embedding_names_per_sharding: List[str],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
) -> None:
super().__init__()
self._awaitables_per_sharding = awaitables_per_sharding
self._features_per_sharding = features_per_sharding
self._embedding_names_per_sharding = embedding_names_per_sharding
self._need_indices = need_indices
self._features_to_permute_indices = features_to_permute_indices

def _wait_impl(self) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
for w, f, e in zip(
for w, f in zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
):
jt_dict.update(
_construct_jagged_tensors(w.wait(), f, e, self._need_indices)
_construct_jagged_tensors(
embeddings=w.wait(),
features=f,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)
)
return jt_dict

@@ -195,7 +240,7 @@ class ShardedEmbeddingCollection(
ShardedModule[
SparseFeaturesList,
List[torch.Tensor],
Dict[str, torch.Tensor],
Dict[str, JaggedTensor],
],
FusedOptimizerModule,
):
@@ -243,19 +288,63 @@ def __init__(
for _, m in lookup.named_modules():
if isinstance(m, FusedOptimizerModule):
# modify param keys to match EmbeddingCollection
params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
params: MutableMapping[str, Union[torch.Tensor, ShardedTensor]] = {}
for param_key, weight in m.fused_optimizer.params.items():
# pyre-fixme[16]: `Mapping` has no attribute `__setitem__`.
params["embeddings." + param_key] = weight
m.fused_optimizer.params = params
optims.append(("", m.fused_optimizer))
self._optim: CombinedOptimizer = CombinedOptimizer(optims)
self._embedding_dim: int = module.embedding_dim
self._embedding_names_per_sharding: List[List[str]] = []
for sharding in self._sharding_type_to_sharding.values():
self._embedding_names_per_sharding.append(sharding.embedding_names())
self._local_embedding_dim: int = self._embedding_dim
self._features_to_permute_indices: Dict[str, List[int]] = {}
if ShardingType.COLUMN_WISE.value in self._sharding_type_to_sharding:
sharding = self._sharding_type_to_sharding[ShardingType.COLUMN_WISE.value]
# CW partition must be same for all CW sharded parameters
self._local_embedding_dim = cast(
ShardMetadata, sharding.embedding_shard_metadata()[0]
).shard_sizes[1]
self._generate_permute_indices_per_feature(
module.embedding_configs, table_name_to_parameter_sharding
)
self._need_indices: bool = module.need_indices

def _generate_permute_indices_per_feature(
self,
embedding_configs: List[EmbeddingConfig],
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
) -> None:
"""
Generates permute indices per feature for column-wise sharding.
Since outputs are stored in order of rank, column-wise shards of a table on the
same rank will be seen as adjacent, which may not be correct.
The permute indices store the correct ordering of outputs relative to the
provided ordering.
Example::
rank_0 = [f_0(shard_0), f_0(shard_2)]
rank_1 = [f_0(shard_1)]
output = [f_0(shard_0), f_0(shard_2), f_0(shard_1)]
shard_ranks = [0, 1, 0]
output_ranks = [0, 0, 1]
# To get the correct order from output_ranks -> shard_ranks
permute_indices = [0, 2, 1]
"""
for table in embedding_configs:
sharding = table_name_to_parameter_sharding[table.name]
if sharding.sharding_type != ShardingType.COLUMN_WISE.value:
continue
ranks = cast(List[int], sharding.ranks)
rank_to_indices = defaultdict(deque)
for i, rank in enumerate(sorted(ranks)):
rank_to_indices[rank].append(i)
permute_indices = [rank_to_indices[rank].popleft() for rank in ranks]
for feature_name in table.feature_names:
self._features_to_permute_indices[feature_name] = permute_indices

def _create_input_dist(
self,
input_feature_names: List[str],
@@ -296,25 +385,22 @@ def input_dist(
features: KeyedJaggedTensor,
) -> Awaitable[SparseFeaturesList]:
if self._has_uninitialized_input_dist:
self._create_input_dist(
input_feature_names=features.keys() if features is not None else []
)
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
with torch.no_grad():
features_by_sharding = []
if self._features_order:
features = features.permute(
self._features_order,
# pyre-ignore [6]
self._features_order_tensor,
)
features_by_sharding = features.split(
features_by_shards = features.split(
self._feature_splits,
)
# save input splits and output splits in sharding context which
# will be reused in sequence embedding all2all
awaitables = []
for module, features in zip(self._input_dists, features_by_sharding):
for module, features in zip(self._input_dists, features_by_shards):
tensor_awaitable = module(
SparseFeatures(
id_list_features=features,
@@ -352,82 +438,78 @@ def compute(
self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList
) -> List[torch.Tensor]:
ret: List[torch.Tensor] = []
for lookup, features, sharding_ctx in zip(
for lookup, features, sharding_ctx, sharding_type in zip(
self._lookups,
dist_input,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
self._sharding_type_to_sharding,
):
sharding_ctx.lengths_after_input_dist = (
features.id_list_features.lengths().view(
-1, features.id_list_features.stride()
)
)
ret.append(lookup(features).view(-1, self._embedding_dim))
embedding_dim = self._embedding_dim_for_sharding_type(sharding_type)
ret.append(lookup(features).view(-1, embedding_dim))
return ret

def output_dist(
self, ctx: ShardedModuleContext, output: List[torch.Tensor]
) -> LazyAwaitable[Dict[str, torch.Tensor]]:
awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
for odist, embeddings, sharding_ctx in zip(
self._output_dists,
output,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
):
awaitables_per_sharding.append(odist(embeddings, sharding_ctx))
features_before_all2all_per_sharding.append(
sharding_ctx.features_before_input_dist
)
# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
# `EmbeddingCollectionAwaitable`.
return EmbeddingCollectionAwaitable(
# pyre-fixme[6]: For 1st param expected `List[Awaitable[Tensor]]` but
# got `List[Awaitable[Dict[str, JaggedTensor]]]`.
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
# pyre-fixme[6]: For 3rd param expected `List[str]` but got
# `List[List[str]]`.
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)

def compute_and_output_dist(
self, ctx: ShardedModuleContext, input: SparseFeaturesList
) -> LazyAwaitable[Dict[str, torch.Tensor]]:
awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
for lookup, odist, features, sharding_ctx in zip(
for lookup, odist, features, sharding_ctx, sharding_type in zip(
self._lookups,
self._output_dists,
input,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
self._sharding_type_to_sharding,
):
sharding_ctx.lengths_after_input_dist = (
features.id_list_features.lengths().view(
-1, features.id_list_features.stride()
)
)
embedding_dim = self._embedding_dim_for_sharding_type(sharding_type)
awaitables_per_sharding.append(
odist(lookup(features).view(-1, self._embedding_dim), sharding_ctx)
odist(lookup(features).view(-1, embedding_dim), sharding_ctx)
)
features_before_all2all_per_sharding.append(
sharding_ctx.features_before_input_dist
)
# pyre-fixme[7]: Expected `LazyAwaitable[Dict[str, Tensor]]` but got
# `EmbeddingCollectionAwaitable`.
return EmbeddingCollectionAwaitable(
# pyre-fixme[6]: For 1st param expected `List[Awaitable[Tensor]]` but
# got `List[Awaitable[Dict[str, JaggedTensor]]]`.
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
# pyre-fixme[6]: For 3rd param expected `List[str]` but got
# `List[List[str]]`.
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)

def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
return (
self._local_embedding_dim
if sharding_type == ShardingType.COLUMN_WISE.value
else self._embedding_dim
)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -546,6 +628,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
types = [
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
ShardingType.ROW_WISE.value,
]
return types
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -126,6 +126,9 @@ def test_sequence_2_table_perf(self) -> None:
("fused", "table_wise"): [0.001880471390093715],
("fused_uvm", "table_wise"): [0.25958192114736517],
("fused_uvm_caching", "table_wise"): [0.060433813055248066],
("fused", "column_wise"): [0.001880471390093715],
("fused_uvm", "column_wise"): [0.25958192114736517],
("fused_uvm_caching", "column_wise"): [0.060433813055248066],
("fused", "row_wise"): [
0.0007915177871551004,
0.0007915177871551004,