Skip to content

Commit

Permalink
Use a u64 counter for autodiff NodeIDs (tracel-ai#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
dae authored Oct 4, 2023
1 parent ce120ea commit 28e2a99
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 4 additions & 6 deletions burn-autodiff/src/grads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
};

/// Gradient identifier.
pub type GradID = String;
pub type GradID = u64;

/// Gradients container used during the backward pass.
pub struct Gradients {
Expand Down Expand Up @@ -80,13 +80,11 @@ impl Gradients {
value: TensorPrimitive<B, D>,
) {
if let Some(tensor_old) = self.container.remove::<B, D>(&node.id.value) {
self.container.register(
node.id.value.clone(),
Tensor::from_primitive(value).add(tensor_old),
);
self.container
.register(node.id.value, Tensor::from_primitive(value).add(tensor_old));
} else {
self.container
.register::<B, D>(node.id.value.clone(), Tensor::from_primitive(value));
.register::<B, D>(node.id.value, Tensor::from_primitive(value));
}
}
}
12 changes: 7 additions & 5 deletions burn-autodiff/src/graph/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use burn_common::id::IdGenerator;

use super::Requirement;

/// A node contains graph metadata and should be used wrapped in an Arc for cheap cloning.
Expand All @@ -27,15 +26,18 @@ impl Node {
/// Unique identifier generated for each [node](Node).
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct NodeID {
pub value: String,
pub value: u64,
}

impl NodeID {
/// Create a unique [node id](NodeID).
pub fn new() -> Self {
Self {
value: IdGenerator::generate(),
static COUNTER: AtomicU64 = AtomicU64::new(0);
let value = COUNTER.fetch_add(1, Ordering::Relaxed);
if value == u64::MAX {
panic!("NodeID overflowed");
}
Self { value }
}
}

Expand Down

0 comments on commit 28e2a99

Please sign in to comment.