Skip to content

Commit

Permalink
Support union ControlNet (#2988)
Browse files Browse the repository at this point in the history
* Basic union impl

* nits

* Remove unused imports

* nit
  • Loading branch information
huchenlei authored Jul 9, 2024
1 parent ee96dc9 commit b92e415
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 12 deletions.
6 changes: 6 additions & 0 deletions internal_controlnet/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ControlMode,
HiResFixOption,
PuLIDMode,
ControlNetUnionControlType,
)
from annotator.util import HWC3

Expand Down Expand Up @@ -202,6 +203,11 @@ def parse_effective_region_mask(cls, value) -> np.ndarray:
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

# ControlNet control type for ControlNet union model.
# https://github.com/xinsir6/ControlNetPlus/tree/main
# The value of this field is only used when the model is ControlNetUnion.
union_control_type: ControlNetUnionControlType = ControlNetUnionControlType.UNKNOWN

# ------- API only fields -------
# The tensor input for ipadapter. When this field is set in the API,
# the base64string will be interpret by torch.load to reconstruct ipadapter
Expand Down
88 changes: 83 additions & 5 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
import torch
import torch.nn as nn

from modules import devices

from scripts.controlnet_core.controlnet_union import ControlAddEmbedding, ResBlockUnionControlnet

try:
from sgm.modules.diffusionmodules.openaimodel import conv_nd, linear, zero_module, timestep_embedding, \
Expand All @@ -26,7 +27,7 @@ def __init__(self, config, state_dict=None):

def reset(self):
pass

def forward(self, *args, **kwargs):
return self.control_model(*args, **kwargs)

Expand Down Expand Up @@ -57,7 +58,7 @@ def send_me_to_gpu(module, _):
def fullvram(self):
self.to(devices.get_device_for("controlnet"))
return


class ControlNet(nn.Module):
def __init__(
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
union_controlnet=False,
device=None,
global_average_pooling=False,
):
Expand Down Expand Up @@ -280,10 +282,74 @@ def __init__(
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch

if union_controlnet:
self.num_control_type = 6
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
self.task_embedding = nn.Parameter(torch.empty(
self.num_control_type, num_trans_channel, dtype=self.dtype, device=device
))

self.transformer_layes = nn.Sequential(*[
ResBlockUnionControlnet(
num_trans_channel, num_trans_head, dtype=self.dtype, device=device
)
for _ in range(num_trans_layer)
])
self.spatial_ch_projs = nn.Linear(
num_trans_channel, num_proj_channel, dtype=self.dtype, device=device
)

control_add_embed_dim = 256
self.control_add_embedding = ControlAddEmbedding(
control_add_embed_dim, time_embed_dim, self.num_control_type,
dtype=self.dtype, device=device
)
else:
self.task_embedding = None
self.control_add_embedding = None

def union_controlnet_merge(
self,
hint: torch.Tensor,
control_type: List[int],
emb: torch.Tensor,
context: torch.Tensor
):
""" Note: control_type is a list of enum values. The length of the list
is the number of control images."""
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
inputs = []
condition_list = []

for idx in range(min(1, len(control_type))):
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
if idx < len(control_type):
feat_seq += self.task_embedding[control_type[idx]]

inputs.append(feat_seq.unsqueeze(1))
condition_list.append(controlnet_cond)

x = torch.cat(inputs, dim=1)
x = self.transformer_layes(x)
controlnet_cond_fuser = None
for idx in range(len(control_type)):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
o = condition_list[idx] + alpha
if controlnet_cond_fuser is None:
controlnet_cond_fuser = o
else:
controlnet_cond_fuser += o
return controlnet_cond_fuser

def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

def forward(self, x, hint, timesteps, context, y=None, **kwargs):
def forward(self, x, hint, timesteps, context, y=None, control_type: List[int] = None, **kwargs):
original_type = x.dtype

x = x.to(self.dtype)
Expand All @@ -297,7 +363,19 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)

guided_hint = self.input_hint_block(hint, emb, context)
guided_hint = None
if self.control_add_embedding is not None:
assert control_type is not None

emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)

outs = []

if self.num_classes is not None:
Expand Down
6 changes: 6 additions & 0 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,10 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
control_model_type.is_controlnet and
model_net.control_model.global_average_pooling
)

if control_model_type == ControlModelType.ControlNetUnion:
logger.info(f"ControlNetUnion control type: {unit.union_control_type}")

forward_param = ControlParams(
control_model=model_net,
preprocessor=preprocessor_dict,
Expand All @@ -1047,6 +1051,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
if unit.effective_region_mask is not None
else None
),
# TODO: Implement merge of units with the same union model.
union_control_types=[unit.union_control_type],
)
forward_params.append(forward_param)

Expand Down
118 changes: 118 additions & 0 deletions scripts/controlnet_core/controlnet_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from collections import OrderedDict
import torch
import torch.nn as nn

try:
from sgm.modules.diffusionmodules.openaimodel import (
timestep_embedding,
)

using_sgm = True
except ImportError:
from ldm.modules.diffusionmodules.openaimodel import (
timestep_embedding,
)

using_sgm = False


def attention_pytorch(
q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False
):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)

out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
return out


class ControlAddEmbedding(nn.Module):
def __init__(
self,
in_dim,
out_dim,
num_control_type,
dtype=None,
device=None,
):
super().__init__()
self.num_control_type = num_control_type
self.in_dim = in_dim
self.linear_1 = nn.Linear(
in_dim * num_control_type, out_dim, dtype=dtype, device=device
)
self.linear_2 = nn.Linear(out_dim, out_dim, dtype=dtype, device=device)

def forward(self, control_type, dtype, device):
c_type = torch.zeros((self.num_control_type,), device=device)
c_type[control_type] = 1.0
c_type = (
timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False)
.to(dtype)
.reshape((-1, self.num_control_type * self.in_dim))
)
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))


class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
self.heads = nhead
self.c = c

self.in_proj = nn.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
self.out_proj = nn.Linear(c, c, bias=True, dtype=dtype, device=device)

def forward(self, x):
x = self.in_proj(x)
q, k, v = x.split(self.c, dim=2)
out = attention_pytorch(q, k, v, self.heads)
return self.out_proj(out)


class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResBlockUnionControlnet(nn.Module):
def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
super().__init__()
self.attn = OptimizedAttention(
dim, nhead, dtype=dtype, device=device, operations=operations
)
self.ln_1 = nn.LayerNorm(dim, dtype=dtype, device=device)
self.mlp = nn.Sequential(
OrderedDict(
[
(
"c_fc",
nn.Linear(dim, dim * 4, dtype=dtype, device=device),
),
("gelu", QuickGELU()),
(
"c_proj",
nn.Linear(dim * 4, dim, dtype=dtype, device=device),
),
]
)
)
self.ln_2 = nn.LayerNorm(dim, dtype=dtype, device=device)

def attention(self, x: torch.Tensor):
return self.attn(x)

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
19 changes: 15 additions & 4 deletions scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,25 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
final_state_dict[key] = p_new
state_dict = final_state_dict

config['use_fp16'] = devices.dtype_unet == torch.float16
if "control_add_embedding.linear_1.bias" in state_dict: # Controlnet Union
config["union_controlnet"] = True
final_state_dict = {}
for k in list(state_dict.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
final_state_dict[new_k] = state_dict.pop(k)
state_dict = final_state_dict

network = PlugableControlModel(config, state_dict)
network.to(devices.dtype_unet)
if "instant_id" in model_path.lower():
control_model_type = ControlModelType.ControlNetUnion
elif "instant_id" in model_path.lower():
control_model_type = ControlModelType.InstantID
else:
control_model_type = ControlModelType.ControlNet

config['use_fp16'] = devices.dtype_unet == torch.float16

network = PlugableControlModel(config, state_dict)
network.to(devices.dtype_unet)

return ControlModel(network, control_model_type)

if 'conv_in.weight' in state_dict:
Expand Down
24 changes: 24 additions & 0 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
PuLIDMode,
ControlMode,
ResizeMode,
ControlNetUnionControlType,
)
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton
Expand Down Expand Up @@ -265,6 +266,7 @@ def __init__(
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
self.pulid_mode = None
self.union_control_type = None

# API-only fields
self.ipadapter_input = gr.State(None)
Expand Down Expand Up @@ -487,6 +489,13 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
visible=False,
)

with gr.Row():
self.union_control_type = gr.Textbox(
label="Union Control Type",
value=ControlNetUnionControlType.UNKNOWN.value,
visible=False,
)

with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
self.type_filter = (
gr.Dropdown
Expand Down Expand Up @@ -664,6 +673,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.advanced_weighting,
self.effective_region_mask,
self.pulid_mode,
self.union_control_type,
)

unit = gr.State(ControlNetUnit())
Expand Down Expand Up @@ -841,6 +851,19 @@ def filter_selected(k: str):
show_progress=False,
)

def register_union_control_type(self):
def filter_selected(k: str):
control_type = ControlNetUnionControlType.from_str(k)
logger.debug(f"Switch to union control type {control_type}")
return gr.update(value=control_type.value)

self.type_filter.change(
fn=filter_selected,
inputs=[self.type_filter],
outputs=[self.union_control_type],
show_progress=False,
)

def register_sd_version_changed(self):
def sd_version_changed(type_filter: str, current_model: str):
"""When SD version changes, update model dropdown choices."""
Expand Down Expand Up @@ -1227,6 +1250,7 @@ def register_core_callbacks(self):
self.register_webcam_mirror_toggle()
self.register_refresh_all_models()
self.register_build_sliders()
self.register_union_control_type()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_shift_pulid_mode()
Expand Down
Loading

0 comments on commit b92e415

Please sign in to comment.