Add check for not running with deploy when calling torch.compiler
Reviewed By: PaulZhang12

Differential Revision: D55389988
s4ayub authored and facebook-github-bot committed Mar 27, 2024
1 parent 26b6899 commit de3f53d
Showing 3 changed files with 19 additions and 9 deletions.
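
Every call site touched by this commit applies the same guard: only consult the torch.compiler namespace when the code is neither being scripted by TorchScript nor running under torch::deploy. A minimal sketch of that pattern, assuming only the helper added to jagged_tensor.py in this commit (the is_dynamo_tracing wrapper is illustrative, not part of the change):

import torch


def can_use_torch_compiler() -> bool:
    # Skip torch.compiler checks under TorchScript scripting and torch::deploy,
    # where the dynamo/compile machinery is unavailable.
    return not torch.jit.is_scripting() and not torch._running_with_deploy()


def is_dynamo_tracing() -> bool:
    # Illustrative wrapper: the guard short-circuits, so torch.compiler is never
    # evaluated when scripting or running with deploy.
    return can_use_torch_compiler() and torch.compiler.is_dynamo_compiling()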
3 changes: 2 additions & 1 deletion torchrec/distributed/comm_ops.py
@@ -18,6 +18,7 @@
 from torch.autograd.profiler import record_function
 from torchrec.distributed.types import Awaitable, NoWait, QuantizedCommCodecs
 from torchrec.distributed.utils import none_throws
+from torchrec.sparse.jagged_tensor import can_use_torch_compiler

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1011,7 +1012,7 @@ def reduce_scatter_v_pooled(
     ]

     equal_splits = False
-    if not torch.compiler.is_dynamo_compiling():
+    if can_use_torch_compiler() and not torch.compiler.is_dynamo_compiling():
         # We can not check during tracing equality of splits -> fallback on general
         equal_splits = all(ip_split == input_splits[0] for ip_split in input_splits)
10 changes: 5 additions & 5 deletions torchrec/distributed/dist_data.py
@@ -27,7 +27,7 @@
 from torchrec.distributed.embedding_types import KJTList
 from torchrec.distributed.types import Awaitable, QuantizedCommCodecs
 from torchrec.fx.utils import fx_marker
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import can_use_torch_compiler, KeyedJaggedTensor

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -105,9 +105,9 @@ def _get_recat(
             recat.append(i + j * local_split)

     vb_condition: bool = batch_size_per_rank is not None
-    if not torch.compiler.is_dynamo_compiling():
+    if can_use_torch_compiler() and not torch.compiler.is_dynamo_compiling():
         vb_condition = vb_condition and any(
-            # pyre-ignore
+            # pyre-ignor
             bs != batch_size_per_rank[0]
             # pyre-ignore
             for bs in batch_size_per_rank
@@ -384,7 +384,7 @@ def _wait_impl(self) -> KJTAllToAllTensorsAwaitable:
         self._output_splits = output_list[:-1]
         self._stride_per_rank = output_list[-1]

-        if torch.compiler.is_dynamo_compiling():
+        if can_use_torch_compiler() and torch.compiler.is_dynamo_compiling():
             rank: int = self._pg.rank()
             for i in range(len(self._output_splits)):
                 for j in range(len(self._output_splits[i])):
@@ -1029,7 +1029,7 @@ def forward(
         """

         # Dynamo can not trace through data dependent condition: len(set(input_splits)) > 1
-        if torch.compiler.is_dynamo_compiling():
+        if can_use_torch_compiler() and torch.compiler.is_dynamo_compiling():
             if input_splits is not None:
                 tensor_awaitable = reduce_scatter_v_pooled(
                     local_embs, input_splits, self._pg, codecs=self._codecs
15 changes: 12 additions & 3 deletions torchrec/sparse/jagged_tensor.py
@@ -41,6 +41,10 @@ def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
         return False


+def can_use_torch_compiler() -> bool:
+    return not torch.jit.is_scripting() and not torch._running_with_deploy()
+
+
 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     if is_torchdynamo_compiling():
         # TODO(ivankobzarev): Dynamo trace with pin_memory once FakeTensor supports it.
@@ -248,7 +252,11 @@ def _permute_tensor_by_segments(


 def is_non_strict_exporting() -> bool:
-    return not torch.compiler.is_dynamo_compiling() and torch.compiler.is_compiling()
+    return (
+        can_use_torch_compiler()
+        and not torch.compiler.is_dynamo_compiling()
+        and torch.compiler.is_compiling()
+    )


 class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
@@ -676,7 +684,7 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
 def _assert_tensor_has_no_elements_or_has_integers(
     tensor: torch.Tensor, tensor_name: str
 ) -> None:
-    if torch.compiler.is_dynamo_compiling():
+    if can_use_torch_compiler() and torch.compiler.is_dynamo_compiling():
         # Skipping assert on tensor.numel() == 0 for dynamo to avoid DataDependentError
         return
@@ -802,7 +810,8 @@ def _maybe_compute_length_per_key(
     else:
         cond: bool = False
         if (
-            torch.compiler.is_dynamo_compiling()
+            can_use_torch_compiler()
+            and torch.compiler.is_dynamo_compiling()
             and not torch.jit.is_scripting()
         ):
             # pyre-ignore
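
For reference, a hedged sketch of how a guarded call site behaves after this change, mirroring the reduce_scatter_v_pooled edit in comm_ops.py (the standalone function below is illustrative only, not code from the commit):

from typing import List

import torch
from torchrec.sparse.jagged_tensor import can_use_torch_compiler


def compute_equal_splits(input_splits: List[int]) -> bool:
    # Equality of splits can only be checked eagerly; under Dynamo tracing,
    # TorchScript scripting, or torch::deploy we fall back to the general
    # (unequal-splits) path by returning False.
    if can_use_torch_compiler() and not torch.compiler.is_dynamo_compiling():
        return all(split == input_splits[0] for split in input_splits)
    return False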

