Skip to content

Commit

Permalink
FLUX dev and lite models support (openvinotoolkit#1340)
Browse files Browse the repository at this point in the history
```python
prompt = "A cat holding a sign that says hello world"
image_tensor = pipe.generate(
    prompt,
    num_inference_steps=5,
    generator=Generator(42)
)
```

Freepik/flux.1-lite-8B-alpha:


![image](https://github.com/user-attachments/assets/fbc60a4e-bc55-470b-92bd-a80ae454504c)

black-forest-labs/FLUX.1-dev:


![image](https://github.com/user-attachments/assets/4163efff-637d-4877-b46d-9462836ace17)
  • Loading branch information
likholat authored Dec 9, 2024
1 parent d91123a commit d465c5b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class OPENVINO_GENAI_EXPORTS FluxTransformer2DModel {
public:
struct Config {
size_t in_channels = 64;
bool guidance_embeds = false;
size_t m_default_sample_size = 128;

explicit Config(const std::filesystem::path& config_path);
Expand Down
6 changes: 6 additions & 0 deletions src/cpp/src/image_generation/flux_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ class FluxPipeline : public DiffusionPipeline {

ov::Tensor latent_image_ids = prepare_latent_image_ids(generation_config.num_images_per_prompt, height / 2, width / 2);

if (m_transformer->get_config().guidance_embeds) {
ov::Tensor guidance = ov::Tensor(ov::element::f32, {generation_config.num_images_per_prompt});
std::fill_n(guidance.data<float>(), guidance.get_size(), static_cast<float>(generation_config.guidance_scale));
m_transformer->set_hidden_states("guidance", guidance);
}

m_transformer->set_hidden_states("pooled_projections", pooled_prompt_embeds);
m_transformer->set_hidden_states("encoder_hidden_states", prompt_embeds);
m_transformer->set_hidden_states("txt_ids", text_ids);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ FluxTransformer2DModel::Config::Config(const std::filesystem::path& config_path)
using utils::read_json_param;

read_json_param(data, "in_channels", in_channels);
read_json_param(data, "guidance_embeds", guidance_embeds);
file.close();
}

Expand Down Expand Up @@ -95,6 +96,8 @@ FluxTransformer2DModel& FluxTransformer2DModel::reshape(int batch_size,
name_to_shape[input_name] = {height * width / 4, name_to_shape[input_name][1]};
} else if (input_name == "txt_ids") {
name_to_shape[input_name] = {tokenizer_model_max_length, name_to_shape[input_name][1]};
} else if (input_name == "guidance") {
name_to_shape[input_name] = {batch_size};
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/docs/SUPPORTED_MODELS.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ The pipeline can work with other similar topologies produced by `optimum-intel`
<td>
<ul>
<li><a href="https://huggingface.co/black-forest-labs/FLUX.1-schnell"><code>black-forest-labs/FLUX.1-schnell</code></a></li>
<li><a href="https://huggingface.co/Freepik/flux.1-lite-8B-alpha"><code>Freepik/flux.1-lite-8B-alpha</code></a></li>
<li><a href="https://huggingface.co/black-forest-labs/FLUX.1-dev"><code>black-forest-labs/FLUX.1-dev</code></a></li>
<li><a href="https://huggingface.co/shuttleai/shuttle-3-diffusion"><code>shuttleai/shuttle-3-diffusion</code></a></li>
</ul>
</td>
</tr>
Expand Down

0 comments on commit d465c5b

Please sign in to comment.