Skip to content

Commit

Permalink
bump candle to 0.3.1 and conv_transpose_1d (tracel-ai#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Nov 21, 2023
1 parent cdf54d0 commit 4711db0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
4 changes: 2 additions & 2 deletions burn-candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.11.0", default-features = false }
half = { workspace = true }

# TODO remove pinned version ("=") once candle-core is updated to 0.3.1
candle-core = { version = "=0.3.0" }
candle-core = { version = "0.3.1" }


[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.11.0", default-features = false, features = [
Expand Down
1 change: 1 addition & 0 deletions burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ impl From<candle_core::Device> for CandleDevice {
match device.location() {
DeviceLocation::Cpu => CandleDevice::Cpu,
DeviceLocation::Cuda { gpu_id } => CandleDevice::Cuda(gpu_id),
DeviceLocation::Metal => panic!("Metal unsupported"),
}
}
}
Expand Down
21 changes: 20 additions & 1 deletion burn-candle/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,26 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<1>,
) -> FloatTensor<Self, 3> {
panic!("Candle does not support conv_transpose1d")
assert!(
options.groups == 1,
"Candle does not support groups in transposed convolutions"
);
let conv_transpose = x
.tensor
.conv_transpose1d(
&weight.tensor,
options.padding[0],
options.padding_out[0],
options.stride[0],
options.dilation[0],
)
.unwrap();
CandleTensor::new(match bias {
Some(bias) => conv_transpose
.broadcast_add(&bias.tensor.unsqueeze(0).unwrap().unsqueeze(2).unwrap())
.unwrap(),
None => conv_transpose,
})
}

fn conv_transpose2d(
Expand Down

0 comments on commit 4711db0

Please sign in to comment.