add further nerf-attention
rromb committed Jul 6, 2022
1 parent 0540e68 commit e2a6bee
Showing 2 changed files with 41 additions and 28 deletions.
@@ -32,11 +32,12 @@ model:
         image_size: 64 # unused
         in_channels: 4
         out_channels: 4
-        model_channels: 352
-        attention_resolutions: [ 8, 4, 2 ]
-        num_res_blocks: [ 2, 2, 2, 6 ]
+        model_channels: 384
+        attention_resolutions: [ 8, 4, 2, 1 ]
+        num_res_blocks: [ 2, 2, 2, 5 ]
         channel_mult: [ 1, 2, 4, 4 ]
-        disable_self_attentions: [ True, True, True, False ] # converts the self-attention to a cross-attention layer if true
+        disable_self_attentions: [ False, False, False, False ] # converts the self-attention to a cross-attention layer if true
+        num_attention_blocks: [1, 1, 1, 3]
         num_heads: 8
         use_spatial_transformer: True
         transformer_depth: 1
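Note (not part of the commit): the new num_attention_blocks entry caps how many residual blocks per level actually receive an attention layer, while attention_resolutions keeps priority over it. A minimal Python sketch of how the values above combine, assuming the usual UNet convention that the downsampling factor ds doubles after every level except the last:

attention_resolutions = [8, 4, 2, 1]  # downsampling factors (ds) at which attention may be inserted
num_res_blocks = [2, 2, 2, 5]         # residual blocks per level
num_attention_blocks = [1, 1, 1, 3]   # per level, only the first N residual blocks get attention

ds = 1
for level, n_res in enumerate(num_res_blocks):
    for nr in range(n_res):
        # attention_resolutions has priority: if ds is not listed, no attention at all;
        # otherwise num_attention_blocks further limits which residual blocks get it
        use_attn = ds in attention_resolutions and nr < num_attention_blocks[level]
        print(f"level {level} (ds={ds}), res block {nr}: attention={use_attn}")
    if level != len(num_res_blocks) - 1:
        ds *= 2  # the UNet downsamples between levels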
ldm/modules/diffusionmodules/openaimodel.py (36 additions, 24 deletions)
@@ -467,7 +467,8 @@ def __init__(
         context_dim=None, # custom transformer support
         n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
         legacy=True,
-        disable_self_attentions=None
+        disable_self_attentions=None,
+        num_attention_blocks=None
     ):
         super().__init__()
         if use_spatial_transformer:
@@ -503,6 +504,13 @@ def __init__(
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.") # todo: convert to warning

         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -538,7 +546,7 @@ def __init__(
         ch = model_channels
         ds = 1
         for level, mult in enumerate(channel_mult):
-            for _ in range(self.num_res_blocks[level]):
+            for nr in range(self.num_res_blocks[level]):
                 layers = [
                     ResBlock(
                         ch,
@@ -564,18 +572,20 @@ def __init__(
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
                         )
-                    )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
@@ -670,18 +680,20 @@ def __init__(
                         disabled_sa = disable_self_attentions[level]
                     else:
                         disabled_sa = False
-                    layers.append(
-                        AttentionBlock(
-                            ch,
-                            use_checkpoint=use_checkpoint,
-                            num_heads=num_heads_upsample,
-                            num_head_channels=dim_head,
-                            use_new_attention_order=use_new_attention_order,
-                        ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
-                            disable_self_attn=disabled_sa
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa
+                            )
                         )
-                    )
                 if level and i == self.num_res_blocks[level]:
                     out_ch = ch
                     layers.append(
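For clarity, the guard added to both the downsampling and upsampling paths amounts to the following standalone sketch (illustrative only; exists is assumed to be the codebase's usual "is not None" helper, and the enclosing "if ds in attention_resolutions:" check from the unchanged surrounding code is written out explicitly):

def exists(x):
    # assumed helper; matches the usual definition used elsewhere in the codebase
    return x is not None

def wants_attention(ds, level, nr, attention_resolutions, num_attention_blocks=None):
    # Sketch of the effective condition for appending an AttentionBlock / SpatialTransformer
    # at downsampling factor ds, level `level`, residual block `nr` after this commit.
    if ds not in attention_resolutions:
        return False  # attention_resolutions keeps priority (see the constructor message above)
    # new in this commit: optionally cap attention to the first num_attention_blocks[level] blocks
    return not exists(num_attention_blocks) or nr < num_attention_blocks[level]

Passing num_attention_blocks=None leaves the pre-commit behaviour unchanged.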
