Commit ca52aa0

generalize trec feature processor (pytorch#443)
Summary:
Pull Request resolved: pytorch#443

OSS change to the torchrec feature processor to support D37165122.

Reviewed By: xing-liu, RenfeiChen-FB

Differential Revision: D35735337

fbshipit-source-id: 0bf93776b0f3994962ae325b6105002284a589fd
Ning Wang authored and facebook-github-bot committed Jun 21, 2022
1 parent 6da4b4c commit ca52aa0
Showing 2 changed files with 13 additions and 9 deletions.
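
For context on what the generalization buys: once the gate checks for the BaseGroupedFeatureProcessor interface rather than the concrete GroupedPositionWeightedModule, any custom grouped processor becomes eligible to run in the lookup's forward pass. Below is a minimal sketch of such a subclass, assuming BaseGroupedFeatureProcessor is exported by torchrec.modules.feature_processor, that its forward contract is KeyedJaggedTensor in, KeyedJaggedTensor out, and that the incoming grouped features carry weights; the scaling logic itself is hypothetical.

from torchrec.modules.feature_processor import BaseGroupedFeatureProcessor
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class ScaleWeightsProcessor(BaseGroupedFeatureProcessor):
    """Hypothetical processor: scales each feature weight by a constant."""

    def __init__(self, scale: float = 2.0) -> None:
        super().__init__()
        self._scale = scale

    def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
        # Assumes the grouped features arrive with weights attached;
        # rebuild the KJT with scaled weights, leaving values intact.
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=features.weights() * self._scale,
            lengths=features.lengths(),
        )

Under the old check, an instance like this would have been silently skipped even with has_feature_processor set on the group config.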
torchrec/distributed/embedding_lookup.py (12 additions, 8 deletions)
@@ -288,9 +288,7 @@ def forward(
             if (
                 config.has_feature_processor
                 and self._feature_processor is not None
-                and isinstance(
-                    self._feature_processor, GroupedPositionWeightedModule
-                )
+                and isinstance(self._feature_processor, BaseGroupedFeatureProcessor)
             ):
                 features = self._feature_processor(features)
             embeddings.append(emb_op(features))
@@ -301,9 +299,17 @@ def forward(
                     self._id_score_list_feature_splits,
                 )
             )
-            for emb_op, features in zip(
-                self._score_emb_modules, id_score_list_features_by_group
+            for config, emb_op, features in zip(
+                self.grouped_score_configs,
+                self._score_emb_modules,
+                id_score_list_features_by_group,
             ):
+                if (
+                    config.has_feature_processor
+                    and self._feature_processor is not None
+                    and isinstance(self._feature_processor, BaseGroupedFeatureProcessor)
+                ):
+                    features = self._feature_processor(features)
                 embeddings.append(emb_op(features))
 
         if len(embeddings) == 0:
@@ -563,9 +569,7 @@ def forward(
             if (
                 config.has_feature_processor
                 and self._feature_processor is not None
-                and isinstance(
-                    self._feature_processor, GroupedPositionWeightedModule
-                )
+                and isinstance(self._feature_processor, BaseGroupedFeatureProcessor)
             ):
                 features = self._feature_processor(features)
             embeddings.append(emb_op(features))
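
The substantive change in this file is the second hunk: the weighted (id-score-list) path previously applied no feature processor at all, whereas it now mirrors the unweighted path by iterating grouped_score_configs alongside the embedding modules and gating per group. The first and third hunks only loosen the type check from the concrete GroupedPositionWeightedModule to the BaseGroupedFeatureProcessor interface.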
torchrec/distributed/embedding_sharding.py (1 addition, 1 deletion)
@@ -419,7 +419,7 @@ def _group_tables_per_rank(
     for data_type in DataType:
         for pooling in PoolingType:
             for is_weighted in [True, False]:
-                for has_feature_processor in [True, False]:
+                for has_feature_processor in [False, True]:
                     for compute_kernel in [
                         EmbeddingComputeKernel.DENSE,
                         EmbeddingComputeKernel.FUSED,
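
The embedding_sharding.py change looks cosmetic but reorders group emission. A minimal sketch of the effect, under the assumption (from the loop structure above) that _group_tables_per_rank emits grouped configs in the order the nested loops visit attribute combinations; the attribute values below are illustrative stand-ins, not the real enums:

from itertools import product

# Stand-in values; the real loops iterate DataType, PoolingType, etc.
visit_order = list(
    product(
        ["FP32"],            # data_type
        ["SUM"],             # pooling
        [True, False],       # is_weighted
        [False, True],       # has_feature_processor (new order)
        ["DENSE", "FUSED"],  # compute_kernel
    )
)
# Within each (data_type, pooling, is_weighted) bucket, groups without a
# feature processor are now emitted before groups that have one.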
