
Commit 2b1710f

joshuadeng authored and facebook-github-bot committed on Jun 14, 2022
Add column wise sharding support for EmbeddingCollection (sequence embeddings) (pytorch#432)
Summary:
Pull Request resolved: pytorch#432

Adds support for column-wise (CW) sharding in EmbeddingCollection (sequence embeddings). Adds logic to stitch per-shard feature outputs, each with the local embedding dim, back together after the output dist so that the result matches the original embedding dim. Because the output dist returns shards ordered by rank rather than in their original shard order, the stitching permutes the outputs back into the correct order for each feature.

Differential Revision: D36944684

fbshipit-source-id: 12e698093d0aa5e09441084d2eba77e08b53d9f1
1 parent ab83e17 commit 2b1710f

File tree: 6 files changed, +248 −44 lines
 

‎torchrec/distributed/dist_data.py

+4 −2

@@ -619,7 +619,7 @@ class PooledEmbeddingsAllToAll(nn.Module):
        dim_sum_per_rank (List[int]): number of features (sum of dimensions) of the
            embedding in each rank.
        device (Optional[torch.device]): device on which buffers will be allocated.
-        callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]])
+        callbacks (Optional[List[Callable[[torch.Tensor], torch.Tensor]]]):

    Example::

@@ -669,6 +669,8 @@ def forward(

        Args:
            local_embs (torch.Tensor): tensor of values to distribute.
+            batch_size_per_rank (Optional[List[int]]): batch size per rank, to support
+                variable batch size.

        Returns:
            PooledEmbeddingsAwaitable: awaitable of pooled embeddings.
@@ -838,7 +840,7 @@ def _wait_impl(self) -> torch.Tensor:
        Syncs sequence embeddings after collective operation.

        Returns:
-            torch.Tensor: synced pooled embeddings.
+            torch.Tensor: synced sequence embeddings.
        """

        ret = self._tensor_awaitable.wait()

‎torchrec/distributed/embedding.py

+137 −38

@@ -7,9 +7,21 @@


 import copy
-from collections import OrderedDict
+from collections import defaultdict, deque, OrderedDict
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union
+from typing import (
+    Any,
+    cast,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)

 import torch
 from torch import nn
@@ -27,6 +39,9 @@
     SparseFeatures,
     SparseFeaturesList,
 )
+from torchrec.distributed.sharding.cw_sequence_sharding import (
+    CwSequenceEmbeddingSharding,
+)
 from torchrec.distributed.sharding.dp_sequence_sharding import (
     DpSequenceEmbeddingSharding,
 )
@@ -46,6 +61,7 @@
     ShardedModuleContext,
     ShardedTensor,
     ShardingEnv,
+    ShardMetadata,
 )
 from torchrec.distributed.utils import append_prefix, filter_state_dict
 from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType
@@ -76,6 +92,8 @@ def create_embedding_sharding(
         return RwSequenceEmbeddingSharding(sharding_infos, env, device)
     elif sharding_type == ShardingType.DATA_PARALLEL.value:
         return DpSequenceEmbeddingSharding(sharding_infos, env, device)
+    elif sharding_type == ShardingType.COLUMN_WISE.value:
+        return CwSequenceEmbeddingSharding(sharding_infos, env, device)
     else:
         raise ValueError(f"Sharding not supported {sharding_type}")

@@ -134,26 +152,45 @@ def create_sharding_infos_by_sharding(
 def _construct_jagged_tensors(
     embeddings: torch.Tensor,
     features: KeyedJaggedTensor,
-    embedding_names: List[str],
+    features_to_permute_indices: Dict[str, List[int]],
     need_indices: bool = False,
 ) -> Dict[str, JaggedTensor]:
     ret: Dict[str, JaggedTensor] = {}
-    lengths = features.lengths().view(-1, features.stride())
-    values = features.values()
-    length_per_key = features.length_per_key()
-    values_list = torch.split(values, length_per_key) if need_indices else None
-    embeddings_list = torch.split(embeddings, length_per_key, dim=0)
     stride = features.stride()
+    length_per_key = features.length_per_key()
+    values = features.values()
+
+    lengths = features.lengths().view(-1, stride)
     lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+    embeddings_list = torch.split(embeddings, length_per_key, dim=0)
+    values_list = torch.split(values, length_per_key) if need_indices else None
+
+    key_indices = defaultdict(list)
     for i, key in enumerate(features.keys()):
+        key_indices[key].append(i)
+    for key, indices in key_indices.items():
+        # combines feature outputs in correct order of shards
+        indices = (
+            _permute_indices(indices, features_to_permute_indices[key])
+            if features_to_permute_indices and key in features_to_permute_indices
+            else indices
+        )
+        feature_embeddings = [embeddings_list[i] for i in indices]
         ret[key] = JaggedTensor(
-            lengths=lengths_tuple[i],
-            values=embeddings_list[i],
-            weights=values_list[i] if need_indices else None,
+            lengths=lengths_tuple[indices[0]],
+            values=torch.cat(feature_embeddings, dim=1),
+            weights=values_list[indices[0]] if need_indices else None,
        )
     return ret


+def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
+    permuted_indices = [0] * len(indices)
+    for i, permuted_index in enumerate(permute):
+        permuted_indices[i] = indices[permuted_index]
+    return permuted_indices
+
+
 @dataclass
 class EmbeddingCollectionContext(ShardedModuleContext):
     sharding_contexts: List[SequenceShardingContext]
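
The stitching in _construct_jagged_tensors above is the core of the change: a column-wise sharded feature now contributes several embedding slices, which are concatenated along the embedding dimension once they are back in shard order. A minimal standalone sketch of that step, with made-up toy shapes (3 ids, two 4-dim shards of an 8-dim table; not taken from the diff):

import torch

# Hypothetical per-shard outputs of one feature, already in shard order:
# each shard holds 4 of the table's 8 embedding dims for the same 3 ids.
shard_outputs = [torch.randn(3, 4), torch.randn(3, 4)]

# Concatenating along dim=1 stitches the local slices back to the
# original embedding_dim, one row per id.
stitched = torch.cat(shard_outputs, dim=1)
assert stitched.shape == (3, 8)
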
@@ -168,24 +205,28 @@ def __init__(
         self,
         awaitables_per_sharding: List[Awaitable[torch.Tensor]],
         features_per_sharding: List[KeyedJaggedTensor],
-        embedding_names_per_sharding: List[str],
+        features_to_permute_indices: Dict[str, List[int]],
         need_indices: bool = False,
     ) -> None:
         super().__init__()
         self._awaitables_per_sharding = awaitables_per_sharding
         self._features_per_sharding = features_per_sharding
-        self._embedding_names_per_sharding = embedding_names_per_sharding
+        self._features_to_permute_indices = features_to_permute_indices
         self._need_indices = need_indices

     def _wait_impl(self) -> Dict[str, JaggedTensor]:
         jt_dict: Dict[str, JaggedTensor] = {}
-        for w, f, e in zip(
+        for w, f in zip(
             self._awaitables_per_sharding,
             self._features_per_sharding,
-            self._embedding_names_per_sharding,
         ):
             jt_dict.update(
-                _construct_jagged_tensors(w.wait(), f, e, self._need_indices)
+                _construct_jagged_tensors(
+                    embeddings=w.wait(),
+                    features=f,
+                    features_to_permute_indices=self._features_to_permute_indices,
+                    need_indices=self._need_indices,
+                )
             )
         return jt_dict

@@ -250,11 +291,58 @@ def __init__(
             optims.append(("", m.fused_optimizer))
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
         self._embedding_dim: int = module.embedding_dim
-        self._embedding_names_per_sharding: List[List[str]] = []
-        for sharding in self._sharding_type_to_sharding.values():
-            self._embedding_names_per_sharding.append(sharding.embedding_names())
+        self._local_embedding_dim: int = self._embedding_dim
+        self._features_to_permute_indices: Dict[str, List[int]] = {}
+
+        if ShardingType.COLUMN_WISE.value in self._sharding_type_to_sharding:
+            sharding = self._sharding_type_to_sharding[ShardingType.COLUMN_WISE.value]
+            self._local_embedding_dim = cast(
+                ShardMetadata, sharding.embedding_shard_metadata()[0]
+            ).shard_sizes[1]
+            self._generate_permute_indices_per_feature(
+                module, table_name_to_parameter_sharding
+            )
+
         self._need_indices: bool = module.need_indices

+    def _generate_permute_indices_per_feature(
+        self,
+        module: EmbeddingCollectionInterface,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+    ) -> None:
+        """
+        Generates permute indices for features in column-wise sharding.
+
+        Outputs are stored in order based on rank i.e. [f_0, f_1, f_2] for f_x = feature
+        on rank x. However, in column-wise sharding, there can be multiple shards of a
+        table on the same rank and thereby multiple outputs on the same rank.
+
+        i.e.
+        rank 0: [f_0, f_0, f_1]
+        rank 1: [f_0, f_1]
+        when flattened this becomes [f_0, f_0, f_1, f_0, f_1]
+
+        f_0's shard ranks = [0, 1, 0]
+
+        Since outputs are stored by rank, the intra-shard order is lost and the shards
+        on rank 0 would be combined first, making an incorrect combination of f_0's
+        output with the shard ranks = [0, 0, 1].
+
+        To keep the correct shard rank of [0, 1, 0] when combining outputs, we generate
+        permute indices for each feature to match the shard ranks.
+        """
+        for table, embedding_names in zip(
+            module.embedding_configs, module.embedding_names_by_table
+        ):
+            sharding = table_name_to_parameter_sharding[table.name]
+            ranks = cast(List[int], sharding.ranks)
+            rank_to_indices = defaultdict(deque)
+            for i, rank in enumerate(sorted(ranks)):
+                rank_to_indices[rank].append(i)
+            permute_indices = [rank_to_indices[rank].popleft() for rank in ranks]
+            for embedding_name in embedding_names:
+                self._features_to_permute_indices[embedding_name] = permute_indices
+
     def _create_input_dist(
         self,
         input_feature_names: List[str],
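
Walking through the docstring's example with a small standalone sketch (only the [0, 1, 0] shard ranks come from the docstring; _permute_indices is copied from the diff, and the labels below are illustrative): the generated permute indices reorder the rank-ordered outputs back into the original shard order.

from collections import defaultdict, deque
from typing import List


def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
    permuted_indices = [0] * len(indices)
    for i, permuted_index in enumerate(permute):
        permuted_indices[i] = indices[permuted_index]
    return permuted_indices


ranks = [0, 1, 0]  # f_0's shard ranks from the docstring example

# Same computation as _generate_permute_indices_per_feature: positions in the
# rank-ordered output are handed out rank by rank, then read back in shard order.
rank_to_indices = defaultdict(deque)
for i, rank in enumerate(sorted(ranks)):
    rank_to_indices[rank].append(i)
permute_indices = [rank_to_indices[rank].popleft() for rank in ranks]
print(permute_indices)  # [0, 2, 1]

# Rank-ordered outputs: both rank-0 shards arrive first, then the rank-1 shard.
rank_ordered = ["shard_0 (rank 0)", "shard_2 (rank 0)", "shard_1 (rank 1)"]
reordered = [rank_ordered[i] for i in _permute_indices([0, 1, 2], permute_indices)]
print(reordered)  # ['shard_0 (rank 0)', 'shard_1 (rank 1)', 'shard_2 (rank 0)']
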
@@ -295,25 +383,22 @@ def input_dist(
         features: KeyedJaggedTensor,
     ) -> Awaitable[SparseFeaturesList]:
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(
-                input_feature_names=features.keys() if features is not None else []
-            )
+            self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
-            features_by_sharding = []
             if self._features_order:
                 features = features.permute(
                     self._features_order,
                     # pyre-ignore [6]
                     self._features_order_tensor,
                 )
-            features_by_sharding = features.split(
+            features_by_shards = features.split(
                 self._feature_splits,
             )
             # save input splits and output splits in sharding context which
             # will be reused in sequence embedding all2all
             awaitables = []
-            for module, features in zip(self._input_dists, features_by_sharding):
+            for module, features in zip(self._input_dists, features_by_shards):
                 tensor_awaitable = module(
                     SparseFeatures(
                         id_list_features=features,
@@ -351,18 +436,23 @@ def compute(
         self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList
     ) -> List[torch.Tensor]:
         ret: List[torch.Tensor] = []
-        for lookup, features, sharding_ctx in zip(
+        for lookup, features, sharding_ctx, sharding_type in zip(
             self._lookups,
             dist_input,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
+            self._sharding_type_to_sharding,
         ):
             sharding_ctx.lengths_after_input_dist = (
                 features.id_list_features.lengths().view(
                     -1, features.id_list_features.stride()
                 )
             )
-            ret.append(lookup(features).view(-1, self._embedding_dim))
+            embedding_dim = (
+                self._local_embedding_dim
+                if sharding_type == ShardingType.COLUMN_WISE.value
+                else self._embedding_dim
+            )
+            ret.append(lookup(features).view(-1, embedding_dim))
         return ret

     def output_dist(
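
Why compute picks self._local_embedding_dim for the column-wise case: each lookup returns a flat tensor whose row width is the shard's local dim, so reshaping with the full table dim would no longer give one row per id. A small sketch with invented numbers (a 16-dim table split into two 8-dim column shards, 4 ids):

import torch

full_dim, local_dim, num_ids = 16, 8, 4  # hypothetical toy sizes

# A column-wise sharded lookup yields local_dim values per id.
flat_lookup_output = torch.randn(num_ids * local_dim)

per_id = flat_lookup_output.view(-1, local_dim)        # [4, 8]: one row per id
rows_mismatch = flat_lookup_output.view(-1, full_dim)  # [2, 16]: rows no longer match ids
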
@@ -373,8 +463,7 @@ def output_dist(
         for odist, embeddings, sharding_ctx in zip(
             self._output_dists,
             output,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
         ):
             awaitables_per_sharding.append(odist(embeddings, sharding_ctx))
             features_before_all2all_per_sharding.append(
@@ -383,7 +472,7 @@
         return EmbeddingCollectionAwaitable(
             awaitables_per_sharding=awaitables_per_sharding,
             features_per_sharding=features_before_all2all_per_sharding,
-            embedding_names_per_sharding=self._embedding_names_per_sharding,
+            features_to_permute_indices=self._features_to_permute_indices,
             need_indices=self._need_indices,
         )

@@ -392,28 +481,33 @@ def compute_and_output_dist(
     ) -> LazyAwaitable[Dict[str, torch.Tensor]]:
         awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
         features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
-        for lookup, odist, features, sharding_ctx in zip(
+        for lookup, odist, features, sharding_ctx, sharding_type in zip(
             self._lookups,
             self._output_dists,
             input,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
+            self._sharding_type_to_sharding,
         ):
             sharding_ctx.lengths_after_input_dist = (
                 features.id_list_features.lengths().view(
                     -1, features.id_list_features.stride()
                 )
             )
+            embedding_dim = (
+                self._local_embedding_dim
+                if sharding_type == ShardingType.COLUMN_WISE.value
+                else self._embedding_dim
+            )
             awaitables_per_sharding.append(
-                odist(lookup(features).view(-1, self._embedding_dim), sharding_ctx)
+                odist(lookup(features).view(-1, embedding_dim), sharding_ctx)
             )
             features_before_all2all_per_sharding.append(
                 sharding_ctx.features_before_input_dist
             )
         return EmbeddingCollectionAwaitable(
             awaitables_per_sharding=awaitables_per_sharding,
             features_per_sharding=features_before_all2all_per_sharding,
-            embedding_names_per_sharding=self._embedding_names_per_sharding,
+            features_to_permute_indices=self._features_to_permute_indices,
             need_indices=self._need_indices,
         )

@@ -533,8 +627,13 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
-            ShardingType.ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
         ]
+        if compute_device_type in {"cuda"}:
+            types += [
+                ShardingType.ROW_WISE.value,
+            ]
+
         return types

     @property
(Diff for the remaining 4 changed files is not shown.)
