Implement tensor.recip() function to calculate elementwise reciprocals (
gzsombor authored Nov 15, 2023
1 parent e882d41 commit 4fc0c27
Showing 25 changed files with 241 additions and 1 deletion.
26 changes: 26 additions & 0 deletions burn-autodiff/src/ops/tensor.rs
@@ -435,6 +435,32 @@ impl<B: Backend> TensorOps<Self> for Autodiff<B> {
.stateless(B::neg(tensor.primitive))
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct Recip;

impl<B: Backend, const D: usize> Backward<B, D, 1> for Recip {
type State = B::TensorPrimitive<D>;

fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let tensor = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let tmp = B::powf(tensor, -2.0);
let value = B::neg(tmp);

B::mul(grad, value)
});
}
}

match Recip.prepare([tensor.node], [tensor.graph]).stateful() {
OpsKind::Tracked(prep) => {
prep.finish(tensor.primitive.clone(), B::recip(tensor.primitive))
}
OpsKind::UnTracked(prep) => prep.finish(B::recip(tensor.primitive)),
}
}

fn swap_dims<const D: usize>(
tensor: FloatTensor<Self, D>,
dim1: usize,
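An aside, not part of the commit: the backward closure above encodes the derivative of the reciprocal,

```latex
y = x^{-1}, \qquad
\frac{\partial y}{\partial x} = -x^{-2}, \qquad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \left( -x^{-2} \right),
```

which is exactly what `B::mul(grad, B::neg(B::powf(tensor, -2.0)))` computes, with `grad` carrying ∂L/∂y.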
2 changes: 2 additions & 0 deletions burn-autodiff/src/tests/mod.rs
@@ -34,6 +34,7 @@ mod mul;
mod multithread;
mod neg;
mod pow;
mod recip;
mod relu;
mod reshape;
mod select;
@@ -94,6 +95,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
20 changes: 20 additions & 0 deletions burn-autodiff/src/tests/recip.rs
@@ -0,0 +1,20 @@
#[burn_tensor_testgen::testgen(ad_recip)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_recip() {
let data = Data::from([2.0, 5.0, 0.4]);

let tensor = TestAutodiffTensor::from_data(data).require_grad();
let tensor_out = tensor.clone().recip();

let grads = tensor_out.backward();
let grad = tensor.grad(&grads).unwrap();

assert_eq!(tensor_out.into_data(), Data::from([0.5, 0.2, 2.5]));
grad.to_data()
.assert_approx_eq(&Data::from([-0.25, -0.04, -6.25]), 3);
}
}
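As a sanity check on those constants: for inputs (2.0, 5.0, 0.4) the forward pass gives (1/2, 1/5, 1/0.4) = (0.5, 0.2, 2.5), and the gradient -1/x² gives (-0.25, -0.04, -6.25), matching the assertions above.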
1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -126,6 +126,7 @@ Those operations are only available for `Float` tensors.
| `tensor.erf()` | `tensor.erf()` |
| `tensor.powf(value)` | `tensor.pow(value)` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.tanh()` | `tensor.tanh()` |
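A minimal usage sketch of the new API (not part of the commit; the `NdArray` backend here is just one possible choice):

```rust
use burn_ndarray::NdArray;
use burn_tensor::{Data, Tensor};

fn main() {
    // Elementwise reciprocal: each output element is 1/x.
    let x = Tensor::<NdArray, 1>::from_floats([2.0, 4.0, 0.5]);
    let y = x.recip();
    assert_eq!(y.into_data(), Data::from([0.5, 0.25, 2.0]));
}
```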
2 changes: 2 additions & 0 deletions burn-candle/src/lib.rs
@@ -56,6 +56,7 @@ mod tests {
burn_tensor::testgen_arg!();
burn_tensor::testgen_cast!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
// burn_tensor::testgen_div!();
@@ -133,6 +134,7 @@ mod tests {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_recip!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_softmax!();
4 changes: 4 additions & 0 deletions burn-candle/src/ops/tensor.rs
@@ -442,4 +442,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<Self> for Candle<F, I
) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.clamp(min, max).unwrap())
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
CandleTensor::new(tensor.tensor.recip().unwrap())
}
}
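Candle already ships a reciprocal kernel, so this backend is a one-line delegation to `tensor.recip()` on the underlying candle tensor; burn-tch below does the same via LibTorch, while burn-ndarray has to map the division by hand.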
7 changes: 7 additions & 0 deletions burn-fusion/src/graph/ops.rs
@@ -100,6 +100,11 @@ pub enum FloatOpsDescription<B: FusionBackend> {
(TensorDescription, Distribution<FloatElem<B>>),
Box<dyn Ops<B, Args = (TensorDescription, Distribution<FloatElem<B>>)>>,
),
/// Operation corresponding to [recip](burn_tensor::ops::TensorOps::recip).
Recip(
UnaryOpsDescription,
Box<dyn Ops<B, Args = UnaryOpsDescription>>,
),
}

/// Operation description specific to module.
@@ -1252,6 +1257,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Log1p(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Erf(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Recip(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Powf(desc, _) => handles.cleanup(&desc.lhs),
FloatOpsDescription::Sqrt(desc, _) => handles.cleanup(&desc.input),
FloatOpsDescription::Cos(desc, _) => handles.cleanup(&desc.input),
@@ -1268,6 +1274,7 @@ impl<B: FusionBackend> FloatOpsDescription<B> {
FloatOpsDescription::Log(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Log1p(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Erf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Recip(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Powf(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Sqrt(desc, ops) => ops.execute(desc, handles),
FloatOpsDescription::Cos(desc, ops) => ops.execute(desc, handles),
15 changes: 15 additions & 0 deletions burn-fusion/src/ops/float.rs
@@ -1391,6 +1391,21 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
out
}

fn recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(Recip, B::recip);

let out = tensor.client.create_tensor_empty(tensor.shape.clone());
out.client
.register(TensorOpsDescription::FloatOps(FloatOpsDescription::Recip(
UnaryOpsDescription {
input: tensor.into_description(),
out: out.to_description_out(),
},
Box::new(Recip::<D>),
)));
out
}

fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary_float_ops!(TanhOps, B::erf);

2 changes: 1 addition & 1 deletion burn-import/SUPPORTED-ONNX-OPS.md
@@ -134,7 +134,7 @@ represent the corresponding Burn Op.
| [RandomUniform][128] |||
| [RandomUniformLike][129] |||
| [Range][130] |||
| [Reciprocal][131] | ❌ | ❌ |
| [Reciprocal][131] | | |
| [ReduceL][132] |||
| [ReduceLogSum][133] |||
| [ReduceLogSumExp][134] |||
1 change: 1 addition & 0 deletions burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/mul/mul.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
14 changes: 14 additions & 0 deletions burn-import/onnx-tests/tests/onnx_tests.rs
@@ -34,6 +34,7 @@ include_models!(
log_softmax,
maxpool2d,
mul,
recip,
relu,
reshape,
sigmoid,
@@ -591,4 +592,17 @@ mod tests {
let expected = Data::from([[[[0.7616, 0.9640, 0.9951, 0.9993]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}

#[test]
fn recip() {
// Initialize the model
let model = recip::Model::<Backend>::new();

// Run the model
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
let output = model.forward(input);
// Data from PyTorch
let expected = Data::from([[[[1.0000, 0.5000, 0.3333, 0.2500]]]]);
output.to_data().assert_approx_eq(&expected, 4);
}
}
17 changes: 17 additions & 0 deletions burn-import/onnx-tests/tests/recip/recip.onnx
@@ -0,0 +1,17 @@
(Binary ONNX protobuf, not human-readable as text. Recoverable details: exported by pytorch 2.1.0; graph `main_graph` contains a single `Reciprocal` node mapping the rank-4 float input `onnx::Reciprocal_0` to output `1`.)
42 changes: 42 additions & 0 deletions burn-import/onnx-tests/tests/recip/recip.py
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/recip/recip.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x):
return x.reciprocal()


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")
onnx_name = "recip.onnx"
dummy_input = torch.randn(1, 2, 3, 4, device=device)

torch.onnx.export(model, (dummy_input), onnx_name,
verbose=False, opset_version=16)

print("Finished exporting model to {}".format(onnx_name))

# Output some test data for use in the test
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])

print("Test input data: {}".format(test_input))
output = model.forward(test_input)
print("Test output data: {}".format(output))


if __name__ == '__main__':
main()
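Running this script regenerates `recip.onnx`, and its printed output for the test input is the source of the expected values asserted in `onnx_tests.rs` above.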
26 changes: 26 additions & 0 deletions burn-import/src/burn/node/unary.rs
@@ -26,6 +26,7 @@ pub enum UnaryNodeKind {
LogSoftmax,
Softmax,
Relu,
Reciprocal,
Sigmoid,
Tanh,
Transpose,
@@ -40,6 +41,7 @@ impl UnaryNodeKind {
Self::LogSoftmax => "log_softmax",
Self::Softmax => "softmax",
Self::Relu => "relu",
Self::Reciprocal => "reciprocal",
Self::Sigmoid => "sigmoid",
Self::Tanh => "tanh",
Self::Transpose => "transpose",
@@ -141,6 +143,11 @@ impl UnaryNode {
Self::new(input, output, UnaryNodeKind::Transpose, Rc::new(function))
}

pub(crate) fn reciprocal(input: Type, output: Type) -> Self {
let function = move |input| quote! { #input.recip() };
Self::new(input, output, UnaryNodeKind::Reciprocal, Rc::new(function))
}

/// Casts the input to the output type.
///
/// Currently this function only supports the following conversions:
@@ -334,6 +341,25 @@ mod tests {
);
}

#[test]
fn test_unary_codegen_reciprocal() {
one_node_graph(
UnaryNode::reciprocal(
Type::Tensor(TensorType::new_float("tensor1", 4)),
Type::Tensor(TensorType::new_float("tensor2", 4)),
),
quote! {
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = tensor1.recip();

tensor2
}
},
vec!["tensor1".to_string()],
vec!["tensor2".to_string()],
);
}

#[test]
fn test_unary_codegen_cast() {
one_node_graph(
1 change: 1 addition & 0 deletions burn-import/src/onnx/dim_inference.rs
@@ -79,6 +79,7 @@ pub fn dim_inference(
NodeType::Erf => same_as_input(node),
NodeType::Sqrt => same_as_input(node),
NodeType::Tanh => same_as_input(node),
NodeType::Reciprocal => same_as_input(node),
NodeType::Softmax => same_as_input(node),
NodeType::ReduceMean => mean_update_outputs(node),
NodeType::Constant => constant_update_outputs(node),
8 changes: 8 additions & 0 deletions burn-import/src/onnx/to_burn.rs
@@ -250,6 +250,7 @@ impl ONNXGraph {
NodeType::Tanh => graph.register(Self::tanh_conversion(node)),
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
NodeType::Reciprocal => graph.register(Self::reciprocal_conversion(node)),
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
NodeType::Concat => graph.register(Self::concat_conversion(node)),
@@ -447,6 +448,13 @@ impl ONNXGraph {
UnaryNode::sigmoid(input, output)
}

fn reciprocal_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();

UnaryNode::reciprocal(input, output)
}

fn log_softmax_conversion(node: Node) -> UnaryNode {
let input = node.inputs.get(0).unwrap().to_type();
let output = node.outputs.get(0).unwrap().to_type();
7 changes: 7 additions & 0 deletions burn-ndarray/src/ops/base.rs
@@ -181,6 +181,13 @@
NdArrayTensor { array }
}

pub fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.map(|x| 1.elem::<E>() / *x);
let array = array.into_shared();

NdArrayTensor { array }
}

pub fn mean<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, 1> {
let data = Data::from([tensor.array.mean().unwrap()]);
NdArrayTensor::from_data(data)
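The implementation above maps `1/x` over the array (note that `1.0/0.0` yields `+inf` under IEEE 754 rather than panicking). A standalone sketch of the same idea against the `ndarray` crate directly, for illustration only:

```rust
use ndarray::array;

fn main() {
    let a = array![2.0_f32, 4.0, 0.5];
    // Elementwise reciprocal via map, mirroring NdArrayMathOps::recip.
    let r = a.map(|x| 1.0 / *x);
    assert_eq!(r, array![0.5, 0.25, 2.0]);
}
```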
4 changes: 4 additions & 0 deletions burn-ndarray/src/ops/tensor.rs
@@ -127,6 +127,10 @@ impl<E: FloatNdArrayElement> TensorOps<Self> for NdArray<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}

fn recip<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::recip(tensor)
}

fn swap_dims<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim1: usize,
4 changes: 4 additions & 0 deletions burn-tch/src/ops/tensor.rs
@@ -172,6 +172,10 @@ impl<E: TchElement> TensorOps<Self> for LibTorch<E> {
Self::mul_scalar(tensor, (-1f32).elem::<E>())
}

fn recip<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
TchTensor::new(tensor.tensor.reciprocal())
}

fn swap_dims<const D: usize>(
tensor: TchTensor<E, D>,
dim1: usize,
5 changes: 5 additions & 0 deletions burn-tensor/src/tensor/api/float.rs
@@ -67,6 +67,11 @@ where
Self::new(B::powf(self.primitive, value))
}

/// Applies the element-wise reciprocal operation.
pub fn recip(self) -> Self {
Self::new(B::recip(self.primitive))
}

/// Applies the element-wise square-root operation.
pub fn sqrt(self) -> Self {
Self::new(B::sqrt(self.primitive))
3 changes: 3 additions & 0 deletions burn-tensor/src/tensor/ops/tensor.rs
@@ -410,6 +410,9 @@ pub trait TensorOps<B: Backend> {
Self::mul_scalar(tensor, (-1.0_f32).elem::<FloatElem<B>>())
}

/// Calculates the reciprocals elementwise.
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;

/// Transposes a tensor.
///
/// # Arguments
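Note that `recip` is declared without a default body, unlike `neg` just above it, so every backend must supply its own implementation; that is why this single operator touches so many crates in one commit.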
(4 more changed files not shown.)