diff --git a/src/onediff/infer_compiler/utils/args_tree_util.py b/src/onediff/infer_compiler/utils/args_tree_util.py index a875b07ad..95b764e95 100644 --- a/src/onediff/infer_compiler/utils/args_tree_util.py +++ b/src/onediff/infer_compiler/utils/args_tree_util.py @@ -1,6 +1,7 @@ import torch import oneflow as flow from oneflow.framework.args_tree import ArgsTree +from .log_utils import logger def input_output_processor(func): @@ -13,10 +14,13 @@ def input_fn(value): return value args_tree = ArgsTree((args, kwargs), False, tensor_type=torch.Tensor) + input_count = len( + [v for v in args_tree.iter_nodes() if isinstance(v, torch.Tensor)] + ) out = args_tree.map_leaf(input_fn) mapped_args = out[0] mapped_kwargs = out[1] - return mapped_args, mapped_kwargs + return mapped_args, mapped_kwargs, input_count def process_output(output): def output_fn(value): @@ -29,9 +33,19 @@ def output_fn(value): out = out_tree.map_leaf(output_fn) return out[0] - def wrapper(cls, *args, **kwargs): - mapped_args, mapped_kwargs = process_input(*args, **kwargs) - output = func(cls, *mapped_args, **mapped_kwargs) + def wrapper(self: "DeployableModule", *args, **kwargs): + mapped_args, mapped_kwargs, input_count = process_input(*args, **kwargs) + if self._deployable_module_use_graph: + count = self._deployable_module_input_count + if count != input_count: + logger.warning( + f"Module {type(self._deployable_module_model.oneflow_module)} input tensor count changed from {count} to {input_count}, will compile again." + ) + self._deployable_module_dpl_graph = None + self._load_graph_first_run = True + self._deployable_module_input_count = input_count + + output = func(self, *mapped_args, **mapped_kwargs) return process_output(output) return wrapper diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py index 085f07847..5501a423e 100644 --- a/src/onediff/infer_compiler/with_oneflow_compile.py +++ b/src/onediff/infer_compiler/with_oneflow_compile.py @@ -198,6 +198,8 @@ def __init__( self._deployable_module_options = options self._deployable_module_dpl_graph = None self._is_raw_deployable_module = True + self._load_graph_first_run = True + self._deployable_module_input_count = None @classmethod def from_existing(cls, existing_module, use_graph=None, dynamic=None, options=None): @@ -207,6 +209,11 @@ def from_existing(cls, existing_module, use_graph=None, dynamic=None, options=No instance._deployable_module_dpl_graph = ( existing_module._deployable_module_dpl_graph if use_graph else None ) + instance._load_graph_first_run = existing_module._load_graph_first_run + instance._deployable_module_input_count = ( + existing_module._deployable_module_input_count + ) + return instance def get_graph(self):