Add a toggle for F16/BF16 accumulation in gemm. (huggingface#2141)
* Add a toggle to control f16/bf16 gemm precision.

* Use the faster variant in the quantized example.

* Bugfix.
LaurentMazare authored Apr 29, 2024
1 parent 287013e commit ed7b99f
Showing 4 changed files with 153 additions and 15 deletions.
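
For context on the new API, here is a minimal usage sketch (not part of the commit). It assumes the crate is imported as candle_core; the in-repo example below aliases it to candle.

fn main() {
    // Opt in to the faster f16/bf16 accumulation path; the flags are process-wide
    // and default to false (f32 accumulation), matching PyTorch.
    candle_core::cuda::set_gemm_reduced_precision_f16(true);
    candle_core::cuda::set_gemm_reduced_precision_bf16(true);

    // The getters report the current setting. With the non-CUDA dummy backend the
    // setters are no-ops and the getters always return true.
    assert!(candle_core::cuda::gemm_reduced_precision_f16());
}
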
137 changes: 125 additions & 12 deletions candle-core/src/cuda_backend/mod.rs
@@ -1635,25 +1635,17 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
- unsafe {
-     self.device
-         .blas
-         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }
- .w()?;
+ unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
+     .w()?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
- unsafe {
-     self.device
-         .blas
-         .gemm_strided_batched(cfg, rhs, lhs, &mut out)
- }
- .w()?;
+ unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
+     .w()?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
@@ -1856,3 +1848,124 @@ impl BackendStorage for CudaStorage {
Ok(())
}
}

// Default for the reduced precision setting is false, similar to pytorch.
// https://github.com/pytorch/pytorch/issues/123157
static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(b: bool) {
MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(b: bool) {
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}

unsafe fn gemm_strided_batched_f16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f16>,
a: &cudarc::driver::CudaView<f16>,
b: &cudarc::driver::CudaView<f16>,
c: &mut CudaSlice<f16>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;

let compute_type = if gemm_reduced_precision_f16() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};

let alpha = cfg.gemm.alpha;
let beta = cfg.gemm.beta;
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&alpha) as *const f16 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldb,
cfg.stride_b,
(&beta) as *const f16 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}

unsafe fn gemm_strided_batched_bf16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<bf16>,
a: &cudarc::driver::CudaView<bf16>,
b: &cudarc::driver::CudaView<bf16>,
c: &mut CudaSlice<bf16>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;

let compute_type = if gemm_reduced_precision_bf16() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};

let alpha = cfg.gemm.alpha;
let beta = cfg.gemm.beta;
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
(&alpha) as *const bf16 as *const _,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldb,
cfg.stride_b,
(&beta) as *const bf16 as *const _,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_16BF,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}
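
As a rough illustration of the trade-off this toggle controls (not part of the commit), one way to compare the two accumulation modes with candle's public tensor API, assuming a CUDA build and device 0:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::new_cuda(0)?;
    let a = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?.to_dtype(DType::F16)?;
    let b = Tensor::randn(0f32, 1.0, (1024, 1024), &dev)?.to_dtype(DType::F16)?;

    // Default: f32 accumulation inside the f16 GEMM.
    candle_core::cuda::set_gemm_reduced_precision_f16(false);
    let c_ref = a.matmul(&b)?.to_dtype(DType::F32)?;

    // Faster variant: f16 accumulation, as enabled in the quantized example below.
    candle_core::cuda::set_gemm_reduced_precision_f16(true);
    let c_fast = a.matmul(&b)?.to_dtype(DType::F32)?;

    // Mean absolute difference gives a rough feel for the accuracy cost.
    let err = c_ref.sub(&c_fast)?.abs()?.mean_all()?.to_scalar::<f32>()?;
    println!("mean abs difference: {err}");
    Ok(())
}
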
20 changes: 20 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
@@ -238,3 +238,23 @@ impl crate::backend::BackendDevice for CudaDevice {
Ok(())
}
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn gemm_reduced_precision_f16() -> bool {
true
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
pub fn set_gemm_reduced_precision_f16(_: bool) {}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn gemm_reduced_precision_bf16() -> bool {
true
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
8 changes: 5 additions & 3 deletions candle-core/src/lib.rs
@@ -47,7 +47,7 @@ mod custom_op;
mod device;
pub mod display;
mod dtype;
- mod dummy_cuda_backend;
+ pub mod dummy_cuda_backend;
mod dummy_metal_backend;
pub mod error;
mod indexer;
@@ -89,10 +89,12 @@ pub use tensor::{Tensor, TensorId};
pub use variable::Var;

#[cfg(feature = "cuda")]
- pub use cuda_backend::{CudaDevice, CudaStorage};
+ pub use cuda_backend as cuda;

#[cfg(not(feature = "cuda"))]
- pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
+ pub use dummy_cuda_backend as cuda;

+ pub use cuda::{CudaDevice, CudaStorage};

#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
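
The point of exposing the cuda module under both feature configurations is that callers can flip the toggle without feature-gating their code; a small sketch under that assumption (the helper name is illustrative, not from the commit):

/// Prefer the faster half-precision accumulation path. With the `cuda` feature
/// disabled, `candle_core::cuda` resolves to the dummy backend and these calls
/// are no-ops, so no #[cfg] guard is needed at the call site.
pub fn prefer_fast_half_gemm() {
    candle_core::cuda::set_gemm_reduced_precision_f16(true);
    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
}
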
3 changes: 3 additions & 0 deletions candle-examples/examples/quantized/main.rs
@@ -374,6 +374,9 @@ fn main() -> anyhow::Result<()> {
#[cfg(feature = "cuda")]
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);

+ candle::cuda::set_gemm_reduced_precision_f16(true);
+ candle::cuda::set_gemm_reduced_precision_bf16(true);

let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
