From 296fc63297138ff0c0347630d543daaf2b86ec39 Mon Sep 17 00:00:00 2001 From: dcvz Date: Thu, 16 Nov 2023 21:22:36 +0100 Subject: [PATCH 1/5] Add GroupNorm --- burn-core/src/nn/norm/group.rs | 139 +++++++++++++++++++++++++++++++++ burn-core/src/nn/norm/mod.rs | 2 + 2 files changed, 141 insertions(+) create mode 100644 burn-core/src/nn/norm/group.rs diff --git a/burn-core/src/nn/norm/group.rs b/burn-core/src/nn/norm/group.rs new file mode 100644 index 0000000000..3f4c94959c --- /dev/null +++ b/burn-core/src/nn/norm/group.rs @@ -0,0 +1,139 @@ +use crate as burn; + +use crate::config::Config; +use crate::module::Module; +use crate::module::Param; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; + +/// Configuration to create a [GroupNorm](GroupNorm) layer. +#[derive(Config)] +pub struct GroupNormConfig { + /// The number of groups to separate the channels into + num_groups: usize, + /// The number of channels expected in the input + num_channels: usize, + /// A value required for numerical stability. Default: 1e-5 + #[config(default = 1e-5)] + epsilon: f64, + /// A boolean value that when set to `true`, this module has learnable + /// per-channel affine parameters initialized to ones (for weights) + /// and zeros (for biases). Default: `true` + #[config(default = true)] + affine: bool, +} + +/// Applies Group Normalization over a mini-batch of inputs. +/// +/// `Y = groupnorm(X) * γ + β` +#[derive(Module, Debug)] +pub struct GroupNorm { + num_groups: usize, + num_channels: usize, + gamma: Param>, + beta: Param>, + epsilon: f64, + affine: bool, +} + +impl GroupNormConfig { + /// Initialize a new [group norm](GroupNorm) module. + pub fn init(&self) -> GroupNorm { + assert_eq!( + self.num_channels % self.num_groups, + 0, + "The number of channels must be divisible by the number of groups" + ); + + let gamma = Tensor::ones([self.num_channels]).into(); + let beta = Tensor::zeros([self.num_channels]).into(); + + GroupNorm { + num_groups: self.num_groups, + num_channels: self.num_channels, + gamma, + beta, + epsilon: self.epsilon, + affine: self.affine, + } + } + + /// Initialize a new [group norm](GroupNorm) module with a [record](GroupNormRecord). + pub fn init_with(&self, record: GroupNormRecord) -> GroupNorm { + GroupNorm { + num_groups: self.num_groups, + num_channels: self.num_channels, + gamma: record.gamma, + beta: record.beta, + epsilon: self.epsilon, + affine: self.affine, + } + } +} + +impl GroupNorm { + /// Applies the forward pass on the input tensor. 
+ /// + /// # Shapes + /// + /// - input: `[..., any, d_model]` + /// - output: `[..., any, d_model]` + pub fn forward(&self, input: Tensor) -> Tensor { + let shape = input.shape(); + if shape.num_elements() <= 2 { + panic!( + "input rank for GroupNorm should be at least 3, but got {}", + shape.num_elements() + ); + } + + let batch_size = shape.dims[0]; + let num_channels = shape.dims[1]; + if num_channels != self.num_channels { + panic!( + "expected {} channels but got {}", + self.num_channels, num_channels + ); + } + + let input = input.reshape([ + batch_size, + self.num_groups, + shape.num_elements() / (batch_size * self.num_groups), + ]); + + let mean = input.clone().mean_dim(D - 1); + let var = (mean.clone() * mean.clone()).mean_dim(D - 1); + + let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); + let input_normalized = input_normalized.reshape(shape); + + if self.affine { + let mut affine_shape = [1; D]; + affine_shape[1] = num_channels; + + input_normalized + .mul(self.gamma.val().reshape(affine_shape)) + .add(self.beta.val().reshape(affine_shape)) + } else { + input_normalized + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use burn_tensor::{Data, Distribution}; + + #[cfg(feature = "std")] + use crate::{TestAutodiffBackend, TestBackend}; + + #[cfg(not(feature = "std"))] + use crate::TestBackend; + + #[test] + fn group_norm_forward() { + let module = GroupNormConfig::new(3, 6).init::(); + } +} diff --git a/burn-core/src/nn/norm/mod.rs b/burn-core/src/nn/norm/mod.rs index 5a23674973..26c01a1683 100644 --- a/burn-core/src/nn/norm/mod.rs +++ b/burn-core/src/nn/norm/mod.rs @@ -1,5 +1,7 @@ mod batch; +mod group; mod layer; pub use batch::*; +pub use group::*; pub use layer::*; From 7adeefe4991fdb139586eafc469bd8c498063103 Mon Sep 17 00:00:00 2001 From: dcvz Date: Thu, 16 Nov 2023 23:17:38 +0100 Subject: [PATCH 2/5] Fix implemenation and add tests --- burn-core/src/nn/norm/group.rs | 66 +++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/burn-core/src/nn/norm/group.rs b/burn-core/src/nn/norm/group.rs index 3f4c94959c..a5bc487887 100644 --- a/burn-core/src/nn/norm/group.rs +++ b/burn-core/src/nn/norm/group.rs @@ -89,6 +89,7 @@ impl GroupNorm { let batch_size = shape.dims[0]; let num_channels = shape.dims[1]; + if num_channels != self.num_channels { panic!( "expected {} channels but got {}", @@ -96,27 +97,24 @@ impl GroupNorm { ); } - let input = input.reshape([ - batch_size, - self.num_groups, - shape.num_elements() / (batch_size * self.num_groups), - ]); - - let mean = input.clone().mean_dim(D - 1); - let var = (mean.clone() * mean.clone()).mean_dim(D - 1); + let hidden_size = + shape.dims[2..].iter().product::() * num_channels / self.num_groups; + let input = input.reshape([batch_size, self.num_groups, hidden_size]); + let mean = input.clone().sum_dim(2) / hidden_size as f64; + let var = input.clone().sqrt().sum_dim(2) / hidden_size as f64; let input_normalized = input.sub(mean).div(var.sqrt().add_scalar(self.epsilon)); - let input_normalized = input_normalized.reshape(shape); if self.affine { let mut affine_shape = [1; D]; affine_shape[1] = num_channels; input_normalized + .reshape(shape) .mul(self.gamma.val().reshape(affine_shape)) .add(self.beta.val().reshape(affine_shape)) } else { - input_normalized + input_normalized.reshape(shape) } } } @@ -124,16 +122,58 @@ impl GroupNorm { #[cfg(test)] mod tests { use super::*; - use burn_tensor::{Data, Distribution}; + use burn_tensor::Data; #[cfg(feature = 
"std")] - use crate::{TestAutodiffBackend, TestBackend}; + use crate::TestBackend; #[cfg(not(feature = "std"))] use crate::TestBackend; #[test] fn group_norm_forward() { - let module = GroupNormConfig::new(3, 6).init::(); + let module = GroupNormConfig::new(2, 6).init::(); + let input = Tensor::from_data(Data::from([ + [ + [-0.3034f32, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [1.9507, 1.2554, -0.8625], + [1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [1.5157, -0.1546, -0.5596], + ], + [ + [-1.6698, -0.4040, -0.7927], + [0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641], + ], + ])); + + let output = module.forward(input); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [-0.1653, 0.3748, -0.7866], + [-0.9916, -1.1220, 0.1353], + [1.9485, 1.2965, -0.6896], + [1.2769, 0.3628, 0.4120], + [-0.7427, -0.6786, -1.3578], + [1.8547, -0.3022, -0.8252], + ], + [ + [-1.9342, 0.0211, -0.5793], + [1.2223, 0.4945, 0.4365], + [-0.8163, 1.4887, -0.3333], + [-1.7960, -0.0392, 0.3875], + [-1.5469, 0.3998, 0.9561], + [-0.3428, 0.7970, 1.1845], + ], + ]), + 3, + ); } } From 34b24f6941b459a19aadd1fba888cdc4604a4013 Mon Sep 17 00:00:00 2001 From: dcvz Date: Sun, 19 Nov 2023 23:46:10 +0100 Subject: [PATCH 3/5] Address PR comments --- burn-core/src/nn/norm/group.rs | 99 +++++++++++++++++++++++++++++----- 1 file changed, 86 insertions(+), 13 deletions(-) diff --git a/burn-core/src/nn/norm/group.rs b/burn-core/src/nn/norm/group.rs index a5bc487887..8bd1d29aad 100644 --- a/burn-core/src/nn/norm/group.rs +++ b/burn-core/src/nn/norm/group.rs @@ -30,8 +30,8 @@ pub struct GroupNormConfig { pub struct GroupNorm { num_groups: usize, num_channels: usize, - gamma: Param>, - beta: Param>, + gamma: Option>>, + beta: Option>>, epsilon: f64, affine: bool, } @@ -45,8 +45,14 @@ impl GroupNormConfig { "The number of channels must be divisible by the number of groups" ); - let gamma = Tensor::ones([self.num_channels]).into(); - let beta = Tensor::zeros([self.num_channels]).into(); + let (gamma, beta) = if self.affine { + let gamma = Tensor::ones([self.num_channels]).into(); + let beta = Tensor::zeros([self.num_channels]).into(); + + (Some(gamma), Some(beta)) + } else { + (None, None) + }; GroupNorm { num_groups: self.num_groups, @@ -111,8 +117,8 @@ impl GroupNorm { input_normalized .reshape(shape) - .mul(self.gamma.val().reshape(affine_shape)) - .add(self.beta.val().reshape(affine_shape)) + .mul(self.gamma.clone().unwrap().val().reshape(affine_shape)) + .add(self.beta.clone().unwrap().val().reshape(affine_shape)) } else { input_normalized.reshape(shape) } @@ -123,16 +129,17 @@ impl GroupNorm { mod tests { use super::*; use burn_tensor::Data; - - #[cfg(feature = "std")] - use crate::TestBackend; - - #[cfg(not(feature = "std"))] use crate::TestBackend; #[test] - fn group_norm_forward() { - let module = GroupNormConfig::new(2, 6).init::(); + fn group_norm_forward_affine_false() { + let module = GroupNormConfig::new(2, 6) + .with_affine(false) + .init::(); + + assert!(module.gamma.is_none()); + assert!(module.beta.is_none()); + let input = Tensor::from_data(Data::from([ [ [-0.3034f32, 0.2726, -0.9659], @@ -176,4 +183,70 @@ mod tests { 3, ); } + + #[test] + fn group_norm_forward_affine_true() { + let module = GroupNormConfig::new(3, 6) + .with_affine(true) + .init::(); + + module + .gamma + .as_ref() + .expect("Gamma is None") + .val() + .to_data() + .assert_approx_eq(&Data::ones([6].into()), 3); + + module + 
.beta + .as_ref() + .expect("beta is None") + .val() + .to_data() + .assert_approx_eq(&Data::zeros([6]), 3); + + let input = Tensor::from_data(Data::from([ + [ + [-0.3034f32, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [1.9507, 1.2554, -0.8625], + [1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [1.5157, -0.1546, -0.5596], + ], + [ + [-1.6698, -0.4040, -0.7927], + [0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641], + ], + ])); + + let output = module.forward(input); + + output.to_data().assert_approx_eq( + &Data::from([ + [ + [0.4560, 1.4014, -0.6313], + [-0.9901, -1.2184, 0.9822], + [1.4254, 0.6360, -1.7682], + [0.4235, -0.3800, -0.3367], + [-0.3890, -0.3268, -0.9862], + [2.1325, 0.0386, -0.4691] + ], + [ + [-1.8797, 0.0777, -0.5234], + [1.2802, 0.5517, 0.4935], + [-1.0102, 1.5327, -0.4773], + [-1.2587, 0.4047, 0.8088], + [-1.9074, 0.1691, 0.7625], + [-0.6230, 0.5928, 1.0061] + ] + ]), + 3, + ); + } } From dc3f9dedddce1980d74c908894f73955b4381268 Mon Sep 17 00:00:00 2001 From: dcvz Date: Mon, 20 Nov 2023 01:07:17 +0100 Subject: [PATCH 4/5] Fix formatting --- burn-core/src/nn/norm/group.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/burn-core/src/nn/norm/group.rs b/burn-core/src/nn/norm/group.rs index 8bd1d29aad..14f1aa5b14 100644 --- a/burn-core/src/nn/norm/group.rs +++ b/burn-core/src/nn/norm/group.rs @@ -128,8 +128,8 @@ impl GroupNorm { #[cfg(test)] mod tests { use super::*; - use burn_tensor::Data; use crate::TestBackend; + use burn_tensor::Data; #[test] fn group_norm_forward_affine_false() { @@ -235,7 +235,7 @@ mod tests { [1.4254, 0.6360, -1.7682], [0.4235, -0.3800, -0.3367], [-0.3890, -0.3268, -0.9862], - [2.1325, 0.0386, -0.4691] + [2.1325, 0.0386, -0.4691], ], [ [-1.8797, 0.0777, -0.5234], @@ -243,8 +243,8 @@ mod tests { [-1.0102, 1.5327, -0.4773], [-1.2587, 0.4047, 0.8088], [-1.9074, 0.1691, 0.7625], - [-0.6230, 0.5928, 1.0061] - ] + [-0.6230, 0.5928, 1.0061], + ], ]), 3, ); From 423b1bbc8b3ae18a106dbd5b5e8ddefc3cf5b7d1 Mon Sep 17 00:00:00 2001 From: dcvz Date: Mon, 20 Nov 2023 10:46:49 +0100 Subject: [PATCH 5/5] Update burn book --- burn-book/src/building-blocks/module.md | 1 + 1 file changed, 1 insertion(+) diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index c4f8b96322..cb5d9cf5ef 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -111,6 +111,7 @@ Burn comes with built-in modules that you can use to build your own modules. | ----------- | --------------------------------------- | | `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | | `LayerNorm` | `nn.LayerNorm` | +| `GroupNorm` | `nn.GroupNorm` | | `Dropout` | `nn.Dropout` | | `GELU` | `nn.GELU` | | `Linear` | `nn.Linear` |
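
For reference, the computation that the new `GroupNorm::forward` is meant to perform can be written out without any framework machinery. The sketch below is a plain-Rust illustration of the standard per-group formula (mean and variance over each group of channels, followed by the optional per-channel affine transform). It follows the `torch.nn.GroupNorm` convention of adding epsilon to the variance before the square root; the function name, the contiguous `[batch, channels, len]` buffer layout, and the `f32` types are illustrative assumptions rather than part of the patch.

```rust
/// Reference group normalization over a contiguous `[batch, channels, len]` buffer.
/// For each (batch, group), the mean and variance are taken over
/// `channels / num_groups` channels, every element in the group is normalized,
/// and the per-channel affine parameters are applied.
fn group_norm_reference(
    x: &[f32],
    batch: usize,
    channels: usize,
    len: usize,
    num_groups: usize,
    gamma: &[f32], // per-channel scale, length `channels`
    beta: &[f32],  // per-channel shift, length `channels`
    eps: f32,
) -> Vec<f32> {
    assert_eq!(channels % num_groups, 0, "channels must be divisible by num_groups");
    let channels_per_group = channels / num_groups;
    let group_size = channels_per_group * len; // elements per (batch, group)
    let mut y = vec![0.0f32; x.len()];

    for b in 0..batch {
        for g in 0..num_groups {
            let start = (b * channels + g * channels_per_group) * len;
            let group = &x[start..start + group_size];

            // Statistics over the whole group (its channels and their spatial elements).
            let mean = group.iter().sum::<f32>() / group_size as f32;
            let var = group.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / group_size as f32;
            let inv_std = 1.0 / (var + eps).sqrt();

            for (i, v) in group.iter().enumerate() {
                let c = g * channels_per_group + i / len; // channel index of this element
                y[start + i] = (v - mean) * inv_std * gamma[c] + beta[c];
            }
        }
    }
    y
}
```

Called with `gamma` set to ones and `beta` to zeros, this should reproduce, to roughly three decimal places, the expected values asserted in the `group_norm_forward_affine_false` test above. Note also that `num_groups = 1` makes the computation equivalent to layer normalization over the channel and spatial dimensions, while `num_groups = num_channels` corresponds to instance normalization.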