@@ -7,9 +7,21 @@
 
 
 import copy
-from collections import OrderedDict
+from collections import defaultdict, deque, OrderedDict
 from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Type, Union
+from typing import (
+    Any,
+    cast,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
 
 import torch
 from torch import nn
@@ -27,6 +39,9 @@
     SparseFeatures,
     SparseFeaturesList,
 )
+from torchrec.distributed.sharding.cw_sequence_sharding import (
+    CwSequenceEmbeddingSharding,
+)
 from torchrec.distributed.sharding.dp_sequence_sharding import (
     DpSequenceEmbeddingSharding,
 )
@@ -46,6 +61,7 @@
     ShardedModuleContext,
     ShardedTensor,
     ShardingEnv,
+    ShardMetadata,
 )
 from torchrec.distributed.utils import append_prefix, filter_state_dict
 from torchrec.modules.embedding_configs import EmbeddingTableConfig, PoolingType
@@ -76,6 +92,8 @@ def create_embedding_sharding(
         return RwSequenceEmbeddingSharding(sharding_infos, env, device)
     elif sharding_type == ShardingType.DATA_PARALLEL.value:
         return DpSequenceEmbeddingSharding(sharding_infos, env, device)
+    elif sharding_type == ShardingType.COLUMN_WISE.value:
+        return CwSequenceEmbeddingSharding(sharding_infos, env, device)
     else:
         raise ValueError(f"Sharding not supported {sharding_type}")
 
@@ -134,26 +152,45 @@ def create_sharding_infos_by_sharding(
 def _construct_jagged_tensors(
     embeddings: torch.Tensor,
     features: KeyedJaggedTensor,
-    embedding_names: List[str],
+    features_to_permute_indices: Dict[str, List[int]],
     need_indices: bool = False,
 ) -> Dict[str, JaggedTensor]:
     ret: Dict[str, JaggedTensor] = {}
-    lengths = features.lengths().view(-1, features.stride())
-    values = features.values()
-    length_per_key = features.length_per_key()
-    values_list = torch.split(values, length_per_key) if need_indices else None
-    embeddings_list = torch.split(embeddings, length_per_key, dim=0)
     stride = features.stride()
+    length_per_key = features.length_per_key()
+    values = features.values()
+
+    lengths = features.lengths().view(-1, stride)
     lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+    embeddings_list = torch.split(embeddings, length_per_key, dim=0)
+    values_list = torch.split(values, length_per_key) if need_indices else None
+
+    key_indices = defaultdict(list)
     for i, key in enumerate(features.keys()):
+        key_indices[key].append(i)
+    for key, indices in key_indices.items():
+        # combines feature outputs in correct order of shards
+        indices = (
+            _permute_indices(indices, features_to_permute_indices[key])
+            if features_to_permute_indices and key in features_to_permute_indices
+            else indices
+        )
+        feature_embeddings = [embeddings_list[i] for i in indices]
         ret[key] = JaggedTensor(
-            lengths=lengths_tuple[i],
-            values=embeddings_list[i],
-            weights=values_list[i] if need_indices else None,
+            lengths=lengths_tuple[indices[0]],
+            values=torch.cat(feature_embeddings, dim=1),
+            weights=values_list[indices[0]] if need_indices else None,
         )
     return ret
 
 
+def _permute_indices(indices: List[int], permute: List[int]) -> List[int]:
+    permuted_indices = [0] * len(indices)
+    for i, permuted_index in enumerate(permute):
+        permuted_indices[i] = indices[permuted_index]
+    return permuted_indices
+
+
 @dataclass
 class EmbeddingCollectionContext(ShardedModuleContext):
     sharding_contexts: List[SequenceShardingContext]
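
For reference, a minimal standalone sketch (not part of the diff; toy tensors, names, and the shard layout are assumed) of what the new permute-and-concatenate path in _construct_jagged_tensors does when a feature's table is column-wise sharded across shard ranks [0, 1, 0], matching the example in the _generate_permute_indices_per_feature docstring added later in this diff:

import torch


def _permute_indices(indices, permute):
    # Same logic as the helper added in this commit.
    permuted_indices = [0] * len(indices)
    for i, permuted_index in enumerate(permute):
        permuted_indices[i] = indices[permuted_index]
    return permuted_indices


# Toy column-wise layout: one table of dim 6 split into three shards of dim 2,
# placed on ranks [0, 1, 0].
shard_cols_0_1 = torch.zeros(4, 2)        # rank 0, first shard
shard_cols_2_3 = torch.ones(4, 2)         # rank 1
shard_cols_4_5 = torch.full((4, 2), 2.0)  # rank 0, second shard

# Outputs arrive grouped by rank: rank 0's shards first, then rank 1's.
embeddings_list = [shard_cols_0_1, shard_cols_4_5, shard_cols_2_3]
indices = [0, 1, 2]  # positions of this feature's outputs in the flat list
permute = [0, 2, 1]  # generated from shard ranks [0, 1, 0]

ordered = [embeddings_list[i] for i in _permute_indices(indices, permute)]
full = torch.cat(ordered, dim=1)
print(full[0])  # tensor([0., 0., 1., 1., 2., 2.]) -> columns back in table order

Without the permute, rank 0's two shards would be concatenated adjacently and the column order of the reassembled embedding would be wrong.
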
@@ -168,24 +205,28 @@ def __init__(
         self,
         awaitables_per_sharding: List[Awaitable[torch.Tensor]],
         features_per_sharding: List[KeyedJaggedTensor],
-        embedding_names_per_sharding: List[str],
+        features_to_permute_indices: Dict[str, List[int]],
         need_indices: bool = False,
     ) -> None:
         super().__init__()
         self._awaitables_per_sharding = awaitables_per_sharding
         self._features_per_sharding = features_per_sharding
-        self._embedding_names_per_sharding = embedding_names_per_sharding
+        self._features_to_permute_indices = features_to_permute_indices
         self._need_indices = need_indices
 
     def _wait_impl(self) -> Dict[str, JaggedTensor]:
         jt_dict: Dict[str, JaggedTensor] = {}
-        for w, f, e in zip(
+        for w, f in zip(
             self._awaitables_per_sharding,
             self._features_per_sharding,
-            self._embedding_names_per_sharding,
         ):
             jt_dict.update(
-                _construct_jagged_tensors(w.wait(), f, e, self._need_indices)
+                _construct_jagged_tensors(
+                    embeddings=w.wait(),
+                    features=f,
+                    features_to_permute_indices=self._features_to_permute_indices,
+                    need_indices=self._need_indices,
+                )
             )
         return jt_dict
 
@@ -250,11 +291,58 @@ def __init__(
                     optims.append(("", m.fused_optimizer))
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
         self._embedding_dim: int = module.embedding_dim
-        self._embedding_names_per_sharding: List[List[str]] = []
-        for sharding in self._sharding_type_to_sharding.values():
-            self._embedding_names_per_sharding.append(sharding.embedding_names())
+        self._local_embedding_dim: int = self._embedding_dim
+        self._features_to_permute_indices: Dict[str, List[int]] = {}
+
+        if ShardingType.COLUMN_WISE.value in self._sharding_type_to_sharding:
+            sharding = self._sharding_type_to_sharding[ShardingType.COLUMN_WISE.value]
+            self._local_embedding_dim = cast(
+                ShardMetadata, sharding.embedding_shard_metadata()[0]
+            ).shard_sizes[1]
+            self._generate_permute_indices_per_feature(
+                module, table_name_to_parameter_sharding
+            )
+
         self._need_indices: bool = module.need_indices
 
+    def _generate_permute_indices_per_feature(
+        self,
+        module: EmbeddingCollectionInterface,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+    ) -> None:
+        """
+        Generates permute indices for features in column-wise sharding.
+
+        Outputs are stored in order based on rank i.e. [f_0, f_1, f_2] for f_x = feature
+        on rank x. However, in column-wise sharding, there can be multiple shards of a
+        table on the same rank and thereby multiple outputs on the same rank.
+
+        i.e.
+        rank 0: [f_0, f_0, f_1]
+        rank 1: [f_0, f_1]
+        when flattened this becomes [f_0, f_0, f_1, f_0, f_1]
+
+        f_0's shard ranks = [0, 1, 0]
+
+        Since outputs are stored by rank, the intra-shard order is lost and the shards
+        on rank 0 would be combined first, making an incorrect combination of f_0's
+        output with the shard ranks = [0, 0, 1].
+
+        To keep the correct shard rank of [0, 1, 0] when combining outputs, we generate
+        permute indices for each feature to match the shard ranks.
+        """
+        for table, embedding_names in zip(
+            module.embedding_configs, module.embedding_names_by_table
+        ):
+            sharding = table_name_to_parameter_sharding[table.name]
+            ranks = cast(List[int], sharding.ranks)
+            rank_to_indices = defaultdict(deque)
+            for i, rank in enumerate(sorted(ranks)):
+                rank_to_indices[rank].append(i)
+            permute_indices = [rank_to_indices[rank].popleft() for rank in ranks]
+            for embedding_name in embedding_names:
+                self._features_to_permute_indices[embedding_name] = permute_indices
+
     def _create_input_dist(
         self,
         input_feature_names: List[str],
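
As a sanity check, here is a small self-contained sketch of the deque-based bookkeeping that _generate_permute_indices_per_feature performs per table (the helper name and the second example rank list are illustrative assumptions, not from the commit):

from collections import defaultdict, deque
from typing import List


def permute_indices_for(ranks: List[int]) -> List[int]:
    # Mirror of the loop body above: positions in the rank-grouped output are
    # handed out in sorted-rank order, then read back in shard (column) order.
    rank_to_indices = defaultdict(deque)
    for i, rank in enumerate(sorted(ranks)):
        rank_to_indices[rank].append(i)
    return [rank_to_indices[rank].popleft() for rank in ranks]


print(permute_indices_for([0, 1, 0]))  # [0, 2, 1] -- the docstring example
print(permute_indices_for([1, 1, 0]))  # [1, 2, 0]

Using a deque per rank keeps the mapping stable when one rank holds several shards of the same table: each shard, visited in column order, pops the next position that its rank occupies in the rank-grouped output.
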
@@ -295,25 +383,22 @@ def input_dist(
         features: KeyedJaggedTensor,
     ) -> Awaitable[SparseFeaturesList]:
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(
-                input_feature_names=features.keys() if features is not None else []
-            )
+            self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
-            features_by_sharding = []
             if self._features_order:
                 features = features.permute(
                     self._features_order,
                     # pyre-ignore [6]
                     self._features_order_tensor,
                 )
-            features_by_sharding = features.split(
+            features_by_shards = features.split(
                 self._feature_splits,
             )
             # save input splits and output splits in sharding context which
             # will be reused in sequence embedding all2all
             awaitables = []
-            for module, features in zip(self._input_dists, features_by_sharding):
+            for module, features in zip(self._input_dists, features_by_shards):
                 tensor_awaitable = module(
                     SparseFeatures(
                         id_list_features=features,
@@ -351,18 +436,23 @@ def compute(
         self, ctx: ShardedModuleContext, dist_input: SparseFeaturesList
     ) -> List[torch.Tensor]:
         ret: List[torch.Tensor] = []
-        for lookup, features, sharding_ctx in zip(
+        for lookup, features, sharding_ctx, sharding_type in zip(
             self._lookups,
             dist_input,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
+            self._sharding_type_to_sharding,
         ):
             sharding_ctx.lengths_after_input_dist = (
                 features.id_list_features.lengths().view(
                     -1, features.id_list_features.stride()
                 )
             )
-            ret.append(lookup(features).view(-1, self._embedding_dim))
+            embedding_dim = (
+                self._local_embedding_dim
+                if sharding_type == ShardingType.COLUMN_WISE.value
+                else self._embedding_dim
+            )
+            ret.append(lookup(features).view(-1, embedding_dim))
         return ret
 
     def output_dist(
@@ -373,8 +463,7 @@ def output_dist(
         for odist, embeddings, sharding_ctx in zip(
             self._output_dists,
             output,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
         ):
             awaitables_per_sharding.append(odist(embeddings, sharding_ctx))
             features_before_all2all_per_sharding.append(
@@ -383,7 +472,7 @@ def output_dist(
         return EmbeddingCollectionAwaitable(
             awaitables_per_sharding=awaitables_per_sharding,
             features_per_sharding=features_before_all2all_per_sharding,
-            embedding_names_per_sharding=self._embedding_names_per_sharding,
+            features_to_permute_indices=self._features_to_permute_indices,
             need_indices=self._need_indices,
         )
 
@@ -392,28 +481,33 @@ def compute_and_output_dist(
     ) -> LazyAwaitable[Dict[str, torch.Tensor]]:
         awaitables_per_sharding: List[Awaitable[Dict[str, JaggedTensor]]] = []
         features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
-        for lookup, odist, features, sharding_ctx in zip(
+        for lookup, odist, features, sharding_ctx, sharding_type in zip(
             self._lookups,
             self._output_dists,
             input,
-            # pyre-ignore [16]
-            ctx.sharding_contexts,
+            cast(EmbeddingCollectionContext, ctx).sharding_contexts,
+            self._sharding_type_to_sharding,
         ):
             sharding_ctx.lengths_after_input_dist = (
                 features.id_list_features.lengths().view(
                     -1, features.id_list_features.stride()
                 )
             )
+            embedding_dim = (
+                self._local_embedding_dim
+                if sharding_type == ShardingType.COLUMN_WISE.value
+                else self._embedding_dim
+            )
             awaitables_per_sharding.append(
-                odist(lookup(features).view(-1, self._embedding_dim), sharding_ctx)
+                odist(lookup(features).view(-1, embedding_dim), sharding_ctx)
             )
             features_before_all2all_per_sharding.append(
                 sharding_ctx.features_before_input_dist
            )
        return EmbeddingCollectionAwaitable(
            awaitables_per_sharding=awaitables_per_sharding,
            features_per_sharding=features_before_all2all_per_sharding,
-            embedding_names_per_sharding=self._embedding_names_per_sharding,
+            features_to_permute_indices=self._features_to_permute_indices,
            need_indices=self._need_indices,
        )
 
@@ -533,8 +627,13 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         types = [
             ShardingType.DATA_PARALLEL.value,
             ShardingType.TABLE_WISE.value,
-            ShardingType.ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
         ]
+        if compute_device_type in {"cuda"}:
+            types += [
+                ShardingType.ROW_WISE.value,
+            ]
+
         return types
 
     @property