Skip to content

Commit

Permalink
Add column wise sharding support for EmbeddingCollection (sequence em…
Browse files Browse the repository at this point in the history
…beddings) (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

Support for CW sharding in EmbeddingCollection.

Added logic to stitch feature outputs with local embedding dim after output dist to match original dim.
Respects the correct order according to sharding as the output loses order of shards (provides it in order based on rank).

Differential Revision: D36944684

fbshipit-source-id: 88a9e08989cf85739a0ec815838dfe34d140690a
  • Loading branch information
joshuadeng authored and facebook-github-bot committed Jun 15, 2022
1 parent b8a1132 commit 9242c92
Showing 8 changed files with 251 additions and 48 deletions.
6 changes: 4 additions & 2 deletions torchrec/distributed/dist_data.py
Original file line number Diff line number Diff line change
@@ -619,7 +619,7 @@ class PooledEmbeddingsAllToAll(nn.Module):
dim_sum_per_rank (List[int]): number of features (sum of dimensions) of the
embedding in each rank.
device (Optional[torch.device]): device on which buffers will be allocated.
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]])
callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]):
Example::
@@ -669,6 +669,8 @@ def forward(
Args:
local_embs (torch.Tensor): tensor of values to distribute.
batch_size_per_rank (Optional[List[int]]): batch size per rank, to support
variable batch size.
Returns:
PooledEmbeddingsAwaitable: awaitable of pooled embeddings.
@@ -838,7 +840,7 @@ def _wait_impl(self) -> torch.Tensor:
Syncs sequence embeddings after collective operation.
Returns:
torch.Tensor: synced pooled embeddings.
torch.Tensor: synced sequence embeddings.
"""

ret = self._tensor_awaitable.wait()
175 changes: 137 additions & 38 deletions torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
@@ -7,9 +7,21 @@


import copy
from collections import OrderedDict
from collections import defaultdict, deque, OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union
from typing import (
Any,
cast,
Dict,
Iterator,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)

import torch
from torch import nn
@@ -27,6 +39,9 @@
SparseFeatures,
SparseFeaturesList,
)
from torchrec.distributed.sharding.cw_sequence_sharding import (
CwSequenceEmbeddingSharding,
)
from torchrec.distributed.sharding.dp_sequence_sharding import (
DpSequenceEmbeddingSharding,
)
@@ -46,6 +61,7 @@
ShardedModuleContext,
ShardedTensor,
ShardingEnv,
ShardMetadata,
)
from torchrec.distributed.utils import append_prefix, filter_state_dict
from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType
@@ -76,6 +92,8 @@ def create_embedding_sharding(
return RwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.DATA_PARALLEL.value:
return DpSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return CwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
raise ValueError(f"Sharding not supported {sharding_type}")

@@ -134,26 +152,45 @@ def create_sharding_infos_by_sharding(
def _construct_jagged_tensors(
embeddings: torch.Tensor,
features: KeyedJaggedTensor,
embedding_names: List[str],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, JaggedTensor]:
ret: Dict[str, JaggedTensor] = {}
lengths = features.lengths().view(-1, features.stride())
values = features.values()
length_per_key = features.length_per_key()
values_list = torch.split(values, length_per_key) if need_indices else None
embeddings_list = torch.split(embeddings, length_per_key, dim=0)
stride = features.stride()
length_per_key = features.length_per_key()
values = features.values()

lengths = features.lengths().view(-1, stride)
lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
embeddings_list = torch.split(embeddings, length_per_key, dim=0)
values_list = torch.split(values, length_per_key) if need_indices else None

key_indices = defaultdict(list)
for i, key in enumerate(features.keys()):
key_indices[key].append(i)
for key, indices in key_indices.items():
# combines feature outputs in correct order of shards
indices = (
_permute_indices(indices, features_to_permute_indices[key])
if features_to_permute_indices and key in features_to_permute_indices
else indices
)
feature_embeddings = [embeddings_list[i] for i in indices]
ret[key] = JaggedTensor(
lengths=lengths_tuple[i],
values=embeddings_list[i],
weights=values_list[i] if need_indices else None,
lengths=lengths_tuple[indices[0]],
values=torch.cat(feature_embeddings, dim=1),
weights=values_list[indices[0]] if need_indices else None,
)
return ret


def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
permuted_indices = [0] * len(indices)
for i, permuted_index in enumerate(permute):
permuted_indices[i] = indices[permuted_index]
return permuted_indices


@dataclass
class EmbeddingCollectionContext(ShardedModuleContext):
sharding_contexts: List[SequenceShardingContext]
@@ -168,24 +205,28 @@ def __init__(
self,
awaitables_per_sharding: List[Awaitable[torch.Tensor]],
features_per_sharding: List[KeyedJaggedTensor],
embedding_names_per_sharding: List[str],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
) -> None:
super().__init__()
self._awaitables_per_sharding = awaitables_per_sharding
self._features_per_sharding = features_per_sharding
self._embedding_names_per_sharding = embedding_names_per_sharding
self._need_indices = need_indices
self._features_to_permute_indices = features_to_permute_indices

def _wait_impl(self) -> Dict[str, JaggedTensor]:
jt_dict: Dict[str, JaggedTensor] = {}
for w, f, e in zip(
for w, f in zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
):
jt_dict.update(
_construct_jagged_tensors(w.wait(), f, e, self._need_indices)
_construct_jagged_tensors(
embeddings=w.wait(),
features=f,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)
)
return jt_dict

@@ -250,11 +291,58 @@ def __init__(
optims.append(("", m.fused_optimizer))
self._optim: CombinedOptimizer = CombinedOptimizer(optims)
self._embedding_dim: int = module.embedding_dim
self._embedding_names_per_sharding: List[List[str]] = []
for sharding in self._sharding_type_to_sharding.values():
self._embedding_names_per_sharding.append(sharding.embedding_names())
self._local_embedding_dim: int = self._embedding_dim
self._features_to_permute_indices: Dict[str, List[int]] = {}

if ShardingType.COLUMN_WISE.value in self._sharding_type_to_sharding:
sharding = self._sharding_type_to_sharding[ShardingType.COLUMN_WISE.value]
self._local_embedding_dim = cast(
ShardMetadata, sharding.embedding_shard_metadata()[0]
).shard_sizes[1]
self._generate_permute_indices_per_feature(
module, table_name_to_parameter_sharding
)

self._need_indices: bool = module.need_indices

def _generate_permute_indices_per_feature(
self,
module: EmbeddingCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
) -> None:
"""
Generates permute indices for features in column-wise sharding.
Outputs are stored in order based on rank i.e. [f_0, f_1, f_2] for f_x = feature
on rank x. However, in column-wise sharding, there can be multiple shards of a
table on the same rank and thereby multiple outputs on the same rank.
i.e.
rank 0: [f_0, f_0, f_1]
rank 1: [f_0, f_1]
when flattened this becomes [f_0, f_0, f_1, f_0, f_1]
f_0's shard ranks = [0, 1, 0]
Since outputs are stored by rank, the intra-shard order is lost and the shards
on rank 0 would be combined first, making an incorrect combination of f_0's
output with the shard ranks = [0, 0, 1].
To keep the correct shard rank of [0, 1, 0] when combining outputs, we generate
permute indices for each feature to match the shard ranks.
"""
for table, embedding_names in zip(
module.embedding_configs, module.embedding_names_by_table
):
sharding = table_name_to_parameter_sharding[table.name]
ranks = cast(List[int], sharding.ranks)
rank_to_indices = defaultdict(deque)
for i, rank in enumerate(sorted(ranks)):
rank_to_indices[rank].append(i)
permute_indices = [rank_to_indices[rank].popleft() for rank in ranks]
for embedding_name in embedding_names:
self._features_to_permute_indices[embedding_name] = permute_indices

def _create_input_dist(
self,
input_feature_names: List[str],
@@ -295,25 +383,22 @@ def input_dist(
features: KeyedJaggedTensor,
) -> Awaitable[SparseFeaturesList]:
if self._has_uninitialized_input_dist:
self._create_input_dist(
input_feature_names=features.keys() if features is not None else []
)
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
with torch.no_grad():
features_by_sharding = []
if self._features_order:
features = features.permute(
self._features_order,
# pyre-ignore [6]
self._features_order_tensor,
)
features_by_sharding = features.split(
features_by_shards = features.split(
self._feature_splits,
)
# save input splits and output splits in sharding context which
# will be reused in sequence embedding all2all
awaitables = []
for module, features in zip(self._input_dists, features_by_sharding):
for module, features in zip(self._input_dists, features_by_shards):
tensor_awaitable = module(
SparseFeatures(
id_list_features=features,
@@ -351,18 +436,23 @@ def compute(
self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList
) -> List[torch.Tensor]:
ret: List[torch.Tensor] = []
for lookup, features, sharding_ctx in zip(
for lookup, features, sharding_ctx, sharding_type in zip(
self._lookups,
dist_input,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
self._sharding_type_to_sharding,
):
sharding_ctx.lengths_after_input_dist = (
features.id_list_features.lengths().view(
-1, features.id_list_features.stride()
)
)
ret.append(lookup(features).view(-1, self._embedding_dim))
embedding_dim = (
self._local_embedding_dim
if sharding_type == ShardingType.COLUMN_WISE.value
else self._embedding_dim
)
ret.append(lookup(features).view(-1, embedding_dim))
return ret

def output_dist(
@@ -373,8 +463,7 @@ def output_dist(
for odist, embeddings, sharding_ctx in zip(
self._output_dists,
output,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
):
awaitables_per_sharding.append(odist(embeddings, sharding_ctx))
features_before_all2all_per_sharding.append(
@@ -383,38 +472,43 @@ def output_dist(
return EmbeddingCollectionAwaitable(
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)

def compute_and_output_dist(
self, ctx: ShardedModuleContext, input: SparseFeaturesList
) -> LazyAwaitable[Dict[str, torch.Tensor]]:
awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
for lookup, odist, features, sharding_ctx in zip(
for lookup, odist, features, sharding_ctx, sharding_type in zip(
self._lookups,
self._output_dists,
input,
# pyre-ignore [16]
ctx.sharding_contexts,
cast(EmbeddingCollectionContext, ctx).sharding_contexts,
self._sharding_type_to_sharding,
):
sharding_ctx.lengths_after_input_dist = (
features.id_list_features.lengths().view(
-1, features.id_list_features.stride()
)
)
embedding_dim = (
self._local_embedding_dim
if sharding_type == ShardingType.COLUMN_WISE.value
else self._embedding_dim
)
awaitables_per_sharding.append(
odist(lookup(features).view(-1, self._embedding_dim), sharding_ctx)
odist(lookup(features).view(-1, embedding_dim), sharding_ctx)
)
features_before_all2all_per_sharding.append(
sharding_ctx.features_before_input_dist
)
return EmbeddingCollectionAwaitable(
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -533,8 +627,13 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
types = [
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
]
if compute_device_type in {"cuda"}:
types += [
ShardingType.ROW_WISE.value,
]

return types

@property
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/tests/test_shard_estimators.py
Original file line number Diff line number Diff line change
@@ -126,6 +126,9 @@ def test_sequence_2_table_perf(self) -> None:
("batched_fused", "table_wise"): [0.001880471390093715],
("batched_fused_uvm", "table_wise"): [0.25958192114736517],
("batched_fused_uvm_caching", "table_wise"): [0.060433813055248066],
("batched_fused", "column_wise"): [0.001880471390093715],
("batched_fused_uvm", "column_wise"): [0.25958192114736517],
("batched_fused_uvm_caching", "column_wise"): [0.060433813055248066],
("batched_fused", "row_wise"): [
0.0007915177871551004,
0.0007915177871551004,
4 changes: 0 additions & 4 deletions torchrec/distributed/quant_embedding.py
Original file line number Diff line number Diff line change
@@ -109,9 +109,6 @@ def __init__(
self._has_uninitialized_output_dist: bool = True

self._embedding_dim: int = module.embedding_dim
self._embedding_names_per_sharding: List[List[str]] = []
for sharding in self._sharding_type_to_sharding.values():
self._embedding_names_per_sharding.append(sharding.embedding_names())
self._need_indices: bool = module.need_indices

def _create_input_dist(
@@ -225,7 +222,6 @@ def output_dist(
return EmbeddingCollectionAwaitable(
awaitables_per_sharding=awaitables_per_sharding,
features_per_sharding=features_before_all2all_per_sharding,
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=self._need_indices,
)

Loading

0 comments on commit 9242c92

Please sign in to comment.