From b1df39e7fc8353c5b513dbf8f8d561b59b8f4e2a Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Fri, 21 Oct 2022 11:36:51 -0400 Subject: [PATCH] Refactor/device ops (#60) --- .../backend/autodiff/ops/aggregation.rs | 8 +-- .../src/tensor/backend/autodiff/ops/device.rs | 36 ---------- .../src/tensor/backend/autodiff/ops/mod.rs | 1 - .../src/tensor/backend/autodiff/tensor_ops.rs | 68 ++++++++++++++++++- burn-tensor/src/tensor/backend/base.rs | 1 - .../src/tensor/backend/ndarray/ops/device.rs | 18 ----- .../src/tensor/backend/ndarray/ops/mod.rs | 1 - .../src/tensor/backend/ndarray/tensor_ops.rs | 18 ++++- .../src/tensor/backend/tch/ops/device.rs | 27 -------- burn-tensor/src/tensor/backend/tch/ops/mod.rs | 1 - .../src/tensor/backend/tch/tensor_ops.rs | 20 +++++- burn-tensor/src/tensor/base.rs | 10 +-- burn-tensor/src/tensor/ops/base.rs | 10 +-- 13 files changed, 115 insertions(+), 104 deletions(-) delete mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/device.rs delete mode 100644 burn-tensor/src/tensor/backend/ndarray/ops/device.rs delete mode 100644 burn-tensor/src/tensor/backend/tch/ops/device.rs diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs b/burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs index 5d49ac42d4..79583ac7a4 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs @@ -36,7 +36,7 @@ impl UnaryOps, B::TensorPrimit state: &UnaryOpsNodeState, B::TensorPrimitive<1>>, ) -> B::TensorPrimitive { let grad = state.output.grad(); - let ones = B::ones(self.state, grad.device()); + let ones = B::ones(self.state, B::device(&grad)); let grad: Tensor = Tensor::new(grad); let val = 1_f64 / self.state.num_elements() as f64; @@ -54,7 +54,7 @@ impl UnaryOps, B::TensorPrimit state: &UnaryOpsNodeState, B::TensorPrimitive<1>>, ) -> B::TensorPrimitive { let grad = state.output.grad(); - let ones = B::ones(self.state, grad.device()); + let ones = 
B::ones(self.state, B::device(&grad)); let grad: Tensor = Tensor::new(grad); let ones: Tensor = Tensor::new(ones); @@ -73,7 +73,7 @@ impl UnaryOps, B::TensorPrimit let (shape, dim) = self.state; let grad = state.output.grad().sum_dim(dim); - let ones = B::ones(shape, grad.device()); + let ones = B::ones(shape, B::device(&grad)); let val = 1_f64 / shape.dims[dim] as f64; let ones = ones.mul_scalar(&B::Elem::from_elem(val)); @@ -92,7 +92,7 @@ impl UnaryOps, B::TensorPrimit let (shape, dim) = self.state; let grad = state.output.grad().sum_dim(dim); - let ones = B::ones(shape, grad.device()); + let ones = B::ones(shape, B::device(&grad)); ones.mul(&grad) } diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/device.rs b/burn-tensor/src/tensor/backend/autodiff/ops/device.rs deleted file mode 100644 index d0f74606a7..0000000000 --- a/burn-tensor/src/tensor/backend/autodiff/ops/device.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::backend::autodiff::ADBackendDecorator; -use crate::backend::Backend; -use crate::ops::TensorOpsDevice; -use crate::{execute_ops, register_ops}; -use crate::{ - graph::ops::{UnaryOps, UnaryOpsNodeState}, - tensor::backend::autodiff::ADTensor, -}; - -register_ops!( - ops UnaryOps, - name ADTensorDeviceOps state B::Device, - partial | - device: &B::Device, - state: &UnaryOpsNodeState, B::TensorPrimitive> - | { - state.output.grad().to_device(*device) - }, -); - -impl TensorOpsDevice, D> - for as Backend>::TensorPrimitive -{ - fn device(&self) -> as Backend>::Device { - TensorOpsDevice::device(&self.tensor()) - } - - fn to_device(&self, device: as Backend>::Device) -> ADTensor { - let tensor = self.tensor(); - execute_ops!( - input self.node.clone(), - out TensorOpsDevice::to_device(&tensor, device), - ops ADTensorDeviceOps::::new(tensor.device()), - ) - } -} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index 2bea71beb7..0028c5caa9 100644 --- 
a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -4,7 +4,6 @@ mod arg; mod cat; mod creation; mod detach; -mod device; mod div; mod erf; mod exp; diff --git a/burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs b/burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs index 5e4775fc0e..2c83972c78 100644 --- a/burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/autodiff/tensor_ops.rs @@ -1,5 +1,30 @@ -use super::ADBackendDecorator; -use crate::{backend::Backend, ops::TensorOps, Data, Shape}; +use super::{ADBackendDecorator, ADTensor}; +use crate::{ + backend::Backend, + graph::{ + node::{ForwardNode, ForwardNodeRef, ForwardNodeState}, + ops::{ForwardUnaryRecordedOps, UnaryOps, UnaryOpsNodeState}, + }, + ops::TensorOps, + Data, Shape, +}; +use std::sync::Arc; + +#[derive(new, Debug)] +struct ToDeviceBackward { + device: B::Device, +} + +impl UnaryOps, B::TensorPrimitive> + for ToDeviceBackward +{ + fn partial( + &self, + state: &UnaryOpsNodeState, B::TensorPrimitive>, + ) -> B::TensorPrimitive { + B::to_device(&state.output.grad(), self.device) + } +} impl TensorOps> for ADBackendDecorator { fn shape( @@ -37,4 +62,43 @@ impl TensorOps> for ADBackendDecorator { ) -> Data { B::bool_into_data(tensor) } + + fn device( + tensor: & as Backend>::TensorPrimitive, + ) -> as Backend>::Device { + B::device(tensor.tensor_ref()) + } + + fn to_device( + tensor: & as Backend>::TensorPrimitive, + device: as Backend>::Device, + ) -> as Backend>::TensorPrimitive { + let input = tensor.node.clone(); + let output = B::to_device(tensor.tensor_ref(), device); + let ops = ToDeviceBackward::::new(B::device(tensor.tensor_ref())); /* backward moves the gradient to the stored device, so store the input's device (as the removed ops/device.rs did), not the target */ + + unary_ops_wrapper(input, output, ops) + } +} + +fn unary_ops_wrapper( + input: ForwardNodeRef>, + output: B::TensorPrimitive, + ops: O, +) -> ADTensor +where + B: Backend, + O: UnaryOps, B::TensorPrimitive> + 'static, +{ + let shape = *B::shape(&output); + let state = 
ForwardNodeState::new(output); + + let ops = Arc::new(ops); + let ops = ForwardUnaryRecordedOps::new(input.clone(), ops); + let ops = Arc::new(ops); + + let node = ForwardNode::from_unary(&input, state, ops); + let node = Arc::new(node); + + ADTensor { node, shape } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 08f53eb2ed..038035c815 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -25,7 +25,6 @@ pub trait Backend: + Ones> + TensorOpsReshape + TensorOpsPrecision - + TensorOpsDevice + TensorOpsIndex + TensorOpsAggregation + TensorOpsExp diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/device.rs b/burn-tensor/src/tensor/backend/ndarray/ops/device.rs deleted file mode 100644 index 87feff24cf..0000000000 --- a/burn-tensor/src/tensor/backend/ndarray/ops/device.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::{ - backend::ndarray::{NdArrayBackend, NdArrayDevice, NdArrayTensor}, - ops::TensorOpsDevice, - NdArrayElement, -}; - -impl TensorOpsDevice, D> for NdArrayTensor -where - E: NdArrayElement, -{ - fn device(&self) -> as crate::backend::Backend>::Device { - NdArrayDevice::Cpu - } - - fn to_device(&self, _device: as crate::backend::Backend>::Device) -> Self { - self.clone() - } -} diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs index bd6d2469a2..65f82ccafd 100644 --- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs @@ -4,7 +4,6 @@ mod arg; mod cat; mod creation; mod detach; -mod device; mod div; mod erf; mod exp; diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs index 3e417fa59b..a05f2971a5 100644 --- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs @@ -1,5 +1,9 @@ -use super::NdArrayBackend; -use 
crate::{backend::Backend, ops::TensorOps, Data, NdArrayElement, Shape}; +use super::{NdArrayBackend, NdArrayTensor}; +use crate::{ + backend::{Backend, NdArrayDevice}, + ops::TensorOps, + Data, NdArrayElement, Shape, +}; impl TensorOps> for NdArrayBackend { fn shape( @@ -41,4 +45,14 @@ impl TensorOps> for NdArrayBackend { let values = tensor.array.into_iter().collect(); Data::new(values, tensor.shape) } + fn device(_tensor: &NdArrayTensor) -> NdArrayDevice { + NdArrayDevice::Cpu + } + + fn to_device( + tensor: &NdArrayTensor, + _device: NdArrayDevice, + ) -> NdArrayTensor { + tensor.clone() + } } diff --git a/burn-tensor/src/tensor/backend/tch/ops/device.rs b/burn-tensor/src/tensor/backend/tch/ops/device.rs deleted file mode 100644 index f728bf7a7c..0000000000 --- a/burn-tensor/src/tensor/backend/tch/ops/device.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::{ - backend::tch::{TchBackend, TchDevice, TchTensor}, - backend::Backend, - ops::TensorOpsDevice, - TchElement, -}; - -impl TensorOpsDevice, D> for TchTensor { - fn device(&self) -> as Backend>::Device { - match self.tensor.device() { - tch::Device::Cpu => TchDevice::Cpu, - tch::Device::Cuda(num) => TchDevice::Cuda(num), - } - } - - fn to_device(&self, device: as Backend>::Device) -> Self { - let device = match device { - TchDevice::Cpu => tch::Device::Cpu, - TchDevice::Cuda(num) => tch::Device::Cuda(num), - }; - Self { - kind: self.kind.clone(), - tensor: self.tensor.to(device), - shape: self.shape, - } - } -} diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index bd6d2469a2..65f82ccafd 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -4,7 +4,6 @@ mod arg; mod cat; mod creation; mod detach; -mod device; mod div; mod erf; mod exp; diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs index 027da68104..66756f36bf 100644 --- 
a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs @@ -1,4 +1,4 @@ -use super::TchBackend; +use super::{TchBackend, TchDevice, TchTensor}; use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement}; impl TensorOps> for TchBackend { @@ -39,4 +39,22 @@ impl TensorOps> for TchBackend { let values: Vec = tensor.tensor.into(); Data::new(values, tensor.shape) } + fn device(tensor: &TchTensor) -> TchDevice { + match tensor.tensor.device() { + tch::Device::Cpu => TchDevice::Cpu, + tch::Device::Cuda(num) => TchDevice::Cuda(num), + } + } + + fn to_device(tensor: &TchTensor, device: TchDevice) -> TchTensor { + let device = match device { + TchDevice::Cpu => tch::Device::Cpu, + TchDevice::Cuda(num) => tch::Device::Cuda(num), + }; + TchTensor { + kind: tensor.kind.clone(), + tensor: tensor.tensor.to(device), + shape: tensor.shape, + } + } } diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs index 4c5e4b5933..1ec0879619 100644 --- a/burn-tensor/src/tensor/base.rs +++ b/burn-tensor/src/tensor/base.rs @@ -34,12 +34,12 @@ where /// Returns a new tensor on the given device. pub fn to_device(&self, device: B::Device) -> Self { - Self::new(self.value.to_device(device)) + Self::new(B::to_device(&self.value, device)) } /// Returns the device of the current tensor. pub fn device(&self) -> B::Device { - self.value.device() + B::device(&self.value) } /// Applies element wise exponential operation. @@ -99,18 +99,18 @@ where /// Returns a new tensor with the same shape and device as the current tensor filled with zeros. pub fn zeros_like(&self) -> Self { - Tensor::new(B::zeros(*self.shape(), self.value.device())) + Tensor::new(B::zeros(*self.shape(), self.device())) } /// Returns a new tensor with the same shape and device as the current tensor filled with ones. 
pub fn ones_like(&self) -> Self { - Tensor::new(B::ones(*self.shape(), self.value.device())) + Tensor::new(B::ones(*self.shape(), self.device())) } /// Returns a new tensor with the same shape and device as the current tensor filled random /// values sampled from the given distribution. pub fn random_like(&self, distribution: Distribution) -> Self { - Tensor::new(B::random(*self.shape(), distribution, self.value.device())) + Tensor::new(B::random(*self.shape(), distribution, self.device())) } /// Create a one hot tensor. diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs index eda8d1067e..1e55064cc5 100644 --- a/burn-tensor/src/tensor/ops/base.rs +++ b/burn-tensor/src/tensor/ops/base.rs @@ -8,11 +8,11 @@ pub trait TensorOps { fn bool_shape(tensor: &B::BoolTensorPrimitive) -> &Shape; fn bool_to_data(tensor: &B::BoolTensorPrimitive) -> Data; fn bool_into_data(tensor: B::BoolTensorPrimitive) -> Data; -} - -pub trait TensorOpsDevice { - fn device(&self) -> B::Device; - fn to_device(&self, device: B::Device) -> Self; + fn device(tensor: &B::TensorPrimitive) -> B::Device; + fn to_device( + tensor: &B::TensorPrimitive, + device: B::Device, + ) -> B::TensorPrimitive; } pub trait TensorOpsAdd: std::ops::Add