Commit

Feat/dropout (tracel-ai#36)
nathanielsimard authored Sep 17, 2022
1 parent 48e0fbd commit 2e29e82
Showing 17 changed files with 163 additions and 97 deletions.
11 changes: 9 additions & 2 deletions burn-dataset/src/dataset/base.rs
@@ -1,8 +1,15 @@
 use crate::DatasetIterator;
 
 pub trait Dataset<I>: Send + Sync {
-    fn iter(&self) -> DatasetIterator<'_, I>;
     fn get(&self, index: usize) -> Option<I>;
     fn len(&self) -> usize;
-    fn is_empty(&self) -> bool;
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+    fn iter(&self) -> DatasetIterator<'_, I>
+    where
+        Self: Sized,
+    {
+        DatasetIterator::new(self)
+    }
 }
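
With `is_empty` and `iter` now carrying default bodies, an implementor of `Dataset` only has to supply `get` and `len`. A minimal sketch, assuming the crate is consumed as `burn_dataset`; the `RangeDataset` type below is hypothetical, not part of this commit:

use burn_dataset::Dataset;

/// Hypothetical dataset yielding the integers 0..n.
struct RangeDataset {
    n: usize,
}

impl Dataset<usize> for RangeDataset {
    fn get(&self, index: usize) -> Option<usize> {
        if index < self.n {
            Some(index)
        } else {
            None
        }
    }

    fn len(&self) -> usize {
        self.n
    }

    // `is_empty` and `iter` are inherited from the trait defaults above.
}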
8 changes: 1 addition & 7 deletions burn-dataset/src/dataset/in_memory.rs
@@ -3,7 +3,7 @@ use std::{
     io::{BufRead, BufReader},
 };
 
-use crate::{Dataset, DatasetIterator};
+use crate::Dataset;
 
 pub struct InMemDataset<I> {
     items: Vec<I>,
@@ -22,15 +22,9 @@ where
     fn get(&self, index: usize) -> Option<I> {
         self.items.get(index).cloned()
     }
-    fn iter(&self) -> DatasetIterator<'_, I> {
-        DatasetIterator::new(self)
-    }
     fn len(&self) -> usize {
         self.items.len()
     }
-    fn is_empty(&self) -> bool {
-        self.items.is_empty()
-    }
 }
 
 impl<I> InMemDataset<I>
10 changes: 1 addition & 9 deletions burn-dataset/src/source/huggingface/mnist.rs
@@ -1,6 +1,6 @@
 use super::downloader::cache_dir;
 use crate::source::huggingface::downloader::{download, Extractor};
-use crate::{Dataset, DatasetIterator, InMemDataset};
+use crate::{Dataset, InMemDataset};
 use serde::{Deserialize, Serialize};
 
 #[derive(Deserialize, Serialize, Debug, Clone)]
@@ -14,21 +14,13 @@ pub struct MNISTDataset {
 }
 
 impl Dataset<MNISTItem> for MNISTDataset {
-    fn iter(&self) -> crate::DatasetIterator<'_, MNISTItem> {
-        DatasetIterator::new(self)
-    }
-
     fn get(&self, index: usize) -> Option<MNISTItem> {
         self.dataset.get(index)
     }
 
     fn len(&self) -> usize {
         self.dataset.len()
     }
-
-    fn is_empty(&self) -> bool {
-        self.dataset.is_empty()
-    }
 }
 
 impl MNISTDataset {
17 changes: 1 addition & 16 deletions burn-dataset/src/transform/composed.rs
@@ -1,4 +1,4 @@
-use crate::{Dataset, DatasetIterator};
+use crate::Dataset;
 
 pub struct ComposedDataset<I> {
     datasets: Vec<Box<dyn Dataset<I>>>,
@@ -24,26 +24,11 @@ where
         }
         None
     }
-    fn iter(&self) -> DatasetIterator<'_, I> {
-        DatasetIterator::new(self)
-    }
     fn len(&self) -> usize {
         let mut total = 0;
         for dataset in self.datasets.iter() {
             total += dataset.len();
         }
         total
     }
-
-    fn is_empty(&self) -> bool {
-        let mut is_empty = true;
-
-        for dataset in self.datasets.iter() {
-            if !dataset.is_empty() {
-                is_empty = false;
-            }
-        }
-
-        is_empty
-    }
 }
9 changes: 1 addition & 8 deletions burn-dataset/src/transform/mapper.rs
@@ -1,4 +1,4 @@
-use crate::{Dataset, DatasetIterator};
+use crate::Dataset;
 
 pub trait Mapper<I, O> {
     fn map(&self, item: &I) -> O;
@@ -26,16 +26,9 @@ where
         item.map(|item| self.mapper.map(&item))
     }
 
-    fn iter(&self) -> DatasetIterator<'_, O> {
-        DatasetIterator::new(self)
-    }
     fn len(&self) -> usize {
         self.dataset.len()
     }
-
-    fn is_empty(&self) -> bool {
-        self.dataset.is_empty()
-    }
 }
 
 #[cfg(test)]
10 changes: 1 addition & 9 deletions burn-dataset/src/transform/partial.rs
@@ -1,4 +1,4 @@
-use crate::{Dataset, DatasetIterator};
+use crate::Dataset;
 use std::sync::Arc;
 
 pub struct PartialDataset<I> {
@@ -51,17 +51,9 @@ where
         self.dataset.get(index)
     }
 
-    fn iter(&self) -> DatasetIterator<'_, I> {
-        DatasetIterator::new(self)
-    }
-
     fn len(&self) -> usize {
         usize::min(self.end_index - self.start_index, self.dataset.len())
     }
-
-    fn is_empty(&self) -> bool {
-        self.dataset.is_empty()
-    }
 }
 
 #[cfg(test)]
10 changes: 1 addition & 9 deletions burn-dataset/src/transform/random.rs
@@ -1,4 +1,4 @@
-use crate::{Dataset, DatasetIterator};
+use crate::Dataset;
 use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
 use std::sync::Arc;
 
@@ -36,15 +36,7 @@ where
         self.dataset.get(*index)
     }
 
-    fn iter(&self) -> DatasetIterator<'_, I> {
-        DatasetIterator::new(self)
-    }
-
     fn len(&self) -> usize {
         self.dataset.len()
     }
-
-    fn is_empty(&self) -> bool {
-        self.dataset.is_empty()
-    }
 }
6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/backend/autodiff/backend.rs
@@ -50,7 +50,7 @@ macro_rules! define_impl {
             distribution: Distribution<Self::Elem>,
             device: Self::Device,
         ) -> Self::TensorPrimitive<D> {
-            Self::from_data(Data::random(shape, distribution), device)
+            Self::from_inner(<$backend as Backend>::random(shape, distribution, device))
         }
 
         fn ad_enabled() -> bool {
@@ -61,14 +61,14 @@ macro_rules! define_impl {
             shape: Shape<D>,
             device: Self::Device,
         ) -> Self::TensorPrimitive<D> {
-            Self::from_data(Data::zeros(shape), device)
+            Self::from_inner(<$backend as Backend>::zeros(shape, device))
        }
 
         fn ones<const D: usize>(
             shape: Shape<D>,
             device: Self::Device,
         ) -> Self::TensorPrimitive<D> {
-            Self::from_data(Data::ones(shape), device)
+            Self::from_inner(<$backend as Backend>::ones(shape, device))
         }
 
         fn name() -> String {
18 changes: 12 additions & 6 deletions burn-tensor/src/tensor/backend/base.rs
@@ -60,18 +60,24 @@ pub trait Backend: Clone + Sized + Default + Send + Sync + std::fmt::Debug + 'static {
         device: Self::Device,
     ) -> Self::BoolTensorPrimitive<D>;
 
+    fn ad_enabled() -> bool;
+    fn name() -> String;
+
     fn random<const D: usize>(
         shape: Shape<D>,
        distribution: Distribution<Self::Elem>,
         device: Self::Device,
-    ) -> Self::TensorPrimitive<D>;
+    ) -> Self::TensorPrimitive<D> {
+        Self::from_data(Data::random(shape, distribution), device)
+    }
 
-    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D>;
+    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
+        Self::from_data(Data::zeros(shape), device)
+    }
 
-    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D>;
-
-    fn ad_enabled() -> bool;
-    fn name() -> String;
+    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
+        Self::from_data(Data::ones(shape), device)
+    }
 }
 
 pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
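
Net effect of this hunk: `random`, `zeros`, and `ones` become provided methods built on `from_data`, so a backend only overrides them when it has a faster native path (the tch backend below does exactly that). A minimal sketch of generic code leaning on the defaults; the import paths are approximate and the `Clone` bound on `B::Device` is an assumption, not something this diff states:

use burn_tensor::tensor::backend::Backend;
use burn_tensor::tensor::Shape;

/// Works on any backend: each call resolves to the backend's override if one
/// exists, otherwise to the Data-based default declared on the trait.
fn zeros_and_ones<B: Backend, const D: usize>(
    shape: Shape<D>,
    device: B::Device,
) -> (B::TensorPrimitive<D>, B::TensorPrimitive<D>)
where
    B::Device: Clone,
{
    (B::zeros(shape.clone(), device.clone()), B::ones(shape, device))
}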
18 changes: 1 addition & 17 deletions burn-tensor/src/tensor/backend/ndarray/backend.rs
@@ -1,6 +1,6 @@
 use super::NdArrayTensor;
+use crate::tensor::Data;
 use crate::tensor::{backend::Backend, NdArrayElement};
-use crate::tensor::{Data, Distribution, Shape};
 
 #[derive(Clone, Copy, Debug)]
 pub enum NdArrayDevice {
@@ -41,22 +41,6 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
         NdArrayTensor::from_data(data)
     }
 
-    fn random<const D: usize>(
-        shape: Shape<D>,
-        distribution: Distribution<Self::Elem>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::random(shape, distribution), device)
-    }
-
-    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::zeros(shape), device)
-    }
-
-    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::ones(shape), device)
-    }
-
     fn ad_enabled() -> bool {
         false
     }
36 changes: 33 additions & 3 deletions burn-tensor/src/tensor/backend/tch/backend.rs
@@ -7,6 +7,14 @@ pub enum TchDevice {
     Cpu,
     Cuda(usize),
 }
+impl From<TchDevice> for tch::Device {
+    fn from(device: TchDevice) -> Self {
+        match device {
+            TchDevice::Cpu => tch::Device::Cpu,
+            TchDevice::Cuda(num) => tch::Device::Cuda(num),
+        }
+    }
+}
 
 impl Default for TchDevice {
     fn default() -> Self {
@@ -55,15 +63,37 @@ impl<E: TchElement> Backend for TchBackend<E> {
         distribution: Distribution<Self::Elem>,
         device: Self::Device,
     ) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::random(shape, distribution), device)
+        match distribution {
+            Distribution::Standard => {
+                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
+                tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
+                tensor
+            }
+            Distribution::Bernoulli(prob) => {
+                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
+                tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
+                tensor
+            }
+            Distribution::Uniform(from, to) => {
+                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
+                tensor.tensor = tensor
+                    .tensor
+                    .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
+                tensor
+            }
+        }
     }
 
     fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::zeros(shape), device)
+        let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
+        tensor.tensor = tensor.tensor.zero_();
+        tensor
     }
 
     fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::ones(shape), device)
+        let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
+        tensor.tensor = tensor.tensor.ones_like();
+        tensor
     }
 
     fn ad_enabled() -> bool {
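
A quick sketch of what the new `From<TchDevice>` impl buys (import path approximate): backend code can hand its own device enum straight to tch instead of hard-coding `tch::Device::Cpu`, which is how `TchTensor::empty` forwards its device via `device.into()` in the next file.

use burn_tensor::tensor::backend::TchDevice;

fn to_native(device: TchDevice) -> tch::Device {
    // TchDevice::Cuda(0) maps to tch::Device::Cuda(0); Cpu maps to Cpu.
    device.into()
}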
10 changes: 6 additions & 4 deletions burn-tensor/src/tensor/backend/tch/tensor.rs
@@ -1,4 +1,7 @@
-use crate::tensor::{ops::TensorOpsUtilities, Data, Shape};
+use crate::{
+    backend::TchDevice,
+    tensor::{ops::TensorOpsUtilities, Data, Shape},
+};
 
 lazy_static::lazy_static! {
     static ref NO_GRAD: tch::NoGradGuard = {
@@ -84,11 +87,10 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
 }
 
 impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> TchTensor<P, D> {
-    pub fn empty(shape: Shape<D>) -> Self {
+    pub fn empty(shape: Shape<D>, device: TchDevice) -> Self {
         let shape_tch = TchShape::from(shape);
-        let device = tch::Device::Cpu;
         let kind = TchKind::new();
-        let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device));
+        let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.into()));
 
         lazy_static::initialize(&NO_GRAD);
         let tensor = tensor.set_requires_grad(false);
14 changes: 14 additions & 0 deletions burn-tensor/src/tensor/data.rs
@@ -17,6 +17,7 @@ pub struct Data<P, const D: usize> {
 #[derive(Clone, Copy)]
 pub enum Distribution<P> {
     Standard,
+    Bernoulli(f64),
     Uniform(P, P),
 }
 
@@ -37,17 +38,26 @@ where
 {
     Standard(rand::distributions::Standard),
     Uniform(rand::distributions::Uniform<P>),
+    Bernoulli(rand::distributions::Bernoulli),
 }
 
 impl<'a, P> DistributionSampler<'a, P>
 where
     Standard: rand::distributions::Distribution<P>,
     P: rand::distributions::uniform::SampleUniform,
+    P: Element,
 {
     pub fn sample(&mut self) -> P {
         match &self.kind {
             DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
             DistributionSamplerKind::Uniform(distribution) => self.rng.sample(distribution),
+            DistributionSamplerKind::Bernoulli(distribution) => {
+                if self.rng.sample(distribution) {
+                    P::ones(&P::default())
+                } else {
+                    P::zeros(&P::default())
+                }
+            }
         }
     }
 }
@@ -65,6 +75,9 @@ where
             Distribution::Uniform(low, high) => {
                 DistributionSamplerKind::Uniform(rand::distributions::Uniform::new(low, high))
             }
+            Distribution::Bernoulli(prob) => DistributionSamplerKind::Bernoulli(
+                rand::distributions::Bernoulli::new(prob).unwrap(),
+            ),
         };
 
         DistributionSampler::new(kind, rng)
@@ -79,6 +92,7 @@ where
         match self {
             Distribution::Standard => Distribution::Standard,
             Distribution::Uniform(a, b) => Distribution::Uniform(E::from_elem(a), E::from_elem(b)),
+            Distribution::Bernoulli(prob) => Distribution::Bernoulli(prob),
         }
     }
 }
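
This `Bernoulli` variant is the piece the dropout feature hangs on: sample a 0/1 mask, then zero the dropped activations. A hypothetical sketch of that idea at the backend level, not the actual dropout module added by this PR (which is among the changed files not shown above); import paths are approximate:

use burn_tensor::tensor::backend::Backend;
use burn_tensor::tensor::{Distribution, Shape};

/// Each element of the returned mask is 1 with probability `1 - prob`
/// (keep the activation) and 0 with probability `prob` (drop it).
fn dropout_mask<B: Backend, const D: usize>(
    shape: Shape<D>,
    prob: f64,
    device: B::Device,
) -> B::TensorPrimitive<D> {
    B::random(shape, Distribution::Bernoulli(1.0 - prob), device)
}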
