feat: module init (tracel-ai#117)
nathanielsimard authored Nov 26, 2022
1 parent 4abc281 commit 46d06f0
Showing 11 changed files with 51 additions and 24 deletions.
5 changes: 5 additions & 0 deletions burn-tch/src/backend.rs
@@ -96,6 +96,11 @@ impl<E: TchElement> Backend for TchBackend<E> {
.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
tensor
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal(mean, std);
tensor
}
}
}

1 change: 1 addition & 0 deletions burn-tensor/Cargo.toml
@@ -24,6 +24,7 @@ burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", opti
num-traits = "0.2"
derive-new = "0.5"
rand = "0.8"
statrs = "0.16"
half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work with tch

# Autodiff
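The new `statrs` dependency supplies the normal-distribution sampler wired into `burn-tensor/src/tensor/data.rs` below. As a rough illustration of the interplay (assuming statrs 0.16 against the rand 0.8 pinned here; not taken from the diff itself), a statrs distribution can be sampled directly through a rand `Rng`:

```rust
use rand::Rng;
use statrs::distribution::Normal;

fn main() {
    // Normal::new(mean, std) returns an Err for a NaN or non-positive std,
    // which is why the sampler below unwraps it at construction time.
    let normal = Normal::new(0.0, 1.0).unwrap();
    // statrs distributions implement rand's Distribution trait, so any Rng can sample them.
    let sample: f64 = rand::thread_rng().sample(normal);
    println!("sampled {sample}");
}
```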
9 changes: 9 additions & 0 deletions burn-tensor/src/tensor/data.rs
@@ -19,6 +19,7 @@ pub enum Distribution<P> {
Standard,
Bernoulli(f64),
Uniform(P, P),
Normal(f64, f64),
}

#[derive(new)]
@@ -39,6 +40,7 @@ where
Standard(rand::distributions::Standard),
Uniform(rand::distributions::Uniform<P>),
Bernoulli(rand::distributions::Bernoulli),
Normal(statrs::distribution::Normal),
}

impl<'a, P> DistributionSampler<'a, P>
@@ -58,6 +60,9 @@ where
P::zeros(&P::default())
}
}
DistributionSamplerKind::Normal(distribution) => {
self.rng.sample(distribution).to_elem()
}
}
}
}
@@ -78,6 +83,9 @@ where
Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli(
rand::distributions::Bernoulli::new(prob).unwrap(),
),
Distribution::Normal(mean, std) => DistributionSamplerKind::Normal(
statrs::distribution::Normal::new(mean, std).unwrap(),
),
};

DistributionSampler::new(kind, rng)
@@ -93,6 +101,7 @@ where
Distribution::Standard => Distribution::Standard,
Distribution::Uniform(a, b) => Distribution::Uniform(E::from_elem(a), E::from_elem(b)),
Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob),
Distribution::Normal(mean, std) => Distribution::Normal(mean, std),
}
}
}
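With the variant, sampler kind, and conversion arm above in place, callers can ask any backend for normally distributed tensors. A hypothetical usage sketch (not part of the commit; the device-less `Tensor::random` signature is assumed from the calls made elsewhere in this diff):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{Distribution, Tensor};

// Build a 2x3 tensor whose entries are drawn from N(0.0, 1.0).
fn sample_normal<B: Backend>() -> Tensor<B, 2> {
    Tensor::random([2, 3], Distribution::Normal(0.0, 1.0))
}
```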
1 change: 1 addition & 0 deletions burn/src/module/param/base.rs
@@ -2,6 +2,7 @@ use super::ParamId;
use crate::module::{LoadingError, State, StateNamed};
use crate::tensor::Element;

/// Define a trainable parameter.
#[derive(Debug)]
pub struct Param<T> {
pub(super) id: ParamId,
4 changes: 2 additions & 2 deletions burn/src/nn/dropout.rs
@@ -31,8 +31,8 @@ impl Dropout {
///
/// # Shapes
///
-    /// - input: [..., any]
-    /// - output: [..., any]
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
if !B::ad_enabled() || self.prob == 0.0 {
return input;
15 changes: 10 additions & 5 deletions burn/src/nn/embedding.rs
@@ -4,7 +4,7 @@ use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
-use crate::tensor::{Distribution, ElementConversion, Tensor};
+use crate::tensor::{Distribution, Tensor};

/// Configuration to create an [Embedding](Embedding) layer.
#[derive(Config)]
@@ -16,6 +16,11 @@ pub struct EmbeddingConfig {
}

/// Lookup table to store a fixed number of vectors.
///
/// # Params
///
/// - weight: Matrix of shape `[n_embedding, d_model]` initialized from a normal distribution:
/// `N(0, 1)`
#[derive(Module, Debug)]
pub struct Embedding<B: Backend> {
weight: Param<Tensor<B, 2>>,
@@ -24,10 +29,10 @@ pub struct Embedding<B: Backend> {
impl<B: Backend> Embedding<B> {
/// Create the module from the given configuration.
pub fn new(config: &EmbeddingConfig) -> Self {
-        let start = -1.0 / f64::sqrt(config.d_model as f64);
-        let end = 1.0 / f64::sqrt(config.d_model as f64);
-        let distribution = Distribution::Uniform(start.to_elem(), end.to_elem());
-        let weight = Tensor::random([config.n_embedding, config.d_model], distribution);
+        let weight = Tensor::random(
+            [config.n_embedding, config.d_model],
+            Distribution::Normal(0.0, 1.0),
+        );

Self {
weight: Param::new(weight),
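The embedding weight therefore moves from the previous `U(-1/sqrt(d_model), 1/sqrt(d_model))` range to a standard normal. A minimal sketch of the new scheme (the free-standing helper is illustrative, not crate code):

```rust
use burn::module::Param;
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, Tensor};

// Equivalent of the weight setup in Embedding::new after this commit.
fn init_embedding_weight<B: Backend>(n_embedding: usize, d_model: usize) -> Param<Tensor<B, 2>> {
    Param::new(Tensor::random(
        [n_embedding, d_model],
        Distribution::Normal(0.0, 1.0), // mean 0.0, standard deviation 1.0
    ))
}
```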
4 changes: 2 additions & 2 deletions burn/src/nn/gelu.rs
@@ -15,8 +15,8 @@ impl GELU {
///
/// # Shapes
///
-    /// - input: [..., any]
-    /// - output: [..., any]
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
crate::tensor::activation::gelu(&input)
}
4 changes: 2 additions & 2 deletions burn/src/nn/layer_norm.rs
@@ -43,8 +43,8 @@ impl<B: Backend> LayerNorm<B> {
///
/// # Shapes
///
-    /// - input: [..., any, d_model]
-    /// - output: [..., any, d_model]
+    /// - input: `[..., any, d_model]`
+    /// - output: `[..., any, d_model]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let (var, mean) = input.var_mean_bias(D - 1);

24 changes: 15 additions & 9 deletions burn/src/nn/linear.rs
@@ -4,7 +4,7 @@ use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::tensor::backend::Backend;
-use crate::tensor::{Distribution, ElementConversion, Shape, Tensor};
+use crate::tensor::{Distribution, ElementConversion, Tensor};
use std::ops::Deref;

/// Configuration to create a [Linear](Linear) layer.
@@ -22,6 +22,14 @@ pub struct LinearConfig {
/// Applies a linear transformation to the input tensor:
///
/// `O = IW + b`
///
/// # Params
///
/// - weight: Matrix of shape `[d_input, d_output]` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
///
/// - bias (optional): Vector of size `d_output` initialized from a uniform distribution:
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
#[derive(Module, Debug)]
pub struct Linear<B: Backend> {
weight: Param<Tensor<B, 2>>,
@@ -31,14 +39,12 @@ pub struct Linear<B: Backend> {
impl<B: Backend> Linear<B> {
/// Create the module from the given configuration.
pub fn new(config: &LinearConfig) -> Self {
-        // Glorot init
-        let start = -1.0 / f64::sqrt(config.d_input as f64);
-        let end = 1.0 / f64::sqrt(config.d_input as f64);
-        let distribution = Distribution::Uniform(start.to_elem(), end.to_elem());
+        let k = f64::sqrt(1.0 / config.d_input as f64);
+        let distribution = Distribution::Uniform((-1.0 * k).to_elem(), k.to_elem());

-        let weight = Tensor::random(Shape::new([config.d_input, config.d_output]), distribution);
+        let weight = Tensor::random([config.d_input, config.d_output], distribution);
         let bias = match config.bias {
-            true => Some(Tensor::zeros(Shape::new([config.d_output]))),
+            true => Some(Tensor::random([config.d_output], distribution)),
false => None,
};

@@ -52,8 +58,8 @@ impl<B: Backend> Linear<B> {
///
/// # Shapes
///
-    /// - input: [..., any, d_input]
-    /// - output: [..., any, d_output]
+    /// - input: `[..., any, d_input]`
+    /// - output: `[..., any, d_output]`
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let output = input.matmul(&self.weight.unsqueeze());

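Despite the removed `// Glorot init` comment, the new bound `k = sqrt(1 / d_input)` is the fan-in-only rule (the same bound PyTorch uses for its Linear layer), not a Glorot bound involving `d_output`. A small sketch of the bound with a worked value (helper name is illustrative, not crate code):

```rust
// k = sqrt(1 / d_input); weight and bias entries are drawn from U(-k, k).
fn init_bound(d_input: usize) -> f64 {
    f64::sqrt(1.0 / d_input as f64)
}

// For d_input = 256: k = sqrt(1/256) = 0.0625, so every entry starts in [-0.0625, 0.0625].
```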
4 changes: 2 additions & 2 deletions burn/src/nn/relu.rs
@@ -16,8 +16,8 @@ impl ReLU {
///
/// # Shapes
///
-    /// - input: [..., any]
-    /// - output: [..., any]
+    /// - input: `[..., any]`
+    /// - output: `[..., any]`
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
crate::tensor::activation::relu(&input)
}
4 changes: 2 additions & 2 deletions examples/mnist/src/mlp.rs
@@ -48,8 +48,8 @@ impl<B: Backend> Mlp<B> {
///
/// # Shapes
///
-    /// - input: [batch_size, d_model]
-    /// - output: [batch_size, d_model]
+    /// - input: `[batch_size, d_model]`
+    /// - output: `[batch_size, d_model]`
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let mut x = input;
