
Add float cast tensor op #2483

Merged · 11 commits · Nov 18, 2024
Remove elem type generic from TchTensor
laggui committed Nov 13, 2024
commit 23d24bf0f7f7bd5009f8f8f377ac7aeac0b5652b
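
For context: TchTensor wraps a tch::Tensor, and libtorch tensors already carry their dtype at runtime (tch::Kind), so the compile-time element parameter was redundant. A minimal sketch of the idea, assuming a simplified wrapper (the struct and method names below are illustrative; the real TchTensor also keeps a storage handle for in-place optimizations):

use tch::{Kind, Tensor};

/// Simplified stand-in for the non-generic TchTensor introduced by this commit.
pub struct UntypedTensor {
    pub tensor: Tensor,
}

impl UntypedTensor {
    pub fn new(tensor: Tensor) -> Self {
        Self { tensor }
    }

    /// The element type is a runtime property of the wrapped tensor,
    /// not a generic parameter on the Rust type.
    pub fn kind(&self) -> Kind {
        self.tensor.kind()
    }
}

fn main() {
    // from_slice is the current tch constructor (of_slice in older releases).
    let t = UntypedTensor::new(Tensor::from_slice(&[0.0f32, 1.0, 2.0]));
    assert_eq!(t.kind(), Kind::Float);
}
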
8 changes: 4 additions & 4 deletions crates/burn-tch/src/backend.rs
@@ -95,15 +95,15 @@ impl<E: TchElement, Q: QuantElement> Backend for LibTorch<E, Q> {
     type Device = LibTorchDevice;
     type FullPrecisionBridge = PrecisionBridge<f32>;
 
-    type FloatTensorPrimitive = TchTensor<E>;
+    type FloatTensorPrimitive = TchTensor;
     type FloatElem = E;
 
-    type IntTensorPrimitive = TchTensor<i64>;
+    type IntTensorPrimitive = TchTensor;
     type IntElem = i64;
 
-    type BoolTensorPrimitive = TchTensor<bool>;
+    type BoolTensorPrimitive = TchTensor;
 
-    type QuantizedTensorPrimitive = TchQTensor<Q>;
+    type QuantizedTensorPrimitive = TchQTensor;
     type QuantizedEncoding = Q;
 
     fn seed(seed: u64) {
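
With the float primitive no longer parameterized by its element type, changing a float tensor's precision is purely a runtime dtype change on the wrapped tch::Tensor, which is what the float cast op in this PR builds on. A hedged sketch using plain tch (the helper name below is illustrative, not burn's API):

use tch::{Kind, Tensor};

/// Illustrative helper: cast a float tensor to another float dtype.
/// tch's to_kind copies the data into a tensor of the requested Kind.
fn cast_float(tensor: &Tensor, target: Kind) -> Tensor {
    tensor.to_kind(target)
}

fn main() {
    let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    let half = cast_float(&x, Kind::Half);
    assert_eq!(half.kind(), Kind::Half);
}
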
2 changes: 1 addition & 1 deletion crates/burn-tch/src/bridge.rs
@@ -2,7 +2,7 @@ use crate::{ops::TchOps, LibTorch, QuantElement, TchElement, TchTensor};
 use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device};
 use std::marker::PhantomData;
 
-/// Handle precision conversion for the candle backend.
+/// Handle precision conversion for the tch backend.
 #[derive(Debug)]
 pub struct PrecisionBridge<E: TchElement> {
     _e: PhantomData<E>,
2 changes: 2 additions & 0 deletions crates/burn-tch/src/element.rs
@@ -16,6 +16,8 @@ impl TchElement for i8 {}
 
 impl TchElement for u8 {}
 
+impl TchElement for bool {}
+
 /// A quantized element for the tch backend.
 pub trait QuantElement: TchElement {}
 
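
The new impl TchElement for bool is needed because, with the untyped primitive, boolean tensors travel through the same TchTensor code paths as numeric ones. In libtorch they are simply tensors with Kind::Bool; a small hedged sketch with plain tch, independent of burn's trait definitions:

use tch::{Kind, Tensor};

fn main() {
    // Boolean slices become Kind::Bool tensors (from_slice; of_slice in older tch).
    let mask = Tensor::from_slice(&[true, false, true]);
    assert_eq!(mask.kind(), Kind::Bool);

    // Comparison ops also produce Kind::Bool tensors, which the untyped
    // TchTensor primitive can now carry without a bool type parameter.
    let x = Tensor::from_slice(&[1i64, 2, 3]);
    let gt = x.gt(1i64);
    assert_eq!(gt.kind(), Kind::Bool);
}
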
10 changes: 5 additions & 5 deletions crates/burn-tch/src/ops/activation.rs
@@ -2,29 +2,29 @@ use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
 use burn_tensor::ops::ActivationOps;
 
 impl<E: TchElement, Q: QuantElement> ActivationOps<Self> for LibTorch<E, Q> {
-    fn relu(tensor: TchTensor<E>) -> TchTensor<E> {
+    fn relu(tensor: TchTensor) -> TchTensor {
         tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu())
     }
 
-    fn gelu(tensor: TchTensor<E>) -> TchTensor<E> {
+    fn gelu(tensor: TchTensor) -> TchTensor {
         tensor.unary_ops(
             |mut tensor| tensor.gelu_("none"),
             |tensor| tensor.gelu("none"),
         )
     }
 
-    fn gelu_backward(tensor: TchTensor<E>, grad: TchTensor<E>) -> TchTensor<E> {
+    fn gelu_backward(tensor: TchTensor, grad: TchTensor) -> TchTensor {
         let storage = tensor.storage.clone();
         let tensor = tensor.tensor.gelu_backward(&grad.tensor, "none");
 
         TchTensor::from_existing(tensor, storage)
     }
 
-    fn sigmoid(tensor: TchTensor<E>) -> TchTensor<E> {
+    fn sigmoid(tensor: TchTensor) -> TchTensor {
         tensor.unary_ops(|mut tensor| tensor.sigmoid_(), |tensor| tensor.sigmoid())
     }
 
-    fn log_sigmoid(tensor: TchTensor<E>) -> TchTensor<E> {
+    fn log_sigmoid(tensor: TchTensor) -> TchTensor {
         // NOTE: we don't override log_sigmoid_backward because Torch has a special backward
         // formula that uses a buffer with computed values from the forward pass
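
A note on the unary_ops helper used throughout this file: it receives two closures, an in-place variant and an allocating variant, and burn-tch dispatches to the in-place one only when the underlying storage is not shared with another tensor. A hedged sketch of that dispatch pattern, assuming a simplified wrapper with a boolean ownership flag in place of the real shared-storage bookkeeping:

use tch::Tensor;

/// Simplified stand-in for burn-tch's tensor wrapper.
struct WrappedTensor {
    tensor: Tensor,
    /// Placeholder for the real check (burn-tch inspects a shared storage handle).
    exclusively_owned: bool,
}

impl WrappedTensor {
    /// Run `finplace` when the data is exclusively owned, otherwise fall back to `fref`.
    fn unary_ops<FIn, FRef>(self, finplace: FIn, fref: FRef) -> Self
    where
        FIn: FnOnce(Tensor) -> Tensor,
        FRef: FnOnce(&Tensor) -> Tensor,
    {
        let tensor = if self.exclusively_owned {
            finplace(self.tensor)
        } else {
            fref(&self.tensor)
        };
        Self { tensor, exclusively_owned: true }
    }
}

fn main() {
    let t = WrappedTensor {
        tensor: Tensor::from_slice(&[-1.0f32, 2.0]),
        exclusively_owned: true,
    };
    // Mirrors the relu implementation above: mutate in place when exclusive,
    // allocate a fresh tensor otherwise.
    let out = t.unary_ops(|mut x| x.relu_(), |x| x.relu());
    out.tensor.print();
}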