Skip to content

Commit

Permalink
Fixes for the stable diffusion example. (huggingface#342)
Browse files — browse the repository at this point in the history.
* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
Authored by LaurentMazare on Aug 8, 2023.
1 parent: ab35684 · commit: 89d3926
Showing 6 changed files with 27 additions and 12 deletions.
4 changes: 2 additions & 2 deletions candle-examples/examples/stable-diffusion/clip.rs
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ pub struct Config {
embed_dim: usize, // aka config.hidden_size
activation: Activation, // aka config.hidden_act
intermediate_size: usize,
max_position_embeddings: usize,
pub max_position_embeddings: usize,
// The character to use for padding, use EOS when not set.
pad_with: Option<String>,
num_hidden_layers: usize,
@@ -90,7 +90,7 @@ impl ClipTextEmbeddings {
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(1)?;
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
2 changes: 1 addition & 1 deletion candle-examples/examples/stable-diffusion/embeddings.rs
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ impl Timesteps {
let exponent = (exponent / (half_dim as f64 - self.downscale_freq_shift))?;
let emb = exponent.exp()?;
// emb = timesteps[:, None].float() * emb[None, :]
let emb = (xs.unsqueeze(D::Minus1)? * emb.unsqueeze(0)?)?;
let emb = xs.unsqueeze(D::Minus1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let (cos, sin) = (emb.cos()?, emb.sin()?);
let emb = if self.flip_sin_to_cos {
Tensor::cat(&[&cos, &sin], D::Minus1)?
15 changes: 13 additions & 2 deletions candle-examples/examples/stable-diffusion/main.rs
Original file line number Diff line number Diff line change
@@ -181,19 +181,29 @@ fn run(args: Args) -> Result<()> {
let device = candle_examples::device(cpu)?;

let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let pad_id = match tokenizer.get_padding() {
Some(padding) => padding.pad_id,
None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
};
println!("Running with prompt \"{prompt}\".");
let tokens = tokenizer
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while tokens.len() < sd_config.clip.max_position_embeddings {
tokens.push(pad_id)
}
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;

let uncond_tokens = tokenizer
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;

println!("Building the Clip transformer.");
@@ -202,6 +212,7 @@ fn run(args: Args) -> Result<()> {
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;

println!("text-embeddings: {text_embeddings:?}");
println!("Building the autoencoder.");
let vae = sd_config.build_vae(&vae_weights, &device)?;
println!("Building the unet.");
2 changes: 1 addition & 1 deletion candle-examples/examples/stable-diffusion/resnet.rs
Original file line number Diff line number Diff line change
@@ -118,7 +118,7 @@ impl ResnetBlock2D {
.forward(&nn::ops::silu(temb)?)?
.unsqueeze(D::Minus1)?
.unsqueeze(D::Minus1)?
.add(&xs)?,
.broadcast_add(&xs)?,
_ => xs,
};
let xs = self
12 changes: 8 additions & 4 deletions candle-nn/src/group_norm.rs
Original file line number Diff line number Diff line change
@@ -59,17 +59,21 @@ impl GroupNorm {
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let mut w_dims = vec![1; x_shape.len()];
w_dims[1] = n_channels;
let weight = self.weight.reshape(w_dims.clone())?;
let bias = self.bias.reshape(w_dims)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?
.reshape(x_shape)
.reshape(x_shape)?
.broadcast_mul(&weight)?
.broadcast_add(&bias)
}
}

pub fn group_norm(
num_channels: usize,
num_groups: usize,
num_channels: usize,
eps: f64,
vb: crate::VarBuilder,
) -> Result<GroupNorm> {
4 changes: 2 additions & 2 deletions candle-nn/tests/group_norm.rs
Original file line number Diff line number Diff line change
@@ -30,8 +30,8 @@ use test_utils::to_vec3_round;
#[test]
fn group_norm() -> Result<()> {
let device = &Device::Cpu;
let w = Tensor::new(&[1f32], device)?;
let b = Tensor::new(&[0f32], device)?;
let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;

0 comments on commit 89d3926

Please sign in to comment.