Commit

refactor graph reuse (#790)
marigoold authored Apr 11, 2024
1 parent b9b7921 commit 85d5666
Showing 7 changed files with 192 additions and 68 deletions.
@@ -4,7 +4,6 @@
import torch

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.oneflow import oneflow_compiler_config
from onediff.schedulers import EulerDiscreteScheduler
from diffusers import StableDiffusionXLPipeline

@@ -62,7 +61,6 @@
base.to("cuda")


oneflow_compiler_config.mlir_enable_inference_optimization = False
# Compile unet with oneflow
if args.compile_unet:
print("Compiling unet with oneflow.")
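Note: this example script drops the oneflow_compiler_config import and its mlir_enable_inference_optimization = False override, so it now runs with the compiler's default inference optimization. A hedged sketch (not part of this commit) of how that override could be restored if ever needed, using only the names visible in the removed lines:

# Sketch only: re-enable the removed override; both names come from the
# lines deleted in this hunk.
from onediff.infer_compiler.oneflow import oneflow_compiler_config

# Setting this to False disables the MLIR inference optimization pass,
# as the example previously did before this commit.
oneflow_compiler_config.mlir_enable_inference_optimization = False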
13 changes: 11 additions & 2 deletions onediff_sd_webui_extensions/compile_vae.py
@@ -1,12 +1,21 @@
from modules import shared
from modules.sd_vae_approx import model as get_vae_model, sd_vae_approx_models
from onediff.infer_compiler import oneflow_compile
from modules.sd_vae_approx import VAEApprox
from onediff.infer_compiler import oneflow_compile, register
from onediff.infer_compiler.transform import proxy_class

__all__ = ["VaeCompileCtx"]


compiled_models = {}

class VAEApproxOflow(proxy_class(VAEApprox)):
pass

torch2oflow_class_map = {
VAEApprox: VAEApproxOflow,
}

register(package_names=["modules"], torch2oflow_class_map=torch2oflow_class_map)

class VaeCompileCtx(object):
def __init__(self, options=None):
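The new code in compile_vae.py registers a oneflow proxy class for VAEApprox so the compiler can transform modules defined in the webui's "modules" package. A hedged sketch (not part of this commit) of the same registration pattern applied to a hypothetical torch module, using only the register and proxy_class calls shown above:

import torch
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import proxy_class

# MyBlock and the package name "my_package" are illustrative placeholders.
class MyBlock(torch.nn.Module):
    def forward(self, x):
        return x * 2

# proxy_class(MyBlock) produces the oneflow-side counterpart; subclass it to
# customize the transformed class if needed (here, nothing is overridden).
class MyBlockOflow(proxy_class(MyBlock)):
    pass

# Point the compiler at the package to scan and map the torch class to its
# oneflow proxy, mirroring the VAEApprox -> VAEApproxOflow mapping above.
register(package_names=["my_package"], torch2oflow_class_map={MyBlock: MyBlockOflow})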
25 changes: 6 additions & 19 deletions onediff_sd_webui_extensions/onediff_lora.py
@@ -1,14 +1,12 @@
import torch
import oneflow as flow
from onediff.infer_compiler.deployable_module import DeployableModule
from onediff.infer_compiler.utils.param_utils import update_graph_related_tensor


class HijackLoraActivate:
def __init__(self, conv_dict=None):
def __init__(self):
from modules import extra_networks

self.conv_dict = conv_dict

if "lora" in extra_networks.extra_network_registry:
cls_extra_network_lora = type(extra_networks.extra_network_registry["lora"])
else:
@@ -19,9 +17,7 @@ def __enter__(self):
if self.lora_class is None:
return
self.orig_func = self.lora_class.activate
self.lora_class.activate = hijacked_activate(
self.lora_class.activate, conv_dict=self.conv_dict
)
self.lora_class.activate = hijacked_activate(self.lora_class.activate)

def __exit__(self, exc_type, exc_val, exc_tb):
if self.lora_class is None:
@@ -31,7 +27,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.orig_func = None


def hijacked_activate(activate_func, *, conv_dict=None):
def hijacked_activate(activate_func):
import networks

if hasattr(activate_func, "_onediff_hijacked"):
@@ -53,17 +49,8 @@ def activate(self, p, params_list):
):
continue
networks.network_apply_weights(sub_module)

# for LyCORIS cases
if conv_dict is not None and isinstance(sub_module, torch.nn.Conv2d):
target_tensor = conv_dict.get(name + ".weight", None)
if target_tensor is None:
continue
target_tensor.copy_(
flow.utils.tensor.from_torch(
sub_module.weight.permute(0, 2, 3, 1)
)
)
if isinstance(sub_module, torch.nn.Conv2d):
update_graph_related_tensor(sub_module)

activate._onediff_hijacked = True
return activate
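With conv_dict gone, hijacked_activate simply wraps the LoRA activate method and, after weights are applied, calls update_graph_related_tensor on each Conv2d so the compiled graph sees the new weights. The wrap-once guard it uses is a small generic pattern; a hedged standalone sketch (illustrative names, not the extension's code):

# Generic "wrap once" monkey-patch pattern, as used by hijacked_activate above.
# Target.apply_weights is a hypothetical stand-in for the LoRA activate method.
def hijack(orig_func):
    # A function we already wrapped is returned unchanged, so entering the
    # context manager repeatedly never stacks wrappers.
    if getattr(orig_func, "_onediff_hijacked", False):
        return orig_func

    def wrapper(self, *args, **kwargs):
        result = orig_func(self, *args, **kwargs)
        # ... post-processing goes here (the real hook syncs conv weights) ...
        return result

    wrapper._onediff_hijacked = True
    return wrapper


class Target:
    def apply_weights(self):
        return "applied"


Target.apply_weights = hijack(Target.apply_weights)
Target.apply_weights = hijack(Target.apply_weights)  # no-op: already wrapped
assert Target().apply_weights() == "applied"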
47 changes: 6 additions & 41 deletions onediff_sd_webui_extensions/scripts/onediff.py
@@ -4,13 +4,10 @@
import gradio as gr
from pathlib import Path
from typing import Union, Dict
from collections import defaultdict
import oneflow as flow
import modules.scripts as scripts
import modules.shared as shared
from modules.sd_models import select_checkpoint
from modules.processing import process_images
from modules import script_callbacks

from compile_ldm import compile_ldm_unet, SD21CompileCtx
from compile_sgm import compile_sgm_unet
@@ -19,7 +16,7 @@
from onediff_hijack import do_hijack as onediff_do_hijack

from onediff.infer_compiler.utils.log_utils import logger
from onediff.infer_compiler.utils.env_var import parse_boolean_from_env
from onediff.infer_compiler.utils.param_utils import get_constant_folding_info
from onediff.optimization.quant_optimizer import (
quantize_model,
varify_can_use_quantization,
@@ -107,7 +104,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):

class Script(scripts.Script):
current_type = None
convname_dict = None

def title(self):
return "onediff_diffusion_model"
@@ -187,50 +183,19 @@ def run(self, p, quantization=False):
model_changed = ckpt_name != compiled_ckpt_name
model_structure_changed = self.check_model_structure_change(shared.sd_model)
need_recompile = (quantization and model_changed) or model_structure_changed
if not need_recompile:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
)
if model_changed:
# need to transpose conv weights
for k in self.convname_dict:
orig_tensor = original_diffusion_model.get_parameter(k)
target_tensor = self.convname_dict[k]
if target_tensor is None:
need_recompile = True
break
target_tensor.copy_(
flow.utils.tensor.from_torch(orig_tensor.permute(0, 2, 3, 1))
)

if need_recompile:
compiled_unet = compile_unet(
original_diffusion_model, quantization=quantization
)
compiled_ckpt_name = ckpt_name
self.convname_dict = None
else:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
)

with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(
self.convname_dict
):
with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
proc = process_images(p)

# AutoNHWC will transpose conv weight, which generate a new tensor in graph
# The part is to find the corresponding relationship between the tensors before/after transpose
def convert_var_name(s: str, prefix="variable_transpose_"):
s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")
return s

if not quantization and self.convname_dict is None:
self.convname_dict = {}
run_state = (
compiled_unet._deployable_module_dpl_graph._c_nn_graph.get_runtime_var_states()
)
self.convname_dict = {
convert_var_name(k): v
for k, v in zip(run_state[0], run_state[1])
if k.startswith("variable_")
}
return proc


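After this refactor, Script.run no longer keeps convname_dict or copies conv weights itself; it only decides whether to recompile, and otherwise logs that compilation is skipped while the compiled graph is reused. A hedged distillation (not part of this commit) of the decision shown in the hunk above:

# Condensed restatement of the need_recompile logic in Script.run above.
def needs_recompile(quantization: bool, model_changed: bool,
                    model_structure_changed: bool) -> bool:
    # A plain checkpoint swap no longer forces a recompile: the existing graph
    # is reused and its conv weights are refreshed by the forward pre-hook
    # registered in backends/oneflow.py. Only a quantized model whose
    # checkpoint changed, or a changed model structure, triggers recompilation.
    return (quantization and model_changed) or model_structure_changed

assert needs_recompile(False, True, False) is False  # new checkpoint, reuse graph
assert needs_recompile(True, True, False) is True    # quantized + new checkpoint
assert needs_recompile(False, False, True) is True   # structure changed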
9 changes: 9 additions & 0 deletions src/onediff/infer_compiler/backends/oneflow.py
@@ -1,6 +1,7 @@
import torch
from .registry import register_backend
from ..utils.options import CompileOptions
from ..utils.param_utils import state_update_hook, init_state_update_attr, forward_pre_check_and_update_state_hook, forward_generate_constant_folding_info_hook


@register_backend("oneflow")
@@ -57,4 +58,12 @@ def state_dict_hook(module, state_dict, prefix, local_metadata):

model._register_state_dict_hook(state_dict_hook)

# for checking state dict update of torch_module
model._torch_module.register_load_state_dict_post_hook(state_update_hook)
init_state_update_attr(model._torch_module)

# hooks for constant folding
model.register_forward_pre_hook(forward_pre_check_and_update_state_hook)
model.register_forward_hook(forward_generate_constant_folding_info_hook)

return model
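The backend now wires up three hooks on top of the existing state-dict hook: a load_state_dict post-hook that flags the torch module as updated, a forward pre-hook that pushes updated weights into the reused graph, and a forward hook that records constant-folding info after the first run. A minimal, hedged sketch of how these torch hook registration points fire (plain torch, no onediff specifics):

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return self.conv(x)

m = Tiny()

# Fires after load_state_dict, like state_update_hook above.
m.register_load_state_dict_post_hook(
    lambda module, incompatible_keys: print("state dict updated")
)
# Fires before every forward, like forward_pre_check_and_update_state_hook.
m.register_forward_pre_hook(lambda module, args: print("pre-forward"))
# Fires after every forward, like forward_generate_constant_folding_info_hook.
m.register_forward_hook(lambda module, args, output: print("post-forward"))

m.load_state_dict(m.state_dict())  # -> "state dict updated"
m(torch.randn(1, 3, 4, 4))         # -> "pre-forward", then "post-forward"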
8 changes: 5 additions & 3 deletions src/onediff/infer_compiler/oneflow/deployable_module.py
Expand Up @@ -50,7 +50,9 @@ def from_existing(cls, existing_module, dynamic=True, options=None):
instance._deployable_module_input_count = (
existing_module._deployable_module_input_count
)
instance._deployable_module_quant_config = existing_module._deployable_module_quant_config
instance._deployable_module_quant_config = (
existing_module._deployable_module_quant_config
)

return instance

@@ -85,12 +87,12 @@ def apply_model(self, *args, **kwargs):
*args, **kwargs
)
return output

@quantize_and_deploy_wrapper
@input_output_processor
@handle_deployable_exception
@graph_file_management
def __call__(self, *args, **kwargs):
def forward(self, *args, **kwargs):
if self._deployable_module_options.use_graph:
dpl_graph = self.get_graph()
with oneflow_exec_mode():
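Renaming __call__ to forward in DeployableModule matters beyond style: torch.nn.Module.__call__ is what dispatches registered forward pre-/post-hooks, so a custom __call__ would silently bypass the hooks added in backends/oneflow.py above. A small, hedged illustration with plain torch modules (not the onediff class):

import torch

class WithCall(torch.nn.Module):
    def __call__(self, x):   # overrides nn.Module.__call__: registered hooks never run
        return x + 1

class WithForward(torch.nn.Module):
    def forward(self, x):    # nn.Module.__call__ wraps this and runs the hooks
        return x + 1

fired = []
a, b = WithCall(), WithForward()
for m in (a, b):
    m.register_forward_pre_hook(lambda module, args: fired.append(type(module).__name__))

a(torch.tensor(0.0))
b(torch.tensor(0.0))
print(fired)  # ['WithForward'] -- only the forward-based module triggers its hook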
156 changes: 155 additions & 1 deletion src/onediff/infer_compiler/utils/param_utils.py
@@ -1,6 +1,9 @@
import re
import torch
import oneflow as flow
from typing import List, Dict, Any
from typing import List, Dict, Any, Union

from .log_utils import logger


def parse_device(args: List[Any], kwargs: Dict[str, Any]):
@@ -25,3 +28,154 @@ def _convert(device):
return device

return _convert(current_device) == _convert(target_device)


# hooks and helper functions for constant folding conv weights

STATE_UPDATED_ATTR = "_onediff_state_updated"
CONSTANT_FOLDING_INFO_ATTR = "_onediff_constant_folding_info"
GRAPH_RELATED_TENSOR_ATTR = "_onediff_graph_related_tensor"


def init_state_update_attr(module: torch.nn.Module):
from onediff.infer_compiler.deployable_module import DeployableModule

if isinstance(module, DeployableModule):
module = module._torch_module
if not isinstance(module, torch.nn.Module):
raise TypeError(f"module must be a torch.nn.Module, got {type(module)}")
setattr(module, STATE_UPDATED_ATTR, False)


def set_constant_folded_conv_attr(
deployable_module, constant_folding_info: Dict[str, flow.Tensor] = None
) -> None:
from onediff.infer_compiler.deployable_module import DeployableModule

if not isinstance(deployable_module, DeployableModule):
raise TypeError(
f"deployable_model must be a DeployableModule, got {type(deployable_module)}"
)

constant_folding_info = constant_folding_info or get_constant_folding_info(
deployable_module
)
if constant_folding_info is None:
return

torch_module: torch.nn.Module = deployable_module._torch_module
for submodule in torch_module.modules():
if isinstance(submodule, torch.nn.Conv2d) and hasattr(
submodule, GRAPH_RELATED_TENSOR_ATTR
):
delattr(submodule, GRAPH_RELATED_TENSOR_ATTR)

for weight_name, weight_tensor in constant_folding_info.items():
submodule = deployable_module._torch_module.get_submodule(
weight_name.removesuffix(".weight")
)
object.__setattr__(submodule, GRAPH_RELATED_TENSOR_ATTR, weight_tensor)


def generate_constant_folding_info(
deployable_module, torch_module: torch.nn.Module = None
) -> Dict[str, flow.Tensor]:
# convert str like 'variable_transpose_model.input_blocks.10.0.in_layers.2.weight_239'
# to 'input_blocks.10.0.in_layers.2.weight'
def convert_var_name(s: str, prefix="variable_transpose_"):
s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")
return s

from onediff.infer_compiler.deployable_module import DeployableModule

if not isinstance(deployable_module, DeployableModule):
raise TypeError(
f"deployable_model must be a DeployableModule, got {type(deployable_module)}"
)
if torch_module is None:
torch_module = deployable_module._torch_module

graph = deployable_module._deployable_module_dpl_graph
if graph is None:
raise RuntimeError(f"The graph of deployable_module is not built yet")

result = {
convert_var_name(k): v
for k, v in zip(*graph._c_nn_graph.get_runtime_var_states())
if k.startswith("variable_")
}
setattr(deployable_module, CONSTANT_FOLDING_INFO_ATTR, result)
set_constant_folded_conv_attr(deployable_module, result)


def update_graph_with_constant_folding_info(
module: torch.nn.Module, info: Dict[str, flow.Tensor] = None
) -> None:
from onediff.infer_compiler.deployable_module import DeployableModule

if isinstance(module, DeployableModule):
if info is None:
info = get_constant_folding_info(module)
module = module._torch_module
if info is None:
return

for k in info:
orig_tensor = module.get_parameter(k)
target_tensor = info.get(k, None)
if target_tensor is None:
raise RuntimeError(f"Can't find tensor named {k} in graph")
target_tensor.copy_(
flow.utils.tensor.from_torch(orig_tensor.permute(0, 2, 3, 1))
)


def update_graph_related_tensor(module: torch.nn.Conv2d) -> None:
if not isinstance(module, torch.nn.Conv2d):
return
target_tensor = getattr(module, GRAPH_RELATED_TENSOR_ATTR, None)
if target_tensor is None:
return
target_tensor.copy_(
flow.utils.tensor.from_torch(module.weight.data.permute(0, 2, 3, 1))
)


def get_constant_folding_info(module) -> Union[Dict[str, flow.Tensor], None]:
from onediff.infer_compiler.deployable_module import DeployableModule

if not isinstance(module, DeployableModule):
raise TypeError(f"module must be a DeployableModule, got {type(module)}")
return getattr(module, CONSTANT_FOLDING_INFO_ATTR, None)


def state_update_hook(module, incompatible_keys):
if not hasattr(module, STATE_UPDATED_ATTR):
return
logger.info(f"load_state_dict called, set {STATE_UPDATED_ATTR} to True")
setattr(module, STATE_UPDATED_ATTR, True)


def forward_generate_constant_folding_info_hook(module, args, output):
if module._deployable_module_dpl_graph is None:
return

if getattr(module, CONSTANT_FOLDING_INFO_ATTR, None) is not None:
return

generate_constant_folding_info(module)


def forward_pre_check_and_update_state_hook(module, args):
if module._deployable_module_dpl_graph is None:
return

if not getattr(module._torch_module, STATE_UPDATED_ATTR, False):
return

constant_folding_info = getattr(module, CONSTANT_FOLDING_INFO_ATTR, None)
if constant_folding_info is None:
return

update_graph_with_constant_folding_info(module, constant_folding_info)
setattr(module._torch_module, STATE_UPDATED_ATTR, False)
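Taken together, these helpers are what make graph reuse work: the first forward pass records which graph variables are the transposed conv weights, a later load_state_dict flags the torch module as updated, and the next forward (or an explicit call) copies the new weights into the existing graph tensors in NHWC order instead of recompiling. A hedged usage sketch (not part of this commit; compiled_unet is assumed to be a module compiled with the oneflow backend above):

from onediff.infer_compiler.utils.param_utils import (
    get_constant_folding_info,
    update_graph_with_constant_folding_info,
)

def reuse_graph_after_checkpoint_swap(compiled_unet, new_state_dict):
    # Hedged sketch: refresh an already-compiled module after a checkpoint swap
    # instead of recompiling it.
    # 1. The forward hook records constant-folding info after the first run:
    #    a mapping from torch conv weight names to the graph's folded tensors.
    info = get_constant_folding_info(compiled_unet)
    if info is None:
        return  # graph has not run yet, nothing to refresh

    # 2. Load the new checkpoint into the torch-side module; the
    #    load_state_dict post-hook marks the module as updated.
    compiled_unet._torch_module.load_state_dict(new_state_dict)

    # 3. Copy the updated conv weights into the graph's folded NHWC tensors
    #    (the forward pre-hook would otherwise do this on the next call).
    update_graph_with_constant_folding_info(compiled_unet, info)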
