Commit
Feat/embedding (tracel-ai#61)
nathanielsimard authored Oct 23, 2022
1 parent b1df39e commit 847243d
Showing 45 changed files with 390 additions and 70 deletions.
2 changes: 2 additions & 0 deletions burn-derive/src/module/base.rs
@@ -19,13 +19,15 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
     let state_fn = param.gen_state_fn();
     let load_fn = param.gen_load_fn();
     let inner_fn = param.gen_inner_fn();
+    let detach_fn = param.gen_detach_fn();

     let gen = quote! {
         impl #generics burn::module::Module for #name #generics_ty #generics_where {
             type Backend=B;

             #devices_fn
             #to_device_fn
+            #detach_fn

             #state_fn
             #load_fn
16 changes: 16 additions & 0 deletions burn-derive/src/module/param.rs
@@ -139,6 +139,22 @@ impl Param {
         }
     }

+    pub fn gen_detach_fn(&self) -> TokenStream {
+        let mut body = quote! {};
+        for field in self.fields_param.iter() {
+            let name = field.ident();
+            body.extend(quote! {
+                self.#name.detach();
+            });
+        }
+
+        quote! {
+            fn detach(&mut self) {
+                #body
+            }
+        }
+    }
+
     pub fn gen_inner_fn(&self) -> TokenStream {
         let mut body = quote! {};
         let mut names = Vec::new();
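The derived method simply forwards detach to every parameter field. For a hypothetical module with two parameter fields (an illustrative expansion only, not code from this commit), the generated impl would look roughly like:

    impl<B: Backend> burn::module::Module for MyModule<B> {
        type Backend = B;
        // ...state, load, devices, to_device and the other generated methods...

        // detach() forwards to each parameter, cutting it out of the
        // autodiff graph in place.
        fn detach(&mut self) {
            self.weight.detach();
            self.bias.detach();
        }
    }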
2 changes: 0 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/mod.rs
@@ -2,8 +2,6 @@ mod activation
 mod backend;
 mod ops;
 mod tensor;
-mod tensor_ops;

 pub use backend::*;
 pub use ops::*;
 pub use tensor::*;
30 changes: 30 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/base.rs
@@ -0,0 +1,30 @@
use crate::{
    backend::{autodiff::ADTensor, Backend},
    graph::{
        node::{ForwardNode, ForwardNodeRef, ForwardNodeState},
        ops::{ForwardUnaryRecordedOps, UnaryOps},
    },
};
use std::sync::Arc;

pub fn unary_ops_wrapper<B, O, const D1: usize, const D2: usize>(
    input: ForwardNodeRef<B::TensorPrimitive<D1>>,
    output: B::TensorPrimitive<D2>,
    ops: O,
) -> ADTensor<D2, B>
where
    B: Backend,
    O: UnaryOps<B::TensorPrimitive<D1>, B::TensorPrimitive<D2>> + 'static,
{
    let shape = *B::shape(&output);
    let state = ForwardNodeState::new(output);

    let ops = Arc::new(ops);
    let ops = ForwardUnaryRecordedOps::new(input.clone(), ops);
    let ops = Arc::new(ops);

    let node = ForwardNode::from_unary(&input, state, ops);
    let node = Arc::new(node);

    ADTensor { node, shape }
}
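This is the wrapper that previously lived next to the to_device op; it is now shared by all autodiff unary ops and generalized to independent input and output ranks D1 and D2, which is exactly what embedding needs (2-D weights in, 3-D activations out). It stores the backend output in a ForwardNodeState, wraps the op in ForwardUnaryRecordedOps so the backward step is recorded on the graph, and returns the resulting ADTensor.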
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -1,6 +1,7 @@
 mod add;
 mod aggregation;
 mod arg;
+mod base;
 mod cat;
 mod creation;
 mod detach;
@@ -12,13 +13,16 @@ mod log;
 mod map_comparison;
 mod mask;
 mod matmul;
+mod module;
 mod mul;
 mod neg;
 mod pow;
 mod precision;
 mod reshape;
 mod sub;
 mod tensor;
 mod transpose;

 mod macros;
+pub(crate) use base::*;
 pub use macros::*;
49 changes: 49 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/ops/module.rs
@@ -0,0 +1,49 @@
use super::unary_ops_wrapper;
use crate::{
    backend::{
        autodiff::{ADBackendDecorator, ADTensor},
        Backend,
    },
    graph::ops::{UnaryOps, UnaryOpsNodeState},
    ops::ModuleOps,
};

#[derive(new, Debug)]
struct EmbeddingBackward<B: Backend> {
    indexes: <B::IntegerBackend as Backend>::TensorPrimitive<2>,
}

impl<B: Backend> UnaryOps<B::TensorPrimitive<2>, B::TensorPrimitive<3>> for EmbeddingBackward<B> {
    fn partial(
        &self,
        state: &UnaryOpsNodeState<B::TensorPrimitive<2>, B::TensorPrimitive<3>>,
    ) -> B::TensorPrimitive<2> {
        B::embedding_backward(&state.input.value, &state.output.grad(), &self.indexes)
    }
}

impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
    fn embedding(
        weights: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<2>,
        indexes: &<<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<
            2,
        >,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<3> {
        let input = weights.node.clone();
        let output = B::embedding(weights.tensor_ref(), indexes);
        let ops = EmbeddingBackward::<B>::new(indexes.clone());

        unary_ops_wrapper(input, output, ops)
    }

    fn embedding_backward(
        weights: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<2>,
        output: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<3>,
        indexes: &<<ADBackendDecorator<B> as Backend>::IntegerBackend as Backend>::TensorPrimitive<
            2,
        >,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<2> {
        let tensor = B::embedding_backward(weights.tensor_ref(), output.tensor_ref(), indexes);
        ADTensor::from_tensor(tensor)
    }
}
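On the autodiff backend, embedding runs the inner backend's forward pass (B::embedding), keeps a copy of the index tensor inside EmbeddingBackward, and registers that op through unary_ops_wrapper. During backpropagation, partial() hands the recorded output gradient and the saved indexes to B::embedding_backward, so only the weight rows that were actually looked up receive gradient. The explicit embedding_backward entry point just delegates to the inner backend and re-wraps the result as an ADTensor.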
@@ -1,14 +1,10 @@
-use super::{ADBackendDecorator, ADTensor};
+use super::unary_ops_wrapper;
 use crate::{
-    backend::Backend,
-    graph::{
-        node::{ForwardNode, ForwardNodeRef, ForwardNodeState},
-        ops::{ForwardUnaryRecordedOps, UnaryOps, UnaryOpsNodeState},
-    },
+    backend::{autodiff::ADBackendDecorator, Backend},
+    graph::ops::{UnaryOps, UnaryOpsNodeState},
     ops::TensorOps,
     Data, Shape,
 };
-use std::sync::Arc;

 #[derive(new, Debug)]
 struct ToDeviceBackward<B: Backend, const D: usize> {
@@ -73,32 +69,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
         device: <ADBackendDecorator<B> as Backend>::Device,
     ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        let device_old = B::device(tensor.tensor_ref());
         let input = tensor.node.clone();
        let output = B::to_device(tensor.tensor_ref(), device);
-        let ops = ToDeviceBackward::<B, D>::new(device);
+        let ops = ToDeviceBackward::<B, D>::new(device_old);

         unary_ops_wrapper(input, output, ops)
     }
 }
-
-fn unary_ops_wrapper<B, O, const D: usize>(
-    input: ForwardNodeRef<B::TensorPrimitive<D>>,
-    output: B::TensorPrimitive<D>,
-    ops: O,
-) -> ADTensor<D, B>
-where
-    B: Backend,
-    O: UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>> + 'static,
-{
-    let shape = *B::shape(&output);
-    let state = ForwardNodeState::new(output);
-
-    let ops = Arc::new(ops);
-    let ops = ForwardUnaryRecordedOps::new(input.clone(), ops);
-    let ops = Arc::new(ops);
-
-    let node = ForwardNode::from_unary(&input, state, ops);
-    let node = Arc::new(node);
-
-    ADTensor { node, shape }
-}
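Two things happen to this file: the local, single-rank unary_ops_wrapper is deleted in favour of the shared rank-generic version above, and to_device now captures device_old — the device the tensor was on before the move — and hands that to ToDeviceBackward instead of the target device, so that the recorded backward step can move the incoming gradient back to the device the input actually lives on.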
10 changes: 9 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -6,7 +6,15 @@ use crate::tensor::{Data, Distribution, Shape};
 use crate::Gradients;

 pub trait Backend:
-    TensorOps<Self> + Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static
+    TensorOps<Self>
+    + ModuleOps<Self>
+    + Clone
+    + Sized
+    + Default
+    + Send
+    + Sync
+    + std::fmt::Debug
+    + 'static
 {
     type Device: Copy + Clone + Default + std::fmt::Debug + Send + Sync;
     type Elem: Element;
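Backend now also requires ModuleOps, so every backend has to provide the embedding kernels. The trait declaration itself is not shown in this view; judging from the implementations below, its shape would be roughly as follows (a sketch inferred from the impls, not the verbatim declaration):

    pub trait ModuleOps<B: Backend> {
        // Gather rows of `weights` by the integer `indexes`:
        // [n_embedding, d_model] -> [batch_size, seq_length, d_model].
        fn embedding(
            weights: &B::TensorPrimitive<2>,
            indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<2>,
        ) -> B::TensorPrimitive<3>;

        // Scatter-add the output gradient back into a [n_embedding, d_model]
        // gradient for the weights.
        fn embedding_backward(
            weights: &B::TensorPrimitive<2>,
            output: &B::TensorPrimitive<3>,
            indexes: &<B::IntegerBackend as Backend>::TensorPrimitive<2>,
        ) -> B::TensorPrimitive<2>;
    }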
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/ndarray/mod.rs
@@ -1,5 +1,6 @@
 mod activation;
 mod backend;
+mod module_ops;
 mod ops;
 mod shape;
 mod tensor;
56 changes: 56 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/module_ops.rs
@@ -0,0 +1,56 @@
use super::{NdArrayBackend, NdArrayTensor};
use crate::{ops::*, NdArrayElement, Shape};

impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    fn embedding(
        weights: &NdArrayTensor<E, 2>,
        indexes: &NdArrayTensor<i64, 2>,
    ) -> NdArrayTensor<E, 3> {
        let [batch_size, seq_length] = indexes.shape.dims;
        let [_n_embedding, d_model] = weights.shape.dims;

        let mut tensors = Vec::with_capacity(batch_size * seq_length);

        for index in indexes
            .reshape(Shape::new([batch_size * seq_length]))
            .array
            .iter()
        {
            let index = *index as usize;
            tensors.push(weights.index([index..index + 1, 0..d_model]));
        }
        let embedding = TensorOpsCat::cat(tensors.iter().collect(), 0);
        embedding.reshape(Shape::new([batch_size, seq_length, d_model]))
    }

    fn embedding_backward(
        weights: &NdArrayTensor<E, 2>,
        output: &NdArrayTensor<E, 3>,
        indexes: &NdArrayTensor<i64, 2>,
    ) -> NdArrayTensor<E, 2> {
        let [batch_size, seq_length] = indexes.shape.dims;
        let [_n_embedding, d_model] = weights.shape.dims;

        let mut weights_grad = weights.zeros();
        let output = output.reshape(Shape::new([batch_size * seq_length, d_model]));

        for (index_output, index) in indexes
            .reshape(Shape::new([batch_size * seq_length]))
            .array
            .iter()
            .enumerate()
        {
            let index = *index as usize;

            let weights_grad_current = weights_grad.index([index..index + 1, 0..d_model]);
            let output_grad = output.index([index_output..index_output + 1, 0..d_model]);

            weights_grad = weights_grad.index_assign(
                [index..index + 1, 0..d_model],
                &output_grad.add(&weights_grad_current),
            );
        }

        weights_grad
    }
}
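As a concrete example of what these two functions compute: with weights of shape [10, 8] and indexes [[1, 4, 4]] (batch_size 1, seq_length 3), embedding gathers rows 1, 4 and 4 of the weights and reshapes them into a [1, 3, 8] output. embedding_backward does the reverse: it flattens the [1, 3, 8] output gradient to [3, 8] and scatter-adds each row back into a zeroed [10, 8] tensor at the corresponding index, so row 4 of the weight gradient accumulates two contributions while rows that were never looked up stay zero.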
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/tch/activation/relu.rs
@@ -10,7 +10,7 @@
         Self {
             tensor,
             shape: self.shape,
-            kind: self.kind.clone(),
+            kind: self.kind,
         }
     }
 }
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/backend/tch/mod.rs
@@ -1,5 +1,6 @@
 mod activation;
 mod backend;
+mod module_ops;
 mod ops;
 mod tensor;
 mod tensor_ops;
38 changes: 38 additions & 0 deletions burn-tensor/src/tensor/backend/tch/module_ops.rs
@@ -0,0 +1,38 @@
use super::{TchBackend, TchTensor};
use crate::{ops::ModuleOps, Shape, TchElement};

impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
    fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {
        let tensor = tch::Tensor::embedding(&weights.tensor, &indexes.tensor, -1, false, false);
        let shape = Shape::from(tensor.size());

        TchTensor {
            kind: weights.kind,
            tensor,
            shape,
        }
    }

    fn embedding_backward(
        weights: &TchTensor<E, 2>,
        output: &TchTensor<E, 3>,
        indexes: &TchTensor<i64, 2>,
    ) -> TchTensor<E, 2> {
        let [n_embedding, _d_model] = weights.shape.dims;
        let tensor = tch::Tensor::embedding_backward(
            &output.tensor,
            &indexes.tensor,
            n_embedding as i64,
            -1,
            false,
            false,
        );
        let shape = Shape::from(tensor.size());

        TchTensor {
            kind: weights.kind,
            tensor,
            shape,
        }
    }
}
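The LibTorch path delegates straight to tch. Assuming the usual tch-rs signatures, the trailing -1, false, false arguments on Tensor::embedding correspond to padding_idx (disabled), scale_grad_by_freq and sparse, and embedding_backward additionally receives the vocabulary size n_embedding so the kernel can produce the [n_embedding, d_model] weight gradient; this reading of the arguments is an assumption about the tch API, not something stated in the diff.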
6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/backend/tch/ops/add.rs
@@ -7,7 +7,7 @@ use std::ops::Add;
 impl<P: TchElement, const D: usize> TensorOpsAdd<P, D> for TchTensor<P, D> {
     fn add(&self, other: &Self) -> Self {
         let tensor = (&self.tensor).add(&other.tensor);
-        let kind = self.kind.clone();
+        let kind = self.kind;
         let shape = self.shape.higher(&other.shape);

         Self {
@@ -18,8 +18,8 @@ impl<P: TchElement, const D: usize> TensorOpsAdd<P, D> for TchTensor<P, D> {
     }
     fn add_scalar(&self, other: &P) -> Self {
         let other: f64 = (other.clone()).to_elem();
-        let tensor = (&self.tensor).add(other);
-        let kind = self.kind.clone();
+        let tensor = (&self.tensor).add(other).to_kind(self.kind.kind());
+        let kind = self.kind;
         let shape = self.shape;

         Self {
8 changes: 4 additions & 4 deletions burn-tensor/src/tensor/backend/tch/ops/aggregation.rs
@@ -10,7 +10,7 @@ use crate::{

 impl<E: TchElement, const D: usize> TensorOpsAggregation<TchBackend<E>, D> for TchTensor<E, D> {
     fn mean(&self) -> <TchBackend<E> as Backend>::TensorPrimitive<1> {
-        let kind = self.kind.clone();
+        let kind = self.kind;
         let tensor = self.tensor.mean(kind.kind());
         let shape = Shape::new([1]);

@@ -22,7 +22,7 @@ impl<E: TchElement, const D: usize> TensorOpsAggregation<TchBackend<E>, D> for TchTensor<E, D> {
     }

     fn sum(&self) -> <TchBackend<E> as Backend>::TensorPrimitive<1> {
-        let kind = self.kind.clone();
+        let kind = self.kind;
         let tensor = self.tensor.sum(kind.kind());
         let shape = Shape::new([1]);

@@ -34,7 +34,7 @@ impl<E: TchElement, const D: usize> TensorOpsAggregation<TchBackend<E>, D> for TchTensor<E, D> {
     }

     fn mean_dim(&self, dim: usize) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
-        let kind = self.kind.clone();
+        let kind = self.kind;
         let tensor = self.tensor.mean_dim(&[dim as i64], true, kind.kind());
         let shape = Shape::from(tensor.size());

@@ -46,7 +46,7 @@ impl<E: TchElement, const D: usize> TensorOpsAggregation<TchBackend<E>, D> for TchTensor<E, D> {
     }

     fn sum_dim(&self, dim: usize) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
-        let kind = self.kind.clone();
+        let kind = self.kind;
         let tensor = self
             .tensor
             .sum_dim_intlist(&[dim as i64], true, kind.kind());