Skip to content

Commit

Permalink
do not automatically enable xformers (open-mmlab#1640)
Browse files Browse the repository at this point in the history
* do not automatically enable xformers

* uP
  • Loading branch information
patrickvonplaten authored Dec 9, 2022
1 parent 63c4944 commit 6b68afd
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
10 changes: 10 additions & 0 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
Expand Down Expand Up @@ -488,6 +489,15 @@ def main(args):
revision=args.revision,
)

if is_xformers_available():
try:
unet.enable_xformers_memory_efficient_attention(True)
except Exception as e:
logger.warning(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
Expand Down
10 changes: 10 additions & 0 deletions examples/text_to_image/train_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
Expand Down Expand Up @@ -364,6 +365,15 @@ def main():
revision=args.revision,
)

if is_xformers_available():
try:
unet.enable_xformers_memory_efficient_attention(True)
except Exception as e:
logger.warning(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
Expand Down
10 changes: 10 additions & 0 deletions examples/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
Expand Down Expand Up @@ -439,6 +440,15 @@ def main():
revision=args.revision,
)

if is_xformers_available():
try:
unet.enable_xformers_memory_efficient_attention(True)
except Exception as e:
logger.warning(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))

Expand Down
11 changes: 0 additions & 11 deletions src/diffusers/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional

Expand Down Expand Up @@ -447,16 +446,6 @@ def __init__(
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)

# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
Expand Down

0 comments on commit 6b68afd

Please sign in to comment.