New autodiff graph memory management strategy #1698

Merged: 9 commits, Apr 26, 2024
Changes from 1 commit
Commit: fix
louisfd committed Apr 26, 2024
commit 29892bc704c2e48b146b2b39549845b4ef1249bd
13 changes: 7 additions & 6 deletions crates/burn-autodiff/src/runtime/memory_management.rs
@@ -40,8 +40,10 @@ impl GraphMemoryManagement {
 
     /// Free the node from the state.
    pub fn consume_node(&mut self, node_id: NodeID) {
-        self.leaves.remove(&node_id);
-        self.nodes.remove(&node_id);
+        if !self.is_referenced(node_id) {
+            self.leaves.remove(&node_id);
+            self.nodes.remove(&node_id);
+        }
    }
 
    /// Free all nodes whose backward call has become impossible
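The added guard means a consumed node is only dropped once nothing in the graph still points at it. This hunk does not show `is_referenced`, so the following is just a minimal sketch of such a check over a simplified node map (each node keyed by `NodeID` and storing its parent IDs); the crate's actual implementation may differ.

```rust
use std::collections::{HashMap, HashSet};

// Simplified, hypothetical stand-ins for the crate's types.
type NodeID = u64;

struct GraphMemoryManagement {
    // Assumption: each registered node maps to the IDs of its parents.
    nodes: HashMap<NodeID, Vec<NodeID>>,
    leaves: HashSet<NodeID>,
}

impl GraphMemoryManagement {
    /// A node is still referenced if some registered node lists it as a
    /// parent, i.e. freeing it now could break a future backward pass.
    fn is_referenced(&self, node_id: NodeID) -> bool {
        self.nodes.values().any(|parents| parents.contains(&node_id))
    }

    /// With the guard, consuming a node only frees it when nothing
    /// else in the graph depends on it anymore.
    fn consume_node(&mut self, node_id: NodeID) {
        if !self.is_referenced(node_id) {
            self.leaves.remove(&node_id);
            self.nodes.remove(&node_id);
        }
    }
}
```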
@@ -70,7 +72,7 @@ impl GraphMemoryManagement {
 
        // New leaves are the roots of a useful backward sub-tree.
        // Deletables are everything not marked as useful.
-        for leaf in leaves.clone() {
+        for leaf in leaves {
            self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
        }
 
@@ -104,9 +106,8 @@ impl GraphMemoryManagement {
                self.statuses.insert(node_id, node_status.clone());
                node_status
            }
-            // If node does not exist, it was either
-            // - deleted, so this all its descendants are unavailable
-            // - not requiring grad or detached, the status remains unknown (TODO REGISTER THEM WITH EMPTY PARENTS LIKE THOSE WITH REGISTER_GRAD)
+            // If node does not exist, it was
+            // deleted, so this and all its descendants are unavailable
            None => {
                self.statuses.insert(node_id, NodeMemoryStatus::Unavailable);
                NodeMemoryStatus::Unavailable
13 changes: 4 additions & 9 deletions crates/burn-autodiff/src/tests/cat.rs
@@ -6,22 +6,17 @@ mod tests {
    #[test]
    fn should_diff_cat() {
        let device = Default::default();
-        let data_1 = Data::from([[2.0, -1.0], [5.0, 2.0]]);
-        let data_2 = Data::from([[5.0, 4.0], [-1.0, 4.0]]);
-
-        let tensor_1 = TestAutodiffTensor::from_data(data_1.clone(), &device).require_grad();
-        let tensor_2 = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
+        let tensor_1 =
+            TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
+        let tensor_2 =
+            TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();
 
        let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
        let grads = tensor_3.backward();
 
        let grad_1 = tensor_1.grad(&grads).unwrap();
        let grad_2 = tensor_2.grad(&grads).unwrap();
 
-        // Redeclared because consumed in previous backward
-        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
-        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
-
        let mut tensor_1_list = Vec::new();
        let mut tensor_2_list = Vec::new();
 
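Because referenced nodes are no longer freed on consumption, the test can keep using the same `require_grad` leaves after the first `backward()` instead of rebuilding them from `data_1`/`data_2`. A compressed sketch of that reuse pattern, written as a test-body fragment against the crate's `TestAutodiffTensor` alias (so `device` is assumed to be in scope):

```rust
// Leaves created once.
let tensor_1 =
    TestAutodiffTensor::from_data([[2.0, -1.0], [5.0, 2.0]], &device).require_grad();
let tensor_2 =
    TestAutodiffTensor::from_data([[5.0, 4.0], [-1.0, 4.0]], &device).require_grad();

// First graph and backward pass.
let grads = tensor_1.clone().matmul(tensor_2.clone()).backward();
let _grad_1 = tensor_1.grad(&grads).unwrap();

// Previously the leaves had to be re-created at this point ("consumed in
// previous backward"); now they feed later operations directly.
let _reused = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
```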
25 changes: 9 additions & 16 deletions crates/burn-autodiff/src/tests/slice.rs
@@ -85,23 +85,16 @@ mod tests {
        let data_3: Data<f32, 2> = Data::from([[14.0, 97.0, 100.0, 9.0], [2.0, 3.0, 15.0, 7.0]]);
 
        let device = Default::default();
-        let tensor_1_slice = TestAutodiffTensor::from_data(data_1.clone(), &device).require_grad();
-        let tensor_2_slice = TestAutodiffTensor::from_data(data_2.clone(), &device).require_grad();
-
-        let tensor_1_cat = TestAutodiffTensor::from_data(data_1, &device).require_grad();
-        let tensor_2_cat = TestAutodiffTensor::from_data(data_2, &device).require_grad();
-
+        let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
+        let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
        let tensor_3 = TestAutodiffTensor::from_data(data_3, &device);
 
        let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default());
-        let slice_assign_output =
-            slice_assign_output.slice_assign([0..2, 0..2], tensor_1_slice.clone());
-        let slice_assign_output =
-            slice_assign_output.slice_assign([0..2, 2..4], tensor_2_slice.clone());
+        let slice_assign_output = slice_assign_output.slice_assign([0..2, 0..2], tensor_1.clone());
+        let slice_assign_output = slice_assign_output.slice_assign([0..2, 2..4], tensor_2.clone());
        let slice_assign_output = slice_assign_output / tensor_3.clone();
 
-        let cat_output =
-            TestAutodiffTensor::cat(vec![tensor_1_cat.clone(), tensor_2_cat.clone()], 1);
+        let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1);
        let cat_output = cat_output / tensor_3;
 
        slice_assign_output
@@ -111,10 +104,10 @@ mod tests {
        let slice_assign_grads = slice_assign_output.backward();
        let cat_grads = cat_output.backward();
 
-        let slice_assign_grad_1 = tensor_1_slice.grad(&slice_assign_grads).unwrap();
-        let slice_assign_grad_2 = tensor_2_slice.grad(&slice_assign_grads).unwrap();
-        let cat_grad_1 = tensor_1_cat.grad(&cat_grads).unwrap();
-        let cat_grad_2 = tensor_2_cat.grad(&cat_grads).unwrap();
+        let slice_assign_grad_1 = tensor_1.grad(&slice_assign_grads).unwrap();
+        let slice_assign_grad_2 = tensor_2.grad(&slice_assign_grads).unwrap();
+        let cat_grad_1 = tensor_1.grad(&cat_grads).unwrap();
+        let cat_grad_2 = tensor_2.grad(&cat_grads).unwrap();
 
        slice_assign_grad_1
            .to_data()
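The slice test now drives both the `slice_assign` graph and the `cat` graph from a single pair of leaves, reading each gradient out of the container returned by the corresponding `backward()` call. A trimmed fragment of that pattern (same `TestAutodiffTensor` test alias, with `tensor_1`, `tensor_2`, and `tensor_3` assumed to be set up as in the diff above):

```rust
// Two independent graphs built from the same require_grad leaves.
let slice_assign_output = TestAutodiffTensor::zeros([2, 4], &Default::default())
    .slice_assign([0..2, 0..2], tensor_1.clone())
    .slice_assign([0..2, 2..4], tensor_2.clone())
    / tensor_3.clone();
let cat_output = TestAutodiffTensor::cat(vec![tensor_1.clone(), tensor_2.clone()], 1) / tensor_3;

// Each backward pass yields its own gradients; the shared leaves are
// queried against whichever pass an assertion targets.
let slice_assign_grads = slice_assign_output.backward();
let cat_grads = cat_output.backward();

let _grad_from_slice_assign = tensor_1.grad(&slice_assign_grads).unwrap();
let _grad_from_cat = tensor_1.grad(&cat_grads).unwrap();
```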