Add batching to quant_ebc (pytorch#463)
Summary:
Pull Request resolved: pytorch#463

ATT, see discussion in

https://fb.workplace.com/groups/970281557043698/permalink/1243408826397635/

Reviewed By: zyan0

Differential Revision: D37352626

fbshipit-source-id: acabc75cea7c1a0b7082b135ab26abd2106c8723
YLGH authored and facebook-github-bot committed Jun 29, 2022
1 parent 286a5cd commit d9684a1
Showing 2 changed files with 258 additions and 73 deletions.
189 changes: 133 additions & 56 deletions torchrec/quant/embedding_modules.py
@@ -6,7 +6,8 @@
# LICENSE file in the root directory of this source tree.

import copy
from collections import OrderedDict
import itertools
from collections import defaultdict, OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
@@ -17,6 +18,7 @@
PoolingMode,
)
from torch import Tensor
from torch.nn.modules.module import _IncompatibleKeys
from torchrec.modules.embedding_configs import (
DATA_TYPE_NUM_BITS,
data_type_to_sparse_type,
@@ -25,6 +27,7 @@
EmbeddingBagConfig,
EmbeddingConfig,
pooling_type_to_pooling_mode,
PoolingType,
)
from torchrec.modules.embedding_modules import (
EmbeddingBagCollection as OriginalEmbeddingBagCollection,
@@ -35,7 +38,6 @@
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -115,7 +117,8 @@ def quantize_state_dict(
class EmbeddingBagCollection(EmbeddingBagCollectionInterface):
"""
EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops
This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops and provides
table batching.
It processes sparse data in the form of KeyedJaggedTensor
with values of the form [F X B X L]
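
For orientation, a minimal usage sketch of the quantized, table-batched collection (not part of this diff). The table and feature names, the qconfig recipe, and the from_float entry point are assumptions drawn from TorchRec's usual inference-quantization flow rather than from this commit:

import torch
import torch.quantization as quant

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.quant.embedding_modules import (
    EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two tables sharing pooling type and data type, so the batched module can fuse
# them into a single IntNBitTableBatchedEmbeddingBagsCodegen.
tables = [
    EmbeddingBagConfig(name="t1", embedding_dim=16, num_embeddings=10, feature_names=["f1"]),
    EmbeddingBagConfig(name="t2", embedding_dim=16, num_embeddings=10, feature_names=["f2"]),
]
ebc = EmbeddingBagCollection(tables=tables, device=torch.device("cpu"))

# Assumed qconfig recipe (PlaceholderObserver + qint8); standard for TorchRec
# inference quantization but not shown in this commit.
ebc.qconfig = quant.QConfig(
    activation=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
    weight=quant.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qebc = QuantEmbeddingBagCollection.from_float(ebc)

# KJT of form [F X B X L]: two features ("f1", "f2"), batch size 2.
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 2]),
)
out = qebc(features)                   # KeyedTensor
print(out.keys(), out.values().shape)  # e.g. ['f1', 'f2'] torch.Size([2, 32])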
@@ -184,73 +187,115 @@ def __init__(
super().__init__()
self._is_weighted = is_weighted
self._embedding_bag_configs: List[EmbeddingBagConfig] = embedding_configs
self.embedding_bags: nn.ModuleList = nn.ModuleList()
self._lengths_per_embedding: List[int] = []
self._key_to_tables: Dict[
Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
] = defaultdict(list)
self._length_per_key: List[int] = []
self._emb_modules: nn.ModuleList = nn.ModuleList()
self._output_dtype = output_dtype

table_names = set()
for emb_config in self._embedding_bag_configs:
if emb_config.name in table_names:
raise ValueError(f"Duplicate table name {emb_config.name}")
table_names.add(emb_config.name)
emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
for table in self._embedding_bag_configs:
if table.name in table_names:
raise ValueError(f"Duplicate table name {table.name}")
table_names.add(table.name)
self._length_per_key.extend(
[table.embedding_dim] * len(table.feature_names)
)
key = (table.pooling, table.data_type)
self._key_to_tables[key].append(table)

self._sum_length_per_key: int = sum(self._length_per_key)

location = (
EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
)

for key, emb_configs in self._key_to_tables.items():
(pooling, data_type) = key
embedding_specs = []
weight_lists = []
for table in emb_configs:
embedding_specs.append(
(
"",
emb_config.num_embeddings,
emb_config.embedding_dim,
data_type_to_sparse_type(emb_config.data_type),
EmbeddingLocation.HOST
if device.type == "cpu"
else EmbeddingLocation.DEVICE,
table.name,
table.num_embeddings,
table.embedding_dim,
data_type_to_sparse_type(data_type),
location,
)
],
pooling_mode=pooling_type_to_pooling_mode(emb_config.pooling),
weight_lists=[table_name_to_quantized_weights[emb_config.name]],
)
weight_lists.append(table_name_to_quantized_weights[table.name])

emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=embedding_specs,
pooling_mode=pooling_type_to_pooling_mode(pooling),
weight_lists=weight_lists,
device=device,
output_dtype=data_type_to_sparse_type(dtype_to_data_type(output_dtype)),
row_alignment=16,
)
self.embedding_bags.append(emb_module)
if not emb_config.feature_names:
emb_config.feature_names = [emb_config.name]
self._lengths_per_embedding.extend(
len(emb_config.feature_names) * [emb_config.embedding_dim]
)
self._emb_modules.append(emb_module)

self._embedding_names: List[str] = [
embedding
for embeddings in get_embedding_names_by_table(embedding_configs)
for embedding in embeddings
]
self._embedding_names: List[str] = list(
itertools.chain(*get_embedding_names_by_table(self._embedding_bag_configs))
)
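
# Illustrative sketch, not part of this commit: tables are fused into one
# IntNBitTableBatchedEmbeddingBagsCodegen only when both their pooling type and
# their data type match. The strings below are hypothetical stand-ins for the
# PoolingType/DataType enums; defaultdict is already imported at the top of
# this file.
demo_tables = [("t1", "SUM", "INT8"), ("t2", "SUM", "INT8"), ("t3", "MEAN", "INT8")]
demo_key_to_tables = defaultdict(list)
for demo_name, demo_pooling, demo_data_type in demo_tables:
    demo_key_to_tables[(demo_pooling, demo_data_type)].append(demo_name)
print(dict(demo_key_to_tables))
# {('SUM', 'INT8'): ['t1', 't2'], ('MEAN', 'INT8'): ['t3']} -> two groups, so
# two batched kernels end up in self._emb_modules.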

def forward(
self,
features: KeyedJaggedTensor,
) -> KeyedTensor:
pooled_embeddings: List[Tensor] = []
length_per_key: List[int] = []
"""
Args:
features (KeyedJaggedTensor): KJT of form [F X B X L].
Returns:
KeyedTensor
"""

feature_dict = features.to_dict()
for emb_config, emb_module in zip(
self._embedding_bag_configs, self.embedding_bags
embeddings = []

# TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
# Once torchscript is no longer a requirement, we should revisit this.

for emb_op, (_key, tables) in zip(
self._emb_modules, self._key_to_tables.items()
):
for feature_name in emb_config.feature_names:
f = feature_dict[feature_name]
values = f.values()
offsets = f.offsets()
pooled_embeddings.append(
emb_module(
indices=values.int(),
offsets=offsets.int(),
per_sample_weights=f.weights() if self._is_weighted else None,
)
indices = []
lengths = []
offsets = []
weights = []

for table in tables:
for feature in table.feature_names:
f = feature_dict[feature]
indices.append(f.values())
lengths.append(f.lengths())
if self._is_weighted:
weights.append(f.weights())

indices = torch.cat(indices)
lengths = torch.cat(lengths)

offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
if self._is_weighted:
weights = torch.cat(weights)

embeddings.append(
emb_op(
indices=indices.int(),
offsets=offsets.int(),
per_sample_weights=weights if self._is_weighted else None,
)
)

length_per_key.append(emb_config.embedding_dim)
embeddings = torch.stack(embeddings).reshape(-1, self._sum_length_per_key)

return KeyedTensor(
keys=self._embedding_names,
values=torch.cat(pooled_embeddings, dim=1),
length_per_key=self._lengths_per_embedding,
values=embeddings,
length_per_key=self._length_per_key,
)
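
# Note, not part of this commit: torch.ops.fbgemm.asynchronous_complete_cumsum
# converts the concatenated per-bag lengths into CSR-style offsets, i.e. a
# cumulative sum with a leading zero and len(lengths) + 1 entries. A plain-torch
# sketch of the same transform (standalone demo, assuming only core torch):
def _complete_cumsum_demo(lengths: Tensor) -> Tensor:
    zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
    return torch.cat([zero, torch.cumsum(lengths, dim=0)])

# _complete_cumsum_demo(torch.tensor([2, 1, 1, 2])) -> tensor([0, 2, 3, 4, 6])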

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
@@ -264,16 +309,48 @@ def state_dict(
destination = OrderedDict()
# pyre-ignore [16]
destination._metadata = OrderedDict()
for emb_config, emb_module in zip(
self._embedding_bag_configs,
self.embedding_bags,
for emb_op, (_key, tables) in zip(
self._emb_modules, self._key_to_tables.items()
):
(weight, _) = emb_module.split_embedding_weights(split_scale_shifts=False)[
0
]
destination[prefix + f"embedding_bags.{emb_config.name}.weight"] = weight
for table, (weight, _) in zip(
tables, emb_op.split_embedding_weights(split_scale_shifts=False)
):
destination[prefix + f"embedding_bags.{table.name}.weight"] = weight
return destination

# pyre-fixme[14]: `load_state_dict` overrides method defined in `Module` inconsistently.
def load_state_dict(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
strict: bool = True,
) -> _IncompatibleKeys:

missing_keys = []
unexpected_keys = []

current_state_dict = self.state_dict()
for key in current_state_dict.keys():
if key not in state_dict:
missing_keys.append(key)
for key in state_dict.keys():
if key not in current_state_dict.keys():
unexpected_keys.append(key)

if missing_keys or unexpected_keys:
return _IncompatibleKeys(
missing_keys=missing_keys, unexpected_keys=unexpected_keys
)

for (_key, tables) in self._key_to_tables.items():
for table in tables:
current_state_dict[
f"embedding_bags.{table.name}.weight"
].detach().copy_(state_dict[f"embedding_bags.{table.name}.weight"])

return _IncompatibleKeys(
missing_keys=missing_keys, unexpected_keys=unexpected_keys
)
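
# Usage sketch, not part of this commit, continuing the example after the class
# docstring: state_dict keys take the form "embedding_bags.<table name>.weight",
# and load_state_dict copies each table's quantized rows into the matching slice
# of the batched kernel. `qebc` is the quantized collection built in that
# sketch; `another_qebc` is an assumed second instance built from the same
# embedding configs.
sd = qebc.state_dict()
print(list(sd.keys()))  # ['embedding_bags.t1.weight', 'embedding_bags.t2.weight']
incompatible = another_qebc.load_state_dict(sd)
assert not incompatible.missing_keys and not incompatible.unexpected_keys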

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, nn.Parameter]]: