Skip to content

Commit

Permalink
Refactor/autodiff graph (tracel-ai#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Nov 20, 2022
1 parent ca94a9f commit cd39a1c
Show file tree
Hide file tree
Showing 11 changed files with 37 additions and 221 deletions.
4 changes: 2 additions & 2 deletions burn-autodiff/src/graph/node/backward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{BackwardNodeState, ForwardNodeRef};
use crate::graph::grad::Grads;
use crate::graph::{
converter::Forward2BackwardGraphConverter,
ops::{BackwardRecordedOpsRef, RecordedOpsParent, RecordedOpsParentRef},
ops::{BackwardRecordedOpsBoxed, RecordedOpsParent, RecordedOpsParentRef},
traversal::{BreadthFirstSearch, GraphTraversal},
};
use burn_tensor::ops::{Ones, Zeros};
Expand All @@ -13,7 +13,7 @@ pub struct BackwardNode<Out> {
pub id: String,
pub order: usize,
pub state: BackwardNodeState<Out>,
pub ops: BackwardRecordedOpsRef<Out>,
pub ops: BackwardRecordedOpsBoxed<Out>,
}
pub type BackwardNodeRef<Out> = Arc<BackwardNode<Out>>;

Expand Down
12 changes: 6 additions & 6 deletions burn-autodiff/src/graph/node/forward.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
use super::ForwardNodeState;
use crate::graph::ops::ForwardRecordedOpsRef;
use crate::graph::ops::ForwardRecordedOpsBoxed;
use std::sync::Arc;

#[derive(Debug)]
pub struct ForwardNode<Out> {
pub id: String,
pub order: usize,
pub state: ForwardNodeState<Out>,
pub ops: ForwardRecordedOpsRef<Out>,
pub ops: ForwardRecordedOpsBoxed<Out>,
}
pub type ForwardNodeRef<Out> = Arc<ForwardNode<Out>>;

impl<Out> ForwardNode<Out> {
pub fn from_root(state: ForwardNodeState<Out>, ops: ForwardRecordedOpsRef<Out>) -> Self {
pub fn from_root(state: ForwardNodeState<Out>, ops: ForwardRecordedOpsBoxed<Out>) -> Self {
let order = 0;
Self::new(order, state, ops)
}

pub fn from_unary<T>(
node: &ForwardNode<T>,
state: ForwardNodeState<Out>,
ops: ForwardRecordedOpsRef<Out>,
ops: ForwardRecordedOpsBoxed<Out>,
) -> Self {
let order = node.order + 1;
Self::new(order, state, ops)
Expand All @@ -30,7 +30,7 @@ impl<Out> ForwardNode<Out> {
lhs: &ForwardNode<Lhs>,
rhs: &ForwardNode<Rhs>,
state: ForwardNodeState<Out>,
ops: ForwardRecordedOpsRef<Out>,
ops: ForwardRecordedOpsBoxed<Out>,
) -> Self {
let order = usize::max(lhs.order, rhs.order) + 1;
Self::new(order, state, ops)
Expand All @@ -39,7 +39,7 @@ impl<Out> ForwardNode<Out> {
pub fn new(
order: usize,
state: ForwardNodeState<Out>,
ops: ForwardRecordedOpsRef<Out>,
ops: ForwardRecordedOpsBoxed<Out>,
) -> Self {
let id = nanoid::nanoid!();
Self {
Expand Down
9 changes: 6 additions & 3 deletions burn-autodiff/src/graph/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ pub trait BackwardRecordedOps<T>: std::fmt::Debug {
}

pub trait ForwardRecordedOps<T>: std::fmt::Debug + Send + Sync {
fn to_backward(&self, graph: &mut Forward2BackwardGraphConverter) -> BackwardRecordedOpsRef<T>;
fn to_backward(
&self,
graph: &mut Forward2BackwardGraphConverter,
) -> BackwardRecordedOpsBoxed<T>;
}

pub trait RecordedOpsParent: std::fmt::Debug {
Expand All @@ -33,6 +36,6 @@ pub trait RecordedOpsParent: std::fmt::Debug {
fn register_grad(&self, grads: &mut Grads);
}

pub type ForwardRecordedOpsRef<T> = Arc<dyn ForwardRecordedOps<T>>;
pub type BackwardRecordedOpsRef<T> = Arc<dyn BackwardRecordedOps<T>>;
pub type ForwardRecordedOpsBoxed<T> = Box<dyn ForwardRecordedOps<T>>;
pub type BackwardRecordedOpsBoxed<T> = Box<dyn BackwardRecordedOps<T>>;
pub type RecordedOpsParentRef = Arc<dyn RecordedOpsParent>;
6 changes: 3 additions & 3 deletions burn-autodiff/src/graph/ops/binary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{
BackwardRecordedOps, BackwardRecordedOpsRef, BinaryOpsNodeState, ForwardRecordedOps,
BackwardRecordedOps, BackwardRecordedOpsBoxed, BinaryOpsNodeState, ForwardRecordedOps,
RecordedOpsParentRef,
};
use crate::graph::{
Expand Down Expand Up @@ -38,12 +38,12 @@ where
fn to_backward(
&self,
graph: &mut Forward2BackwardGraphConverter,
) -> BackwardRecordedOpsRef<Out> {
) -> BackwardRecordedOpsBoxed<Out> {
let lhs = graph.from(&self.lhs);
let rhs = graph.from(&self.rhs);
let ops = self.ops.clone();

Arc::new(BackwardBinaryRecordedOps::new(lhs, rhs, ops))
Box::new(BackwardBinaryRecordedOps::new(lhs, rhs, ops))
}
}

Expand Down
6 changes: 3 additions & 3 deletions burn-autodiff/src/graph/ops/root.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{BackwardRecordedOps, ForwardRecordedOps, RecordedOpsParentRef};
use crate::graph::{converter::Forward2BackwardGraphConverter, node::BackwardNodeState};
use burn_tensor::ops::Zeros;
use std::{ops::Add, sync::Arc};
use std::ops::Add;

#[derive(new, Debug, Clone)]
pub struct InitRecordedOps {}
Expand All @@ -23,7 +23,7 @@ where
fn to_backward(
&self,
_graph: &mut Forward2BackwardGraphConverter,
) -> super::BackwardRecordedOpsRef<Out> {
Arc::new(self.clone())
) -> super::BackwardRecordedOpsBoxed<Out> {
Box::new(self.clone())
}
}
4 changes: 2 additions & 2 deletions burn-autodiff/src/graph/ops/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ where
fn to_backward(
&self,
graph: &mut Forward2BackwardGraphConverter,
) -> super::BackwardRecordedOpsRef<Out> {
) -> super::BackwardRecordedOpsBoxed<Out> {
let input = graph.from(&self.input);
let ops = self.ops.clone();

Arc::new(BackwareUnaryRecordedOps::new(input, ops))
Box::new(BackwareUnaryRecordedOps::new(input, ops))
}
}

Expand Down
4 changes: 2 additions & 2 deletions burn-autodiff/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ where

let ops = Arc::new(ops);
let ops = ForwardUnaryRecordedOps::new(input.clone(), ops);
let ops = Arc::new(ops);
let ops = Box::new(ops);

let node = ForwardNode::from_unary(&input, state, ops);
let node = Arc::new(node);
Expand Down Expand Up @@ -56,7 +56,7 @@ where

let ops = Arc::new(ops);
let ops = ForwardBinaryRecordedOps::new(lhs.clone(), rhs.clone(), ops);
let ops = Arc::new(ops);
let ops = Box::new(ops);

let node = ForwardNode::from_binary(&lhs, &rhs, state, ops);
let node = Arc::new(node);
Expand Down
190 changes: 0 additions & 190 deletions burn-autodiff/src/ops/macros.rs

This file was deleted.

2 changes: 0 additions & 2 deletions burn-autodiff/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,4 @@ mod creation;
mod module;
mod tensor;

mod macros;
pub(crate) use base::*;
pub use macros::*;
6 changes: 3 additions & 3 deletions burn-autodiff/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn to_backward(
&self,
graph: &mut Forward2BackwardGraphConverter,
) -> BackwardRecordedOpsRef<B::TensorPrimitive<D>> {
Arc::new(BackwardCatOps::<D, B>::new(
) -> BackwardRecordedOpsBoxed<B::TensorPrimitive<D>> {
Box::new(BackwardCatOps::<D, B>::new(
self.nodes
.iter()
.map(|node| {
Expand Down Expand Up @@ -1117,7 +1117,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
let state = crate::graph::node::ForwardNodeState::new(out);

let ops = ForwardCatOps::<D, B>::new(nodes, dim);
let ops = Arc::new(ops);
let ops = Box::new(ops);

let node = crate::graph::node::ForwardNode::new(order, state, ops);
let node = std::sync::Arc::new(node);
Expand Down
Loading

0 comments on commit cd39a1c

Please sign in to comment.