Feat/gelu (tracel-ai#45)
nathanielsimard authored Sep 24, 2022
1 parent a84df25 commit fe5ed0d
Showing 6 changed files with 57 additions and 3 deletions.
6 changes: 4 additions & 2 deletions burn-tensor/examples/autodiff_simple.rs
@@ -1,4 +1,5 @@
use burn_tensor::{activation, backend, Data, Distribution, Shape, Tensor};
+ use rand::{rngs::StdRng, SeedableRng};

fn loss<B: backend::Backend>(x: &Tensor<B, 2>, y: &Tensor<B, 2>) -> Tensor<B, 2> {
let z = x.matmul(y);
@@ -35,8 +36,9 @@ fn run<B: backend::Backend>(x: Data<B::Elem, 2>, y: Data<B::Elem, 2>) {

fn main() {
// Same data for all backends
- let x = Data::random(Shape::new([2, 3]), Distribution::Standard);
- let y = Data::random(Shape::new([3, 1]), Distribution::Standard);
+ let mut rng = StdRng::from_entropy();
+ let x = Data::random(Shape::new([2, 3]), Distribution::Standard, &mut rng);
+ let y = Data::random(Shape::new([3, 1]), Distribution::Standard, &mut rng);

#[cfg(feature = "ndarray")]
{
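The change above threads an explicit RNG through Data::random so that x and y (and every backend exercised by the example) are built from the same random source. A minimal sketch of the new call shape, assuming f32 element data; the fixed seed and the standalone main are illustrative, not part of the commit:

use burn_tensor::{Data, Distribution, Shape};
use rand::{rngs::StdRng, SeedableRng};

fn main() {
    // A seeded generator makes the sampled data reproducible across runs;
    // StdRng::from_entropy(), as in the example, gives fresh data each run.
    let mut rng = StdRng::seed_from_u64(42);

    // Both tensors draw from the same generator, matching the
    // "Same data for all backends" intent of the example.
    let x = Data::<f32, 2>::random(Shape::new([2, 3]), Distribution::Standard, &mut rng);
    let y = Data::<f32, 2>::random(Shape::new([3, 1]), Distribution::Standard, &mut rng);

    println!("x = {:?}", x);
    println!("y = {:?}", y);
}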
14 changes: 13 additions & 1 deletion burn-tensor/src/tensor/activation/base.rs
@@ -1,15 +1,27 @@
use crate::backend::Backend;
use crate::Tensor;
- use crate::{ElementPrecision, Precision};
+ use crate::{ElementConversion, ElementPrecision, Precision};

/// Applies the rectified linear unit function.
pub fn relu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
tensor.relu()
}

+ /// Applies the Gaussian Error Linear Units function as described in the paper [Gaussian Error Linear Units (GELUs)](https://arxiv.org/pdf/1606.08415v3.pdf).
+ pub fn gelu<const D: usize, B: Backend>(tensor: &Tensor<B, D>) -> Tensor<B, D> {
+ let x = tensor
+ .div_scalar(&2.0_f32.sqrt().to_elem())
+ .erf()
+ .add_scalar(&1.0_f32.to_elem());
+ tensor.mul(&x).mul_scalar(&0.5_f32.to_elem())
+ }

/// Applies the softmax function.
pub fn softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
log_softmax(tensor, dim).exp()
}

/// Applies the log softmax function.
pub fn log_softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let tensor_tmp = match B::Elem::precision() {
Precision::Half => {
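For reference, the gelu function added above implements the exact form of the activation, written with the Gaussian error function; this is just the formula the code mirrors, stated here for clarity:

\mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{1}{2}\,x\left(1 + \operatorname{erf}\!\left(\tfrac{x}{\sqrt{2}}\right)\right)

where \Phi is the standard normal CDF. The chain div_scalar(sqrt(2)) → erf → add_scalar(1) → mul(tensor) → mul_scalar(0.5) in the diff corresponds term by term to x/\sqrt{2}, erf(·), 1 + ·, the multiplication by x, and the factor 1/2.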
18 changes: 18 additions & 0 deletions burn-tensor/tests/tensor/activation/gelu.rs
@@ -0,0 +1,18 @@
use super::super::TestBackend;
use burn_tensor::activation;
use burn_tensor::{Data, Tensor};

#[test]
fn test_gelu() {
let data = Data::from([[
0.5447, 0.9809, 0.4114, 0.1398, 0.8045, 0.4103, 0.2388, 0.5262, 0.6677, 0.6737,
]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);

let data_actual = activation::gelu(&tensor).to_data();

let data_expected = Data::from([[
0.3851, 0.8207, 0.2714, 0.0777, 0.6351, 0.2704, 0.1419, 0.3687, 0.4993, 0.5051,
]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
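As a quick sanity check on the expected values (a hand computation, not part of the commit): for the first input element,

\mathrm{GELU}(0.5447) = 0.5 \cdot 0.5447 \cdot \left(1 + \operatorname{erf}(0.5447/\sqrt{2})\right) \approx 0.5 \cdot 0.5447 \cdot (1 + 0.4140) \approx 0.3851,

which matches the first expected entry at the three-decimal precision enforced by assert_approx_eq(&data_actual, 3).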
1 change: 1 addition & 0 deletions burn-tensor/tests/tensor/activation/mod.rs
@@ -1,2 +1,3 @@
+ mod gelu;
mod relu;
mod softmax;
19 changes: 19 additions & 0 deletions burn/src/nn/gelu.rs
@@ -0,0 +1,19 @@
use crate::module::Forward;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Applies the Gaussian Error Linear Units function element-wise.
#[derive(Clone, Debug, Default)]
pub struct GELU {}

impl GELU {
pub fn new() -> Self {
Self {}
}
}

impl<B: Backend, const D: usize> Forward<Tensor<B, D>, Tensor<B, D>> for GELU {
fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D> {
crate::tensor::activation::gelu(&input)
}
}
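A minimal usage sketch for the new module (illustrative, not part of the commit; the import paths are assumed from the crate layout shown in this diff, and B stands for any concrete backend):

use burn::module::Forward;
use burn::nn::GELU;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Applies GELU to a tensor of any rank D on any backend B.
/// GELU holds no parameters, so constructing it is essentially free.
fn apply_gelu<B: Backend, const D: usize>(input: Tensor<B, D>) -> Tensor<B, D> {
    GELU::new().forward(input)
}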
2 changes: 2 additions & 0 deletions burn/src/nn/mod.rs
@@ -1,9 +1,11 @@
mod dropout;
+ mod gelu;
mod layer_norm;
mod linear;
mod relu;

pub use dropout::*;
+ pub use gelu::*;
pub use layer_norm::*;
pub use linear::*;
pub use relu::*;
