Commit b1df39e: Refactor/device ops (tracel-ai#60)
nathanielsimard authored Oct 21, 2022
1 parent 0e1b0ac
Showing 13 changed files with 115 additions and 104 deletions.
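In short: this commit deletes the per-dimension `TensorOpsDevice` trait (previously implemented once per backend tensor type) and moves `device`/`to_device` onto the backend-level `TensorOps` trait as associated functions, with the `Tensor` wrapper and the autodiff decorator updated to route through the backend type. A toy sketch of the before/after trait shapes; the `Backend` trait here is a stand-in, and only the `device`/`to_device` signatures are taken from this commit:

```rust
// Stand-in backend trait, just enough to show the two API shapes.
trait Backend: Sized {
    type Device: Copy;
    type TensorPrimitive<const D: usize>;
}

// Before: a per-dimension trait, implemented on each tensor primitive.
trait TensorOpsDevice<B: Backend, const D: usize> {
    fn device(&self) -> B::Device;
    fn to_device(&self, device: B::Device) -> Self;
}

// After: associated functions on the backend-wide ops trait, so device
// handling lives next to the other tensor ops and no longer needs its
// own bound on the Backend trait.
trait TensorOps<B: Backend> {
    fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
    fn to_device<const D: usize>(
        tensor: &B::TensorPrimitive<D>,
        device: B::Device,
    ) -> B::TensorPrimitive<D>;
}
```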
8 changes: 4 additions & 4 deletions burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs
```diff
@@ -36,7 +36,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
         state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
     ) -> B::TensorPrimitive<D> {
         let grad = state.output.grad();
-        let ones = B::ones(self.state, grad.device());
+        let ones = B::ones(self.state, B::device(&grad));
 
         let grad: Tensor<B, 1> = Tensor::new(grad);
         let val = 1_f64 / self.state.num_elements() as f64;
@@ -54,7 +54,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
         state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<1>>,
     ) -> B::TensorPrimitive<D> {
         let grad = state.output.grad();
-        let ones = B::ones(self.state, grad.device());
+        let ones = B::ones(self.state, B::device(&grad));
 
         let grad: Tensor<B, 1> = Tensor::new(grad);
         let ones: Tensor<B, D> = Tensor::new(ones);
@@ -73,7 +73,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
         let (shape, dim) = self.state;
 
         let grad = state.output.grad().sum_dim(dim);
-        let ones = B::ones(shape, grad.device());
+        let ones = B::ones(shape, B::device(&grad));
 
         let val = 1_f64 / shape.dims[dim] as f64;
         let ones = ones.mul_scalar(&B::Elem::from_elem(val));
@@ -92,7 +92,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
         let (shape, dim) = self.state;
 
         let grad = state.output.grad().sum_dim(dim);
-        let ones = B::ones(shape, grad.device());
+        let ones = B::ones(shape, B::device(&grad));
 
         ones.mul(&grad)
     }
```
36 changes: 0 additions & 36 deletions burn-tensor/src/tensor/backend/autodiff/ops/device.rs

This file was deleted.

1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
```diff
@@ -4,7 +4,6 @@ mod arg;
 mod cat;
 mod creation;
 mod detach;
-mod device;
 mod div;
 mod erf;
 mod exp;
```
68 changes: 66 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs
```diff
@@ -1,5 +1,30 @@
-use super::ADBackendDecorator;
-use crate::{backend::Backend, ops::TensorOps, Data, Shape};
+use super::{ADBackendDecorator, ADTensor};
+use crate::{
+    backend::Backend,
+    graph::{
+        node::{ForwardNode, ForwardNodeRef, ForwardNodeState},
+        ops::{ForwardUnaryRecordedOps, UnaryOps, UnaryOpsNodeState},
+    },
+    ops::TensorOps,
+    Data, Shape,
+};
+use std::sync::Arc;
+
+#[derive(new, Debug)]
+struct ToDeviceBackward<B: Backend, const D: usize> {
+    device: B::Device,
+}
+
+impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+    for ToDeviceBackward<B, D>
+{
+    fn partial(
+        &self,
+        state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
+    ) -> B::TensorPrimitive<D> {
+        B::to_device(&state.output.grad(), self.device)
+    }
+}
 
 impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     fn shape<const D: usize>(
@@ -37,4 +62,43 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
     ) -> Data<bool, D> {
         B::bool_into_data(tensor)
     }
+
+    fn device<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+    ) -> <ADBackendDecorator<B> as Backend>::Device {
+        B::device(tensor.tensor_ref())
+    }
+
+    fn to_device<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        device: <ADBackendDecorator<B> as Backend>::Device,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        let input = tensor.node.clone();
+        let output = B::to_device(tensor.tensor_ref(), device);
+        let ops = ToDeviceBackward::<B, D>::new(device);
+
+        unary_ops_wrapper(input, output, ops)
+    }
 }
+
+fn unary_ops_wrapper<B, O, const D: usize>(
+    input: ForwardNodeRef<B::TensorPrimitive<D>>,
+    output: B::TensorPrimitive<D>,
+    ops: O,
+) -> ADTensor<D, B>
+where
+    B: Backend,
+    O: UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>> + 'static,
+{
+    let shape = *B::shape(&output);
+    let state = ForwardNodeState::new(output);
+
+    let ops = Arc::new(ops);
+    let ops = ForwardUnaryRecordedOps::new(input.clone(), ops);
+    let ops = Arc::new(ops);
+
+    let node = ForwardNode::from_unary(&input, state, ops);
+    let node = Arc::new(node);
+
+    ADTensor { node, shape }
+}
```
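Mechanically, the decorated `to_device` computes the forward result eagerly with `B::to_device`, stores the target device in `ToDeviceBackward`, and registers everything on the graph through `unary_ops_wrapper`, which wraps the op in `ForwardUnaryRecordedOps`, builds a `ForwardNode`, and returns the resulting `ADTensor`. On the backward pass, `partial` simply applies `B::to_device` to the output gradient, so device transfers participate in autodiff like any other unary op; and since `unary_ops_wrapper` is generic over the op, other unary ops can reuse the same plumbing.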
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/base.rs
```diff
@@ -25,7 +25,6 @@ pub trait Backend:
         + Ones<Self::TensorPrimitive<D>>
         + TensorOpsReshape<Self, D>
         + TensorOpsPrecision<Self, D>
-        + TensorOpsDevice<Self, D>
         + TensorOpsIndex<Self::Elem, D>
         + TensorOpsAggregation<Self, D>
         + TensorOpsExp<Self::Elem, D>
```
18 changes: 0 additions & 18 deletions burn-tensor/src/tensor/backend/ndarray/ops/device.rs

This file was deleted.

1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/ndarray/ops/mod.rs
```diff
@@ -4,7 +4,6 @@ mod arg;
 mod cat;
 mod creation;
 mod detach;
-mod device;
 mod div;
 mod erf;
 mod exp;
```
18 changes: 16 additions & 2 deletions burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs
```diff
@@ -1,5 +1,9 @@
-use super::NdArrayBackend;
-use crate::{backend::Backend, ops::TensorOps, Data, NdArrayElement, Shape};
+use super::{NdArrayBackend, NdArrayTensor};
+use crate::{
+    backend::{Backend, NdArrayDevice},
+    ops::TensorOps,
+    Data, NdArrayElement, Shape,
+};
 
 impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
     fn shape<const D: usize>(
@@ -41,4 +45,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
         let values = tensor.array.into_iter().collect();
         Data::new(values, tensor.shape)
     }
+    fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
+        NdArrayDevice::Cpu
+    }
+
+    fn to_device<const D: usize>(
+        tensor: &NdArrayTensor<E, D>,
+        _device: NdArrayDevice,
+    ) -> NdArrayTensor<E, D> {
+        tensor.clone()
+    }
 }
```
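Since ndarray is CPU-only, `device` is constant and `to_device` reduces to a clone. A minimal hedged sketch of calling the new associated functions directly; it assumes `burn_tensor` re-exports `Backend`, `NdArrayBackend`, `NdArrayDevice`, and `TensorOps` at these paths and that `f32` implements `NdArrayElement`:

```rust
use burn_tensor::backend::{Backend, NdArrayBackend, NdArrayDevice};
use burn_tensor::ops::TensorOps;

type B = NdArrayBackend<f32>;

// Reads the device through the backend rather than a tensor method;
// on this backend the answer is always NdArrayDevice::Cpu.
fn is_cpu<const D: usize>(tensor: &<B as Backend>::TensorPrimitive<D>) -> bool {
    matches!(<B as TensorOps<B>>::device(tensor), NdArrayDevice::Cpu)
}
```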
27 changes: 0 additions & 27 deletions burn-tensor/src/tensor/backend/tch/ops/device.rs

This file was deleted.

1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/tch/ops/mod.rs
```diff
@@ -4,7 +4,6 @@ mod arg;
 mod cat;
 mod creation;
 mod detach;
-mod device;
 mod div;
 mod erf;
 mod exp;
```
20 changes: 19 additions & 1 deletion burn-tensor/src/tensor/backend/tch/tensor_ops.rs
```diff
@@ -1,4 +1,4 @@
-use super::TchBackend;
+use super::{TchBackend, TchDevice, TchTensor};
 use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};
 
 impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
@@ -39,4 +39,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         let values: Vec<bool> = tensor.tensor.into();
         Data::new(values, tensor.shape)
     }
+    fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
+        match tensor.tensor.device() {
+            tch::Device::Cpu => TchDevice::Cpu,
+            tch::Device::Cuda(num) => TchDevice::Cuda(num),
+        }
+    }
+
+    fn to_device<const D: usize>(tensor: &TchTensor<E, D>, device: TchDevice) -> TchTensor<E, D> {
+        let device = match device {
+            TchDevice::Cpu => tch::Device::Cpu,
+            TchDevice::Cuda(num) => tch::Device::Cuda(num),
+        };
+        TchTensor {
+            kind: tensor.kind.clone(),
+            tensor: tensor.tensor.to(device),
+            shape: tensor.shape,
+        }
+    }
 }
```
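The tch implementation translates between burn's `TchDevice` and `tch::Device` in both directions, preserving CUDA ordinals, and `to_device` rebuilds the `TchTensor` around `tensor.to(device)`, which copies the underlying data to the target device. A self-contained toy mirror of that mapping (stand-in enums, not the real `tch` types) showing that the conversion round-trips:

```rust
// Stand-ins for burn's TchDevice and tch::Device; the real mapping above
// has exactly these arms.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TchDevice {
    Cpu,
    Cuda(usize),
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum LibtorchDevice {
    Cpu,
    Cuda(usize),
}

fn to_libtorch(device: TchDevice) -> LibtorchDevice {
    match device {
        TchDevice::Cpu => LibtorchDevice::Cpu,
        TchDevice::Cuda(num) => LibtorchDevice::Cuda(num),
    }
}

fn from_libtorch(device: LibtorchDevice) -> TchDevice {
    match device {
        LibtorchDevice::Cpu => TchDevice::Cpu,
        LibtorchDevice::Cuda(num) => TchDevice::Cuda(num),
    }
}

fn main() {
    for device in [TchDevice::Cpu, TchDevice::Cuda(0)] {
        assert_eq!(from_libtorch(to_libtorch(device)), device);
    }
}
```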
10 changes: 5 additions & 5 deletions burn-tensor/src/tensor/base.rs
```diff
@@ -34,12 +34,12 @@
 
     /// Returns a new tensor on the given device.
     pub fn to_device(&self, device: B::Device) -> Self {
-        Self::new(self.value.to_device(device))
+        Self::new(B::to_device(&self.value, device))
    }
 
     /// Returns the device of the current tensor.
     pub fn device(&self) -> B::Device {
-        self.value.device()
+        B::device(&self.value)
     }
 
     /// Applies element wise exponential operation.
@@ -99,18 +99,18 @@
 
     /// Returns a new tensor with the same shape and device as the current tensor filled with zeros.
     pub fn zeros_like(&self) -> Self {
-        Tensor::new(B::zeros(*self.shape(), self.value.device()))
+        Tensor::new(B::zeros(*self.shape(), self.device()))
     }
 
     /// Returns a new tensor with the same shape and device as the current tensor filled with ones.
     pub fn ones_like(&self) -> Self {
-        Tensor::new(B::ones(*self.shape(), self.value.device()))
+        Tensor::new(B::ones(*self.shape(), self.device()))
     }
 
     /// Returns a new tensor with the same shape and device as the current tensor filled random
     /// values sampled from the given distribution.
     pub fn random_like(&self, distribution: Distribution<B::Elem>) -> Self {
-        Tensor::new(B::random(*self.shape(), distribution, self.value.device()))
+        Tensor::new(B::random(*self.shape(), distribution, self.device()))
     }
 
     /// Create a one hot tensor.
```
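Caller-facing code keeps the same `Tensor` methods; only the plumbing underneath changes, and the `*_like` constructors now pick up the source tensor's device through `self.device()`. A hedged generic sketch of the caller's view, assuming `Backend` and `Tensor` are re-exported at these paths:

```rust
use burn_tensor::{backend::Backend, Tensor};

// Moves a tensor, then allocates a zero-filled buffer on the same device;
// to_device/device now dispatch through B::to_device / B::device.
fn move_with_buffer<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    device: B::Device,
) -> (Tensor<B, D>, Tensor<B, D>) {
    let moved = tensor.to_device(device);
    let buffer = moved.zeros_like(); // lands on `device` via self.device()
    (moved, buffer)
}
```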
10 changes: 5 additions & 5 deletions burn-tensor/src/tensor/ops/base.rs
```diff
@@ -8,11 +8,11 @@ pub trait TensorOps<B: Backend> {
     fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> &Shape<D>;
     fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D>;
     fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
-}
-
-pub trait TensorOpsDevice<B: Backend, const D: usize> {
-    fn device(&self) -> B::Device;
-    fn to_device(&self, device: B::Device) -> Self;
+    fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
+    fn to_device<const D: usize>(
+        tensor: &B::TensorPrimitive<D>,
+        device: B::Device,
+    ) -> B::TensorPrimitive<D>;
 }
 
 pub trait TensorOpsAdd<E, const D: usize>: std::ops::Add<Self, Output = Self>
```
