diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 45f6554baa..91e569372d 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -173,8 +173,8 @@ impl Device { pub fn supports_bf16(&self) -> bool { match self { - Self::Cuda(_) => true, - Self::Metal(_) | Self::Cpu => false, + Self::Cuda(_) | Self::Metal(_) => true, + Self::Cpu => false, } } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 09d5fd49cd..19557cf2ea 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1398,6 +1398,7 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; Ok(acc) } + fn matmul( &self, rhs: &Self, @@ -1406,32 +1407,51 @@ impl BackendStorage for MetalStorage { rhs_l: &Layout, ) -> Result { let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?; - let name = match self.dtype { - DType::F32 => "sgemm", - DType::F16 => "hgemm", - DType::BF16 => "bgemm", - dtype => { - return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into()) - } - }; - let command_buffer = self.device.command_buffer()?; command_buffer.set_label("matmul"); - candle_metal_kernels::call_gemm( - &self.device.device, - &command_buffer, - &self.device.kernels, - name, - (b, m, n, k), - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &self.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; + if self.dtype == DType::BF16 { + candle_metal_kernels::call_mlx_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + candle_metal_kernels::GemmDType::BF16, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + } else { + let name = match self.dtype { + DType::F32 => "sgemm", + DType::F16 => "hgemm", + dtype => { + return Err( + MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(), + ) + } + }; + + candle_metal_kernels::call_gemm( + &self.device.device, + &command_buffer, + &self.device.kernels, + name, + (b, m, n, k), + lhs_l.stride(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), + &self.buffer, + rhs_l.stride(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), + &rhs.buffer, + &buffer, + ) + .map_err(MetalError::from)?; + } Ok(Self::new( buffer, self.device.clone(),