refactor: autodiff gradients types (tracel-ai#107)
nathanielsimard authored Nov 20, 2022
1 parent dda067e commit ca94a9f
Showing 8 changed files with 24 additions and 28 deletions.
6 changes: 0 additions & 6 deletions burn-autodiff/Cargo.toml
@@ -17,11 +17,5 @@ export_tests = ["burn-tensor-testgen"]
[dependencies]
burn-tensor-testgen = { path = "../burn-tensor-testgen", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.2.3" }
-libm = "0.2"
-derive-new = "0.5"
-rand = "0.8"
-num-traits = "0.2"
-nanoid = "0.4"

[dev-dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.2.3", features = ["export_tests"] }
23 changes: 12 additions & 11 deletions burn-autodiff/src/graph/grad.rs
@@ -2,38 +2,39 @@ use crate::graph::{
node::{BackwardNode, ForwardNode},
traversal::{BreadthFirstSearch, GraphTraversal},
};
-use burn_tensor::{backend::Gradients, ops::Zeros};
+use burn_tensor::{
+backend::{ADBackend, Gradients},
+ops::Zeros,
+Tensor,
+};
use std::{any::Any, collections::HashMap, ops::Add};

#[derive(Default)]
pub struct Grads {
grads: HashMap<String, Box<dyn Any + Send + Sync>>,
}

-impl Gradients for Grads {
+impl<B: ADBackend> Gradients<B> for Grads {
fn empty() -> Self {
Self {
grads: HashMap::new(),
}
}
-fn get<V: 'static>(&self, id: &str) -> Option<&V> {
+fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>> {
let grad = match self.grads.get(id) {
Some(grad) => grad,
None => return None,
};

grad.downcast_ref()
}
-fn register<V>(&mut self, id: String, value: V)
-where
-V: std::fmt::Debug + 'static + Send + Sync,
-{
+fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>) {
self.grads.insert(id, Box::new(value));
}
}

impl Grads {
-pub fn register<T>(&mut self, node: &BackwardNode<T>)
+pub fn register_node<T>(&mut self, node: &BackwardNode<T>)
where
T: Zeros + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static + Send + Sync,
@@ -42,14 +42,14 @@ impl Grads {
self.grads.insert(node.id.clone(), Box::new(grad));
}

-pub fn from<T>(node: &BackwardNode<T>) -> Self
+pub fn from_node<T>(node: &BackwardNode<T>) -> Self
where
T: Zeros + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static + Send + Sync,
{
-let mut grads = Self::empty();
+let mut grads = Self::default();
let traversal = BreadthFirstSearch::new(node);
-grads.register(node);
+grads.register_node(node);

traversal.traverse(|node| {
node.register_grad(&mut grads);
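For readers following the Grads change above: the container still type-erases gradients behind Box<dyn Any + Send + Sync>, but its API now speaks in tensors indexed by a const dimension D instead of an arbitrary V: 'static. Below is a minimal standalone sketch of that storage-and-downcast pattern; GradTensor is a hypothetical stand-in for burn's Tensor<B::InnerBackend, D>, not a real burn type.

```rust
use std::any::Any;
use std::collections::HashMap;

// Hypothetical stand-in for burn's Tensor<B::InnerBackend, D>; illustrative only.
#[derive(Debug)]
struct GradTensor<const D: usize> {
    data: Vec<f32>,
}

#[derive(Default)]
struct Grads {
    grads: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl Grads {
    // Store a gradient under its node id, erasing the concrete dimension.
    fn register<const D: usize>(&mut self, id: String, value: GradTensor<D>) {
        self.grads.insert(id, Box::new(value));
    }

    // Recover the gradient at a known dimension via downcast_ref.
    fn get<const D: usize>(&self, id: &str) -> Option<&GradTensor<D>> {
        self.grads.get(id)?.downcast_ref()
    }
}

fn main() {
    let mut grads = Grads::default();
    grads.register("node-1".to_string(), GradTensor::<2> { data: vec![1.0; 4] });

    // Only the dimension is named at the call site.
    println!("{:?}", grads.get::<2>("node-1"));
    // A mismatched dimension fails the downcast and yields None.
    assert!(grads.get::<3>("node-1").is_none());
}
```

A lookup with the wrong dimension simply fails the downcast and returns None, mirroring how the real Grads::get falls back to downcast_ref in the diff above.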
4 changes: 2 additions & 2 deletions burn-autodiff/src/graph/node/backward.rs
@@ -66,7 +66,7 @@ where
}
}

-Grads::from(self)
+Grads::from_node(self)
}
}

@@ -89,6 +89,6 @@
&self.id
}
fn register_grad(&self, grads: &mut Grads) {
-grads.register(self)
+grads.register_node(self)
}
}
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/base.rs
@@ -67,7 +67,7 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =

pub trait ADBackend: Backend {
type InnerBackend: Backend<Device = Self::Device, Elem = Self::Elem>;
-type Gradients: Gradients;
+type Gradients: Gradients<Self>;

fn backward<const D: usize>(tensor: &Self::TensorPrimitive<D>) -> Self::Gradients;
fn grad<const D: usize>(
11 changes: 6 additions & 5 deletions burn-tensor/src/tensor/backend/grad.rs
@@ -1,7 +1,8 @@
-pub trait Gradients: Send + Sync {
+use crate::backend::ADBackend;
+use crate::Tensor;
+
+pub trait Gradients<B: ADBackend>: Send + Sync {
fn empty() -> Self;
-fn get<V: 'static>(&self, id: &str) -> Option<&V>;
-fn register<V>(&mut self, id: String, value: V)
-where
-V: std::fmt::Debug + 'static + Send + Sync;
+fn get<const D: usize>(&self, id: &str) -> Option<&Tensor<B::InnerBackend, D>>;
+fn register<const D: usize>(&mut self, id: String, value: Tensor<B::InnerBackend, D>);
}
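Combined with the ADBackend change above (type Gradients: Gradients<Self>), the backend-parameterized trait lets generic code retrieve typed gradients without spelling out a tensor type. A hypothetical helper, not part of this commit, sketching that call pattern:

```rust
use burn_tensor::backend::{ADBackend, Gradients};
use burn_tensor::Tensor;

// Hypothetical helper, not part of burn: fetch the gradient stored under `id`
// as a concretely typed tensor. The `Gradients<Self>` bound on
// `ADBackend::Gradients` is what makes the typed return possible without any
// downcasting at the call site.
fn grad_for_id<'a, B: ADBackend, const D: usize>(
    grads: &'a B::Gradients,
    id: &str,
) -> Option<&'a Tensor<B::InnerBackend, D>> {
    grads.get::<D>(id)
}
```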
2 changes: 1 addition & 1 deletion burn/src/optim/base.rs
@@ -86,7 +86,7 @@ pub(super) fn register_state_gradients<const D: usize, B: ADBackend, F: Fn(&str)
) {
let id = id.to_string();

-if let Some(velocity) = grads.get::<Tensor<B::InnerBackend, D>>(&id) {
+if let Some(velocity) = grads.get::<D>(&id) {
let data = State::Data(velocity.to_data().serialize());
state.register_state(id_to_key(&id).as_str(), data);
};
2 changes: 1 addition & 1 deletion burn/src/optim/decay.rs
@@ -35,7 +35,7 @@ impl<B: ADBackend> WeightDecay<B> {
) -> Tensor<B::InnerBackend, D> {
let id = id.to_string();

-let grad = match self.gradients.get::<Tensor<B::InnerBackend, D>>(&id) {
+let grad = match self.gradients.get::<D>(&id) {
Some(grad_last_step) => grad_last_step.mul_scalar(self.penalty).add(&grad),
None => grad,
};
2 changes: 1 addition & 1 deletion burn/src/optim/momentum.rs
@@ -46,7 +46,7 @@ impl<B: ADBackend> Momentum<B> {
) -> Tensor<B::InnerBackend, D> {
let id = id.to_string();

-let velocity = match self.velocity.get::<Tensor<B::InnerBackend, D>>(&id) {
+let velocity = match self.velocity.get::<D>(&id) {
Some(grad_last_step) => grad
.mul_scalar(1.0 - self.dampening)
.add(&grad_last_step.mul_scalar(self.momentum)),
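The optimizer updates in burn/src/optim/base.rs, decay.rs, and momentum.rs are the same call-site substitution; condensed side by side (a fragment only, assuming the surrounding generics B: ADBackend and const D: usize from those files):

```rust
// Before: the full tensor type was spelled out in the turbofish.
let grad = grads.get::<Tensor<B::InnerBackend, D>>(&id);

// After: only the dimension is passed; the tensor type is fixed by Gradients<B>.
let grad = grads.get::<D>(&id);
```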
