Skip to content

Commit

Permalink
[Quality] Split utils.h and utils.cpp (pytorch#2348)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Aug 2, 2024
1 parent 59d2ae1 commit 0029c32
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 45 deletions.
48 changes: 48 additions & 0 deletions torchrl/csrc/utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// utils.cpp
#include "utils.h"

#include <iostream>
// Numerically-safe tanh: output is clamped to [-(1-eps), 1-eps] so a
// subsequent atanh (or log(1 - x^2)) cannot hit ±1 and produce inf/nan.
torch::Tensor safetanh(torch::Tensor input, float eps) {
  auto clamped = SafeTanh::apply(std::move(input), eps);
  return clamped;
}
// Numerically-safe atanh: the argument is clamped into the open interval
// (-1, 1) (by eps) before atanh, so the result stays finite.
torch::Tensor safeatanh(torch::Tensor input, float eps) {
  auto result = SafeInvTanh::apply(std::move(input), eps);
  return result;
}
// Forward pass: tanh(input) clamped to [-(1-eps), 1-eps].
// The clamped output (not the input) is saved for backward, so the
// gradient is computed w.r.t. the value actually returned.
torch::Tensor SafeTanh::forward(torch::autograd::AutogradContext* ctx,
                                torch::Tensor input, float eps) {
  const double bound = 1.0 - eps;
  auto y = torch::tanh(input).clamp(-bound, bound);
  ctx->save_for_backward({y});
  return y;
}
// Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, evaluated at the saved
// (clamped) output. The second slot is an undefined tensor because the
// eps argument takes no gradient.
torch::autograd::tensor_list SafeTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  const auto y = ctx->get_saved_variables()[0];
  auto grad_input = grad_outputs[0] * (1 - y * y);
  return {grad_input, torch::Tensor()};
}
// Forward pass: atanh applied to the input clamped into [-(1-eps), 1-eps],
// keeping the result finite at the boundaries. The clamped input is saved
// so backward differentiates at the point where atanh was evaluated.
torch::Tensor SafeInvTanh::forward(torch::autograd::AutogradContext* ctx,
                                   torch::Tensor input, float eps) {
  const double bound = 1.0 - eps;
  auto clamped = input.clamp(-bound, bound);
  ctx->save_for_backward({clamped});
  return torch::atanh(clamped);
}
// Backward pass: d/dx atanh(x) = 1 / (1 - x^2), evaluated at the saved
// clamped input (so the denominator is bounded away from zero).
// eps receives no gradient, hence the trailing undefined tensor.
torch::autograd::tensor_list SafeInvTanh::backward(
    torch::autograd::AutogradContext* ctx,
    torch::autograd::tensor_list grad_outputs) {
  const auto x = ctx->get_saved_variables()[0];
  auto grad_input = grad_outputs[0] / (1 - x * x);
  return {grad_input, torch::Tensor()};
}
62 changes: 17 additions & 45 deletions torchrl/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,30 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
// utils.h

#pragma once

#include <torch/extension.h>
#include <torch/torch.h>

#include <iostream>

using namespace torch::autograd;
torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6);
torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6);

class SafeTanh : public Function<SafeTanh> {
class SafeTanh : public torch::autograd::Function<SafeTanh> {
public:
static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
float eps = 1e-6) {
auto out = torch::tanh(input);
auto lim = 1.0 - eps;
out = out.clamp(-lim, lim);
ctx->save_for_backward({out});
return out;
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto out = saved[0];
auto go = grad_outputs[0];
auto grad = go * (1 - out * out);
return {grad, torch::Tensor()};
}
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
torch::Tensor input, float eps);
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs);
};

torch::Tensor safetanh(torch::Tensor input, float eps = 1e-6) {
return SafeTanh::apply(input, eps);
}

class SafeInvTanh : public Function<SafeInvTanh> {
class SafeInvTanh : public torch::autograd::Function<SafeInvTanh> {
public:
static torch::Tensor forward(AutogradContext* ctx, torch::Tensor input,
float eps = 1e-6) {
auto lim = 1.0 - eps;
auto intermediate = input.clamp(-lim, lim);
ctx->save_for_backward({intermediate});
auto out = torch::atanh(intermediate);
return out;
}

static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto go = grad_outputs[0];
auto grad = go / (1 - input * input);
return {grad, torch::Tensor()};
}
static torch::Tensor forward(torch::autograd::AutogradContext* ctx,
torch::Tensor input, float eps);
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs);
};

torch::Tensor safeatanh(torch::Tensor input, float eps = 1e-6) {
return SafeInvTanh::apply(input, eps);
}

0 comments on commit 0029c32

Please sign in to comment.