Support for Mixed Input TensorOp #1084

Merged: 21 commits, Sep 27, 2023
Changes from 1 commit
rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast
Manish Gupta committed Sep 26, 2023
commit 45c4450347d46d22a45993a5b01e219ec5e6c044
4 changes: 2 additions & 2 deletions include/cutlass/arch/mma.h
@@ -69,8 +69,8 @@ struct OpMultiplyAddFastF16 {};
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input data types are mixed and the narrower type is
-/// converted to the wider type
-struct OpMultiplyAddMixedInput {};
+/// upcasted to the wider type
+struct OpMultiplyAddMixedInputUpcast {};

/////////////////////////////////////////////////////////////////////////////////////////////////

18 changes: 9 additions & 9 deletions include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -252,20 +252,20 @@ template <
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
-    GemmShape<16, 8, 16>,          // InstructionShape
-    ElementA,                      // Element type of A matrix in Global Memory
-    LayoutA,                       // Layout of A matrix in Global Memory
-    ElementB,                      // Element type of B matrix in Global Memory
-    LayoutB,                       // Layout of B matrix in Global Memory
-    ElementC,                      // Element type of C matrix in Global Memory
-    LayoutC,                       // Layout of C matrix in Global Memory
-    arch::OpMultiplyAddMixedInput, // Tag to indicate mixed-input datatype
+    GemmShape<16, 8, 16>,                // InstructionShape
+    ElementA,                            // Element type of A matrix in Global Memory
+    LayoutA,                             // Layout of A matrix in Global Memory
+    ElementB,                            // Element type of B matrix in Global Memory
+    LayoutB,                             // Layout of B matrix in Global Memory
+    ElementC,                            // Element type of C matrix in Global Memory
+    LayoutC,                             // Layout of C matrix in Global Memory
+    arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {


// Check if the ElementA and ElementB are of different data types
static_assert(!std::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInput ElementA and ElementB cannot be of the same data type");
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename std::conditional<(sizeof(ElementA) > sizeof(ElementB)),
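The hunk is cut off above mid-statement; as a self-contained illustration of that wider-type selection (hypothetical alias, not the file's actual code):

#include <cstdint>
#include <type_traits>
#include <cutlass/numeric_types.h>

// Pick the wider of the two operand types for the internal mma.sync computation,
// mirroring the std::conditional pattern above (illustrative alias only).
template <typename ElementA, typename ElementB>
using WiderOperand = typename std::conditional<
    (sizeof(ElementA) > sizeof(ElementB)), ElementA, ElementB>::type;

static_assert(std::is_same<WiderOperand<cutlass::half_t, int8_t>, cutlass::half_t>::value,
              "f16 x s8 computes internally in f16");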
8 changes: 4 additions & 4 deletions python/cutlass_library/generator.py
@@ -2204,22 +2204,22 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
-      MathOperation.multiply_add_mixed_input),
+      MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
-      MathOperation.multiply_add_mixed_input),
+      MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
-      MathOperation.multiply_add_mixed_input),
+      MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
-      MathOperation.multiply_add_mixed_input),
+      MathOperation.multiply_add_mixed_input_upcast),
]

min_cc = 80
4 changes: 2 additions & 2 deletions python/cutlass_library/library.py
@@ -289,7 +289,7 @@ class ComplexMultiplyOp(enum.Enum):
class MathOperation(enum.Enum):
multiply_add = enum_auto()
multiply_add_saturate = enum_auto()
-  multiply_add_mixed_input = enum_auto()
+  multiply_add_mixed_input_upcast = enum_auto()
xor_popc = enum_auto()
and_popc = enum_auto()
multiply_add_fast_bf16 = enum_auto()
@@ -303,7 +303,7 @@ class MathOperation(enum.Enum):
MathOperationTag = {
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
-  MathOperation.multiply_add_mixed_input: 'cutlass::arch::OpMultiplyAddMixedInput',
+  MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
@@ -83,7 +83,7 @@ TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x
4, // Stages
8, // AlignmentA
16, // AlignmentB
-    cutlass::arch::OpMultiplyAddMixedInput,
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
@@ -83,7 +83,7 @@ TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x
4, // Stages
8, // AlignmentA
16, // AlignmentB
-    cutlass::arch::OpMultiplyAddMixedInput,
+    cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
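The two device-level hunks above show only the tail of the kernel definition. A hedged sketch of what the full cutlass::gemm::device::GemmUniversal type in the f16 x u8 test plausibly looks like; the tile shapes, epilogue, and swizzle below are assumptions inferred from the test name, not copied from the file:

#include <cutlass/arch/arch.h>
#include <cutlass/arch/mma.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/gemm/threadblock/threadblock_swizzle.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>

// Assumed reconstruction (illustrative): f16 (row-major) x u8 (row-major) -> f16, f16 accumulate.
using Gemm = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,    // A
    uint8_t, cutlass::layout::RowMajor,            // B
    cutlass::half_t, cutlass::layout::RowMajor,    // C/D
    cutlass::half_t,                               // Accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,        // Threadblock tile (assumed)
    cutlass::gemm::GemmShape<64, 64, 64>,          // Warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 16>,           // Instruction shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                             // Stages
    8,                                             // AlignmentA
    16,                                            // AlignmentB
    cutlass::arch::OpMultiplyAddMixedInputUpcast,  // Renamed operator tag
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
>;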
24 changes: 12 additions & 12 deletions test/unit/gemm/warp/gemm_mixed_input_sm80.cu
@@ -68,7 +68,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
@@ -89,7 +89,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -112,7 +112,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
@@ -133,7 +133,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -157,7 +157,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -177,7 +177,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
@@ -200,7 +200,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -220,7 +220,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
@@ -243,7 +243,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -266,7 +266,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -289,7 +289,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
@@ -312,7 +312,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
-    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
+    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
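For completeness (not part of the diff): the warp-level hunks above all elide the typedefs that the DefaultMmaTensorOp alias depends on. A hedged sketch for the f16 x s8 case; the crosswise shared-memory layouts, shapes, and element choices here are assumptions based on the test names and the generator entries above, not the file's actual definitions:

#include <cutlass/arch/mma.h>
#include <cutlass/gemm/gemm.h>
#include <cutlass/gemm/warp/default_mma_tensor_op_sm80.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/layout/tensor_op_multiplicand_sm75.h>
#include <cutlass/numeric_types.h>

// Assumed reconstruction (illustrative) of the elided typedefs for one 64x64x64 test case.
using Shape            = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;   // wider operand
using ElementB = int8_t;            // narrower operand, upcast to f16 internally
using ElementC = float;             // f32 accumulation (assumed)
using LayoutA  = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
    cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB  = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
    cutlass::sizeof_bits<ElementB>::value, 64>;

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
    Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
    cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;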