diff --git a/torchrl/csrc/utils.cpp b/torchrl/csrc/utils.cpp
new file mode 100644
index 00000000000..79cd43fdffb
--- /dev/null
+++ b/torchrl/csrc/utils.cpp
@@ -0,0 +1,48 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+// utils.h
+#include "utils.h"
+
+#include <torch/torch.h>
+torch::Tensor safetanh(torch::Tensor input, float eps) {
+  return SafeTanh::apply(input, eps);
+}
+torch::Tensor safeatanh(torch::Tensor input, float eps) {
+  return SafeInvTanh::apply(input, eps);
+}
+torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
+                                torch::Tensor input, float eps) {
+  auto out = torch::tanh(input);
+  auto lim = 1.0 - eps;
+  out = out.clamp(-lim, lim);
+  ctx->save_for_backward({out});
+  return out;
+}
+torch::autograd::tensor_list SafeTanh::backward(
+    torch::autograd::AutogradContext* ctx,
+    torch::autograd::tensor_list grad_outputs) {
+  auto saved = ctx->get_saved_variables();
+  auto out = saved[0];
+  auto go = grad_outputs[0];
+  auto grad = go * (1 - out * out);
+  return {grad, torch::Tensor()};
+}
+torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
+                                   torch::Tensor input, float eps) {
+  auto lim = 1.0 - eps;
+  auto intermediate = input.clamp(-lim, lim);
+  ctx->save_for_backward({intermediate});
+  auto out = torch::atanh(intermediate);
+  return out;
+}
+torch::autograd::tensor_list SafeInvTanh::backward(
+    torch::autograd::AutogradContext* ctx,
+    torch::autograd::tensor_list grad_outputs) {
+  auto saved = ctx->get_saved_variables();
+  auto input = saved[0];
+  auto go = grad_outputs[0];
+  auto grad = go / (1 - input * input);
+  return {grad, torch::Tensor()};
+}
diff --git a/torchrl/csrc/utils.h b/torchrl/csrc/utils.h
index a6e5e0b2161..2d93469d82a 100644
--- a/torchrl/csrc/utils.h
+++ b/torchrl/csrc/utils.h
@@ -2,58 +2,30 @@
 //
 // This source code is licensed under the MIT license found in the
 // LICENSE file in the root directory of this source tree.
+// utils.h
+
+#pragma once
 
 #include <torch/extension.h>
 #include <torch/torch.h>
 
-#include <iostream>
-
-using namespace torch::autograd;
+torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
+torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);
 
-class SafeTanh : public Function<SafeTanh> {
+class SafeTanh : public torch::autograd::Function<SafeTanh> {
 public:
-  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
-                               float eps = 1e-6) {
-    auto out = torch::tanh(input);
-    auto lim = 1.0 - eps;
-    out = out.clamp(-lim, lim);
-    ctx->save_for_backward({out});
-    return out;
-  }
-
-  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto out = saved[0];
-    auto go = grad_outputs[0];
-    auto grad = go * (1 - out * out);
-    return {grad, torch::Tensor()};
-  }
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };
 
-torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6) {
-  return SafeTanh::apply(input, eps);
-}
-
-class SafeInvTanh : public Function<SafeInvTanh> {
+class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
 public:
-  static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
-                               float eps = 1e-6) {
-    auto lim = 1.0 - eps;
-    auto intermediate = input.clamp(-lim, lim);
-    ctx->save_for_backward({intermediate});
-    auto out = torch::atanh(intermediate);
-    return out;
-  }
-
-  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto input = saved[0];
-    auto go = grad_outputs[0];
-    auto grad = go / (1 - input * input);
-    return {grad, torch::Tensor()};
-  }
+  static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
+                               torch::Tensor input, float eps);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs);
 };
-
-torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6) {
-  return SafeInvTanh::apply(input, eps);
-}