Add a CPU nbit to float dequantization op that supports torch.quintMxN type and QuantizedCPU backend #2995

Closed
wants to merge 1 commit into from
19 changes: 19 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/quantize_ops_utils.h
@@ -111,4 +111,23 @@ hfp8_to_float(uint8_t hfp8_val, int ebits, int exponent_bias) {
return val_out.F;
}

// Get the number of bytes of a row in a tensor with quantized nbit integers
inline int32_t nbit_elems_to_bytes(const at::Tensor& input) {
const auto input_sizes = input.sizes();
const int32_t ncols = input_sizes[1];
// at::kQUInt4x2 is the dtype for quantized int4 tensors and at::kQUInt2x4 is
// the dtype for quantized int2 tensors. QUIntMxN (M * N = 8) denotes quantized
// M-bit integers, with each byte holding N such elements.
// input_sizes[1] is the number of elements in each row, so we divide it
// (rounding up) by 2 for quint4x2 or by 4 for quint2x4 to get the number of
// bytes in each row.
if (input.dtype() == at::kQUInt2x4) {
return fbgemm_gpu::div_up(ncols, 4);
} else if (input.dtype() == at::kQUInt4x2) {
return fbgemm_gpu::div_up(ncols, 2);
} else {
return ncols;
}
}
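
For illustration, a minimal standalone sketch (not from the patch; it only re-creates the same element-to-byte arithmetic with the standard library, re-declaring the ceiling division locally instead of using fbgemm_gpu::div_up) of how element counts map to byte counts for quint4x2 and quint2x4 rows:

#include <cassert>
#include <cstdint>

// Same ceiling division as fbgemm_gpu::div_up.
inline std::int64_t div_up(std::int64_t val, std::int64_t unit) {
  return (val + unit - 1) / unit;
}

int main() {
  // quint4x2: two 4-bit elements per byte, so 7 elements occupy ceil(7 / 2) = 4 bytes.
  assert(div_up(7, 2) == 4);
  // quint2x4: four 2-bit elements per byte, so 7 elements occupy ceil(7 / 4) = 2 bytes.
  assert(div_up(7, 4) == 2);
  return 0;
}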

} // namespace fbgemm_gpu
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/ops_utils.h
@@ -56,6 +56,11 @@ __builtin_ia32_serialize(void) {
#define DISPATCH_TO_CPU(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))

#define DISPATCH_TO_QUANTIZED_CPU(name, function) \
m.impl( \
name, \
torch::dispatch(c10::DispatchKey::QuantizedCPU, TORCH_FN(function)))

#define DISPATCH_TO_META(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(function)))

4 changes: 4 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/types.h
@@ -15,4 +15,8 @@ using fint32 = union fint32 {
float F;
};

inline int64_t div_up(int64_t val, int64_t unit) {
return (val + unit - 1) / unit;
}

} // namespace fbgemm_gpu
63 changes: 62 additions & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -123,7 +123,8 @@ Tensor _fusednbitrowwise_to_float_cpu(

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
const int32_t ncols = input_sizes[1];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
@@ -149,6 +150,40 @@ Tensor _fusednbitrowwise_to_float_cpu(
return output;
}

Tensor _fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
TENSOR_ON_CPU(input);
TENSOR_NDIM_EQUALS(input, 2);

const auto input_sizes = input.sizes();
const int64_t nrows = input_sizes[0];
// Here we want the number of bytes in a row
const int32_t ncols = nbit_elems_to_bytes(input);
const int32_t num_elem_per_byte = 8 / bit_rate;
const int32_t output_columns =
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;

Tensor output;
output = at::empty(
{nrows, output_columns}, // each output element is a 4-byte float
input.options().dtype(at::kFloat));

float* output_data = static_cast<float*>(
output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
// unresolved data_ptr symbol.

fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
bit_rate,
input.data_ptr<uint8_t>(),
nrows,
ncols,
output_data,
/*scale_bias_last=*/false);

return output;
}

/// @ingroup quantize-data-cpu
///
Tensor& _fused8bitrowwise_to_float_cpu_out(
@@ -274,6 +309,24 @@ Tensor fusednbitrowwise_to_float_cpu(
return _fusednbitrowwise_to_float_cpu<float>(input, bit_rate);
}

/// @ingroup quantize-data-cpu
/// @brief Dequantize int4/int2 rows with scale and bias stored in the front
/// into float32.
/// @param input Tensor of int4/int2 rows with scale and bias stored in the
/// front.
/// @param bit_rate Bit rate of each element. Should be 4 or 2.
/// @return Tensor of float32, holding dequantized numbers.
///
/// Dequantize int4/int2 rows with scale and bias stored in the front into
/// float32. The input tensor should have torch.quint4x2 or torch.quint2x4 dtype
/// and QuantizedCPU backend. This operator is recommended only for testing
/// purposes because its kernel is a reference implementation and is not optimized.
Tensor fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
}
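
For illustration, a self-contained sketch (not from the patch) of how one such row is decoded, mirroring the reference kernel's shift-and-mask logic. For brevity the scale and bias are plain floats here; in the real layout they are two fp16 values occupying the first 2 * sizeof(at::Half) = 4 bytes of each row, ahead of the packed data (scale_bias_last == false):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int bit_rate = 4;                      // int4 elements
  const int num_elem_per_byte = 8 / bit_rate;  // 2 elements per byte
  const float scale = 0.5f;                    // stands in for the fp16 scale at the row front
  const float bias = 1.0f;                     // stands in for the fp16 bias that follows it

  // Packed payload that would sit right after the scale/bias header:
  // 0x21 holds the quantized values 1 (low nibble) and 2 (high nibble);
  // 0x43 holds 3 and 4.
  const std::vector<std::uint8_t> nums = {0x21, 0x43};
  const int output_columns = static_cast<int>(nums.size()) * num_elem_per_byte;

  for (int col = 0; col < output_columns; ++col) {
    std::uint8_t quantized = nums[col / num_elem_per_byte];
    quantized >>= (col % num_elem_per_byte) * bit_rate;  // shift the target nibble down
    quantized &= (1 << bit_rate) - 1;                     // mask off the other nibble
    // Prints 1.5, 2.0, 2.5, 3.0
    std::printf("col %d -> %f\n", col, scale * quantized + bias);
  }
  return 0;
}

A tensor whose rows follow this front-of-row layout should dequantize to the same values through the operator above, with the fp16 header supplying scale and bias.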

/// @ingroup quantize-data-cpu
///
Tensor fusednbitrowwise_to_half_cpu(
@@ -466,6 +519,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
@@ -485,6 +540,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("dequantize_mx_cuda(Tensor input, int mx_group_size) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
DISPATCH_TO_QUANTIZED_CPU(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"FloatToFused8BitRowwiseQuantized",
3 changes: 2 additions & 1 deletion include/fbgemm/QuantUtils.h
@@ -366,7 +366,8 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
OutputType* output,
bool scale_bias_last = true);

/**
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
14 changes: 10 additions & 4 deletions src/QuantUtils.cc
@@ -729,7 +729,8 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output) {
OutputType* output,
bool scale_bias_last) {
static_assert(
std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
"Only float and float16 types are allowed.");
@@ -742,13 +743,17 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
const std::uint8_t* input_row = input + row * input_columns;
const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
input_row +
(output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
(scale_bias_last
? (output_columns + num_elem_per_byte - 1) / num_elem_per_byte
: 0));
float scale = cpu_half2float(input_row_scale_bias[0]);
float bias = cpu_half2float(input_row_scale_bias[1]);
const std::uint8_t* nums =
(scale_bias_last) ? input_row : input_row + 2 * sizeof(float16);
OutputType* output_row = output + row * output_columns;

for (int64_t col = 0; col < output_columns; ++col) {
std::uint8_t quantized = input_row[col / num_elem_per_byte];
std::uint8_t quantized = nums[col / num_elem_per_byte];
quantized >>= (col % num_elem_per_byte) * bit_rate;
quantized &= (1 << bit_rate) - 1;
float output_value = scale * quantized + bias;
@@ -857,7 +862,8 @@ void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const uint8_t* input, \
size_t input_rows, \
int input_columns, \
type* output); \
type* output, \
bool scale_bias_last); \
template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
int bit_rate, \
const uint8_t* input, \