port(models): unet #2

Closed
wants to merge 3 commits into from
11 changes: 8 additions & 3 deletions Cargo.toml
@@ -8,12 +8,17 @@ wgpu-backend = ["burn-wgpu"]

[dependencies.burn-wgpu]
package = "burn-wgpu"
version = "0.10.0"
version = "0.11.0"
optional = true

[dependencies]
anyhow = "1.0.75"
burn = "0.10.0"
burn-tch = "0.10.0"
burn = "0.11.0"
burn-tch = "0.11.0"
clap = { version = "4.4.7", features = ["derive"] }
serde = { version = "1.0.192", features = ["derive"] }

[patch.crates-io]
burn = { git = "https://github.com/OxideAI/burn", branch = "feature/unsqueeze-dim" }
burn-wgpu = { git = "https://github.com/OxideAI/burn", branch = "feature/unsqueeze-dim" }
burn-tch = { git = "https://github.com/OxideAI/burn", branch = "feature/unsqueeze-dim" }
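
Note on the new [patch.crates-io] section: it redirects all three burn crates to the OxideAI fork's feature/unsqueeze-dim branch, so the whole dependency graph builds against a single patched copy of burn. Judging by the branch name, the fork adds an `unsqueeze_dim` tensor op; the sketch below shows how such an op would typically be used, with the signature assumed rather than confirmed by this diff.

use burn::tensor::{backend::Backend, Tensor};

// Assumed API from the `feature/unsqueeze-dim` branch: insert a size-1 axis at `dim`.
fn add_channel_axis<B: Backend>(t: Tensor<B, 2>) -> Tensor<B, 3> {
    // [batch, features] -> [batch, 1, features]
    t.unsqueeze_dim(1)
}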
16 changes: 7 additions & 9 deletions src/cli/txt2img.rs
@@ -1,15 +1,12 @@
use anyhow::Result;
use clap::Args;

use burn::{
module::{Module},
tensor::{backend::Backend},
};
use burn::{module::Module, tensor::backend::Backend};

#[cfg(feature = "wgpu-backend")]
use burn_wgpu::{WgpuBackend, WgpuDevice, AutoGraphicsApi};
#[cfg(not(feature = "wgpu-backend"))]
use burn_tch::{TchBackend, TchDevice};
#[cfg(feature = "wgpu-backend")]
use burn_wgpu::{AutoGraphicsApi, WgpuBackend, WgpuDevice};
use diffusers_burn::pipelines::stable_diffusion;

const GUIDANCE_SCALE: f64 = 7.5;
@@ -85,7 +82,6 @@ pub struct Txt2ImgArgs {

#[arg(long)]
use_flash_attn: bool,

// #[arg(long)]
// use_f16: bool,
}
@@ -98,11 +94,13 @@ enum StableDiffusionVersion {
}

pub fn handle_txt2img(args: &Txt2ImgArgs) -> Result<()> {
#[cfg(feature = "wgpu-backend")] {
#[cfg(feature = "wgpu-backend")]
{
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;
let device = WgpuDevice::BestAvailable;
}
#[cfg(not(feature = "wgpu-backend"))] {
#[cfg(not(feature = "wgpu-backend"))]
{
type Backend = TchBackend<f32>;
let device = TchDevice::Cpu;
}
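
Review note: as written, `Backend` and `device` are declared inside the `#[cfg]` blocks, so they go out of scope before the rest of the function can use them. A common alternative, sketched below under that assumption (reusing this file's imports, with the pipeline call elided), is to lift the backend choice to a cfg-gated type alias and keep only the device selection inside the function.

#[cfg(feature = "wgpu-backend")]
type SelectedBackend = WgpuBackend<AutoGraphicsApi, f32, i32>;
#[cfg(not(feature = "wgpu-backend"))]
type SelectedBackend = TchBackend<f32>;

pub fn handle_txt2img(args: &Txt2ImgArgs) -> Result<()> {
    #[cfg(feature = "wgpu-backend")]
    let device = WgpuDevice::BestAvailable;
    #[cfg(not(feature = "wgpu-backend"))]
    let device = TchDevice::Cpu;
    // ... run the pipeline with `SelectedBackend`, `device`, and `args` ...
    Ok(())
}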
16 changes: 2 additions & 14 deletions src/lib.rs
@@ -1,17 +1,5 @@
pub mod models;
pub mod pipelines;
pub mod transformers;

// pub fn add(left: usize, right: usize) -> usize {
// left + right
// }
//
// #[cfg(test)]
// mod tests {
// use super::*;
//
// #[test]
// fn it_works() {
// let result = add(2, 2);
// assert_eq!(result, 4);
// }
// }
mod utils;
146 changes: 146 additions & 0 deletions src/models/attention.rs
@@ -0,0 +1,146 @@
//! Attention-based building blocks.

use burn::config::Config;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::{gelu, softmax};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

#[derive(Config)]
pub struct GeGluConfig {
dim_in: usize,
dim_out: usize,
}

impl GeGluConfig {
fn init<B: Backend>(&self) -> GeGlu<B> {
let proj = LinearConfig::new(self.dim_in, self.dim_out * 2).init();
GeGlu { proj }
}
}

#[derive(Module, Debug)]
struct GeGlu<B: Backend> {
proj: Linear<B>
}

impl<B: Backend> GeGlu<B> {
fn forward(&self, xs: Tensor<B, 3>) -> Tensor<B, 3> {
let projected = self.proj.forward(xs);
let [n_batch, n_ctx, n_state] = projected.dims();

let n_state_out = n_state / 2;

let xs = projected
.clone()
.slice([0..n_batch, 0..n_ctx, 0..n_state_out]);
let gate = projected.slice([0..n_batch, 0..n_ctx, n_state_out..n_state]);

xs * gelu(gate)
}
}
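
GeGlu projects to twice `dim_out`, then splits the result along the last axis into a value half and a gate half, multiplying the value by `gelu(gate)`. A minimal in-module shape sketch with hypothetical sizes (the type and its config are private, so this only compiles inside attention.rs):

fn geglu_shape_sketch<B: Backend>() {
    // dim_in = 320, dim_out = 1280: the projection emits 2 * 1280 features,
    // which forward() splits into the value half and the gelu-gated half.
    let geglu = GeGluConfig::new(320, 1280).init::<B>();
    let xs = Tensor::<B, 3>::zeros([2, 77, 320]);
    assert_eq!(geglu.forward(xs).dims(), [2, 77, 1280]);
}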

#[derive(Config)]
pub struct FeedForwardConfig {
dim: usize,
dim_out: Option<usize>,
mult: usize,
}

impl FeedForwardConfig {
fn init<B: Backend>(&self) -> FeedForward<B> {
let inner_dim = self.dim * self.mult;
let dim_out = self.dim_out.unwrap_or(self.dim);
let proj_in = GeGluConfig::new(self.dim, inner_dim).init();
let linear = LinearConfig::new(inner_dim, dim_out).init();
FeedForward { proj_in, linear }
}
}

#[derive(Module, Debug)]
struct FeedForward<B: Backend> {
proj_in: GeGlu<B>,
linear: Linear<B>,
}

impl<B: Backend> FeedForward<B> {
fn forward(&self, xs: Tensor<B, 3>) -> Tensor<B, 3> {
let xs = self.proj_in.forward(xs);
self.linear.forward(xs)
}
}

#[derive(Config)]
pub struct CrossAttentionConfig {
d_query: usize,
d_context: Option<usize>,
n_heads: usize,
d_head: usize,
slice_size: Option<usize>,
use_flash_attn: bool,
}

#[derive(Module, Debug)]
struct CrossAttention<B: Backend> {
query: Linear<B>,
key: Linear<B>,
value: Linear<B>,
output: Linear<B>,
n_heads: usize,
scale: f64,
slice_size: Option<usize>,
use_flash_attn: bool,
}

impl CrossAttentionConfig {
fn init<B: Backend>(&self) -> CrossAttention<B> {
let linear = |in_dim: usize, out_dim: usize| {
LinearConfig::new(in_dim, out_dim)
.with_bias(false)
.init()
};

let inner_dim = self.d_head * self.n_heads;
let context_dim = self.d_context.unwrap_or(self.d_query);

CrossAttention {
query: linear(self.d_query, inner_dim),
key: linear(context_dim, inner_dim),
value: linear(context_dim, inner_dim),
output: linear(inner_dim, self.d_query),
n_heads: self.n_heads,
scale: 1.0 / f64::sqrt(self.d_head as f64),
slice_size: self.slice_size,
use_flash_attn: self.use_flash_attn,
}
}
}


impl<B: Backend> CrossAttention<B> {
fn reshape_heads_to_batch_dim(&self, xs: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch_size, seq_len, dim] = xs.dims();
xs.reshape([batch_size, seq_len, self.n_heads, dim / self.n_heads])
.swap_dims(1, 2)
.reshape([batch_size * self.n_heads, seq_len, dim / self.n_heads])
}

fn reshape_batch_dim_to_heads(&self, xs: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch_size, seq_len, dim] = xs.dims();
xs.reshape([batch_size / self.n_heads, self.n_heads, seq_len, dim])
.swap_dims(1, 2)
.reshape([batch_size / self.n_heads, seq_len, dim * self.n_heads])
}
}
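
The two reshape helpers fold the head dimension into the batch dimension and back, which keeps the attention matmuls rank-3. A shape round-trip sketch with hypothetical sizes, assuming `attn` was built with `n_heads = 8` and `d_head = 40`:

fn head_reshape_sketch<B: Backend>(attn: &CrossAttention<B>) {
    // [batch, seq, n_heads * d_head] -> [batch * n_heads, seq, d_head] and back.
    let xs = Tensor::<B, 3>::zeros([2, 77, 320]);
    let per_head = attn.reshape_heads_to_batch_dim(xs);
    assert_eq!(per_head.dims(), [16, 77, 40]);
    let merged = attn.reshape_batch_dim_to_heads(per_head);
    assert_eq!(merged.dims(), [2, 77, 320]);
}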

// trait TensorStack where Self: Sized {
// fn stack<B: Backend, const D2: usize>(xs: &[Self], dim: usize) -> Tensor<B, D2>;
// }
//
// impl<B: Backend, const D: usize> TensorStack for Tensor<B, D> {
// fn stack<C: Backend, const D2: usize>(xs: &[Self], dim: usize) -> Tensor<C, D2> {
//
//
// }
// }
117 changes: 117 additions & 0 deletions src/models/groupnorm.rs
@@ -0,0 +1,117 @@
use burn::config::Config;
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Configuration to create a [GroupNorm](GroupNorm) layer.
#[derive(Config)]
pub struct GroupNormConfig {
/// The number of groups to separate the channels into
num_groups: usize,
/// The number of channels expected in the input
num_channels: usize,
/// A value required for numerical stability. Default: 1e-5
#[config(default = 1e-5)]
epsilon: f64,
    /// When set to `true`, this module has learnable per-channel affine
    /// parameters, initialized to ones (weights) and zeros (biases).
    /// Default: `true`
#[config(default = true)]
affine: bool,
}

/// Applies Group Normalization over a mini-batch of inputs.
///
/// `Y = groupnorm(X) * γ + β`
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
num_groups: usize,
num_channels: usize,
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
epsilon: f64,
affine: bool,
}

impl GroupNormConfig {
/// Initialize a new [group norm](GroupNorm) module.
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
assert_eq!(
self.num_channels % self.num_groups,
0,
"The number of channels must be divisible by the number of groups"
);

let gamma = Tensor::ones([self.num_channels]).into();
let beta = Tensor::zeros([self.num_channels]).into();

GroupNorm {
num_groups: self.num_groups,
num_channels: self.num_channels,
gamma,
beta,
epsilon: self.epsilon,
affine: self.affine,
}
}

/// Initialize a new [group norm](GroupNorm) module with a [record](GroupNormRecord).
pub fn init_with<B: Backend>(&self, record: GroupNormRecord<B>) -> GroupNorm<B> {
GroupNorm {
num_groups: self.num_groups,
num_channels: self.num_channels,
gamma: record.gamma,
beta: record.beta,
epsilon: self.epsilon,
affine: self.affine,
}
}
}

impl<B: Backend> GroupNorm<B> {
/// Applies the forward pass on the input tensor.
///
/// # Shapes
///
    /// - input: `[batch_size, num_channels, ...]`
    /// - output: `[batch_size, num_channels, ...]` (same shape as the input)
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let shape = input.shape();
        if D < 3 {
            panic!(
                "input rank for GroupNorm should be at least 3, but got {}",
                D
            );
        }

let batch_size = shape.dims[0];
let num_channels = shape.dims[1];

if num_channels != self.num_channels {
panic!(
"expected {} channels but got {}",
self.num_channels, num_channels
);
}

        // Each group normalizes over (num_channels / num_groups) * spatial elements.
        let hidden_size =
            shape.dims[2..].iter().product::<usize>() * num_channels / self.num_groups;
        let input = input.reshape([batch_size, self.num_groups, hidden_size]);

        let mean = input.clone().sum_dim(2) / hidden_size as f64;
        let centered = input.sub(mean);
        let var = centered.clone().mul(centered.clone()).sum_dim(2) / hidden_size as f64;
        let input_normalized = centered.div(var.add_scalar(self.epsilon).sqrt());

if self.affine {
let mut affine_shape = [1; D];
affine_shape[1] = num_channels;

input_normalized
.reshape(shape)
.mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape))
} else {
input_normalized.reshape(shape)
}
}
}
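
A minimal usage sketch with hypothetical sizes: 32 groups over 64 channels on a [batch, channels, height, width] input; the output keeps the input shape.

fn group_norm_sketch<B: Backend>() {
    let norm = GroupNormConfig::new(32, 64).init::<B>();
    let input = Tensor::<B, 4>::zeros([1, 64, 16, 16]);
    let output = norm.forward(input);
    assert_eq!(output.dims(), [1, 64, 16, 16]);
}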
9 changes: 9 additions & 0 deletions src/models/mod.rs
@@ -0,0 +1,9 @@
//! # Models
//!
//! A collection of models to be used in a diffusion loop.

mod groupnorm;

pub mod attention;
pub mod resnet;
pub mod unet_2d_blocks;