Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DML] Update DML to 1.14 #20304

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
  <package id="python" version="3.9.7" targetFramework="native" />
  <!-- Single DirectML entry: the superseded 1.13.1 line is dropped so the package id is listed only once. -->
  <package id="Microsoft.AI.DirectML" version="1.14.0" targetFramework="native" />
  <package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
  <package id="pythonx86" version="3.9.7" targetFramework="native" />
  <!-- Single DirectML entry: the superseded 1.13.1 line is dropped so the package id is listed only once. -->
  <package id="Microsoft.AI.DirectML" version="1.14.0" targetFramework="native" />
  <package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.14.0)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,28 @@ enum class MLOperatorTensorDataType : uint32_t
Complex64 = 14,

//! 128 bit complex type (unsupported)
Complex128 = 15
Complex128 = 15,

//! bfloat16 type (unsupported)
TensorProto_DataType_BFLOAT16 = 16,

//! FLOAT8E4M3FN type (unsupported)
TensorProto_DataType_FLOAT8E4M3FN = 17,

//! FLOAT8E4M3FNUZ type (unsupported)
TensorProto_DataType_FLOAT8E4M3FNUZ = 18,

//! FLOAT8E5M2 type (unsupported)
TensorProto_DataType_FLOAT8E5M2 = 19,

//! FLOAT8E5M2FNUZ type (unsupported)
TensorProto_DataType_FLOAT8E5M2FNUZ = 20,

//! 4 bit unsigned integer
UInt4 = 21,

//! 4 bit signed integer
Int4 = 22
};

//! \enum MLOperatorEdgeType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ T ApiTraits::StringifyHelpers::FromString(std::string_view value)
#endif
}


template <>
DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
{
Expand All @@ -39,6 +40,8 @@ DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view va
{"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64},
{"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64},
{"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64},
{"DML_TENSOR_DATA_TYPE_UINT4", DML_TENSOR_DATA_TYPE_UINT4},
{"DML_TENSOR_DATA_TYPE_INT4", DML_TENSOR_DATA_TYPE_INT4},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
Expand Down Expand Up @@ -243,6 +246,10 @@ DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value
{"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION},
{"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING},
{"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT},
{"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2},
{"DML_OPERATOR_MULTIHEAD_ATTENTION1", DML_OPERATOR_MULTIHEAD_ATTENTION1},
{"DML_OPERATOR_QUANTIZE", DML_OPERATOR_QUANTIZE},
{"DML_OPERATOR_DEQUANTIZE", DML_OPERATOR_DEQUANTIZE},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
Expand Down Expand Up @@ -446,6 +453,7 @@ DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value
{"DML_FEATURE_LEVEL_6_0", DML_FEATURE_LEVEL_6_0},
{"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1},
{"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2},
{"DML_FEATURE_LEVEL_6_3", DML_FEATURE_LEVEL_6_3},
};
auto index = StringUtil::MapToIndex(value, mapping);
if (!index)
Expand Down Expand Up @@ -568,3 +576,21 @@ DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::s
return static_cast<DML_MULTIHEAD_ATTENTION_MASK_TYPE>(*index);
}


// Parses the textual name of a DML_QUANTIZATION_TYPE enumerator back into the enum value.
// A name that is not in the table trips assert(false) and falls back to the value 0.
template <>
DML_QUANTIZATION_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value)
{
    constexpr StringUtil::NameAndIndex quantizationTypeNames[] =
    {
        {"DML_QUANTIZATION_TYPE_NONE", DML_QUANTIZATION_TYPE_NONE},
        {"DML_QUANTIZATION_TYPE_SCALE", DML_QUANTIZATION_TYPE_SCALE},
        {"DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT", DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT},
    };

    const auto match = StringUtil::MapToIndex(value, quantizationTypeNames);
    if (match)
    {
        return static_cast<DML_QUANTIZATION_TYPE>(*match);
    }

    // Unknown name: flag it during development, then return a defined value.
    assert(false);
    return static_cast<DML_QUANTIZATION_TYPE>(0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ DML_TENSOR_DATA_TYPE GetDmlDataTypeFromMlDataTypeNoThrow(MLOperatorTensorDataTyp
switch (tensorDataType)
{
case MLOperatorTensorDataType::Float: return DML_TENSOR_DATA_TYPE_FLOAT32;
case MLOperatorTensorDataType::UInt4: return DML_TENSOR_DATA_TYPE_UINT4;
case MLOperatorTensorDataType::Int4: return DML_TENSOR_DATA_TYPE_INT4;
case MLOperatorTensorDataType::UInt8: return DML_TENSOR_DATA_TYPE_UINT8;
case MLOperatorTensorDataType::Int8: return DML_TENSOR_DATA_TYPE_INT8;
case MLOperatorTensorDataType::UInt16: return DML_TENSOR_DATA_TYPE_UINT16;
Expand Down Expand Up @@ -41,10 +43,12 @@ bool IsSigned(DML_TENSOR_DATA_TYPE dataType)
case DML_TENSOR_DATA_TYPE_UINT32: return false;
case DML_TENSOR_DATA_TYPE_UINT16: return false;
case DML_TENSOR_DATA_TYPE_UINT8: return false;
case DML_TENSOR_DATA_TYPE_UINT4: return false;
case DML_TENSOR_DATA_TYPE_INT64: return true;
case DML_TENSOR_DATA_TYPE_INT32: return true;
case DML_TENSOR_DATA_TYPE_INT16: return true;
case DML_TENSOR_DATA_TYPE_INT8: return true;
case DML_TENSOR_DATA_TYPE_INT4: return true;
default:
assert(false);
return false;
Expand All @@ -69,6 +73,8 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
switch (tensorDataType)
{
case DML_TENSOR_DATA_TYPE_FLOAT32: return MLOperatorTensorDataType::Float;
case DML_TENSOR_DATA_TYPE_UINT4: return MLOperatorTensorDataType::UInt4;
case DML_TENSOR_DATA_TYPE_INT4: return MLOperatorTensorDataType::Int4;
case DML_TENSOR_DATA_TYPE_UINT8: return MLOperatorTensorDataType::UInt8;
case DML_TENSOR_DATA_TYPE_INT8: return MLOperatorTensorDataType::Int8;
case DML_TENSOR_DATA_TYPE_UINT16: return MLOperatorTensorDataType::UInt16;
Expand All @@ -87,9 +93,15 @@ MLOperatorTensorDataType GetMlDataTypeFromDmlDataType(DML_TENSOR_DATA_TYPE tenso
}
#pragma warning(pop)

// Returns the size in bits of a tensor with the given dimensions and element type: the element count
// from ComputeElementCountFromDimensions multiplied by the value GetBitSizeFromMlDataType reports
// for the element type.
size_t ComputeBitSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
{
    const auto elementCount = ComputeElementCountFromDimensions(dimensions);
    const auto bitsPerElement = GetBitSizeFromMlDataType(tensorDataType);
    return elementCount * bitsPerElement;
}

// Returns the number of whole bytes needed to store a tensor with the given dimensions and element type.
//
// The size is taken in bits from ComputeBitSizeFromDimensions and rounded up to the next multiple of
// CHAR_BIT, so a bit count that is not byte-aligned still gets a full trailing byte. The earlier
// element-count * byte-size return is removed: it executed first and left this computation unreachable.
size_t ComputeByteSizeFromDimensions(gsl::span<const DimensionType> dimensions, MLOperatorTensorDataType tensorDataType)
{
    // Ceiling division: (bits + CHAR_BIT - 1) / CHAR_BIT.
    return (ComputeBitSizeFromDimensions(dimensions, tensorDataType) + CHAR_BIT - 1) / CHAR_BIT;
}

size_t ComputeByteSizeFromTensor(IMLOperatorTensor& tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct EnumTraits
// Traits for DML_TENSOR_DATA_TYPE. ValueCount is defined exactly once; the superseded value 12 is
// removed because a second definition of the same static member cannot compile.
template <>
struct EnumTraits<DML_TENSOR_DATA_TYPE>
{
    static constexpr auto ValueCount = 14;
};

template <>
Expand All @@ -24,7 +24,7 @@ struct EnumTraits<DML_TENSOR_TYPE>
// Traits for DML_OPERATOR_TYPE. ValueCount is defined exactly once; the superseded value 168 is
// removed because a second definition of the same static member cannot compile.
template <>
struct EnumTraits<DML_OPERATOR_TYPE>
{
    static constexpr auto ValueCount = 174;
    static constexpr size_t ActivationFunctionCount = 26;
};

Expand Down Expand Up @@ -62,7 +62,7 @@ struct EnumTraits<DML_CONVOLUTION_DIRECTION>
// Traits for DML_PADDING_MODE. ValueCount is defined exactly once; the superseded value 5 is
// removed because a second definition of the same static member cannot compile.
template <>
struct EnumTraits<DML_PADDING_MODE>
{
    static constexpr auto ValueCount = 4;
};

template <>
Expand All @@ -86,7 +86,7 @@ struct EnumTraits<DML_FEATURE>
// Traits for DML_FEATURE_LEVEL. ValueCount is defined exactly once; the superseded value 13 is
// removed because a second definition of the same static member cannot compile.
template <>
struct EnumTraits<DML_FEATURE_LEVEL>
{
    static constexpr auto ValueCount = 14;
};

template <>
Expand Down Expand Up @@ -125,6 +125,12 @@ struct EnumTraits<DML_MULTIHEAD_ATTENTION_MASK_TYPE>
static constexpr auto ValueCount = 5;
};

// Traits for DML_QUANTIZATION_TYPE. ValueCount is 3, which matches the three names
// (NONE, SCALE, SCALE_ZERO_POINT) handled by the FromString/ToString specializations in this text.
template <>
struct EnumTraits<DML_QUANTIZATION_TYPE>
{
static constexpr auto ValueCount = 3;
};

// Shorthand variable template: EnumValueCount<T> evaluates to EnumTraits<T>::ValueCount.
template <typename T>
constexpr auto EnumValueCount = EnumTraits<T>::ValueCount;

Expand Down Expand Up @@ -879,12 +885,6 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY;
};

// Associates DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator.
// NOTE(review): an identical specialization for this desc type appears again further down, after the
// DML_QUANTIZED_LINEAR_AVERAGE_POOLING entry; only one of the two can exist in a translation unit.
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};

template <>
struct OperatorDescTraits<DML_CONVOLUTION_INTEGER_OPERATOR_DESC>
{
Expand Down Expand Up @@ -1047,6 +1047,36 @@ struct OperatorDescTraits<DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
};

// Associates DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator
// through the static member Type.
template <>
struct OperatorDescTraits<DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT;
};

// Associates DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator
// through the static member Type.
template <>
struct OperatorDescTraits<DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2;
};

// Associates DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator
// through the static member Type.
template <>
struct OperatorDescTraits<DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION1;
};

// Associates DML_QUANTIZE_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator
// through the static member Type.
template <>
struct OperatorDescTraits<DML_QUANTIZE_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZE;
};

// Associates DML_DEQUANTIZE_OPERATOR_DESC with its DML_OPERATOR_TYPE enumerator
// through the static member Type.
template <>
struct OperatorDescTraits<DML_DEQUANTIZE_OPERATOR_DESC>
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DEQUANTIZE;
};

template <>
struct OperatorDescTraits<DML_ACTIVATION_ELU_OPERATOR_DESC>
{
Expand Down Expand Up @@ -1203,6 +1233,7 @@ struct OperatorDescTraits<DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC>
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH;
};


template <DML_OPERATOR_TYPE Type>
struct OperatorTypeTraits
{
Expand Down Expand Up @@ -2072,6 +2103,30 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGE
using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2>
{
using DescType = DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION1>
{
using DescType = DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZE>
{
using DescType = DML_QUANTIZE_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DEQUANTIZE>
{
using DescType = DML_DEQUANTIZE_OPERATOR_DESC;
};

template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU>
{
Expand Down Expand Up @@ -2228,16 +2283,15 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH>
using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC;
};


// Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as
// the first argument.
//
//
// For example:
// Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) {
// using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs
// });
//
#pragma warning(push)
#pragma warning(disable:4702)
template <typename Visitor, typename... Ts>
auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args)
{
Expand Down Expand Up @@ -2531,6 +2585,14 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT:
return std::invoke(std::forward<Visitor>(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2:
return std::invoke(std::forward<Visitor>(visitor), DML_MEAN_VARIANCE_NORMALIZATION2_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_MULTIHEAD_ATTENTION1:
return std::invoke(std::forward<Visitor>(visitor), DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_QUANTIZE:
return std::invoke(std::forward<Visitor>(visitor), DML_QUANTIZE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_DEQUANTIZE:
return std::invoke(std::forward<Visitor>(visitor), DML_DEQUANTIZE_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_ELU:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
case DML_OPERATOR_ACTIVATION_CELU:
Expand Down Expand Up @@ -2584,11 +2646,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
case DML_OPERATOR_ACTIVATION_HARD_SWISH:
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward<Ts>(args)...);
default:
ORT_THROW_HR(E_INVALIDARG);
return std::invoke(std::forward<Visitor>(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward<Ts>(args)...);
THROW_HR(E_INVALIDARG);
}
}
#pragma warning(pop)


namespace StringifyHelpers
{
Expand Down Expand Up @@ -2619,6 +2680,8 @@ inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value)
case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64";
case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64";
case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64";
case DML_TENSOR_DATA_TYPE_UINT4: return "DML_TENSOR_DATA_TYPE_UINT4";
case DML_TENSOR_DATA_TYPE_INT4: return "DML_TENSOR_DATA_TYPE_INT4";
default:
assert(false);
return "<unknown>";
Expand Down Expand Up @@ -2813,6 +2876,10 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING";
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION2";
case DML_OPERATOR_MULTIHEAD_ATTENTION1: return "DML_OPERATOR_MULTIHEAD_ATTENTION1";
case DML_OPERATOR_QUANTIZE: return "DML_OPERATOR_QUANTIZE";
case DML_OPERATOR_DEQUANTIZE: return "DML_OPERATOR_DEQUANTIZE";
default:
assert(false);
return "<unknown>";
Expand Down Expand Up @@ -2968,6 +3035,7 @@ inline gsl::czstring ToString(DML_FEATURE_LEVEL value)
case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0";
case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1";
case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2";
case DML_FEATURE_LEVEL_6_3: return "DML_FEATURE_LEVEL_6_3";
default:
assert(false);
return "<unknown>";
Expand Down Expand Up @@ -3056,6 +3124,20 @@ inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value)
}
}

// Stringifies a DML_QUANTIZATION_TYPE enumerator as its identifier text.
// Any value other than the three handled enumerators trips assert(false) and yields "<unknown>".
template <>
inline gsl::czstring ToString(DML_QUANTIZATION_TYPE value)
{
    switch (value)
    {
    case DML_QUANTIZATION_TYPE_NONE:
        return "DML_QUANTIZATION_TYPE_NONE";
    case DML_QUANTIZATION_TYPE_SCALE:
        return "DML_QUANTIZATION_TYPE_SCALE";
    case DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT:
        return "DML_QUANTIZATION_TYPE_SCALE_ZERO_POINT";
    default:
        break;
    }

    // Unrecognized value: flag it during development, then return a placeholder string.
    assert(false);
    return "<unknown>";
}


// Declaration of the generic string-to-enum conversion: takes an enumerator name and returns a T.
// Only the declaration is visible here; per-enum behavior comes from explicit specializations.
template <typename T>
T FromString(std::string_view value);
Expand Down
Loading
Loading