Skip to content

Commit

Permalink
Lintrunner: Enable mypy testing on backends/arm (#7776)
Browse files Browse the repository at this point in the history
Migration from pyre to mypy in the lintrunner by enabling mypy for
backends/arm.

But, choosing to ignore the directory backends/arm/test.

Adding ignores all over the place. These needs to be fixed properly in
the future, but now we will start to catch new things trying to sneak
in.

Change-Id: Ie7f73d5688aaec3b32dca9f0cd042da94c06f487
  • Loading branch information
perheld authored Jan 28, 2025
1 parent 6cce750 commit f4786ac
Show file tree
Hide file tree
Showing 65 changed files with 191 additions and 179 deletions.
2 changes: 2 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ code = 'MYPY'
include_patterns = [
# TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders.
# 'backends/**/*.py',
'backends/arm/**/*.py',
'build/**/*.py',
'codegen/**/*.py',
# 'devtools/**/*.py',
Expand All @@ -312,6 +313,7 @@ exclude_patterns = [
'**/third-party/**',
'scripts/check_binary_dependencies.py',
'profiler/test/test_profiler_e2e.py',
'backends/arm/test/**',
]
command = [
'python',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
else:
dim_order = tuple(range(node_data.dim()))
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
node.meta["tosa_dim_order"] = dim_order
# Take care of cases when:
# 4D (NHWC) -> >4D (NCH)
Expand Down
18 changes: 11 additions & 7 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.convert_squeezes_to_view import (
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
ConvertSqueezesToViewPass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
Expand All @@ -30,7 +30,9 @@
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
DecomposeSelectPass,
)
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
)
Expand All @@ -40,18 +42,20 @@
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
ConvertMeanDimToAveragePoolPass,
)
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
ConvertMmToBmmPass,
)
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
ScalarsToAttributePass,
Expand Down Expand Up @@ -89,7 +93,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(FoldAndAnnotateQParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))

Expand Down Expand Up @@ -125,7 +129,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(FoldAndAnnotateQParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))

Expand Down
8 changes: 4 additions & 4 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -58,9 +58,9 @@ def get_param_tensor(
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target)
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
except AttributeError:
return getattr(exp_prog.graph_module, node.target)
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
raise RuntimeError(f"unsupported param type, {node.op}.")


Expand Down Expand Up @@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value) # pyre-ignore[16]
return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16]
elif isclass(key):
for arg in args:
if isinstance(arg, key):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def fold_and_annotate_arg(
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
assert n.target == dq_op
n.replace_all_uses_with(n.args[0])
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
graph_module.graph.erase_node(n)

def call(self, graph_module: GraphModule) -> PassResult:
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -66,15 +66,15 @@ def call(self, graph_module: torch.fx.GraphModule):
sum_node = cast(torch.fx.Node, node)
keep_dim = get_node_arg(
# pyre-ignore[6]
sum_node.args,
sum_node.args, # type: ignore[arg-type]
keep_dim_index,
False,
)

if keep_dim:
continue

dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6]
dim_list = get_node_arg(sum_node.args, 1, [0]) # type: ignore[arg-type] # pyre-ignore[6]

# Add keep_dim = True arg to sum node.
set_node_arg(sum_node, 2, True)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -54,7 +54,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
):
new_args.append(arg)
new_args.append(arg) # type: ignore[arg-type]
continue

prefix = "_tensor_constant_"
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from typing import cast, final, List, Optional

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors

Expand Down Expand Up @@ -230,7 +230,7 @@ def preprocess( # noqa: C901
# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
exported_program=edge_program
)

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import (
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
ArmBackend,
) # usort: skip
from executorch.backends.arm.operator_support.tosa_supported_operators import (
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/arm_vela.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -12,7 +12,7 @@
from typing import List

import numpy as np
from ethosu.vela import vela
from ethosu.vela import vela # type: ignore


# Pack either input or output tensor block, compose the related arrays into
Expand Down Expand Up @@ -96,13 +96,13 @@ def vela_compile(tosa_graph, args: List[str], shape_order=None):
block_name = block_name + b"\x00" * (16 - len(block_name))

# We need the acual unpadded block lengths for hw setup
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0) # type: ignore[assignment]

# Pad block data to multiple of 16 bytes
block_data = bin_blocks[key]
block_data = block_data + b"\x00" * (15 - (len(block_data) - 1) % 16)

block = block_name + block_length + block_data
block = block_name + block_length + block_data # type: ignore[operator]
blocks = blocks + block

return blocks
10 changes: 5 additions & 5 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
f"{node.target.name()}." # pyre-ignore[16]
f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

Expand All @@ -107,7 +107,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target.name()} for input dtype {input_dtype}. " # pyre-ignore[16]
f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16]
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
)
Expand All @@ -118,18 +118,18 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
f"Argument 'memory_format' is not supported for "
f"{node.target.name()} right now." # pyre-ignore[16]
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

# Check dim_order (to_dim_order_copy)
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))):
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
logger.info(
f"Argument {dim_order=} is not supported for "
f"{node.target.name()} right now." # pyre-ignore[16]
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,5 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
def is_node_supported_custom(self, node: fx.Node) -> bool:
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
if node.target in tosa_checks.keys():
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
return False
6 changes: 3 additions & 3 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,7 +7,7 @@

from typing import Dict, List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -44,7 +44,7 @@ def define_node(


# container for all node visitors
_node_visitor_dicts = {
_node_visitor_dicts = { # type: ignore[var-annotated]
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand Down Expand Up @@ -75,7 +75,7 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]


@register_node_visitor
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch

# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_batch_norm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -7,7 +7,7 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
Expand Down Expand Up @@ -75,14 +75,14 @@ def define_node(
if output.dtype == ts.DType.INT8:
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
final_output_scale = (
input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61]
input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
) / output_qparams.scale

build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
input_node=bmm_result,
input_node=bmm_result, # type: ignore[possibly-undefined]
output_name=output.name,
output_type=ts.DType.INT8,
output_shape=bmm_result.shape,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,7 +7,7 @@

from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand Down
Loading

0 comments on commit f4786ac

Please sign in to comment.