Skip to content

Commit

Permalink
Revert D68246404
Browse files Browse the repository at this point in the history
Differential Revision: D68780760

Pull Request resolved: #8002
  • Loading branch information
Gasoonjia authored Jan 28, 2025
1 parent e37a585 commit 98697f6
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 78 deletions.
6 changes: 1 addition & 5 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ExecutorchProgramManager,
to_edge,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
Expand Down Expand Up @@ -187,17 +186,14 @@ def export_to_edge(
edge_prog_manager = to_edge(
expo_program,
compile_config=EdgeCompileConfig(
_skip_dim_order=True,
# Allow specific non-core aten ops in the IR.
_core_aten_ops_exception_list=[
torch.ops.aten._native_batch_norm_legit_functional.default,
torch.ops.aten.linear.default,
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.unfold.default,
torch.ops.aten.angle.default,
# cadence replaced to_dim_order_copy with _to_copy for performance
# skip _to_copy op to get around of dim order check
# We should remove this op once cadence can support dim order
exir_ops.edge.aten._to_copy.default,
],
),
constant_methods=constant_methods,
Expand Down
73 changes: 0 additions & 73 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

# pyre-unsafe

import copy
import math
from operator import neg
from typing import cast, Dict, Iterable, Sequence, Set, Tuple
Expand All @@ -36,12 +35,7 @@
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.dim_order_utils import get_memory_format
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import (
DimOrderOpsMap,
MemoryFormatOpsMap,
)
from torch._subclasses import FakeTensor
from torch.fx.node import Argument

Expand Down Expand Up @@ -1805,72 +1799,6 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
    """
    Replace dim_order_ops::to_dim_order_copy with aten::_to_copy.

    dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass.
    If the dim order is sequential (i.e. contiguous), we don't need the extra work
    with strides and can just use to_copy with the equivalent memory_format kwarg.
    """

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """Rewrite a dim-order op into its memory-format counterpart.

        Ops not present in DimOrderOpsMap pass through unchanged. For mapped
        ops, the ``dim_order`` kwarg is removed and replaced by the equivalent
        ``memory_format`` kwarg on the replacement op.

        Raises:
            AssertionError: if the input's rank cannot be determined, or if a
                non-contiguous dim order is requested (Cadence only supports
                contiguous memory format).
        """
        if op not in DimOrderOpsMap:
            return super().call_operator(op, args, kwargs, meta)

        # New kwargs with dim_order removed and memory_format added for the
        # replacement op. The original kwargs are immutable, so deep-copy them.
        nkwargs = dict(copy.deepcopy(kwargs))

        # Determine the rank of the input so we can construct the identity
        # (contiguous) dim order to validate against. Rank is assumed to be
        # specialized, so the shape is always available.
        ndim = None
        # pyre-ignore[16]: `None` has no attribute `to_tensor`
        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
            # pyre-ignore[16]: `None` has no attribute `to_tensor`
            ndim = args[0].to_tensor().dim()
        elif isinstance(args[0], torch.Tensor):
            # pyre-ignore[16]: `None` has no attribute `dim`
            ndim = args[0].dim()
        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
            # pyre-ignore[6]: Incompatible parameter type
            ndim = len(args[0])
        else:
            raise AssertionError(
                f"Expecting a Tensor or a ProxyValue but got {type(args[0])}"
            )

        # The identity permutation [0, 1, ..., ndim-1] is the contiguous order.
        contiguous_dim_order = list(range(ndim))
        dim_order = nkwargs.pop("dim_order", None)

        # Cadence only supports contiguous memory format; reject any
        # explicit non-contiguous dim order. An absent or empty dim_order is
        # treated as contiguous. (Raise instead of `assert` so the check is
        # not stripped when running under `python -O`.)
        if not (
            dim_order is None
            # pyre-ignore[6]: Incompatible parameter type
            or len(dim_order) == 0
            or dim_order == contiguous_dim_order
        ):
            raise AssertionError(
                f"Expected dim order in contiguous or preserve memory format, "
                f"but got {dim_order}"
            )

        # Translate the validated dim order back into a torch memory_format
        # for the replacement op.
        # pyre-ignore[6]: Incompatible parameter type
        nkwargs["memory_format"] = get_memory_format(dim_order)

        memory_format_op = MemoryFormatOpsMap[op]

        return super().call_operator(
            memory_format_op,
            args,
            nkwargs,
            meta,
        )


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceFullLikeWithFullPass(ExportPass):
"""
Expand Down Expand Up @@ -2180,5 +2108,4 @@ class CadenceReplaceOpsInGraph:
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
ReplaceToDimOrderCopyWithToCopyPass,
]

0 comments on commit 98697f6

Please sign in to comment.