feat: handle seed with backend (tracel-ai#41)
nathanielsimard authored Sep 21, 2022
1 parent 5844d43 commit 1125948
Showing 8 changed files with 79 additions and 21 deletions.
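Taken together, the diff below makes randomness a backend concern: the `Backend` trait gains a required `seed(u64)`, `random` loses its trait-level default so each implementation can consult its own RNG state, and `Data::random` takes the generator explicitly. A minimal usage sketch of the new surface (the `burn_tensor` import paths and the `Copy` bound are assumptions; the `seed`/`random` signatures come from this commit):

    use burn_tensor::tensor::backend::Backend; // path assumed
    use burn_tensor::{Distribution, Shape};

    // Re-seeding replays the same random sequence, whichever backend `B` is.
    fn sample_twice<B: Backend>(device: B::Device)
    where
        B::Device: Copy, // holds for the ndarray and tch devices in this commit
    {
        B::seed(42);
        let first = B::random(Shape::new([2, 3]), Distribution::Standard, device);
        B::seed(42);
        let second = B::random(Shape::new([2, 3]), Distribution::Standard, device);
        // `first` and `second` now contain identical values.
        let _ = (first, second);
    }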
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -74,6 +74,10 @@ macro_rules! define_impl {
             fn name() -> String {
                 format!("autodiff<{}>", <$backend as Backend>::name())
             }
+
+            fn seed(seed: u64) {
+                <$backend as Backend>::seed(seed)
+            }
         }

         impl<E: $element> ADBackend for $name<E> {
5 changes: 2 additions & 3 deletions burn-tensor/src/tensor/backend/base.rs
@@ -62,14 +62,13 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static

     fn ad_enabled() -> bool;
     fn name() -> String;
+    fn seed(seed: u64);

     fn random<const D: usize>(
         shape: Shape<D>,
         distribution: Distribution<Self::Elem>,
         device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::random(shape, distribution), device)
-    }
+    ) -> Self::TensorPrimitive<D>;

     fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
         Self::from_data(Data::zeros(shape), device)
27 changes: 27 additions & 0 deletions burn-tensor/src/tensor/backend/ndarray/backend.rs
@@ -1,6 +1,12 @@
 use super::NdArrayTensor;
 use crate::tensor::Data;
 use crate::tensor::{backend::Backend, NdArrayElement};
+use crate::{Distribution, Shape};
+use rand::rngs::StdRng;
+use rand::SeedableRng;
+use std::sync::Mutex;
+
+static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

 #[derive(Clone, Copy, Debug)]
 pub enum NdArrayDevice {
@@ -45,7 +51,28 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
         false
     }

+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<Self::Elem>,
+        device: Self::Device,
+    ) -> Self::TensorPrimitive<D> {
+        let mut seed = SEED.lock().unwrap();
+        let mut rng: StdRng = match seed.as_ref() {
+            Some(rng) => rng.clone(),
+            None => StdRng::from_entropy(),
+        };
+        let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
+        *seed = Some(rng);
+        tensor
+    }
+
     fn name() -> String {
         "ndarray".to_string()
     }
+
+    fn seed(seed: u64) {
+        let rng = StdRng::seed_from_u64(seed);
+        let mut seed = SEED.lock().unwrap();
+        *seed = Some(rng);
+    }
 }
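The pattern above deserves a note: the backend's generator lives in a process-wide `Mutex<Option<StdRng>>`; each `random` call clones the stored state out, draws from the clone, and writes the advanced state back, so successive calls continue one deterministic stream. A standalone sketch of the same mechanism in plain `rand` (nothing burn-specific; `next_f32` is a hypothetical stand-in for `random`):

    use rand::{rngs::StdRng, Rng, SeedableRng};
    use std::sync::Mutex;

    // Process-wide RNG slot, as in the diff above (const `Mutex::new`
    // in a static requires Rust 1.63+).
    static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

    fn seed(seed: u64) {
        *SEED.lock().unwrap() = Some(StdRng::seed_from_u64(seed));
    }

    fn next_f32() -> f32 {
        let mut guard = SEED.lock().unwrap();
        // Clone the stored state, or fall back to OS entropy when unseeded.
        let mut rng = match guard.as_ref() {
            Some(rng) => rng.clone(),
            None => StdRng::from_entropy(),
        };
        let value = rng.gen::<f32>();
        // Write the advanced state back so the next call continues the stream.
        *guard = Some(rng);
        value
    }

    fn main() {
        seed(42);
        let a = next_f32();
        seed(42);
        assert_eq!(a, next_f32()); // same seed, same first draw
    }

One consequence of writing the state back on the entropy fallback, here as in the diff: the first unseeded call pins the stream, so later unseeded calls stay deterministic relative to it.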
27 changes: 22 additions & 5 deletions burn-tensor/src/tensor/backend/ndarray/tensor.rs
@@ -150,12 +150,18 @@ where

 #[cfg(test)]
 mod tests {
+    use rand::{rngs::StdRng, SeedableRng};
+
     use super::*;
     use crate::tensor::Distribution;

     #[test]
     fn should_support_into_and_from_data_1d() {
-        let data_expected = Data::<f32, 1>::random(Shape::new([3]), Distribution::Standard);
+        let data_expected = Data::<f32, 1>::random(
+            Shape::new([3]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = NdArrayTensor::from_data(data_expected.clone());

         let data_actual = tensor.into_data();
@@ -165,7 +171,11 @@

     #[test]
     fn should_support_into_and_from_data_2d() {
-        let data_expected = Data::<f32, 2>::random(Shape::new([2, 3]), Distribution::Standard);
+        let data_expected = Data::<f32, 2>::random(
+            Shape::new([2, 3]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = NdArrayTensor::from_data(data_expected.clone());

         let data_actual = tensor.into_data();
@@ -175,7 +185,11 @@

     #[test]
     fn should_support_into_and_from_data_3d() {
-        let data_expected = Data::<f32, 3>::random(Shape::new([2, 3, 4]), Distribution::Standard);
+        let data_expected = Data::<f32, 3>::random(
+            Shape::new([2, 3, 4]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = NdArrayTensor::from_data(data_expected.clone());

         let data_actual = tensor.into_data();
@@ -185,8 +199,11 @@

     #[test]
     fn should_support_into_and_from_data_4d() {
-        let data_expected =
-            Data::<f32, 4>::random(Shape::new([2, 3, 4, 2]), Distribution::Standard);
+        let data_expected = Data::<f32, 4>::random(
+            Shape::new([2, 3, 4, 2]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = NdArrayTensor::from_data(data_expected.clone());

         let data_actual = tensor.into_data();
4 changes: 4 additions & 0 deletions burn-tensor/src/tensor/backend/tch/backend.rs
@@ -90,6 +90,10 @@ impl<E: TchElement> Backend for TchBackend<E> {
         tensor
     }

+    fn seed(seed: u64) {
+        tch::manual_seed(seed as i64);
+    }
+
     fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
         let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
         tensor.tensor = tensor.tensor.ones_like();
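The tch backend, by contrast, keeps no RNG state of its own: the seed is forwarded to LibTorch's global generator, whose API takes an `i64` (hence the `as i64` cast, which wraps seeds above `i64::MAX`). A sketch of the effect, assuming the `tch` crate with libtorch installed:

    fn main() {
        tch::manual_seed(42); // seeds LibTorch's global RNG
        // Every tch-backed random op from here on is deterministic.
        let t = tch::Tensor::rand(&[2, 3], tch::kind::FLOAT_CPU);
        t.print();
    }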
17 changes: 13 additions & 4 deletions burn-tensor/src/tensor/backend/tch/tensor.rs
@@ -137,13 +137,18 @@ impl<const D: usize> TensorOpsUtilities<usize, D> for TchTensor<i64, D> {

 #[cfg(test)]
 mod tests {
-    use crate::tensor::Distribution;
-
     use super::*;
+    use crate::tensor::Distribution;
+    use rand::prelude::StdRng;
+    use rand::SeedableRng;

     #[test]
     fn should_support_into_and_from_data_1d() {
-        let data_expected = Data::<f32, 1>::random(Shape::new([3]), Distribution::Standard);
+        let data_expected = Data::<f32, 1>::random(
+            Shape::new([3]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu);

         let data_actual = tensor.into_data();
@@ -153,7 +158,11 @@

     #[test]
     fn should_support_into_and_from_data_2d() {
-        let data_expected = Data::<f32, 2>::random(Shape::new([2, 3]), Distribution::Standard);
+        let data_expected = Data::<f32, 2>::random(
+            Shape::new([2, 3]),
+            Distribution::Standard,
+            &mut StdRng::from_entropy(),
+        );
         let tensor = TchTensor::from_data(data_expected.clone(), tch::Device::Cpu);

         let data_actual = tensor.into_data();
15 changes: 6 additions & 9 deletions burn-tensor/src/tensor/data.rs
@@ -1,6 +1,6 @@
 use super::ops::{Ones, Zeros};
 use crate::{tensor::Shape, Element, ElementConversion};
-use rand::{distributions::Standard, prelude::StdRng, Rng, SeedableRng};
+use rand::{distributions::Standard, prelude::StdRng, Rng};

 #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq)]
 pub struct DataSerialize<P> {
@@ -134,21 +134,16 @@ impl<const D: usize> Data<bool, D> {
     }
 }
 impl<P: Element, const D: usize> Data<P, D> {
-    pub fn random(shape: Shape<D>, distribution: Distribution<P>) -> Self {
+    pub fn random(shape: Shape<D>, distribution: Distribution<P>, rng: &mut StdRng) -> Self {
         let num_elements = shape.num_elements();
-        let mut rng = StdRng::from_entropy();
         let mut data = Vec::with_capacity(num_elements);

         for _ in 0..num_elements {
-            data.push(P::random(distribution, &mut rng));
+            data.push(P::random(distribution, rng));
         }

         Data::new(data, shape)
     }
-    /// Usefull to force a kind
-    pub fn random_(shape: Shape<D>, distribution: Distribution<P>, _kind: P) -> Self {
-        Self::random(shape, distribution)
-    }
 }
 impl<P: std::fmt::Debug, const D: usize> Data<P, D>
 where
@@ -310,11 +305,13 @@ impl<P: std::fmt::Debug, const D: usize> std::fmt::Display for Data<P, D> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use rand::SeedableRng;

     #[test]
     fn should_have_right_num_elements() {
         let shape = Shape::new([3, 5, 6]);
-        let data = Data::<f32, 3>::random(shape, Distribution::Standard);
+        let data =
+            Data::<f32, 3>::random(shape, Distribution::Standard, &mut StdRng::from_entropy());
         assert_eq!(shape.num_elements(), data.value.len());
     }

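`Data::random` now borrows its generator, so determinism is the caller's choice: pass a freshly seeded `StdRng` for reproducible fixtures, or `StdRng::from_entropy()` (as the updated tests do) to keep the old behaviour. A sketch, assuming `Data`, `Shape`, and `Distribution` are imported from `burn_tensor`:

    use rand::{rngs::StdRng, SeedableRng};

    fn fixture() -> Data<f32, 2> {
        // Same seed, same Data, every run.
        let mut rng = StdRng::seed_from_u64(7);
        Data::random(Shape::new([2, 3]), Distribution::Standard, &mut rng)
    }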
1 change: 1 addition & 0 deletions burn/examples/mnist.rs
@@ -202,6 +202,7 @@ fn run<B: ADBackend>(device: B::Device) {
             MlpConfig::new(),
         ),
     };
+    B::seed(config.seed);

     let batcher_train = Arc::new(MNISTBatcher::<B> { device });
     let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
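The one-line change carries the ordering that matters: seed the backend before anything random is constructed, so parameter initialization and shuffling replay identically across runs. As a sketch (`RunConfig` is a hypothetical stand-in for the example's config; the `Backend` path is assumed):

    use burn_tensor::tensor::backend::Backend; // path assumed

    struct RunConfig {
        seed: u64,
    }

    fn run<B: Backend>(config: &RunConfig) {
        // Seed first; the model, optimizer, and batchers built after this
        // line draw their random initialization from the seeded generator.
        B::seed(config.seed);
        // ...
    }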
