Add None return type to init
ghstack-source-id: 8aa50782aa1bb0b413fdb5d54fd0bdf377647727
Pull Request resolved: #132335
oulgen committed Jul 31, 2024
1 parent f0da167 commit d733b03
Showing 128 changed files with 290 additions and 290 deletions.
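Every hunk below is the same one-line change: def __init__(self): becomes def __init__(self) -> None:. Because __init__ usually has no other annotated parameters, the explicit -> None return type is what marks the method as typed, so mypy under its default settings checks the constructor body instead of skipping it. A minimal sketch of the pattern, using a hypothetical Example module that is not part of this commit:

import torch


class Example(torch.nn.Module):
    # Without "-> None" this signature carries no annotations at all, so
    # mypy's default mode would leave the body unchecked.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)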
2 changes: 1 addition & 1 deletion torch/_classes.py
@@ -19,7 +19,7 @@ def __getattr__(self, attr):
class _Classes(types.ModuleType):
__file__ = "_classes.py"

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("torch.classes")

def __getattr__(self, name):
2 changes: 1 addition & 1 deletion torch/_decomp/decompositions_for_rng.py
@@ -71,7 +71,7 @@ class PhiloxState:
trace time.
"""

- def __init__(self):
+ def __init__(self) -> None:
self.reset()

def reset(self):
2 changes: 1 addition & 1 deletion torch/_dynamo/backends/distributed.py
@@ -247,7 +247,7 @@ def run_node(self, n: Node) -> Any:
# This gives us the appropriately strided outputs here which will reflect runtime strides.

class FakeifyFirstAOTInvocationGuard:
- def __init__(self):
+ def __init__(self) -> None:
self.tc = torch._guards.TracingContext.try_get()
assert self.tc
torch._guards.TracingContext.try_get().fakify_first_call = True
2 changes: 1 addition & 1 deletion torch/_dynamo/code_context.py
@@ -5,7 +5,7 @@


class CodeContextDict:
- def __init__(self):
+ def __init__(self) -> None:
self.code_context = ExactWeakKeyDictionary()

def has_context(self, code: types.CodeType):
4 changes: 2 additions & 2 deletions torch/_dynamo/debug_utils.py
@@ -170,7 +170,7 @@ def convert(gm):
"""
from torch.nn import *
class Repro(torch.nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
"""
)
@@ -491,7 +491,7 @@ def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:


class NopInputReader:
- def __init__(self):
+ def __init__(self) -> None:
self.total = 0

def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
6 changes: 3 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -496,7 +496,7 @@ def _fn(*args, **kwargs):
wrapper function.
>> class CallableClass:
- >> def __init__(self):
+ >> def __init__(self) -> None:
>> super().__init__()
>> self.relu = torch.nn.ReLU()
>>
@@ -577,7 +577,7 @@ def __reduce__(self):


class RunOnlyContext(_TorchDynamoContext):
- def __init__(self):
+ def __init__(self) -> None:
# cudagraph trees relies on generation increment
def on_enter():
torch._dynamo.mutation_guard.GenerationTracker.generation += 1
@@ -589,7 +589,7 @@ def __reduce__(self):


class DisableContext(_TorchDynamoContext):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(callback=None)

def __call__(self, fn):
2 changes: 1 addition & 1 deletion torch/_dynamo/exc.py
@@ -74,7 +74,7 @@ def __init__(self, name):


class ResetRequired(TorchDynamoException):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(
textwrap.dedent(
"""
2 changes: 1 addition & 1 deletion torch/_dynamo/profiler.py
@@ -92,7 +92,7 @@ def print_missing(stack):
class Profiler:
unique_graphs = 0

- def __init__(self):
+ def __init__(self) -> None:
self.prof = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU],
with_stack=should_print_missing(),
2 changes: 1 addition & 1 deletion torch/_dynamo/variables/base.py
@@ -70,7 +70,7 @@ class MutableLocal(MutableLocalBase):
state.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__(MutableLocalSource.Local)

def __hash__(self):
4 changes: 2 additions & 2 deletions torch/_dynamo/variables/builder.py
@@ -270,7 +270,7 @@ def __eq__(self, other):


class BackwardStateGraphArg(GraphArg):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__(
source=None,
_example=BackwardState(),
@@ -2638,7 +2638,7 @@ class SourcelessBuilder:
if/else type->VariableTracker trees that were cropping up all over dynamo.
"""

- def __init__(self):
+ def __init__(self) -> None:
raise AssertionError("Use SourcelessBuilder.create()")

@staticmethod
2 changes: 1 addition & 1 deletion torch/_export/db/examples/class_method.py
@@ -10,7 +10,7 @@ class ClassMethod(torch.nn.Module):
def method(cls, x):
return x + 1

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 2)

2 changes: 1 addition & 1 deletion torch/_export/db/examples/cond_branch_class_method.py
@@ -26,7 +26,7 @@ class CondBranchClassMethod(torch.nn.Module):
NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.subm = MySubModule()

2 changes: 1 addition & 1 deletion torch/_export/db/examples/model_attr_mutation.py
@@ -8,7 +8,7 @@ class ModelAttrMutation(torch.nn.Module):
Attribute mutation is not supported.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

2 changes: 1 addition & 1 deletion torch/_export/db/examples/scalar_output.py
@@ -11,7 +11,7 @@ class ScalarOutput(torch.nn.Module):
Returning scalar values from the graph is supported, in addition to Tensor
outputs. Symbolic shapes are captured and rank is specialized.
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()

def forward(self, x):
2 changes: 1 addition & 1 deletion torch/_export/db/examples/specialized_attribute.py
@@ -11,7 +11,7 @@ class SpecializedAttribute(torch.nn.Module):
Model attributes are specialized.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.a = "moo"
self.b = 4
2 changes: 1 addition & 1 deletion torch/_export/passes/lift_constants_pass.py
@@ -24,7 +24,7 @@ class ConstantAttrMap(collections.abc.MutableMapping):
if that's the case).
"""

- def __init__(self):
+ def __init__(self) -> None:
# Underlying dict that we use to implement this mapping.
self._constant_attrs: Dict[
Union[int, torch.Tensor, FakeScriptObject], List[Any]
2 changes: 1 addition & 1 deletion torch/_export/serde/serialize.py
@@ -1413,7 +1413,7 @@ class Result:
constants: Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]]
example_inputs: Optional[Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]]

- def __init__(self):
+ def __init__(self) -> None:
self.serialized_name_to_node: Dict[str, torch.fx.Node] = {}
self.serialized_name_to_meta: Dict[str, MetaType] = {}
self.graph = torch.fx.Graph()
2 changes: 1 addition & 1 deletion torch/_functorch/_aot_autograd/schemas.py
@@ -602,7 +602,7 @@ class SubclassMeta:
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]] = None

- def __init__(self):
+ def __init__(self) -> None:
# The fields in this class get set after its construction.
pass

2 changes: 1 addition & 1 deletion torch/_functorch/aot_autograd.py
@@ -878,7 +878,7 @@ def functional_call(named_params, named_buffers, *args, **kwargs):
)

class AOTModule(nn.Module):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.orig_module = mod

4 changes: 2 additions & 2 deletions torch/_functorch/autograd_function.py
@@ -30,7 +30,7 @@
# We do this by using creating a custom HigherOrderOperator that only functorch
# dispatches specially.
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("custom_function_call")

def __call__(self, autograd_function, *args, **kwargs):
@@ -713,7 +713,7 @@ def new_forward(ctx, *args, **kwargs):


class AutogradFunctionApply(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("autograd_function_apply")

def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
6 changes: 3 additions & 3 deletions torch/_guards.py
@@ -407,7 +407,7 @@ def __eq__(self, other):


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
- def __init__(self):
+ def __init__(self) -> None:
self.nn_modules: Dict[str, Any] = {}

def copy_graphstate(self):
@@ -456,7 +456,7 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
"autocast_cache_enabled",
}

- def __init__(self):
+ def __init__(self) -> None:
self.global_state: Dict[str, Tuple[Callable, ...]] = {}

def copy_graphstate(self):
@@ -524,7 +524,7 @@ def remove_guards_with_source(self, source):


class GuardsContext(Checkpointable[GuardsCheckpointState]):
- def __init__(self):
+ def __init__(self) -> None:
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []

2 changes: 1 addition & 1 deletion torch/_higher_order_ops/auto_functionalize.py
@@ -54,7 +54,7 @@ class AutoFunctionalized(HigherOrderOperator):
underscore is to prevent collisions with kwarg names in **kwargs.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("auto_functionalized")

def __call__(
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/effects.py
@@ -55,7 +55,7 @@ class WithEffects(HigherOrderOperator):
per "effect type", which are enumerated in the _EffectType enum.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("with_effects")

def __call__(
4 changes: 2 additions & 2 deletions torch/_higher_order_ops/flex_attention.py
@@ -38,7 +38,7 @@ def __torch_function__(self, func, types, args, kwargs=None):


class FlexAttentionHOP(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("flex_attention")

def __call__(
@@ -74,7 +74,7 @@ def __call__(


class FlexAttentionBackwardHOP(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("flex_attention_backward")

def __call__(
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/out_dtype.py
@@ -45,7 +45,7 @@ class OutDtypeOperator(HigherOrderOperator):
3. Cast the output to `out_dtype`
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("out_dtype")
# TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
# become different (torch._higher_order_ops.out_dtype) which will result
4 changes: 2 additions & 2 deletions torch/_higher_order_ops/triton_kernel_wrap.py
@@ -519,7 +519,7 @@ def identify_mutated_tensors(kernel, kwargs):

# Used for wrapping a Triton Kernel
class TritonKernelWrapperMutation(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_mutation")


@@ -528,7 +528,7 @@ def __init__(self):

# Used for wrapping a Triton Kernel in a functional manner
class TritonKernelWrapperFunctional(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("triton_kernel_wrapper_functional")


2 changes: 1 addition & 1 deletion torch/_higher_order_ops/while_loop.py
@@ -18,7 +18,7 @@


class WhileLoopOp(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("while_loop")

def __call__(
8 changes: 4 additions & 4 deletions torch/_higher_order_ops/wrap.py
@@ -15,7 +15,7 @@

# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap")

def __call__(self, func, *args, **kwargs):
@@ -36,7 +36,7 @@ def wrapper():


class WrapWithSetGradEnabled(HigherOrderOperator):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap_with_set_grad_enabled")

def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
@@ -74,7 +74,7 @@ class WrapActivationCheckpoint(HigherOrderOperator):
partitioners. See TagActivationCheckpoint for more information.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("wrap_activation_checkpoint")

def __call__(self, function, *args, **kwargs):
@@ -113,7 +113,7 @@ class TagActivationCheckpoint(HigherOrderOperator):
the forward and recomputed forward in backward.
"""

- def __init__(self):
+ def __init__(self) -> None:
super().__init__("tag_activation_checkpoint")

@staticmethod
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/common.py
@@ -1560,7 +1560,7 @@ def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariabl


class CodeGen:
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.exit_stack = contextlib.ExitStack()

2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -29,7 +29,7 @@ class CppWrapperCuda(CppWrapperCpu):
Generates cpp wrapper for running on GPU and calls CUDA kernels
"""

- def __init__(self):
+ def __init__(self) -> None:
self.device = "cuda"
super().__init__()
self.grid_id = count()
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/triton.py
@@ -1113,7 +1113,7 @@ class HelperFunctions:
_templates_seen: Dict[str, str] # Template code to function name
finalized_helpers: List[str]

- def __init__(self):
+ def __init__(self) -> None:
self._templates_seen = {}
self.finalized_helpers = []

2 changes: 1 addition & 1 deletion torch/_inductor/dependencies.py
@@ -588,7 +588,7 @@ def canonicalization_prefix():
class FreeUnbackedSymbolsOpsHandler:
symbols: Set[sympy.Symbol]

- def __init__(self):
+ def __init__(self) -> None:
self.symbols = set()

def __getattr__(self, name: str) -> Callable[..., Any]:
(The remaining changed files are not shown here.)
