{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-flux-rf-inversion",
   "metadata": {},
   "source": [
    "# FLUX RF-inversion image editing\n",
    "\n",
    "Edits `panda.png` towards a text prompt with the local `FluxRFInversionPipeline`, using two CUDA devices:\n",
    "\n",
    "- `cuda:0` holds the text encoders (used once to precompute prompt embeddings) and the VAE.\n",
    "- `cuda:1` holds the pipeline that is loaded without text encoders, tokenizers or VAE and runs the sampling call.\n",
    "\n",
    "Run the cells top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e778599f-aabf-409c-9cc5-cdb9113de0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import argparse\n",
    "import gc\n",
    "import importlib\n",
    "import os\n",
    "\n",
    "# Third-party\n",
    "import torch\n",
    "from diffusers import AutoencoderKL, FluxImg2ImgPipeline\n",
    "from diffusers.image_processor import VaeImageProcessor\n",
    "from diffusers.utils import load_image\n",
    "# Aliased so it does not collide with PIL's `Image`, which the rest of the notebook uses.\n",
    "from IPython.display import Image as IPyImage\n",
    "from PIL import Image\n",
    "from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n",
    "\n",
    "# Local module; reloaded so edits to pipeline_flux_rf_inversion.py are picked up\n",
    "# without restarting the kernel.\n",
    "import pipeline_flux_rf_inversion\n",
    "importlib.reload(pipeline_flux_rf_inversion)\n",
    "from pipeline_flux_rf_inversion import FluxRFInversionPipeline\n",
    "\n",
    "# Hugging Face repo id used for the text encoders, tokenizers and both pipelines.\n",
    "ckpt_id = \"black-forest-labs/FLUX.1-dev\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a6c9d9-55ef-4be7-bc14-b15558d5dbd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text encoders are loaded individually so they can be requested in bfloat16.\n",
    "text_encoder = CLIPTextModel.from_pretrained(\n",
    "    ckpt_id, subfolder=\"text_encoder\", torch_dtype=torch.bfloat16\n",
    ")\n",
    "text_encoder_2 = T5EncoderModel.from_pretrained(\n",
    "    ckpt_id, subfolder=\"text_encoder_2\", torch_dtype=torch.bfloat16\n",
    ")\n",
    "tokenizer = CLIPTokenizer.from_pretrained(ckpt_id, subfolder=\"tokenizer\")\n",
    "tokenizer_2 = T5TokenizerFast.from_pretrained(ckpt_id, subfolder=\"tokenizer_2\")\n",
    "\n",
    "# Encoder-only pipeline on cuda:0: transformer and VAE are disabled here, so this\n",
    "# object is only used for `encode_prompt`.\n",
    "pipeline_encoder = FluxRFInversionPipeline.from_pretrained(\n",
    "    ckpt_id,\n",
    "    text_encoder=text_encoder,\n",
    "    text_encoder_2=text_encoder_2,\n",
    "    tokenizer=tokenizer,\n",
    "    tokenizer_2=tokenizer_2,\n",
    "    transformer=None,\n",
    "    vae=None,\n",
    ").to(\"cuda:0\")\n",
    "\n",
    "# NOTE(review): the VAE is pulled from the FLUX.1-schnell repo (revision refs/pr/1)\n",
    "# rather than from `ckpt_id` -- confirm this is intentional.\n",
    "vae = AutoencoderKL.from_pretrained(\n",
    "    \"black-forest-labs/FLUX.1-schnell\",\n",
    "    revision=\"refs/pr/1\",\n",
    "    subfolder=\"vae\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    ").to(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69bf8028-1bac-4951-91ad-e0b4de46b114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling pipeline on cuda:1: text encoders, tokenizers and VAE are disabled here;\n",
    "# precomputed embeddings and the `vae` object are supplied at call time instead.\n",
    "# No `generator` is passed: `from_pretrained` does not accept one (it is ignored with\n",
    "# a warning), so seeding is done right before the sampling call.\n",
    "pipeline = FluxRFInversionPipeline.from_pretrained(\n",
    "    ckpt_id,\n",
    "    text_encoder=None,\n",
    "    text_encoder_2=None,\n",
    "    tokenizer=None,\n",
    "    tokenizer_2=None,\n",
    "    vae=None,\n",
    "    torch_dtype=torch.bfloat16,\n",
    ").to(\"cuda:1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ba1aa0e-ada3-4718-bc3d-8a2b355880c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input image; converted to RGB so PNGs with an alpha channel or palette are\n",
    "# normalised to three channels before being handed to the pipeline.\n",
    "init_image = Image.open(\"panda.png\").convert(\"RGB\").resize((1024, 1024))\n",
    "prompt = \"a panda wearing a crown\"\n",
    "prompt_2 = prompt\n",
    "\n",
    "# gamma / eta / start_timestep / stop_timestep are forwarded unchanged to the pipeline call.\n",
    "# NOTE(review): presumably the RF-inversion controller settings -- confirm their meaning\n",
    "# against pipeline_flux_rf_inversion.py.\n",
    "gamma = 0.5\n",
    "eta = 0.9\n",
    "start_timestep = 0\n",
    "stop_timestep = 12\n",
    "num_inference_steps = 28\n",
    "strength = 0.95\n",
    "guidance_scale = 3.5\n",
    "seed = 0\n",
    "output = \"stop12.jpg\"\n",
    "\n",
    "# Inference only: no_grad stops autograd from tracking the encoder forward passes.\n",
    "with torch.no_grad():\n",
    "    print(\"Encoding prompts.\")\n",
    "    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline_encoder.encode_prompt(\n",
    "        prompt=prompt, prompt_2=prompt_2, max_sequence_length=256\n",
    "    )\n",
    "\n",
    "    print(\"Encoding null prompts.\")\n",
    "    null_prompt_embeds, null_pooled_prompt_embeds, null_text_ids = pipeline_encoder.encode_prompt(\n",
    "        prompt=\"\", prompt_2=\"\", max_sequence_length=256\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5014e83d-a4eb-44e2-836c-ebb7bad3635f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's default RNGs (CPU and all CUDA devices) so that random draws made\n",
    "# without an explicit generator are repeatable when this cell is re-run.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Embeddings were computed on cuda:0 and are moved to the sampling device here.\n",
    "edited_image = pipeline(\n",
    "    prompt_embeds=prompt_embeds.to(\"cuda:1\"),\n",
    "    pooled_prompt_embeds=pooled_prompt_embeds.to(\"cuda:1\"),\n",
    "    null_prompt_embeds=null_prompt_embeds.to(\"cuda:1\"),\n",
    "    null_pooled_prompt_embeds=null_pooled_prompt_embeds.to(\"cuda:1\"),\n",
    "    image=init_image,\n",
    "    num_inference_steps=num_inference_steps,\n",
    "    strength=strength,\n",
    "    guidance_scale=guidance_scale,\n",
    "    vae=vae,\n",
    "    gamma=gamma,\n",
    "    eta=eta,\n",
    "    start_timestep=start_timestep,\n",
    "    stop_timestep=stop_timestep,\n",
    ").images[0]\n",
    "\n",
    "edited_image.save(output)\n",
    "print(f\"Output image saved as {output}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}