Add OneDiffCheckpointLoader #457

Merged: 7 commits, Dec 26, 2023
Changes from 1 commit
3 changes: 3 additions & 0 deletions onediff_comfy_nodes/__init__.py
@@ -12,6 +12,7 @@
     ControlNetGraphSaver,
     SVDSpeedup,
     ModuleDeepCacheSpeedup,
+    OneDiffCheckpointLoaderSimple,
 )
 from ._compare_node import CompareModel, ShowImageDiff

@@ -30,6 +31,7 @@
     "ControlNetGraphSaver": ControlNetGraphSaver,
     "SVDSpeedup": SVDSpeedup,
     "ModuleDeepCacheSpeedup": ModuleDeepCacheSpeedup,
+    "OneDiffCheckpointLoaderSimple": OneDiffCheckpointLoaderSimple,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -46,6 +48,7 @@
     "ControlNetGraphSaver": "ControlNet Graph Saver",
     "SVDSpeedup": "SVD Speedup",
     "ModuleDeepCacheSpeedup": "Model DeepCache Speedup",
+    "OneDiffCheckpointLoaderSimple": "OneDiff Load Checkpoint",
 }

 if _USE_UNET_INT8:
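
For context, ComfyUI discovers custom nodes by importing the package and reading these two module-level dicts: NODE_CLASS_MAPPINGS registers each node class, and NODE_DISPLAY_NAME_MAPPINGS sets the label shown in the UI. A minimal sketch of the convention, with a hypothetical node:

class ExampleNode:
    # Declare inputs so ComfyUI can render the node's sockets and widgets.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {"default": ""})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"  # name of the method ComfyUI calls
    CATEGORY = "examples"

    def run(self, text):
        return (text,)

NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}
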
38 changes: 38 additions & 0 deletions onediff_comfy_nodes/_nodes.py
@@ -14,6 +14,7 @@
 import folder_paths
 from comfy import model_management
 from comfy.cli_args import args
+from folder_paths import get_input_directory

 from .utils import (
     OneFlowSpeedUpModelPatcher,
@@ -577,3 +578,40 @@ def apply_model(model_function, kwargs):

         oneflow_model.set_model_unet_function_wrapper(apply_model)
         return (oneflow_model,)
+
+
+def get_guess_graph_path(ckpt_name, model):
+    input_dir = get_input_directory()
+    input_dir = Path(input_dir)
+    graph_dir = input_dir / "graphs" / ckpt_name
+    graph_file_path = graph_dir / (type(model).__name__ + ".graph")
+    return graph_file_path
+
+
+from nodes import CheckpointLoaderSimple
+class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple):
+    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
+        model, clip, vae = super().load_checkpoint(ckpt_name, output_vae, output_clip)
+        offload_device = model_management.unet_offload_device()
+
+        diffusion_model = model.model.diffusion_model
+        file_path = get_guess_graph_path(ckpt_name, diffusion_model)
+        print(f" OneDiffCheckpointLoaderSimple load_checkpoint file_path {file_path}")
+
+        oneflow_model = OneFlowSpeedUpModelPatcher(
+            model.model,
+            load_device=model_management.get_torch_device(),
+            offload_device=offload_device,
+            use_graph=True,
+            graph_path=file_path,
+            graph_device=model_management.get_torch_device(),
+        )
+
+        file_path = get_guess_graph_path(ckpt_name, vae.first_stage_model)
+        vae.first_stage_model = oneflow_compile(
+            vae.first_stage_model,
+            use_graph=True,
+            graph_path=file_path,
+            graph_device=model_management.get_torch_device(),
+        )
+        return oneflow_model, clip, vae
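
The node subclasses ComfyUI's stock CheckpointLoaderSimple, so it drops into existing workflows. A usage sketch outside the graph UI (hypothetical checkpoint name; assumes a working ComfyUI plus OneDiff/OneFlow environment, since the node depends on folder_paths and comfy.model_management):

from onediff_comfy_nodes._nodes import OneDiffCheckpointLoaderSimple

loader = OneDiffCheckpointLoaderSimple()
# The UNet comes back wrapped in OneFlowSpeedUpModelPatcher and the VAE's
# first_stage_model is compiled with oneflow_compile. Per get_guess_graph_path,
# each compiled graph is cached at input/graphs/<ckpt_name>/<ClassName>.graph:
# saved after the first run, loaded on later runs.
model, clip, vae = loader.load_checkpoint("sd-v1-5.safetensors")
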
18 changes: 13 additions & 5 deletions onediff_comfy_nodes/utils/model_patcher.py
@@ -30,6 +30,8 @@ def __init__(
         weight_inplace_update=False,
         *,
         use_graph=None,
+        graph_path=None,
+        graph_device=None,
     ):
         from onediff.infer_compiler import oneflow_compile
         from onediff.infer_compiler.with_oneflow_compile import DeployableModule
@@ -46,7 +48,10 @@
             ] = self.model.diffusion_model
         else:
             self.model.__dict__["_modules"]["diffusion_model"] = oneflow_compile(
-                self.model.diffusion_model, use_graph=use_graph
+                self.model.diffusion_model,
+                use_graph=use_graph,
+                graph_path=graph_path,
+                graph_device=graph_device,
             )
         self.model._register_state_dict_hook(state_dict_hook)
         self.patches = {}
@@ -495,7 +500,6 @@ def __init__(
     ):
         from onediff.infer_compiler import oneflow_compile
         from onediff.infer_compiler.with_oneflow_compile import DeployableModule

-
         self.weight_inplace_update = weight_inplace_update
         self.object_patches = {}
@@ -504,10 +508,14 @@
         self.model = copy.copy(model)
         self.model.__dict__["_modules"] = copy.copy(model.__dict__["_modules"])
         self.deep_cache_unet = oneflow_compile(
-            DeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id), use_graph=use_graph
+            DeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id),
+            use_graph=use_graph,
         )
-        self.fast_deep_cache_unet =oneflow_compile(
-            FastDeepCacheUNet(self.model.diffusion_model, cache_layer_id, cache_block_id), use_graph=use_graph
+        self.fast_deep_cache_unet = oneflow_compile(
+            FastDeepCacheUNet(
+                self.model.diffusion_model, cache_layer_id, cache_block_id
+            ),
+            use_graph=use_graph,
         )
         self.model._register_state_dict_hook(state_dict_hook)
         self.patches = {}
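
The patcher only threads the two new keyword arguments through to oneflow_compile; the load/save behavior itself is implemented in with_oneflow_compile.py below. A minimal round-trip sketch under stated assumptions (a toy module stands in for the diffusion UNet, the cache path is made up, and a CUDA device plus a working OneFlow install are assumed):

import torch
from onediff.infer_compiler import oneflow_compile

module = torch.nn.Linear(8, 8).cuda()  # toy stand-in for a UNet
compiled = oneflow_compile(
    module,
    use_graph=True,
    graph_path="input/graphs/demo/Linear.graph",  # made-up cache location
    graph_device=torch.device("cuda"),
)
# First call: graph_path does not exist yet, so the graph is traced, run,
# and then saved to graph_path. A later process with the same graph_path
# loads the file before executing, skipping the retrace.
out = compiled(torch.randn(1, 8, device="cuda"))
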
113 changes: 96 additions & 17 deletions src/onediff/infer_compiler/with_oneflow_compile.py
@@ -56,10 +56,12 @@ def _align_tensor(torch_module, oneflow_module):
         [x for x, _ in oneflow_module.named_parameters()]
         + [x for x, _ in oneflow_module.named_buffers()]
     )
-    for name, tensor in chain.from_iterable([
-        torch_module.named_parameters(),
-        torch_module.named_buffers(),
-    ]):
+    for name, tensor in chain.from_iterable(
+        [
+            torch_module.named_parameters(),
+            torch_module.named_buffers(),
+        ]
+    ):
         if name not in oneflow_tensor_list:
             tensor.data = tensor.to(*args, **kwargs)
         else:
@@ -76,7 +78,6 @@ def _align_tensor(torch_module, oneflow_module):
         else:
             _align_tensor(module, self._oneflow_module.get_submodule(name))

-
     def __getattr__(self, name):
         if name == "_torch_module":
             return self._modules[name]
@@ -124,7 +125,11 @@ def __init__(self, torch_modules, oneflow_modules):
         for torch_module, oneflow_module in zip(
             self._torch_modules, self._oneflow_modules
         ):
-            dual_modules.append(get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module))
+            dual_modules.append(
+                get_mixed_dual_module(torch_module.__class__)(
+                    torch_module, oneflow_module
+                )
+            )
         # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules
         self._modules.clear()
         self += dual_modules
@@ -147,6 +152,7 @@ def __setattr__(self, key, value):
             setattr(self._oneflow_modules, key, value)
         return object.__setattr__(self, key, value)

+
 def get_mixed_dual_module(module_cls):
     class MixedDualModule(DualModule, module_cls):
         def __init__(self, torch_module, oneflow_module):
@@ -158,29 +164,77 @@ def __init__(self, torch_module, oneflow_module):
 def handle_deployable_exception(func):
     @wraps(func)
     def wrapper(self, *args, **kwargs):
+        def load_graph_from_config():
+            try:
+                if self._graph_config is not None:
+                    graph_path = self._graph_config[0]
+                    if not os.path.exists(graph_path):
+                        logger.warning(
+                            f"Graph file {graph_path} not exists, skip load graph."
+                        )
+                        return
+                    graph_device = torch2oflow(self._graph_config[1])
+                    self.load_graph(graph_path, graph_device)
+                    self._graph_config = None
+            except Exception as e:
+                logger.error(f"Exception in load_graph_from_config: {e=}")
+
+        def save_graph_to_config():
+            try:
+                if self._graph_config is not None:
+                    graph_file = self._graph_config[0]
+                    os.makedirs(os.path.dirname(graph_file), exist_ok=True)
+                    self.save_graph(self._graph_config[0])
+                    logger.info(f"Save graph to {self._graph_config[0]} done!")
+            except Exception as e:
+                logger.error(f"Exception in save_graph_to_config: {e=}")
+
+            finally:
+                self._graph_config = None
+
+        def _run_func():
+            load_graph_from_config()
+            result = func(self, *args, **kwargs)
+            save_graph_to_config()
+            return result
+
         if transform_mgr.debug_mode:
-            return func(self, *args, **kwargs)
+            return _run_func()
         else:
             try:
-                return func(self, *args, **kwargs)
+                return _run_func()
             except Exception as e:
                 logger.error(f"Exception in {func.__name__}: {e=}")
                 logger.warning("Recompile oneflow module ...")
                 del self._deployable_module_model.oneflow_module
                 self._deployable_module_dpl_graph = None
-                return func(self, *args, **kwargs)
+                return _run_func()

     return wrapper


 class DeployableModule(torch.nn.Module):
-    def __init__(self, torch_module, oneflow_module, use_graph=True, options={}):
+    def __init__(
+        self,
+        torch_module,
+        oneflow_module,
+        use_graph=True,
+        options={},
+        graph_path=None,
+        graph_device=None,
+    ):
         torch.nn.Module.__init__(self)
-        self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module)
+        self._deployable_module_model = get_mixed_dual_module(torch_module.__class__)(
+            torch_module, oneflow_module
+        )
         self._deployable_module_use_graph = use_graph
         self._deployable_module_options = options
         self._deployable_module_dpl_graph = None
         self._is_raw_deployable_module = True
+        if graph_path is not None:
+            self._graph_config = (graph_path, graph_device)
+        else:
+            self._graph_config = None

     @classmethod
     def from_existing(cls, existing_module, use_graph=None, options=None):
@@ -190,6 +244,7 @@ def from_existing(cls, existing_module, use_graph=None, options=None):
         instance._deployable_module_dpl_graph = (
             existing_module._deployable_module_dpl_graph if use_graph else None
         )
+        instance._graph_config = existing_module._graph_config
         return instance

     def get_graph(self):
@@ -245,7 +300,10 @@ def to(self, *args, **kwargs):

         # assert the target device is same as graph device
         target_device = parse_device(args, kwargs)
-        if target_device is not None and len(self._deployable_module_dpl_graph._blocks) > 0:
+        if (
+            target_device is not None
+            and len(self._deployable_module_dpl_graph._blocks) > 0
+        ):
             current_device = next(self._deployable_module_dpl_graph._state()).device
             if not check_device(current_device, target_device):
                 raise RuntimeError(
@@ -357,11 +415,25 @@ def state_dict_hook(module, state_dict, prefix, local_metadata):


 # Return a DeployableModule that using module_cls as it's parent class.
-def get_mixed_deployable_module(module_cls):
+def get_mixed_deployable_module(module_cls, graph_path=None, graph_device=None):
     class MixedDeployableModule(DeployableModule, module_cls):
-        def __init__(self, torch_module, oneflow_module, use_graph=True, options={}):
+        def __init__(
+            self,
+            torch_module,
+            oneflow_module,
+            use_graph=True,
+            options={},
+            graph_path=None,
+            graph_device=None,
+        ):
             DeployableModule.__init__(
-                self, torch_module, oneflow_module, use_graph, options
+                self,
+                torch_module,
+                oneflow_module,
+                use_graph,
+                options,
+                graph_path,
+                graph_device,
             )
             self._is_raw_deployable_module = False

@@ -378,7 +450,14 @@ def from_existing(cls, existing_module, use_graph=None, options=None):
         return MixedDeployableModule


-def oneflow_compile(torch_module: torch.nn.Module, *, use_graph=True, options={}):
+def oneflow_compile(
+    torch_module: torch.nn.Module,
+    *,
+    use_graph=True,
+    options={},
+    graph_path=None,
+    graph_device=None,
+):
     set_default_registry()

Review thread on the new graph_path parameter:

Collaborator: Let's put these into options instead; we should be cautious about adding new parameters to oneflow_compile.

ccssu (Contributor, Author), Dec 26, 2023: OK. Do you mean passing options = {"graph_config": (graph_path, graph_device)}, or options = {"graph_path": graph_path, "graph_device": graph_device}?

Collaborator: graph_file, graph_file_device

def wrap_module(module):
Expand All @@ -387,7 +466,7 @@ def wrap_module(module):
return module.__class__.from_existing(module, use_graph, options)
else:
return get_mixed_deployable_module(module.__class__)(
module, None, use_graph, options
module, None, use_graph, options, graph_path, graph_device
)

model = wrap_module(torch_module)
Expand Down
Loading
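
To make the review suggestion above concrete, here is a hypothetical sketch of the options-based API the reviewers converged toward (the key names graph_file and graph_file_device come from the last comment; this is not what this commit implements, which uses the graph_path/graph_device keyword arguments shown in the diff; vae is reused from the loader sketch earlier):

compiled_vae = oneflow_compile(
    vae.first_stage_model,
    use_graph=True,
    options={
        "graph_file": "input/graphs/sd-v1-5.safetensors/AutoencoderKL.graph",  # assumed path
        "graph_file_device": "cuda",  # device the saved graph belongs to
    },
)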