Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

candle 0.3.0 #881

Merged
merged 2 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
candle 0.3.0
  • Loading branch information
louisfd committed Oct 20, 2023
commit e8e00cf85f56589050ffa531c5c9843d5d927972
3 changes: 1 addition & 2 deletions burn-candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ version = "0.10.0"
derive-new = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = false }
half = { workspace = true }
# candle-core = { version = "0.1.2" }
candle-core = { git = "https://github.com/huggingface/candle", rev = "237323c" }
candle-core = { version = "0.3.0" }

[dev-dependencies]
burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", default-features = false, features = [
Expand Down
4 changes: 2 additions & 2 deletions burn-candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod tests {
burn_tensor::testgen_clamp!();
burn_tensor::testgen_cos!();
// burn_tensor::testgen_div!();
// burn_tensor::testgen_erf!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_full!();
Expand Down Expand Up @@ -121,7 +121,7 @@ mod tests {
burn_autodiff::testgen_ad_cos!();
burn_autodiff::testgen_ad_cross_entropy_loss!();
burn_autodiff::testgen_ad_div!();
// burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_exp!();
// burn_autodiff::testgen_ad_slice!();
burn_autodiff::testgen_ad_gather_scatter!();
Expand Down
3 changes: 1 addition & 2 deletions burn-candle/src/ops/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> TensorOps<CandleBackend<F, I>>
}

fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
// TODO submit an issue at Candle
panic!("erf not supported by Candle")
CandleTensor::new(tensor.tensor.erf().unwrap())
}

fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
Expand Down
76 changes: 76 additions & 0 deletions burn-common/src/benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use alloc::string::String;
use alloc::vec::Vec;
use core::time::Duration;
#[cfg(feature = "std")]
use std::time::Instant;

/// Results of a benchmark run.
#[derive(new, Debug)]
pub struct BenchmarkResult {
durations: Vec<Duration>,
}

impl BenchmarkResult {
/// Returns the median duration among all durations
pub fn median_duration(&self) -> Duration {
let mut sorted = self.durations.clone();
sorted.sort();
*sorted.get(sorted.len() / 2).unwrap()
}
}

/// Benchmark trait.
pub trait Benchmark {
/// Benchmark arguments.
type Args;

/// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
/// count as included in the duration.
///
/// # Notes
///
/// This should not include warmup, the benchmark will be run at least one time without
/// measuring the execution time.
fn prepare(&self) -> Self::Args;
/// Execute the benchmark and returns the time it took to complete.
fn execute(&mut self, args: Self::Args);
/// Number of samples required to have a statistical significance.
fn num_samples(&self) -> usize {
10
}
/// Name of the benchmark.
fn name(&self) -> String;
/// Wait for computations to be over
fn sync(&mut self);
/// Run the benchmark a number of times.
fn run(&mut self) -> BenchmarkResult {
#[cfg(not(feature = "std"))]
panic!("Attempting to run benchmark in a no-std environment");

#[cfg(feature = "std")]
{
// Warmup
self.execute(self.prepare());
self.sync();

let mut durations = Vec::with_capacity(self.num_samples());

for _ in 0..self.num_samples() {
// Prepare
let args = self.prepare();
self.sync();

// Execute the benchmark
let start = Instant::now();
self.execute(args);
self.sync();
let end = Instant::now();

// Register the duration
durations.push(end - start);
}

BenchmarkResult { durations }
}
}
}
Loading