refactor/mul-ops (tracel-ai#67)
nathanielsimard authored Nov 5, 2022
1 parent 2bdad6f commit ee61e84
Showing 25 changed files with 330 additions and 352 deletions.
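Side note (reviewer annotation, not part of the commit): the bulk of the diff follows one pattern. Element-wise multiplication moves from method calls on the tensor primitive (`ones.mul(&grad)`, `x.mul_scalar(&s)`) to associated functions on the backend type (`B::mul(&ones, &grad)`, `B::mul_scalar(&x, &s)`), and the dedicated autodiff `mul` module is dropped. Below is a minimal sketch of that calling-convention change using a simplified stand-in trait; the names `mul` and `mul_scalar` come from the diff, everything else is illustrative and not burn's actual `Backend` definition (the real trait uses a const-generic tensor primitive and many more items).

```rust
// Simplified stand-in for the pattern in this commit; not the real burn API.
trait Backend {
    type Elem;
    type TensorPrimitive;

    // After the refactor, ops are associated functions on the backend...
    fn mul(lhs: &Self::TensorPrimitive, rhs: &Self::TensorPrimitive) -> Self::TensorPrimitive;
    fn mul_scalar(lhs: &Self::TensorPrimitive, scalar: &Self::Elem) -> Self::TensorPrimitive;
}

// ...so generic code calls `B::mul(&a, &b)` instead of `a.mul(&b)`.
fn scale_grad<B: Backend>(ones: &B::TensorPrimitive, grad: &B::TensorPrimitive) -> B::TensorPrimitive {
    B::mul(ones, grad)
}

// Toy backend over `Vec<f32>` just to make the sketch runnable.
struct ToyBackend;

impl Backend for ToyBackend {
    type Elem = f32;
    type TensorPrimitive = Vec<f32>;

    fn mul(lhs: &Vec<f32>, rhs: &Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(rhs).map(|(a, b)| a * b).collect()
    }

    fn mul_scalar(lhs: &Vec<f32>, scalar: &f32) -> Vec<f32> {
        lhs.iter().map(|a| a * scalar).collect()
    }
}

fn main() {
    let ones = vec![1.0_f32; 3];
    let grad = vec![0.5, 1.5, 2.5];
    println!("{:?}", scale_grad::<ToyBackend>(&ones, &grad)); // [0.5, 1.5, 2.5]
}
```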
2 changes: 1 addition & 1 deletion .github/workflows/test-burn-dataset.yml
@@ -14,7 +14,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
- toolchain: beta
+ toolchain: stable
components: rustfmt, clippy
override: true

2 changes: 1 addition & 1 deletion .github/workflows/test-burn-tensor.yml
@@ -14,7 +14,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
- toolchain: beta
+ toolchain: stable
components: rustfmt, clippy
override: true

2 changes: 1 addition & 1 deletion .github/workflows/test-burn.yml
@@ -14,7 +14,7 @@ jobs:
uses: actions-rs/toolchain@v1
with:
profile: minimal
- toolchain: beta
+ toolchain: stable
components: rustfmt, clippy
override: true

6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/backend/autodiff/ops/aggregation.rs
@@ -76,9 +76,9 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
let ones = B::ones(shape, B::device(&grad));

let val = 1_f64 / shape.dims[dim] as f64;
- let ones = ones.mul_scalar(&B::Elem::from_elem(val));
+ let ones = B::mul_scalar(&ones, &B::Elem::from_elem(val));

- ones.mul(&grad)
+ B::mul(&ones, &grad)
}
}

@@ -94,7 +94,7 @@ impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimit
let grad = state.output.grad().sum_dim(dim);
let ones = B::ones(shape, B::device(&grad));

- ones.mul(&grad)
+ B::mul(&ones, &grad)
}
}

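Side note on these two hunks (reviewer annotation, not part of the commit): the `mean_dim` backward broadcasts the upstream gradient through a tensor of ones scaled by `1 / shape.dims[dim]`, while `sum_dim` broadcasts it unscaled, matching d(mean)/d(x_i) = 1/n and d(sum)/d(x_i) = 1. A tiny numeric check of the 1/n factor, under that reading:

```rust
// Hypothetical check of the 1/n scaling used in mean_dim's backward above.
fn main() {
    let x = [2.0_f64, 4.0, 6.0]; // one slice along the reduced dim
    let n = x.len() as f64;
    let upstream = 1.0; // d(loss)/d(mean)

    // mean = (x0 + x1 + x2) / n, so d(mean)/d(xi) = 1/n for every i.
    let grad: Vec<f64> = x.iter().map(|_| upstream * (1.0 / n)).collect();
    assert_eq!(grad, vec![1.0 / 3.0; 3]);
}
```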
8 changes: 4 additions & 4 deletions burn-tensor/src/tensor/backend/autodiff/ops/div.rs
@@ -13,14 +13,14 @@ register_ops!(
let value = state.right.value();
let value = value.ones().div(&value);

- state.output.grad().mul(&value)
+ B::mul(&state.output.grad(), &value)
},
partial_right |state: &BinaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>, B::TensorPrimitive<D>>| {
let value_left = state.left.value();
let value_right = state.right.value();
- let value = value_left.neg().div(&value_right.mul(&value_right));
+ let value = value_left.neg().div(&B::mul(&value_right, &value_right));

- state.output.grad().mul(&value)
+ B::mul(&state.output.grad(), &value)
},
);

@@ -31,7 +31,7 @@ register_ops!(
let value = state_recorded.input.value();
let tmp = value.ones().div_scalar(state);

- state_recorded.output.grad().mul(&tmp)
+ B::mul(&state_recorded.output.grad(), &tmp)
},
);

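For reference (reviewer annotation, not part of the commit): with `out = left / right`, the partials implemented above are d(out)/d(left) = 1/right (`value.ones().div(&value)`) and d(out)/d(right) = -left/right² (`value_left.neg().div(&B::mul(&value_right, &value_right))`). A quick finite-difference sanity check of those formulas:

```rust
// Hypothetical finite-difference check of the quotient-rule partials.
fn main() {
    let (l, r, eps) = (3.0_f64, 2.0_f64, 1e-6);
    let f = |l: f64, r: f64| l / r;

    let dl_analytic = 1.0 / r;       // d(l/r)/dl
    let dr_analytic = -l / (r * r);  // d(l/r)/dr

    let dl_numeric = (f(l + eps, r) - f(l, r)) / eps;
    let dr_numeric = (f(l, r + eps) - f(l, r)) / eps;

    assert!((dl_analytic - dl_numeric).abs() < 1e-4);
    assert!((dr_analytic - dr_numeric).abs() < 1e-4);
}
```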
5 changes: 3 additions & 2 deletions burn-tensor/src/tensor/backend/autodiff/ops/erf.rs
@@ -13,10 +13,11 @@ register_ops!(
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
let value = state.input.value();
let exponent = value.powf(2.0.to_elem()).neg();
- let numerator = exponent.exp().mul_scalar(&2.0.to_elem());
+ let numerator = B::mul_scalar(&exponent.exp(), &2.0.to_elem());
let denominator = std::f64::consts::PI.sqrt().to_elem();
let value = numerator.div_scalar(&denominator);
- state.output.grad().mul(&value)
+
+ B::mul(&state.output.grad(), &value)
},
);

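For reference on this partial (reviewer annotation, not part of the commit), the derivative of the error function is

$$\frac{d}{dx}\,\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\,e^{-x^{2}}$$

which is what the hunk computes: `exponent = -x²`, `numerator = 2·exp(exponent)`, `denominator = √π`, all multiplied into the upstream gradient.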
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/exp.rs
@@ -10,7 +10,7 @@ register_ops!(
ops UnaryOps,
name ADTensorExpOps,
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
- state.output.grad().mul(&state.output.value())
+ B::mul(&state.output.grad(), &state.output.value())
},
);

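The one-line partial here reuses the forward output (reviewer annotation, not part of the commit) because

$$\frac{d}{dx}\,e^{x} = e^{x}$$

so the gradient is just the upstream gradient times `state.output.value()`.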
2 changes: 1 addition & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/log.rs
@@ -12,7 +12,7 @@ register_ops!(
partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
let value = state.input.value();
let value = value.ones().div(&value);
- state.output.grad().mul(&value)
+ B::mul(&state.output.grad(), &value)
},
);

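For reference (reviewer annotation, not part of the commit), `value.ones().div(&value)` implements

$$\frac{d}{dx}\,\ln x = \frac{1}{x}$$

before the result is multiplied into the upstream gradient.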
1 change: 0 additions & 1 deletion burn-tensor/src/tensor/backend/autodiff/ops/mod.rs
@@ -13,7 +13,6 @@ mod map_comparison;
mod mask;
mod matmul;
mod module;
- mod mul;
mod neg;
mod pow;
mod precision;
109 changes: 0 additions & 109 deletions burn-tensor/src/tensor/backend/autodiff/ops/mul.rs

This file was deleted.

6 changes: 3 additions & 3 deletions burn-tensor/src/tensor/backend/autodiff/ops/pow.rs
@@ -14,11 +14,11 @@ register_ops!(
value: &f32,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
| {
- let value = state.input
+ let value = B::mul_scalar(&state.input
.value()
.powf(value - 1.0)
- .mul_scalar(&value.clone().to_elem());
- state.output.grad().mul(&value)
+ , &value.clone().to_elem());
+ B::mul(&state.output.grad(), &value)
},
);

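For reference (reviewer annotation, not part of the commit), the partial above implements the power rule

$$\frac{d}{dx}\,x^{p} = p\,x^{p-1}$$

which is why the input is raised to `value - 1.0` and then scaled by the exponent via `B::mul_scalar` before being multiplied into the upstream gradient.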