diff --git a/BUILD.md b/BUILD.md index ddb9acd1dbbac..5cf420c44c2ec 100644 --- a/BUILD.md +++ b/BUILD.md @@ -403,6 +403,8 @@ alias cmake="/usr/bin/cmake -DCMAKE_TOOLCHAIN_FILE=$OECORE_NATIVE_SYSROOT/usr/sh ``` cmake ../onnxruntime-arm-upstream/cmake -DONNX_CUSTOM_PROTOC_EXECUTABLE=/usr/bin/protoc -Donnxruntime_RUN_ONNX_TESTS=OFF -Donnxruntime_GENERATE_TEST_REPORTS=ON -Donnxruntime_DEV_MODE=ON -DPYTHON_EXECUTABLE=/usr/bin/python3 -Donnxruntime_USE_CUDA=OFF -Donnxruntime_USE_NSYNC=OFF -Donnxruntime_CUDNN_HOME= -Donnxruntime_USE_JEMALLOC=OFF -Donnxruntime_ENABLE_PYTHON=OFF -Donnxruntime_BUILD_CSHARP=OFF -Donnxruntime_BUILD_SHARED_LIB=ON -Donnxruntime_USE_EIGEN_FOR_BLAS=ON -Donnxruntime_USE_OPENBLAS=OFF -Donnxruntime_USE_ACL=ON -Donnxruntime_USE_DNNL=OFF -Donnxruntime_USE_MKLML=OFF -Donnxruntime_USE_OPENMP=ON -Donnxruntime_USE_TVM=OFF -Donnxruntime_USE_LLVM=OFF -Donnxruntime_ENABLE_MICROSOFT_INTERNAL=OFF -Donnxruntime_USE_BRAINSLICE=OFF -Donnxruntime_USE_NUPHAR=OFF -Donnxruntime_USE_EIGEN_THREADPOOL=OFF -Donnxruntime_BUILD_UNIT_TESTS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo ``` +The ```-Donnxruntime_USE_ACL=ON``` option will use, by default, the 19.05 version of the Arm Compute Library. To set the right version you can use: +```-Donnxruntime_USE_ACL_1902=ON```, ```-Donnxruntime_USE_ACL_1905=ON``` or ```-Donnxruntime_USE_ACL_1908=ON```; 2. Build ONNX Runtime library, test and performance application: ``` diff --git a/cgmanifest.json b/cgmanifest.json index f0f08d036467e..6034653b8c010 100644 --- a/cgmanifest.json +++ b/cgmanifest.json @@ -450,7 +450,7 @@ { "component": { "git": { - "commitHash": "9dfa77ad5c59abd07b77a4c463315ef8112daefc", + "commitHash": "7922489c1e7e7baf20c4b1557743d6c7ea72647d", "repositoryUrl": "https://github.com/microsoft/FeaturizersLibrary.git" }, "type": "git" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fdfd64cb83d78..f5d98495efb12 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -57,7 +57,7 @@ option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON) option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_MKLML "Build DNNL with MKL-ML binary dependency" OFF) +option(onnxruntime_USE_MKLML "Build the default cpu provider with MKL-ML binary dependency" OFF) option(onnxruntime_USE_FEATURIZERS "Build ML Featurizers support" OFF) option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF) option(onnxruntime_USE_OPENBLAS "Use openblas" OFF) @@ -91,6 +91,9 @@ option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) option(onnxruntime_USE_ACL "Build with ACL support" OFF) +option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) +option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) +option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF) option(onnxruntime_USE_TELEMETRY "Build with Telemetry" OFF) #The onnxruntime_PREFER_SYSTEM_LIB is mainly designed for package managers like apt/yum/vcpkg. @@ -213,8 +216,8 @@ if (MSVC) set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib for gtest" FORCE) endif() #Always enable exception handling, even for Windows ARM - string(APPEND CMAKE_CXX_FLAGS " /EHsc") - string(APPEND CMAKE_C_FLAGS " /EHsc") + string(APPEND CMAKE_CXX_FLAGS " /EHsc /wd26812") + string(APPEND CMAKE_C_FLAGS " /EHsc /wd26812") if(onnxruntime_USE_AVX) string(APPEND CMAKE_CXX_FLAGS " /arch:AVX") string(APPEND CMAKE_C_FLAGS " /arch:AVX") @@ -472,7 +475,18 @@ endfunction() set(onnxruntime_EXTERNAL_DEPENDENCIES onnx_proto) # ACL -if (onnxruntime_USE_ACL) +if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908) + set(onnxruntime_USE_ACL ON) + if(onnxruntime_USE_ACL_1902) + add_definitions(-DACL_1902=1) + else() + if(onnxruntime_USE_ACL_1908) + add_definitions(-DACL_1908=1) + else() + add_definitions(-DACL_1905=1) + endif() + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute acl arm_compute_graph arm_compute_core) endif() @@ -571,7 +585,7 @@ if (WIN32) string(APPEND CMAKE_CXX_FLAGS " /wd4251") if (onnxruntime_ENABLE_STATIC_ANALYSIS) string(APPEND CMAKE_CXX_FLAGS - " /analyze:WX- " + " /analyze:stacksize 131072" # disable warning because there are many occurrences from test macros " /wd6326 " # potential comparison of a constant with another constant ) @@ -756,12 +770,12 @@ if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_STANDARD 11) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) set(ONNXRUNTIME_CUDA_LIBRARIES ${CUDA_LIBRARIES}) - list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublas cudnn curand) + list(APPEND ONNXRUNTIME_CUDA_LIBRARIES cublas cudnn curand cufft) if (WIN32) link_directories(${onnxruntime_CUDNN_HOME}/lib/x64) # delayload causes crash on exit, so disable for now - #file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*") + #file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*" "${onnxruntime_CUDA_HOME}/bin/cufft64_*") #set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll") #foreach(cuda_dll_path ${cuda_dll_paths}) # get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME) @@ -867,7 +881,7 @@ else() list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${CMAKE_DL_LIBS} Threads::Threads) endif() -# Default version parts for Windows.AI.MachineLearning.dll and onnxruntime.dll in non-ADO pipeline local builds +# Default version parts for Microsoft.AI.MachineLearning.dll and onnxruntime.dll in non-ADO pipeline local builds set(VERSION_MAJOR_PART 0 CACHE STRING "First part of numeric file/product version.") set(VERSION_MINOR_PART 0 CACHE STRING "Second part of numeric file/product version.") set(VERSION_BUILD_PART 0 CACHE STRING "Third part of numeric file/product version.") diff --git a/cmake/CMakeSettings.json b/cmake/CMakeSettings.json index cd740a929cfad..72ffe909863fb 100644 --- a/cmake/CMakeSettings.json +++ b/cmake/CMakeSettings.json @@ -16,6 +16,10 @@ "value": "True", "type": "BOOL" }, + { + "name": "onnxruntime_WINML_NAMESPACE_OVERRIDE", + "value": "Microsoft" + }, { "name": "onnxruntime_USE_DML", "value": "True", @@ -45,6 +49,10 @@ "value": "True", "type": "BOOL" }, + { + "name": "onnxruntime_WINML_NAMESPACE_OVERRIDE", + "value": "Microsoft" + }, { "name": "onnxruntime_USE_DML", "value": "True", diff --git a/cmake/external/FeaturizersLibrary b/cmake/external/FeaturizersLibrary index 9dfa77ad5c59a..7922489c1e7e7 160000 --- a/cmake/external/FeaturizersLibrary +++ b/cmake/external/FeaturizersLibrary @@ -1 +1 @@ -Subproject commit 9dfa77ad5c59abd07b77a4c463315ef8112daefc +Subproject commit 7922489c1e7e7baf20c4b1557743d6c7ea72647d diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index a1aa3e1260d35..f67ecbf49cd74 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -77,7 +77,7 @@ if (onnxruntime_USE_DNNL) GIT_TAG ${DNNL_TAG} # PATCH_COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${DNNL_PATCH_COMMAND} SOURCE_DIR ${DNNL_SOURCE} - CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} + CMAKE_ARGS -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} ) link_directories(${DNNL_LIB_DIR}) #if (onnxruntime_USE_MKLML) diff --git a/cmake/onnxruntime_dependencies.dot b/cmake/onnxruntime_dependencies.dot deleted file mode 100644 index 7f1f5d9af5c41..0000000000000 --- a/cmake/onnxruntime_dependencies.dot +++ /dev/null @@ -1,21 +0,0 @@ -digraph "GG" { -node [ - fontsize = "12" -]; - "node12" [ label="onnxruntime_graph" shape="diamond"]; - "node10" [ label="onnxruntime_common" shape="diamond"]; - "node12" -> "node10" // onnxruntime_graph -> onnxruntime_common - "node4" [ label="onnx" shape="diamond"]; - "node12" -> "node4" // onnxruntime_graph -> onnx - "node15" [ label="onnxruntime_framework" shape="diamond"]; - "node15" -> "node12" // onnxruntime_framework -> onnxruntime_graph - "node15" -> "node10" // onnxruntime_framework -> onnxruntime_common - "node15" -> "node4" // onnxruntime_framework -> onnx - "node17" [ label="onnxruntime_providers" shape="diamond"]; - "node17" -> "node10" // onnxruntime_providers -> onnxruntime_common - "node17" -> "node15" // onnxruntime_providers -> onnxruntime_framework - "node18" [ label="onnxruntime_test_common" shape="house"]; - "node6" [ label="onnxruntime_test_framework" shape="house"]; - "node19" [ label="onnxruntime_test_ir" shape="house"]; - "node20" [ label="onnxruntime_test_providers" shape="house"]; -} diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 07e40ad809c18..1567ef92e510a 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -75,6 +75,12 @@ endif() if (onnxruntime_USE_NUPHAR) target_compile_definitions(onnxruntime4j_jni PRIVATE USE_NUPHAR=1) endif() +if (onnxruntime_USE_ACL) + target_compile_definitions(onnxruntime4j_jni PRIVATE USE_ACL=1) +endif() +if (onnxruntime_USE_DML) + target_compile_definitions(onnxruntime4j_jni PRIVATE USE_DIRECTML=1) +endif() # depend on java sources. if they change, the JNI should recompile add_dependencies(onnxruntime4j_jni onnxruntime4j) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e22d1f1a5a38e..679a563008a70 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -263,3 +263,6 @@ endif() add_library(onnxruntime_mlas STATIC ${mlas_common_srcs} ${mlas_platform_srcs}) target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib ${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64) set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "/wd6385") +endif() \ No newline at end of file diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 7a4533c7b1e03..fbe632fd507c6 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -28,6 +28,9 @@ function(AddTest) #TODO: fix the warnings, they are dangerous target_compile_options(${_UT_TARGET} PRIVATE "/wd4244") endif() + if (MSVC) + target_compile_options(${_UT_TARGET} PRIVATE "/wd6330") + endif() source_group(TREE ${TEST_SRC_DIR} FILES ${_UT_SOURCES}) set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest") @@ -474,7 +477,7 @@ add_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto) add_dependencies(onnx_test_data_proto onnx_proto ${onnxruntime_EXTERNAL_DEPENDENCIES}) if(WIN32) - target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456" "/wd4100" "/wd4267") + target_compile_options(onnx_test_data_proto PRIVATE "/wd4125" "/wd4456" "/wd4100" "/wd4267" "/wd6011" "/wd6387" "/wd28182") else() if(HAS_UNUSED_PARAMETER) target_compile_options(onnx_test_data_proto PRIVATE "-Wno-unused-parameter") @@ -588,6 +591,8 @@ if(onnxruntime_BUILD_BENCHMARKS) if(WIN32) target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" "$<$>:/wd4141>") + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" + "$<$>:/utf-8>") endif() target_link_libraries(onnxruntime_benchmark PRIVATE onnx_test_runner_common benchmark ${onnx_test_libs}) add_dependencies(onnxruntime_benchmark ${onnxruntime_EXTERNAL_DEPENDENCIES}) @@ -751,7 +756,7 @@ if(UNIX) if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker -dead_strip") else() - set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker --no-undefined -Xlinker --gc-sections") + set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-Xlinker --version-script=${REPO_ROOT}/onnxruntime/test/testdata/custom_op_library/custom_op_library.lds -Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${REPO_ROOT}/onnxruntime/test/testdata/custom_op_library/custom_op_library.def") diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 5b8fb9a1c6310..a4b7a8f9522b0 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -24,6 +24,24 @@ set(winml_lib_api_ort_dir ${REPO_ROOT}/winml/lib/api.ort) set(winml_lib_common_dir ${REPO_ROOT}/winml/lib/common) set(winml_lib_telemetry_dir ${REPO_ROOT}/winml/lib/telemetry) +if (onnxruntime_WINML_NAMESPACE_OVERRIDE) + set(output_name "${onnxruntime_WINML_NAMESPACE_OVERRIDE}.AI.MachineLearning") + set(idl_native_output_name "${onnxruntime_WINML_NAMESPACE_OVERRIDE}.AI.MachineLearning.Native") + set(idl_native_internal_output_name "${onnxruntime_WINML_NAMESPACE_OVERRIDE}.AI.MachineLearning.Native.Internal") + set(winml_midl_defines "/DROOT_NS=${onnxruntime_WINML_NAMESPACE_OVERRIDE}") + set(winml_root_ns "${onnxruntime_WINML_NAMESPACE_OVERRIDE}") + set(BINARY_NAME "${onnxruntime_WINML_NAMESPACE_OVERRIDE}.AI.MachineLearning.dll") + set(winml_api_use_ns_prefix false) +else() + set(output_name "Microsoft.AI.MachineLearning") + set(idl_native_output_name "Microsoft.AI.MachineLearning.Native") + set(idl_native_internal_output_name "Microsoft.AI.MachineLearning.Native.Internal") + set(winml_midl_defines "/DROOT_NS=Microsoft") + set(winml_root_ns "Microsoft") + set(BINARY_NAME "Microsoft.AI.MachineLearning.dll") + set(winml_api_use_ns_prefix true) +endif() + get_filename_component(exclusions "${winml_api_root}/exclusions.txt" ABSOLUTE) convert_forward_slashes_to_back(${exclusions} CPPWINRT_COMPONENT_EXCLUSION_LIST) @@ -51,25 +69,33 @@ add_generate_cppwinrt_sdk_headers_target( # generate winml headers from idl target_cppwinrt(winml_api - ${winrt_idl} # winml winrt idl to compile - ${winml_lib_api_dir} # location for cppwinrt generated component sources - ${sdk_folder} # location of sdk folder - ${sdk_version} # sdk version - ${target_folder} # the folder this target will be placed under + ${winrt_idl} # winml winrt idl to compile + ${output_name} # outputs name + ${winml_lib_api_dir} # location for cppwinrt generated component sources + ${sdk_folder} # location of sdk folder + ${sdk_version} # sdk version + ${target_folder} # the folder this target will be placed under + ${winml_midl_defines} # the midl compiler defines + ${winml_api_use_ns_prefix} # set ns_prefix ) target_midl(winml_api_native - ${idl_native} # winml native idl to compile - ${sdk_folder} # location of sdk folder - ${sdk_version} # sdk version - ${target_folder} # the folder this target will be placed under + ${idl_native} # winml native idl to compile + ${idl_native_output_name} # outputs name + ${sdk_folder} # location of sdk folder + ${sdk_version} # sdk version + ${target_folder} # the folder this target will be placed under + ${winml_midl_defines} # the midl compiler defines ) target_midl(winml_api_native_internal - ${idl_native_internal} # winml internal native idl to compile - ${sdk_folder} # location of sdk folder - ${sdk_version} # sdk version - ${target_folder}) # the folder this target will be placed under + ${idl_native_internal} # winml internal native idl to compile + ${idl_native_internal_output_name} # outputs name + ${sdk_folder} # location of sdk folder + ${sdk_version} # sdk version + ${target_folder} # the folder this target will be placed under + ${winml_midl_defines} # the midl compiler defines +) ########################### # Add winml_lib_telemetry @@ -96,6 +122,7 @@ endif() # Compiler flags target_compile_definitions(winml_lib_telemetry PRIVATE PLATFORM_WINDOWS) target_compile_definitions(winml_lib_telemetry PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators +target_compile_definitions(winml_lib_telemetry PRIVATE BINARY_NAME=\"${BINARY_NAME}\") # Specify the usage of a precompiled header target_precompiled_header(winml_lib_telemetry pch.h) @@ -153,6 +180,7 @@ target_compile_features(winml_lib_ort PRIVATE cxx_std_17) target_compile_options(winml_lib_ort PRIVATE /GR- /await /wd4238) # Compiler definitions +target_compile_definitions(winml_lib_ort PRIVATE WINML_ROOT_NS=${winml_root_ns}) target_compile_definitions(winml_lib_ort PRIVATE PLATFORM_WINDOWS) target_compile_definitions(winml_lib_ort PRIVATE _SCL_SECURE_NO_WARNINGS) # remove warnings about unchecked iterators @@ -280,6 +308,7 @@ target_compile_features(winml_lib_image PRIVATE cxx_std_17) target_compile_options(winml_lib_image PRIVATE /GR- /await /wd4238) # Compiler flags +target_compile_definitions(winml_lib_image PRIVATE WINML_ROOT_NS=${winml_root_ns}) target_compile_definitions(winml_lib_image PRIVATE ONNX_NAMESPACE=onnx) target_compile_definitions(winml_lib_image PRIVATE ONNX_ML) target_compile_definitions(winml_lib_image PRIVATE LOTUS_LOG_THRESHOLD=2) @@ -370,6 +399,7 @@ target_compile_features(winml_lib_api PRIVATE cxx_std_17) target_compile_options(winml_lib_api PRIVATE /GR- /await /bigobj /wd4238) # Compiler flags +target_compile_definitions(winml_lib_api PRIVATE WINML_ROOT_NS=${winml_root_ns}) target_compile_definitions(winml_lib_api PRIVATE ONNX_NAMESPACE=onnx) target_compile_definitions(winml_lib_api PRIVATE ONNX_ML) target_compile_definitions(winml_lib_api PRIVATE LOTUS_LOG_THRESHOLD=2) @@ -452,18 +482,23 @@ set_target_properties(winml_lib_common PROPERTIES CXX_STANDARD_REQUIRED ON) target_compile_options(winml_lib_common PRIVATE /GR- /await /bigobj /wd4238) target_link_libraries(winml_lib_common PRIVATE wil) target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) -target_compile_definitions(winml_lib_common PRIVATE - ONNX_NAMESPACE=onnx - ONNX_ML - LOTUS_LOG_THRESHOLD=2 - LOTUS_ENABLE_STDERR_LOGGING - PLATFORM_WINDOWS - _SCL_SECURE_NO_WARNINGS) + +# Compiler flags +target_compile_definitions(winml_lib_common PRIVATE BINARY_NAME=\"${BINARY_NAME}\") +target_compile_definitions(winml_lib_common PRIVATE WINML_ROOT_NS=${winml_root_ns}) +target_compile_definitions(winml_lib_common PRIVATE ONNX_NAMESPACE=onnx) +target_compile_definitions(winml_lib_common PRIVATE ONNX_ML) +target_compile_definitions(winml_lib_common PRIVATE LOTUS_LOG_THRESHOLD=2) +target_compile_definitions(winml_lib_common PRIVATE LOTUS_ENABLE_STDERR_LOGGING) +target_compile_definitions(winml_lib_common PRIVATE PLATFORM_WINDOWS) +target_compile_definitions(winml_lib_common PRIVATE _SCL_SECURE_NO_WARNINGS) + add_dependencies(winml_lib_common winml_sdk_cppwinrt) add_dependencies(winml_lib_common winml_api) add_dependencies(winml_lib_common winml_api_native) add_dependencies(winml_lib_common winml_api_native_internal) + target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api) # windows machine learning generated component headers target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated) # windows machine learning generated component headers target_include_directories(winml_lib_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/winml/sdk/cppwinrt/include) # sdk cppwinrt headers @@ -489,7 +524,7 @@ set_source_files_properties( # Add library add_library(winml_dll SHARED ${CMAKE_CURRENT_BINARY_DIR}/winml_api/comp_generated/module.g.excl.cpp - ${winml_dll_dir}/windows.ai.machinelearning.def + ${winml_dll_dir}/winml.def ${winml_dll_dir}/winml.rc ${winml_dll_dir}/pch.h ${winml_dll_dir}/module.cpp @@ -500,6 +535,7 @@ target_compile_features(winml_dll PRIVATE cxx_std_17) target_compile_options(winml_dll PRIVATE /GR- /await /bigobj /wd4238) # Compiler definitions +target_compile_definitions(winml_dll PRIVATE WINML_ROOT_NS=${winml_root_ns}) target_compile_definitions(winml_dll PRIVATE ONNX_NAMESPACE=onnx) target_compile_definitions(winml_dll PRIVATE ONNX_ML) target_compile_definitions(winml_dll PRIVATE LOTUS_LOG_THRESHOLD=2) @@ -510,6 +546,7 @@ target_compile_definitions(winml_dll PRIVATE VER_MINOR=${VERSION_MINOR_PART}) target_compile_definitions(winml_dll PRIVATE VER_BUILD=${VERSION_BUILD_PART}) target_compile_definitions(winml_dll PRIVATE VER_PRIVATE=${VERSION_PRIVATE_PART}) target_compile_definitions(winml_dll PRIVATE VER_STRING=\"${VERSION_STRING}\") +target_compile_definitions(winml_dll PRIVATE BINARY_NAME=\"${BINARY_NAME}\") # Specify the usage of a precompiled header target_precompiled_header(winml_dll pch.h) @@ -545,7 +582,7 @@ target_include_directories(winml_dll PRIVATE ${REPO_ROOT}/cmake/external/eigen) # Properties set_target_properties(winml_dll PROPERTIES - OUTPUT_NAME windows.ai.machinelearning) + OUTPUT_NAME ${output_name}) if (onnxruntime_USE_DML) set(delayload_dml "/DELAYLOAD:directml.dll") @@ -553,7 +590,7 @@ endif(onnxruntime_USE_DML) set(os_component_link_flags_list ${os_component_link_flags}) separate_arguments(os_component_link_flags_list) -target_link_options(winml_dll PRIVATE /DEF:${WINML_DIR}/windows.ai.machinelearning.def ${os_component_link_flags_list} /DELAYLOAD:api-ms-win-core-libraryloader-l1-2-1.dll /DELAYLOAD:api-ms-win-core-threadpool-legacy-l1-1-0.dll /DELAYLOAD:api-ms-win-core-processtopology-obsolete-l1-1-0.dll /DELAYLOAD:api-ms-win-core-kernel32-legacy-l1-1-0.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll ${delayload_dml}) +target_link_options(winml_dll PRIVATE /DEF:${WINML_DIR}/winml.def ${os_component_link_flags_list} /DELAYLOAD:api-ms-win-core-libraryloader-l1-2-1.dll /DELAYLOAD:api-ms-win-core-threadpool-legacy-l1-1-0.dll /DELAYLOAD:api-ms-win-core-processtopology-obsolete-l1-1-0.dll /DELAYLOAD:api-ms-win-core-kernel32-legacy-l1-1-0.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:d3d11.dll /DELAYLOAD:dxgi.dll ${delayload_dml}) if (EXISTS ${dxcore_header}) target_link_options(winml_dll PRIVATE /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll) diff --git a/cmake/winml_cppwinrt.cmake b/cmake/winml_cppwinrt.cmake index c047689b32588..acde0406c7318 100644 --- a/cmake/winml_cppwinrt.cmake +++ b/cmake/winml_cppwinrt.cmake @@ -20,9 +20,12 @@ function(target_midl target_name idl_file - sdk_folder # sdk kit directory - sdk_version # sdk version - folder_name) + output_name # output name of the generated headers, winmd and tlb + sdk_folder # sdk kit directory + sdk_version # sdk version + folder_name + midl_options # defines for the midl compiler + ) if (MSVC) # get sdk include paths for midl get_sdk_include_folder(${sdk_folder} ${sdk_version} sdk_include_folder) @@ -38,9 +41,7 @@ function(target_midl get_sdk_midl_exe(${sdk_folder} ${sdk_version} midl_exe) # Filename variables - get_filename_component(file_name_with_extension ${idl_file} NAME) - string(REGEX REPLACE "\\.[^.]*$" "" file_name ${file_name_with_extension}) - set(header_filename ${file_name}.h) + set(header_filename ${output_name}.h) convert_forward_slashes_to_back(${idl_file} idl_file_forward_slash) # using add_custom_command trick to prevent rerunning script unless ${file} is changed @@ -55,6 +56,7 @@ function(target_midl /I ${winrt_sdk_directory} /I ${CMAKE_CURRENT_SOURCE_DIR} /h ${header_filename} + ${midl_options} ${idl_file_forward_slash} DEPENDS ${idl_file} ) @@ -70,12 +72,15 @@ function(target_midl endfunction() function(target_cppwinrt - target_name # the name of the target to add - file # name of the idl file to compile - out_sources_folder # path where generated sources will be placed - sdk_folder # sdk kit directory - sdk_version # sdk version - folder_name # folder this target will be placed + target_name # the name of the target to add + file # name of the idl file to compile + output_name # output name of the generated headers, winmd and tlb + out_sources_folder # path where generated sources will be placed + sdk_folder # sdk kit directory + sdk_version # sdk version + folder_name # folder this target will be placed + midl_options # defines for the midl compiler + set_ns_prefix # set ns_prefix option ) if (MSVC) # get sdk include paths for midl @@ -95,12 +100,9 @@ function(target_cppwinrt get_sdk_cppwinrt_exe(${sdk_folder} ${sdk_version} cppwinrt_exe) # Filename variables - convert_forward_slashes_to_back(${file} idl_file_forward_slash) - get_filename_component(file_name_with_extension ${file} NAME) - string(REGEX REPLACE "\\.[^.]*$" "" fileName ${file_name_with_extension}) - set(header_filename ${fileName}.h) - set(winmd_filename ${fileName}.winmd) - set(tlb_filename ${fileName}.tlb) + set(header_filename ${output_name}.h) + set(winmd_filename ${output_name}.winmd) + set(tlb_filename ${output_name}.tlb) # Get directory get_filename_component(idl_source_directory ${file} DIRECTORY) @@ -111,6 +113,24 @@ function(target_cppwinrt convert_forward_slashes_to_back(${target_outputs}/comp_generated generated_dir_back_slash) convert_forward_slashes_to_back(${generated_dir_back_slash}/module.g.cpp module_g_cpp_back_slash) convert_forward_slashes_to_back(${generated_dir_back_slash}/module.g.excl.cpp module_g_ecxl_cpp_back_slash) + + if (set_ns_prefix) + set(ns_prefix "/ns_prefix") + else() + set(ns_prefix "") + endif() + + # Get name + set(renamed_idl_filename ${output_name}.idl) + set(renamed_idl_fullpath ${target_outputs}/${renamed_idl_filename}) + + get_filename_component(idl_source_filename ${file} NAME) + set(copied_idl_fullpath ${target_outputs}/${idl_source_filename}) + + file(COPY ${file} DESTINATION ${target_outputs}) + file(RENAME ${copied_idl_fullpath} ${renamed_idl_fullpath}) + + convert_forward_slashes_to_back(${renamed_idl_fullpath} renamed_idl_fullpath_back_slash) # using add_custom_command trick to prevent rerunning script unless ${file} is changed add_custom_command( @@ -125,9 +145,11 @@ function(target_cppwinrt /I ${winrt_sdk_directory} /I ${idl_source_directory} /winmd ${winmd_filename} + ${ns_prefix} /h ${header_filename} /tlb ${tlb_filename} - ${idl_file_forward_slash} + ${midl_options} + ${renamed_idl_fullpath_back_slash} COMMAND ${cppwinrt_exe} -in ${winmd_filename} -comp ${output_dir_back_slash} -ref ${sdk_metadata_directory} -out ${generated_dir_back_slash} -verbose COMMAND diff --git a/cmake/winml_unittests.cmake b/cmake/winml_unittests.cmake index 73d0ff7fde512..c13e607df8905 100644 --- a/cmake/winml_unittests.cmake +++ b/cmake/winml_unittests.cmake @@ -23,6 +23,8 @@ function(set_winml_target_properties target) CXX_EXTENSIONS NO ) target_include_directories(${target} PRIVATE ${WINML_TEST_INC_DIR}) + target_compile_definitions(${target} PRIVATE WINML_ROOT_NS=${winml_root_ns}) + target_compile_definitions(${target} PRIVATE BINARY_NAME=\"${BINARY_NAME}\") endfunction() function(add_winml_test) @@ -59,10 +61,14 @@ function(get_winml_test_scenario_src output_winml_test_scenario_libs ) if (onnxruntime_USE_DML) - file(GLOB winml_test_scenario_src CONFIGURE_DEPENDS "${winml_test_src_path}/scenario/cppwinrt/*.cpp") + file(GLOB winml_test_scenario_src CONFIGURE_DEPENDS + "${winml_test_src_path}/scenario/cppwinrt/*.h" + "${winml_test_src_path}/scenario/cppwinrt/*.cpp") set(${output_winml_test_scenario_libs} "onnxruntime_providers_dml" PARENT_SCOPE) else() - set(winml_test_scenario_src "${winml_test_src_path}/scenario/cppwinrt/scenariotestscppwinrt.cpp") + set(winml_test_scenario_src + "${winml_test_src_path}/scenario/cppwinrt/scenariotestscppwinrt.h" + "${winml_test_src_path}/scenario/cppwinrt/scenariotestscppwinrt.cpp") endif() set(${output_winml_test_scenario_src} ${winml_test_scenario_src} PARENT_SCOPE) endfunction() @@ -71,7 +77,9 @@ function(get_winml_test_api_src winml_test_src_path output_winml_test_api_src ) - file(GLOB winml_test_api_src CONFIGURE_DEPENDS "${winml_test_src_path}/api/*.cpp") + file(GLOB winml_test_api_src CONFIGURE_DEPENDS + "${winml_test_src_path}/api/*.h" + "${winml_test_src_path}/api/*.cpp") set(${output_winml_test_api_src} ${winml_test_api_src} PARENT_SCOPE) endfunction() @@ -79,10 +87,24 @@ function(get_winml_test_concurrency_src winml_test_src_path output_winml_test_concurrency_src ) - file(GLOB winml_test_concurrency_src CONFIGURE_DEPENDS "${winml_test_src_path}/concurrency/*.cpp") + file(GLOB winml_test_concurrency_src CONFIGURE_DEPENDS + "${winml_test_src_path}/concurrency/*.h" + "${winml_test_src_path}/concurrency/*.cpp") set(${output_winml_test_concurrency_src} ${winml_test_concurrency_src} PARENT_SCOPE) endfunction() +function(get_winml_test_adapter_src + winml_test_src_path + output_winml_test_adapter_src + output_winml_test_adapter_libs +) + set(${output_winml_test_adapter_libs} "onnxruntime" PARENT_SCOPE) + file(GLOB winml_test_adapter_src CONFIGURE_DEPENDS + "${winml_test_src_path}/adapter/*.h" + "${winml_test_src_path}/adapter/*.cpp") + set(${output_winml_test_adapter_src} ${winml_test_adapter_src} PARENT_SCOPE) +endfunction() + function(get_winml_test_image_src winml_test_src_path output_winml_test_image_src @@ -90,11 +112,15 @@ function(get_winml_test_image_src if (onnxruntime_USE_DML) set(${output_winml_test_scenario_libs} "onnxruntime_providers_dml" PARENT_SCOPE) endif() - file(GLOB winml_test_image_src CONFIGURE_DEPENDS "${winml_test_src_path}/image/*.cpp") + file(GLOB winml_test_image_src CONFIGURE_DEPENDS + "${winml_test_src_path}/image/*.h" + "${winml_test_src_path}/image/*.cpp") set(${output_winml_test_image_src} ${winml_test_image_src} PARENT_SCOPE) endfunction() -file(GLOB winml_test_common_src CONFIGURE_DEPENDS "${WINML_TEST_SRC_DIR}/common/*.cpp") +file(GLOB winml_test_common_src CONFIGURE_DEPENDS + "${WINML_TEST_SRC_DIR}/common/*.h" + "${WINML_TEST_SRC_DIR}/common/*.cpp") add_library(winml_test_common STATIC ${winml_test_common_src}) add_dependencies(winml_test_common onnx @@ -159,6 +185,14 @@ add_winml_test( ) target_include_directories(winml_test_concurrency PRIVATE ${ONNXRUNTIME_ROOT}/core/graph) +get_winml_test_adapter_src(${WINML_TEST_SRC_DIR} winml_test_adapter_src winml_test_adapter_libs) +add_winml_test( + TARGET winml_test_adapter + SOURCES ${winml_test_adapter_src} + LIBS winml_test_common ${winml_test_adapter_libs} +) +target_include_directories(winml_test_adapter PRIVATE ${REPO_ROOT}/winml/adapter) + # During build time, copy any modified collaterals. # configure_file(source destination COPYONLY), which configures CMake to copy the file whenever source is modified, # can't be used here because we don't know the destination during configure time (in multi-configuration generators, diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj index 669497220c567..bceeb06ba1e10 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj @@ -10,7 +10,8 @@ ..\..\OnnxRuntime.snk - ..\.. + ..\..\.. + $(OnnxRuntimeRoot)\csharp $(OnnxRuntimeCsharpRoot)\..\build\Windows $(OnnxRuntimeBuildDirectory)\$(Configuration)\$(Configuration) x64 @@ -92,8 +93,8 @@ CopyToOutputDirectory="Always" Visible="false" /> - - - - - - + + PreserveNewest false + + Microsoft.AI.MachineLearning.dll + PreserveNewest + false + onnxruntime.dll @@ -79,5 +85,11 @@ PreserveNewest false + + Microsoft.AI.MachineLearning.dll + PreserveNewest + false + diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 86d9af68a890a..d99dce3508c31 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -1836,4 +1836,44 @@ public SkipNonPackageTests() } } + + // A Disposable list is a list of IDisposable objects. All elements will be disposed when the container is disposed. + internal class DisposableList : List, IDisposableReadOnlyCollection + where T : IDisposable + { + public DisposableList() { } + public DisposableList(int count) : base(count) { } + + #region IDisposable Support + private bool disposedValue = false; // To detect redundant calls + + protected virtual void Dispose(bool disposing) + { + if (!disposedValue) + { + if (disposing) + { + for (int i = 0; i < this.Count; i++) + { + this[i]?.Dispose(); + } + this.Clear(); + } + + disposedValue = true; + } + } + + ~DisposableList() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + #endregion + } } diff --git a/docs/ONNX_Runtime_Perf_Tuning.md b/docs/ONNX_Runtime_Perf_Tuning.md index aaab8f8422fe5..3f6019550d447 100644 --- a/docs/ONNX_Runtime_Perf_Tuning.md +++ b/docs/ONNX_Runtime_Perf_Tuning.md @@ -1,17 +1,17 @@ # ONNX Runtime Performance Tuning ## Why do we need to tune performance? -ONNX Runtime is designed to be open and extensible with its concept of "Execution Provider" to represents different execution kernels. See the [design overview](./HighLevelDesign.md). +ONNX Runtime is designed to be open and extensible with its concept of "Execution Provider" to represent different execution kernels. See the [design overview](./HighLevelDesign.md). ONNX Runtime supports a variety of execution providers across CPU and GPU: [see the list here](../README.md#high-performance). -For different models and different hardware, there is no silver bullet which can always perform the best. Even for a single execution provider, often there are several knobs that can be tuned (e.g. thread number, wait policy etc.). +For different models and different hardware, there is no silver bullet that can always perform the best. Even for a single execution provider, often there are several knobs that can be tuned (e.g. thread number, wait policy etc.). This document covers basic tools and knobs that can be leveraged to find the best performance for your model and hardware. ## Is there a tool to help with performance tuning? Yes, the onnxruntime_perf_test.exe tool (available from the build drop) can be used to test various knobs. Please find the usage instructions using `onnxruntime_perf_test.exe -h`. -Additionally, the [ONNX Go Live "OLive" tool](https://github.com/microsoft/OLive) provides an easy-to-use pipeline for converting models to ONNX and optimizing performance with ONNX Runtime. The tool can help identify the optimal runtime configuration to get the best performance on the target hardware for the model. +Additionally, the [ONNX Go Live "OLive" tool](https://github.com/microsoft/OLive) provides an easy-to-use pipeline for converting models to ONNX and optimizing performance with ONNX Runtime. The tool can help identify the optimal runtime configuration to get the best performance on the target hardware for the model. For quickstart, check out the notebooks on how to use OLive [here](https://github.com/microsoft/OLive/blob/master/notebook/Convert_Models_and_Tune_Performance_with_OLive_Python_SDK.ipynb) (using Python) and [here](https://github.com/microsoft/OLive/blob/master/notebook/Convert_Models_and_Tune_Performance_with_OLive_Docker_Images.ipynb) (using Docker). ## Using different execution providers @@ -84,7 +84,7 @@ sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL * Thread Count * `sess_options.intra_op_num_threads = 2` controls the number of threads to use to run the model * Sequential vs Parallel Execution - * `sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL` controls whether then operators in the graph should run sequentially or in parallel. Usually when a model has many branches, setting this option to false will provide better performance. + * `sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL` controls whether the operators in the graph run sequentially or in parallel. Usually when a model has many branches, setting this option to false will provide better performance. * When `sess_options.execution_mode = rt.ExecutionMode.ORT_PARALLEL`, you can set `sess_options.inter_op_num_threads` to control the number of threads used to parallelize the execution of the graph (across nodes). @@ -122,3 +122,12 @@ In both cases, you will get a JSON file which contains the detailed performance * Open chrome browser * Type chrome://tracing in the address bar * Load the generated JSON file + +## Performance Tuning for Bert Models + +For Bert models, sometimes ONNX Runtime cannot apply the best optimization due to reasons such as framework version updates. In this case, we recommend trying out the [Bert optimization tool](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/bert), which reflects the latest changes in graph pattern matching and model conversions, and a set of [notebooks](https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/bert/notebooks) for quickstart. + + +## Model graph is not optimized even with graph_optimization_level set to ORT_ENABLE_ALL? + +ONNX model from IR_VERSION 4 only treats initializers that appear in graph input as non-constant. This may fail some of the graph optimizations, like const folding, operator fusion and etc. Move initializers out of graph inputs if there is no need to override them, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py. diff --git a/docs/WinRT_API.md b/docs/WinRT_API.md index 2b52ea005cbd6..78649252b315d 100644 --- a/docs/WinRT_API.md +++ b/docs/WinRT_API.md @@ -1,14 +1,14 @@ # Windows Machine Learning WinRT API -New in the ONNX Runtime Nuget package is the ability to use the full [Windows.AI.MachineLearning API](https://docs.microsoft.com/en-us/windows/ai/windows-ml/api-reference). +New in the ONNX Runtime Nuget package is the ability to use the full [WinML API](https://docs.microsoft.com/en-us/windows/ai/windows-ml/api-reference). This allows scenarios such as passing a [Windows.Media.VideoFrame](https://docs.microsoft.com/en-us/uwp/api/Windows.Media.VideoFrame) from your connected camera directly into the runtime for realtime inference. -The Windows.AI.MachineLearning API is a WinRT API that shipped inside the Windows OS starting with build 1809 (RS5). It embedded a version of the ONNX Runtime. +The WinML API is a WinRT API that shipped inside the Windows OS starting with build 1809 (RS5) in the Windows.AI.MachineLearning namespace. It embedded a version of the ONNX Runtime. Many customers have asked for a way to use this offering as an application redistributable package. -With our new [layered architecture](HighLevelDesign.md#the-onnx-runtime-and-windows-os-integration) you can now do this, with some limitations. +With our new [layered architecture](HighLevelDesign.md#the-onnx-runtime-and-windows-os-integration) you can now do this, with some limitations. The WinML APIs have been lifted and mirrored into the Microsoft.AI.MachineLearning namespace in the redistributable. ## NuGet Package @@ -18,30 +18,10 @@ Note: As of the 1.2 release, you can use all of the CPU functionality from these ## Sample Code -Any code already written for the Windows.AI.MachineLearning API can be easily modified to run against the Microsoft.ML.OnnxRuntime package. Check out these [existing samples](https://github.com/microsoft/windows-Machine-Learning) in github. +Any code already written for the Windows.AI.MachineLearning API can be easily modified to run against the Microsoft.ML.OnnxRuntime package. All types originally referenced by inbox customers via the Windows namespace will need to be updated to now use the Microsoft namespace. Check out these [existing samples](https://github.com/microsoft/windows-Machine-Learning) in github. -## Activation and Side-by-Side - -Because Windows.AI.MachineLearning ships inside the OS, default object activation is going to use those OS binaries. Applications must explicitly code to enable the use of the redist binaries when creating WinML objects (Like [LearningModelSession](https://docs.microsoft.com/en-us/uwp/api/windows.ai.machinelearning.learningmodelsession)). - -Read up [here](HighLevelDesign.md#the-onnx-runtime-and-windows-os-integration) in how to decide when to use the OS binaries and when to use redist binaries. - -To create objects using the redist binaries you have several choices depending on how you are consuming the WinRT: - -* cpp/winrt: You can use WINRT_RoGetActivationFactory hooking as shown [here](https://github.com/microsoft/Windows-Machine-Learning/blob/master/Samples/SqueezeNetObjectDetection/Desktop/cpp/dllload.cpp) in our sample projects. -* WRL: (coming soon) -* Raw C++: Simply use the similar code to the cpp/winrt sample to load and use the activation factory in your redist binary. - -## Deciding which header files to use - -The best way to use the API is to use the header files that come with the Windows SDK. - -* For Visual Studio they are included as an optional feature. -* For Visual Studio Code you can download them [here](https://developer.microsoft.com/en-US/windows/downloads/windows-10-sdk/). - -This [tutorial](https://docs.microsoft.com/en-us/windows/ai/windows-ml/get-started-desktop) is a great place to get started. - -To detect if an OS already has Windows.AI.MachineLearning you can use the [IsApiContractPresent](https://docs.microsoft.com/en-us/uwp/api/windows.foundation.metadata.apiinformation.isapicontractpresent) method. This can be called from either UWP or native apps. +## Deciding on whether to use WinML in the Windows SDK or the Redist +To detect if a particular OS version of Windows has the WinML APIs, use the [IsApiContractPresent](https://docs.microsoft.com/en-us/uwp/api/windows.foundation.metadata.apiinformation.isapicontractpresent) method. This can be called from either UWP or native apps. If the OS does not have the runtime you need you can switch to use the redist binaries instead. diff --git a/docs/onnxruntime_dependencies.dot b/docs/onnxruntime_dependencies.dot new file mode 100644 index 0000000000000..a765cf4d6b6a9 --- /dev/null +++ b/docs/onnxruntime_dependencies.dot @@ -0,0 +1,88 @@ +digraph "GG" { + compound=true; + +node [ + fontsize = "12" +]; +subgraph cluster_0 { + label = "onnxruntime.dll"; + "ort_graph" [ label="onnxruntime_graph\n(schemas)" shape="box"]; + "ort_common" [ label="onnxruntime_common" shape="box"]; + "ort_util" [ label="onnxruntime_util" shape="box"]; + "ort_mlas" [ label="onnxruntime_mlas" shape="box"]; + "ort_optimizer" [ label="onnxruntime_optimizer" shape="box"]; + "ort_session" [ label="onnxruntime_session" shape="box"]; + "ort_graph" -> "ort_common" + "onnx" [ label="onnx" shape="box"]; + "protobuf" [ label="Google Protobuf" shape="box"]; + "onnx" -> "protobuf" + "ort_graph" -> "protobuf" + "ort_graph" -> "onnx" + "ort_optimizer" -> "onnx" + "ort_framework" [ label="onnxruntime_framework" shape="box"]; + "ort_framework" -> "ort_graph" + "ort_framework" -> "ort_common" + "ort_framework" -> "onnx" + "ort_cpu_provider" [ label="onnxruntime_cpu_provider\n(kernels)" shape="box"]; + "ort_cpu_provider" -> "ort_common" + "ort_cpu_provider" -> "ort_framework" + "ort_cpu_provider" -> "ort_util" + "ort_cpu_provider" -> "ort_mlas" + "ort_cpu_provider" -> "onnx" + "ort_cuda_provider" [ label="onnxruntime_cuda_provider\n(kernels)" shape="box"]; + "ort_cuda_provider" -> "ort_common" + "ort_cuda_provider" -> "ort_framework" + "ort_cuda_provider" -> "ort_util" + "ort_cuda_provider" -> "ort_mlas" + "ort_cuda_provider" -> "onnx" + "ort_util" -> "ort_common" + "ort_util" -> "ort_framework" + "ort_util" -> "ort_mlas" + "ort_mlas" -> "ort_common" + "ort_session" -> "ort_framework" + "ort_session" -> "ort_common" + "ort_session" -> "ort_graph" + "ort_session" -> "ort_optimizer" + "ort_session" -> "ort_cpu_provider" + "ort_optimizer" -> "ort_cpu_provider" + "ort_optimizer" -> "ort_common" + "ort_optimizer" -> "ort_framework" + "ort_optimizer" -> "ort_graph" + "capi" [ label="C API" shape="box"]; +} + +subgraph cluster_1 { + label = "Application Interfaces"; + style=filled; + color=lightgrey; + node [style=filled,color=white]; + "javaapi" [ label="Java API" shape="box"]; + "csharpapi" [ label="C# API" shape="box"]; + "cppapi" [ label="C++ API\n(header only)" shape="box"]; + "javaapi" -> "capi" + "cppapi" -> "capi" + "csharpapi" -> "capi" + "pythonapi" [ label="Python API" shape="box"]; + pythonapi -> ort_session [lhead=cluster_0] +} + +"grpc" [ label="gRPC" shape="box"]; +"boost" [ label="Boost" shape="box"]; +"onnx2" [ label="onnx" shape="box"]; +"protobuf2" [ label="Google Protobuf" shape="box"]; +"onnx2" -> "protobuf2" +"grpc" -> "protobuf2" + +subgraph cluster_2 { + label = "Applications"; + "onnxruntime_server" [ label="ONNX Runtime Server" shape="box"]; + "onnxruntime_server" -> "cppapi" + "app1" [ label="User application" shape="box"]; + "app2" [ label="User application" shape="box"]; +} + "onnxruntime_server" -> "grpc" + "onnxruntime_server" -> "boost" + "onnxruntime_server" -> "onnx2" +} + + diff --git a/docs/onnxruntime_dependencies.png b/docs/onnxruntime_dependencies.png new file mode 100644 index 0000000000000..a04b5e85b1005 Binary files /dev/null and b/docs/onnxruntime_dependencies.png differ diff --git a/include/onnxruntime/core/common/version.h b/include/onnxruntime/core/common/version.h deleted file mode 100644 index c2274b634d76d..0000000000000 --- a/include/onnxruntime/core/common/version.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#define ONNXRUNTIME_VERSION_STRING "1.0" diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 91d2401ab1d66..c348ab4ff27ad 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -84,7 +84,7 @@ struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor // use string for name, so we could have customized allocator in execution provider. - const char* name; + const char* name = nullptr; int id = -1; OrtMemType mem_type = OrtMemTypeDefault; OrtAllocatorType alloc_type = Invalid; diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index aac9c4db5f58e..b2e642765b8af 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -15,22 +15,19 @@ class KernelRegistry { KernelRegistry() = default; // Register a kernel with kernel definition and function to create the kernel. - Status Register(KernelDefBuilder& kernel_def_builder, - const KernelCreateFn& kernel_creator); + Status Register(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator) ORT_MUST_USE_RESULT; - Status Register(KernelCreateInfo&& create_info); + Status Register(KernelCreateInfo&& create_info) ORT_MUST_USE_RESULT; // factory functions should always return a unique_ptr for maximum flexibility // for its clients unless the factory is managing the lifecycle of the pointer // itself. // TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent - Status TryCreateKernel(const onnxruntime::Node& node, - const IExecutionProvider& execution_provider, + Status TryCreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, const std::unordered_map& constant_initialized_tensors, - const OrtValueNameIdxMap& mlvalue_name_idx_map, - const FuncManager& funcs_mgr, + const OrtValueNameIdxMap& mlvalue_name_idx_map, const FuncManager& funcs_mgr, const DataTransferManager& data_transfer_mgr, - std::unique_ptr& op_kernel) const; + std::unique_ptr& op_kernel) const ORT_MUST_USE_RESULT; // Check if an execution provider can create kernel for a node and return // the kernel if so diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 31a93d1518b7e..4b5c5eba723a4 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -43,9 +43,9 @@ class OpKernel { return op_kernel_info_.GetKernelDef(); } - virtual Status Compute(OpKernelContext* context) const ORT_MUST_USE_RESULT = 0; + virtual Status Compute(_Inout_ OpKernelContext* context) const ORT_MUST_USE_RESULT = 0; - virtual Status ComputeAsync(OpKernelContext*, DoneCallback) const ORT_MUST_USE_RESULT { + virtual Status ComputeAsync(_Inout_ OpKernelContext*, DoneCallback) const ORT_MUST_USE_RESULT { ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); } @@ -64,10 +64,8 @@ class OpKernelContext { public: using ArgMap = std::unordered_map; - explicit OpKernelContext(IExecutionFrame* frame, - const OpKernel* kernel, - concurrency::ThreadPool* threadpool, - const logging::Logger& logger); + OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, + _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger); virtual ~OpKernelContext() = default; @@ -136,7 +134,7 @@ class OpKernelContext { Return an allocator on device 0, with memtype of OrtMemTypeDefault. @remarks Use SafeInt when calculating the size of memory to allocate using AllocatorPtr->Alloc. */ - Status GetTempSpaceAllocator(AllocatorPtr* output) const; + Status GetTempSpaceAllocator(AllocatorPtr* output) const ORT_MUST_USE_RESULT; /** Return the fence of current node's input. @@ -193,10 +191,10 @@ class OpKernelContext { int GetImplicitInputArgIndex(int index) const; int GetOutputArgIndex(int index) const; - IExecutionFrame* execution_frame_{nullptr}; - const OpKernel* kernel_{nullptr}; - concurrency::ThreadPool* threadpool_{nullptr}; - const logging::Logger* logger_{nullptr}; + IExecutionFrame* const execution_frame_; + const OpKernel* const kernel_; + concurrency::ThreadPool* const threadpool_; + const logging::Logger* const logger_; // The argument starting index in ExecutionFrame. int node_input_start_index_{-1}; diff --git a/include/onnxruntime/core/graph/onnx_protobuf.h b/include/onnxruntime/core/graph/onnx_protobuf.h index 86ea13940df99..40d55a9581333 100644 --- a/include/onnxruntime/core/graph/onnx_protobuf.h +++ b/include/onnxruntime/core/graph/onnx_protobuf.h @@ -29,6 +29,9 @@ #pragma warning(disable : 4506) /*no definition for inline function 'function'*/ #pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ #pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#pragma warning(disable : 6011) /*Dereferencing NULL pointer*/ +#pragma warning(disable : 6387) /*'value' could be '0'*/ +#pragma warning(disable : 26495) /*Variable is uninitialized.*/ #endif #include "onnx/defs/schema.h" #include "onnx/onnx_pb.h" diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e7867142b5ec1..af6e461d392e0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7,7 +7,7 @@ #include // This value is used in structures passed to ORT so that a newer version of ORT will still work with them -#define ORT_API_VERSION 2 +#define ORT_API_VERSION 3 #ifdef __cplusplus extern "C" { @@ -16,7 +16,9 @@ extern "C" { // SAL2 Definitions #ifndef _WIN32 #define _In_ +#define _In_z_ #define _In_opt_ +#define _In_opt_z_ #define _Out_ #define _Outptr_ #define _Out_opt_ @@ -26,7 +28,13 @@ extern "C" { #define _Ret_maybenull_ #define _Ret_notnull_ #define _Check_return_ +#define _Outptr_result_maybenull_ +#define _In_reads_(X) +#define _Inout_updates_all_(X) +#define _Out_writes_bytes_all_(X) +#define _Out_writes_all_(X) #define _Success_(X) +#define _Outptr_result_buffer_maybenull_(X) #define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) #else #include @@ -125,25 +133,10 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, } OrtErrorCode; -// __VA_ARGS__ on Windows and Linux are different -#define ORT_API(RETURN_TYPE, NAME, ...) \ - ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION - -#define ORT_API_STATUS(NAME, ...) \ - ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT - -#define ORT_API2_STATUS(NAME, ...) \ - _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus*(ORT_API_CALL * NAME)(__VA_ARGS__)NO_EXCEPTION ORT_MUST_USE_RESULT - -// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT -#define ORT_API_STATUS_IMPL(NAME, ...) \ - ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION #define ORT_RUNTIME_CLASS(X) \ struct Ort##X; \ typedef struct Ort##X Ort##X; -// ORT_API(void, OrtRelease##X, _Frees_ptr_opt_ Ort##X* input); -#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) // The actual types defined have an Ort prefix ORT_RUNTIME_CLASS(Env); @@ -162,6 +155,28 @@ ORT_RUNTIME_CLASS(ModelMetadata); ORT_RUNTIME_CLASS(ThreadPoolParams); ORT_RUNTIME_CLASS(ThreadingOptions); +#ifdef _WIN32 +typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; +#else +typedef OrtStatus* OrtStatusPtr; +#endif + +// __VA_ARGS__ on Windows and Linux are different +#define ORT_API(RETURN_TYPE, NAME, ...) ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_API_STATUS(NAME, ...) \ + ORT_EXPORT _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// XXX: Unfortunately, SAL annotations are known to not work with function pointers +#define ORT_API2_STATUS(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT +#define ORT_API_STATUS_IMPL(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) + // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. typedef struct OrtAllocator { @@ -244,105 +259,106 @@ struct OrtApi { /** * \param out Should be freed by `OrtReleaseEnv` after use */ - OrtStatus*(ORT_API_CALL* CreateEnv)(OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) - NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + ORT_API2_STATUS(CreateEnv, OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** * \param out Should be freed by `OrtReleaseEnv` after use */ - OrtStatus*(ORT_API_CALL* CreateEnvWithCustomLogger)(OrtLoggingFunction logging_function, - _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, - _In_ const char* logid, - _Outptr_ OrtEnv** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out); // Platform telemetry events are on by default since they are lightweight. You can manually turn them off. - OrtStatus*(ORT_API_CALL* EnableTelemetryEvents)(_In_ const OrtEnv* env)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* DisableTelemetryEvents)(_In_ const OrtEnv* env)NO_EXCEPTION; + ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); + ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); // TODO: document the path separator convention? '/' vs '\' // TODO: should specify the access characteristics of model_path. Is this read only during the // execution of CreateSession, or does the OrtSession retain a handle to the file/directory // and continue to access throughout the OrtSession lifetime? // What sort of access is needed to model_path : read or read/write? - OrtStatus*(ORT_API_CALL* CreateSession)(_In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - OrtStatus*(ORT_API_CALL* CreateSessionFromArray)(_In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - OrtStatus*(ORT_API_CALL* Run)(_Inout_ OrtSession* sess, - _In_opt_ const OrtRunOptions* run_options, - _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, - _In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output)NO_EXCEPTION; + ORT_API2_STATUS(Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output); /** * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use */ - OrtStatus*(ORT_API_CALL* CreateSessionOptions)(_Outptr_ OrtSessionOptions** options)NO_EXCEPTION; + ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); // Set filepath to save optimized model after graph level transformations. - OrtStatus*(ORT_API_CALL* SetOptimizedModelFilePath)(_Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath)NO_EXCEPTION; + ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, + _In_ const ORTCHAR_T* optimized_model_filepath); // create a copy of an existing OrtSessionOptions - OrtStatus*(ORT_API_CALL* CloneSessionOptions)(_In_ const OrtSessionOptions* in_options, _Outptr_ OrtSessionOptions** out_options)NO_EXCEPTION; + ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, + _Outptr_ OrtSessionOptions** out_options); // Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model // has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. // See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. - OrtStatus*(ORT_API_CALL* SetSessionExecutionMode)(_Inout_ OrtSessionOptions* options, ExecutionMode execution_mode)NO_EXCEPTION; + ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); // Enable profiling for this session. - OrtStatus*(ORT_API_CALL* EnableProfiling)(_Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* DisableProfiling)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); + ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); // Enable the memory pattern optimization. // The idea is if the input shapes are the same, we could trace the internal memory allocation // and generate a memory pattern for future request. So next time we could just do one allocation // with a big chunk for all the internal memory allocation. // Note: memory pattern optimization is only available when SequentialExecution enabled. - OrtStatus*(ORT_API_CALL* EnableMemPattern)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* DisableMemPattern)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); + ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); // Enable the memory arena on CPU // Arena may pre-allocate memory for future usage. // set this option to false if you don't want it. - OrtStatus*(ORT_API_CALL* EnableCpuMemArena)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* DisableCpuMemArena)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); + ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); // < logger id to use for session output - OrtStatus*(ORT_API_CALL* SetSessionLogId)(_Inout_ OrtSessionOptions* options, const char* logid)NO_EXCEPTION; + ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); // < applies to session load, initialization, etc - OrtStatus*(ORT_API_CALL* SetSessionLogVerbosityLevel)(_Inout_ OrtSessionOptions* options, int session_log_verbosity_level)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* SetSessionLogSeverityLevel)(_Inout_ OrtSessionOptions* options, int session_log_severity_level)NO_EXCEPTION; + ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); + ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); - OrtStatus*(ORT_API_CALL* SetSessionGraphOptimizationLevel)(_Inout_ OrtSessionOptions* options, GraphOptimizationLevel graph_optimization_level)NO_EXCEPTION; + ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, + GraphOptimizationLevel graph_optimization_level); // Sets the number of threads used to parallelize the execution within nodes // A value of 0 means ORT will pick a default - OrtStatus*(ORT_API_CALL* SetIntraOpNumThreads)(_Inout_ OrtSessionOptions* options, int intra_op_num_threads); + ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); // Sets the number of threads used to parallelize the execution of the graph (across nodes) // If sequential execution is enabled this value is ignored // A value of 0 means ORT will pick a default - OrtStatus*(ORT_API_CALL* SetInterOpNumThreads)(_Inout_ OrtSessionOptions* options, int inter_op_num_threads); + ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); /* Create a custom op domain. After all sessions using it are released, call OrtReleaseCustomOpDomain */ - OrtStatus*(ORT_API_CALL* CreateCustomOpDomain)(_In_ const char* domain, _Outptr_ OrtCustomOpDomain** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); /* * Add custom ops to the OrtCustomOpDomain * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released */ - OrtStatus*(ORT_API_CALL* CustomOpDomain_Add)(_Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op)NO_EXCEPTION; + ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op); /* * Add a custom op domain to the OrtSessionOptions * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released */ - OrtStatus*(ORT_API_CALL* AddCustomOpDomain)(_Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain)NO_EXCEPTION; + ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); /* * Loads a DLL named 'library_path' and looks for this entry point: @@ -351,7 +367,8 @@ struct OrtApi { * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in * session options are destroyed, or if an error occurs and it is non null. */ - OrtStatus*(ORT_API_CALL* RegisterCustomOpsLibrary)(_Inout_ OrtSessionOptions* options, _In_ const char* library_path, void** library_handle)NO_EXCEPTION; + ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, + void** library_handle); /** * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these @@ -364,130 +381,135 @@ struct OrtApi { * If none are called Ort will use its internal CPU execution provider. */ - OrtStatus*(ORT_API_CALL* SessionGetInputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* SessionGetOutputCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerCount)(_In_ const OrtSession* sess, _Out_ size_t* out)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* sess, _Out_ size_t* out); + ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* sess, _Out_ size_t* out); /** * \param out should be freed by OrtReleaseTypeInfo after use */ - OrtStatus*(ORT_API_CALL* SessionGetInputTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info); /** * \param out should be freed by OrtReleaseTypeInfo after use */ - OrtStatus*(ORT_API_CALL* SessionGetOutputTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* sess, size_t index, + _Outptr_ OrtTypeInfo** type_info); /** * \param out should be freed by OrtReleaseTypeInfo after use */ - OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerTypeInfo)(_In_ const OrtSession* sess, size_t index, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* sess, size_t index, + _Outptr_ OrtTypeInfo** type_info); /** * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. */ - OrtStatus*(ORT_API_CALL* SessionGetInputName)(_In_ const OrtSession* sess, size_t index, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* SessionGetOutputName)(_In_ const OrtSession* sess, size_t index, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* SessionGetOverridableInitializerName)(_In_ const OrtSession* sess, size_t index, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* sess, size_t index, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** * \return A pointer to the newly created object. The pointer should be freed by OrtReleaseRunOptions after use */ - OrtStatus*(ORT_API_CALL* CreateRunOptions)(_Outptr_ OrtRunOptions** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); - OrtStatus*(ORT_API_CALL* RunOptionsSetRunLogVerbosityLevel)(_Inout_ OrtRunOptions* options, int value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* RunOptionsSetRunLogSeverityLevel)(_Inout_ OrtRunOptions* options, int value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* RunOptionsSetRunTag)(_In_ OrtRunOptions*, _In_ const char* run_tag)NO_EXCEPTION; + ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value); + ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int value); + ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions*, _In_ const char* run_tag); - OrtStatus*(ORT_API_CALL* RunOptionsGetRunLogVerbosityLevel)(_In_ const OrtRunOptions* options, _Out_ int* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* RunOptionsGetRunLogSeverityLevel)(_In_ const OrtRunOptions* options, _Out_ int* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* RunOptionsGetRunTag)(_In_ const OrtRunOptions*, _Out_ const char** out)NO_EXCEPTION; + ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); + ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); + ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const char** out); // Set a flag so that ALL incomplete OrtRun calls that are using this instance of OrtRunOptions // will exit as soon as possible. - OrtStatus*(ORT_API_CALL* RunOptionsSetTerminate)(_Inout_ OrtRunOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); // Unset the terminate flag to enable this OrtRunOptions instance being used in new OrtRun calls. - OrtStatus*(ORT_API_CALL* RunOptionsUnsetTerminate)(_Inout_ OrtRunOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); /** * Create a tensor from an allocator. OrtReleaseValue will also release the buffer inside the output value * \param out Should be freed by calling OrtReleaseValue * \param type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx */ - OrtStatus*(ORT_API_CALL* CreateTensorAsOrtValue)(_Inout_ OrtAllocator* allocator, - _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, - _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out); /** * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. OrtReleaseValue won't release p_data. * \param out Should be freed by calling OrtReleaseValue */ - OrtStatus*(ORT_API_CALL* CreateTensorWithDataAsOrtValue)(_In_ const OrtMemoryInfo* info, - _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); /** * \Sets *out to 1 iff an OrtValue is a tensor, 0 otherwise */ - OrtStatus*(ORT_API_CALL* IsTensor)(_In_ const OrtValue* value, _Out_ int* out)NO_EXCEPTION; + ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); // This function doesn't work with string tensor // this is a no-copy method whose pointer is only valid until the backing OrtValue is free'd. - OrtStatus*(ORT_API_CALL* GetTensorMutableData)(_Inout_ OrtValue* value, _Outptr_ void** out)NO_EXCEPTION; + ORT_API2_STATUS(GetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** out); /** * \param value A tensor created from OrtCreateTensor... function. * \param s each A string array. Each string in this array must be null terminated. * \param s_len length of s */ - OrtStatus*(ORT_API_CALL* FillStringTensor)(_Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len)NO_EXCEPTION; + ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); /** * \param value A tensor created from OrtCreateTensor... function. * \param len total data length, not including the trailing '\0' chars. */ - OrtStatus*(ORT_API_CALL* GetStringTensorDataLength)(_In_ const OrtValue* value, _Out_ size_t* len)NO_EXCEPTION; + ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); /** * \param s string contents. Each string is NOT null-terminated. * \param value A tensor created from OrtCreateTensor... function. * \param s_len total data length, get it from OrtGetStringTensorDataLength */ - OrtStatus*(ORT_API_CALL* GetStringTensorContent)(_In_ const OrtValue* value, _Out_ void* s, size_t s_len, - _Out_ size_t* offsets, size_t offsets_len)NO_EXCEPTION; + ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); /** * Don't free the 'out' value */ - OrtStatus*(ORT_API_CALL* CastTypeInfoToTensorInfo)(_In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*, + _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); /** * Return OnnxType from OrtTypeInfo */ - OrtStatus*(ORT_API_CALL* GetOnnxTypeFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ enum ONNXType* out)NO_EXCEPTION; + ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ enum ONNXType* out); /** * The 'out' value should be released by calling OrtReleaseTensorTypeAndShapeInfo */ - OrtStatus*(ORT_API_CALL* CreateTensorTypeAndShapeInfo)(_Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); - OrtStatus*(ORT_API_CALL* SetTensorElementType)(_Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type)NO_EXCEPTION; + ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type); /** * \param info Created from CreateTensorTypeAndShapeInfo() function * \param dim_values An array with length of `dim_count`. Its elements can contain negative values. * \param dim_count length of dim_values */ - OrtStatus*(ORT_API_CALL* SetDimensions)(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count)NO_EXCEPTION; + ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); - OrtStatus*(ORT_API_CALL* GetTensorElementType)(_In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* GetDimensionsCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* GetDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* GetSymbolicDimensions)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char** dim_params, size_t dim_params_length)NO_EXCEPTION; + ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, + _Out_ enum ONNXTensorElementDataType* out); + ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, + size_t dim_values_length); + ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); /** * Return the number of elements specified by the tensor shape. @@ -498,57 +520,57 @@ struct OrtApi { * [2,0,4] -> 0 * [-1,3,4] -> -1 */ - OrtStatus*(ORT_API_CALL* GetTensorShapeElementCount)(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out)NO_EXCEPTION; + ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); /** * \param out Should be freed by OrtReleaseTensorTypeAndShapeInfo after use */ - OrtStatus*(ORT_API_CALL* GetTensorTypeAndShape)(_In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); /** * Get the type information of an OrtValue * \param value * \param out The returned value should be freed by OrtReleaseTypeInfo after use */ - OrtStatus*(ORT_API_CALL* GetTypeInfo)(_In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); - OrtStatus*(ORT_API_CALL* GetValueType)(_In_ const OrtValue* value, _Out_ enum ONNXType* out)NO_EXCEPTION; + ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); - OrtStatus*(ORT_API_CALL* CreateMemoryInfo)(_In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, + enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out); /** * Convenience function for special case of CreateMemoryInfo, for the CPU allocator. Uses name = "Cpu" and id = 0. */ - OrtStatus*(ORT_API_CALL* CreateCpuMemoryInfo)(enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out)NO_EXCEPTION - ORT_ALL_ARGS_NONNULL; + ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, + _Outptr_ OrtMemoryInfo** out); /** * Test if two memory info are equal * \Sets 'out' to 0 if equal, -1 if not equal */ - OrtStatus*(ORT_API_CALL* CompareMemoryInfo)(_In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out)NO_EXCEPTION - ORT_ALL_ARGS_NONNULL; + ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); /** * Do not free the returned value */ - OrtStatus*(ORT_API_CALL* MemoryInfoGetName)(_In_ const OrtMemoryInfo* ptr, _Out_ const char** out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* MemoryInfoGetId)(_In_ const OrtMemoryInfo* ptr, _Out_ int* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* MemoryInfoGetMemType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* MemoryInfoGetType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out)NO_EXCEPTION; + ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); + ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); + ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); + ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); - OrtStatus*(ORT_API_CALL* AllocatorAlloc)(_Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* AllocatorFree)(_Inout_ OrtAllocator* ptr, void* p)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* AllocatorGetInfo)(_In_ const OrtAllocator* ptr, _Out_ const OrtMemoryInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out); + ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ptr, void* p); + ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Outptr_ const struct OrtMemoryInfo** out); // The returned pointer doesn't have to be freed. // Always returns the same instance on every invocation. - OrtStatus*(ORT_API_CALL* GetAllocatorWithDefaultOptions)(_Outptr_ OrtAllocator** out)NO_EXCEPTION; + ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); // Override symbolic dimensions with actual values if known at session initialization time to enable // optimizations that can take advantage of fixed values (such as memory planning, etc) - OrtStatus*(ORT_API_CALL* AddFreeDimensionOverride)(_Inout_ OrtSessionOptions* options, - _In_ const char* symbolic_dim, _In_ int64_t dim_override)NO_EXCEPTION; + ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* symbolic_dim, + _In_ int64_t dim_override); /** * APIs to support non-tensor types - map and sequence. @@ -581,13 +603,14 @@ struct OrtApi { * If input OrtValue represents a sequence, use index to retrieve the index'th element * of the sequence. */ - OrtStatus*(ORT_API_CALL* GetValue)(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out); /** * Returns 2 for type map and N for sequence where N is the number of elements * in the sequence. */ - OrtStatus*(ORT_API_CALL* GetValueCount)(_In_ const OrtValue* value, _Out_ size_t* out)NO_EXCEPTION; + ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); /** * To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues @@ -596,8 +619,8 @@ struct OrtApi { * sequence. 'in' should be an arrary of N OrtValues. * \value_type should be either map or sequence. */ - OrtStatus*(ORT_API_CALL* CreateValue)(_In_ const OrtValue* const* in, size_t num_values, enum ONNXType value_type, - _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out); /** * Construct OrtValue that contains a value of non-standard type created for @@ -618,8 +641,8 @@ struct OrtApi { * \data_container_size - sizeof() of the data container. Must match the sizeof() of the expected * data_container size internally. */ - OrtStatus*(ORT_API_CALL* CreateOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, - _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, + _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); /** * Fetch data from an OrtValue that contains a value of non-standard type created for @@ -631,17 +654,22 @@ struct OrtApi { * data_container size internally. */ - OrtStatus*(ORT_API_CALL* GetOpaqueValue)(_In_ const char* domain_name, _In_ const char* type_name, - _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size)NO_EXCEPTION; + ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, + _Out_ void* data_container, size_t data_container_size); - OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_float)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_int64)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* KernelInfoGetAttribute_string)(_In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size)NO_EXCEPTION; + ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out); + ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out); + ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, + _Inout_ size_t* size); - OrtStatus*(ORT_API_CALL* KernelContext_GetInputCount)(_In_ const OrtKernelContext* context, _Out_ size_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* KernelContext_GetOutputCount)(_In_ const OrtKernelContext* context, _Out_ size_t* out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* KernelContext_GetInput)(_In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* KernelContext_GetOutput)(_Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out)NO_EXCEPTION; + ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); + ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, + _Out_ const OrtValue** out); + ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); ORT_CLASS_RELEASE(Env); ORT_CLASS_RELEASE(Status); // nullptr for Status* indicates success @@ -663,7 +691,8 @@ struct OrtApi { * This api augments OrtTypeInfo to return denotations on the type. * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. */ - OrtStatus*(ORT_API_CALL* GetDenotationFromTypeInfo)(_In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len)NO_EXCEPTION; + ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, + _Out_ size_t* len); // OrtTypeInfo Casting methods @@ -676,7 +705,8 @@ struct OrtApi { * * Don't free the 'out' value */ - OrtStatus*(ORT_API_CALL* CastTypeInfoToMapTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtMapTypeInfo** out); /** * CastTypeInfoToSequenceTypeInfo @@ -686,7 +716,8 @@ struct OrtApi { * * Don't free the 'out' value */ - OrtStatus*(ORT_API_CALL* CastTypeInfoToSequenceTypeInfo)(_In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out)NO_EXCEPTION; + ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); // OrtMapTypeInfo Accessors @@ -695,13 +726,13 @@ struct OrtApi { * This api augments get the key type of a map. Key types are restricted to being scalar types and use ONNXTensorElementDataType. * This is used by WinML to support model reflection APIs. */ - OrtStatus*(ORT_API_CALL* GetMapKeyType)(_In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out)NO_EXCEPTION; + ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); /** * GetMapValueType * This api augments get the value type of a map. */ - OrtStatus*(ORT_API_CALL* GetMapValueType)(_In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); // OrtSequenceTypeInfo Accessors @@ -710,7 +741,8 @@ struct OrtApi { * This api augments get the element type of a sequence. * This is used by WinML to support model reflection APIs. */ - OrtStatus*(ORT_API_CALL* GetSequenceElementType)(_In_ const OrtSequenceTypeInfo* sequence_type_info, _Outptr_ OrtTypeInfo** type_info)NO_EXCEPTION; + ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, + _Outptr_ OrtTypeInfo** type_info); ORT_CLASS_RELEASE(MapTypeInfo); ORT_CLASS_RELEASE(SequenceTypeInfo); @@ -720,34 +752,32 @@ struct OrtApi { * Profiling is turned ON automatically if enabled for the particular session by invoking EnableProfiling() * on the SessionOptions instance used to create the session. */ - OrtStatus*(ORT_API_CALL* SessionEndProfiling)(_In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, - _Outptr_ char** out)NO_EXCEPTION; + ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); /** * \param out is a pointer to the newly created object. The pointer should be freed by calling ReleaseModelMetadata after use. */ - OrtStatus*(ORT_API_CALL* SessionGetModelMetadata)(_In_ const OrtSession* sess, - _Outptr_ OrtModelMetadata** out)NO_EXCEPTION; + ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* sess, _Outptr_ OrtModelMetadata** out); /** * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. */ - OrtStatus*(ORT_API_CALL* ModelMetadataGetProducerName)(_In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* ModelMetadataGetGraphName)(_In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* ModelMetadataGetDomain)(_In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; - OrtStatus*(ORT_API_CALL* ModelMetadataGetDescription)(_In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value)NO_EXCEPTION; + ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); /** * \param value is set to a null terminated string allocated using 'allocator'. The caller is responsible for freeing it. * 'value' will be a nullptr if the given key is not found in the custom metadata map. */ - OrtStatus*(ORT_API_CALL* ModelMetadataLookupCustomMetadataMap)(_In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, - _In_ const char* key, _Outptr_ char** value)NO_EXCEPTION; + ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); - OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION; + ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); ORT_CLASS_RELEASE(ModelMetadata); @@ -756,9 +786,8 @@ struct OrtApi { * Use this in conjunction with DisablePerSessionThreads API or else the session will use * its own thread pools. */ - OrtStatus*(ORT_API_CALL* CreateEnvWithGlobalThreadPools)(OrtLoggingLevel default_logging_level, _In_ const char* logid, - _In_ const OrtThreadingOptions* t_options, _Outptr_ OrtEnv** out) - NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel default_logging_level, _In_ const char* logid, + _In_ const OrtThreadingOptions* t_options, _Outptr_ OrtEnv** out); /* TODO: Should there be a version of CreateEnvWithGlobalThreadPools with custom logging function? */ @@ -766,12 +795,20 @@ struct OrtApi { * Calling this API will make the session use the global threadpools shared across sessions. * This API should be used in conjunction with CreateEnvWithGlobalThreadPools API. */ - OrtStatus*(ORT_API_CALL* DisablePerSessionThreads)(_Inout_ OrtSessionOptions* options)NO_EXCEPTION; + ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); - OrtStatus*(ORT_API_CALL* CreateThreadingOptions)(_Outptr_ OrtThreadingOptions** out) - NO_EXCEPTION; + ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); ORT_CLASS_RELEASE(ThreadingOptions); + + /** + * \param num_keys contains the number of keys in the custom metadata map + * \param keys is an array of null terminated strings (array count = num_keys) allocated using 'allocator'. + * The caller is responsible for freeing each string and the pointer array. + * 'keys' will be a nullptr if custom metadata map is empty. + */ + ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); }; /* @@ -790,7 +827,8 @@ struct OrtCustomOp { uint32_t version; // Initialize to ORT_API_VERSION // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. - void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtApi* api, _In_ const OrtKernelInfo* info); + void*(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info); // Returns the name of the op const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index cd132b162897c..de6886bcdd847 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -199,6 +199,7 @@ struct ModelMetadata : Base { char* GetGraphName(OrtAllocator* allocator) const; char* GetDomain(OrtAllocator* allocator) const; char* GetDescription(OrtAllocator* allocator) const; + char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; int64_t GetVersion() const; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 258b2bc4905c5..5c60743cf7b28 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -84,7 +84,7 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog ThrowOnError(Global::api_.CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_)); } -inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, const char* logid) { +inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) { ThrowOnError(Global::api_.CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_)); } @@ -324,6 +324,12 @@ inline char* ModelMetadata::LookupCustomMetadataMap(const char* key, OrtAllocato return out; } +inline char** ModelMetadata::GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const { + char** out; + ThrowOnError(Global::api_.ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys)); + return out; +} + inline int64_t ModelMetadata::GetVersion() const { int64_t out; ThrowOnError(Global::api_.ModelMetadataGetVersion(p_, &out)); diff --git a/java/build.gradle b/java/build.gradle index 528867c1de2b8..dd6c7b3b38ce0 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -1,5 +1,6 @@ plugins { id 'java' + id 'jacoco' id 'com.diffplug.gradle.spotless' version '3.26.0' } @@ -105,7 +106,6 @@ if (cmakeBuildDir != null) { } - dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api:5.1.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.1.1' @@ -118,3 +118,11 @@ test { events "passed", "skipped", "failed" } } + +jacocoTestReport { + reports { + xml.enabled true + csv.enabled true + html.destination file("${buildDir}/jacocoHtml") + } +} diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index a82bc00a02308..18f9c397bba93 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -22,6 +22,8 @@ final class OnnxRuntime { // The initial release of the ORT API. private static final int ORT_API_VERSION_1 = 1; + // Post 1.0 builds of the ORT API. + private static final int ORT_API_VERSION_2 = 2; /** The short name of the ONNX runtime shared library */ static final String ONNXRUNTIME_LIBRARY_NAME = "onnxruntime"; @@ -48,7 +50,7 @@ static synchronized void init() throws IOException { try { load(tempDirectory, ONNXRUNTIME_LIBRARY_NAME); load(tempDirectory, ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_1); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_2); loaded = true; } finally { if (!isAndroid()) { diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 2cd21cd3bfab7..80a1d78fa49f7 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -60,18 +60,9 @@ public class OrtSession implements AutoCloseable { */ OrtSession(OrtEnvironment env, String modelPath, OrtAllocator allocator, SessionOptions options) throws OrtException { - nativeHandle = - createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle); - this.allocator = allocator; - numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); - inputNames = - new LinkedHashSet<>( - Arrays.asList(getInputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); - numOutputs = getNumOutputs(OnnxRuntime.ortApiHandle, nativeHandle); - outputNames = - new LinkedHashSet<>( - Arrays.asList( - getOutputNames(OnnxRuntime.ortApiHandle, nativeHandle, allocator.handle))); + this( + createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelPath, options.nativeHandle), + allocator); } /** @@ -81,12 +72,24 @@ public class OrtSession implements AutoCloseable { * @param modelArray The model protobuf as a byte array. * @param allocator The allocator to use. * @param options Session configuration options. - * @throws OrtException If the mode was corrupted or some other error occurred in native code. + * @throws OrtException If the model was corrupted or some other error occurred in native code. */ OrtSession(OrtEnvironment env, byte[] modelArray, OrtAllocator allocator, SessionOptions options) throws OrtException { - nativeHandle = - createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle); + this( + createSession(OnnxRuntime.ortApiHandle, env.nativeHandle, modelArray, options.nativeHandle), + allocator); + } + + /** + * Private constructor to build the Java object wrapped around a native session. + * + * @param nativeHandle The pointer to the native session. + * @param allocator The allocator to use. + * @throws OrtException If the model's inputs, outputs or metadata could not be read. + */ + private OrtSession(long nativeHandle, OrtAllocator allocator) throws OrtException { + this.nativeHandle = nativeHandle; this.allocator = allocator; numInputs = getNumInputs(OnnxRuntime.ortApiHandle, nativeHandle); inputNames = @@ -289,17 +292,17 @@ public void close() throws OrtException { private static Map wrapInMap(NodeInfo[] infos) { Map output = new LinkedHashMap<>(); - for (int i = 0; i < infos.length; i++) { - output.put(infos[i].getName(), infos[i]); + for (NodeInfo info : infos) { + output.put(info.getName(), info); } return output; } - private native long createSession( + private static native long createSession( long apiHandle, long envHandle, String modelPath, long optsHandle) throws OrtException; - private native long createSession( + private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; @@ -548,6 +551,26 @@ public void addNuphar(boolean allowUnalignedBuffers, String settings) throws Ort addNuphar(OnnxRuntime.ortApiHandle, nativeHandle, allowUnalignedBuffers ? 1 : 0, settings); } + /** + * Adds DirectML as an execution backend. + * + * @param deviceId The id of the DirectML device. + * @throws OrtException If there was an error in native code. + */ + public void addDirectML(int deviceId) throws OrtException { + addDirectML(OnnxRuntime.ortApiHandle, nativeHandle, deviceId); + } + + /** + * Adds the ARM Compute Library as an execution backend. + * + * @param useArena If true use the arena memory allocator. + * @throws OrtException If there was an error in native code. + */ + public void addACL(boolean useArena) throws OrtException { + addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); + } + private native void setExecutionMode(long apiHandle, long nativeHandle, int mode) throws OrtException; @@ -601,6 +624,11 @@ private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum private native void addNuphar( long apiHandle, long nativeHandle, int allowUnalignedBuffers, String settings) throws OrtException; + + private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) + throws OrtException; + + private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; } /** diff --git a/java/src/main/java/ai/onnxruntime/SequenceInfo.java b/java/src/main/java/ai/onnxruntime/SequenceInfo.java index ded856bc04301..a417634b72de0 100644 --- a/java/src/main/java/ai/onnxruntime/SequenceInfo.java +++ b/java/src/main/java/ai/onnxruntime/SequenceInfo.java @@ -49,7 +49,7 @@ public class SequenceInfo implements ValueInfo { } /** - * Constructs a sequence of known lenght containing maps. + * Constructs a sequence of known length containing maps. * * @param length The length of the sequence. * @param keyType The map key type. diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index f2d97dd24a86b..e89e0a70c2292 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -205,10 +205,14 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo); } case ONNX_TYPE_SEQUENCE: { - return createEmptySequenceInfo(jniEnv); + const OrtSequenceTypeInfo* sequenceInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo)); + return convertToSequenceInfo(jniEnv, api, sequenceInfo); } case ONNX_TYPE_MAP: { - return createEmptyMapInfo(jniEnv); + const OrtMapTypeInfo* mapInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo)); + return convertToMapInfo(jniEnv, api, mapInfo); } case ONNX_TYPE_UNKNOWN: case ONNX_TYPE_OPAQUE: @@ -261,8 +265,56 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT return tensorInfo; } -//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) { -// As map info isn't available at this point, it creates an empty map info type. +jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) { + // Create the java methods we need to call. + // Get the ONNXTensorType enum static method + char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; + jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); + jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); + + // Get the ONNXJavaType enum static method + char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; + jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); + jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); + + // Get the map info class + char *mapInfoClassName = "ai/onnxruntime/MapInfo"; + jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); + jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); + + // Extract the key type + ONNXTensorElementDataType keyType; + checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType)); + + // Convert key type to java + jint onnxTypeKey = convertFromONNXDataFormat(keyType); + jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey); + jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey); + + // according to include/onnxruntime/core/framework/data_types.h only the following values are supported. + // string, int64, float, double + // So extract the value type, then convert it to a tensor type so we can get it's element type. + OrtTypeInfo* valueTypeInfo; + checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo)); + const OrtTensorTypeAndShapeInfo* tensorValueInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo)); + ONNXTensorElementDataType valueType; + checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType)); + api->ReleaseTypeInfo(valueTypeInfo); + tensorValueInfo = NULL; + valueTypeInfo = NULL; + + // Convert value type to java + jint onnxTypeValue = convertFromONNXDataFormat(valueType); + jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue); + jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue); + + // Construct map info + jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue); + + return mapInfo; +} + jobject createEmptyMapInfo(JNIEnv *jniEnv) { // Create the ONNXJavaType enum char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; @@ -278,8 +330,72 @@ jobject createEmptyMapInfo(JNIEnv *jniEnv) { return mapInfo; } -//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) { -// As sequence info isn't available at this point, it creates an empty sequence info type. +jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) { + // Get the sequence info class + char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; + jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); + + // according to include/onnxruntime/core/framework/data_types.h the following values are supported. + // tensor types, map and map + OrtTypeInfo* elementTypeInfo; + checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo)); + ONNXType type; + checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type)); + + jobject sequenceInfo; + + switch (type) { + case ONNX_TYPE_TENSOR: { + // Figure out element type + const OrtTensorTypeAndShapeInfo* elementTensorInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo)); + ONNXTensorElementDataType element; + checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element)); + + // Convert element type into ONNXTensorType + jint onnxTypeInt = convertFromONNXDataFormat(element); + // Get the ONNXTensorType enum static method + char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; + jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); + jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); + jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt); + + // Get the ONNXJavaType enum static method + char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; + jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); + jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); + jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType); + break; + } + case ONNX_TYPE_MAP: { + // Extract the map info + const OrtMapTypeInfo* mapInfo; + checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo)); + + // Convert it using the existing convert function + jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo); + break; + } + default: { + sequenceInfo = createEmptySequenceInfo(jniEnv); + throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence"); + break; + } + } + api->ReleaseTypeInfo(elementTypeInfo); + elementTypeInfo = NULL; + + return sequenceInfo; +} + jobject createEmptySequenceInfo(JNIEnv *jniEnv) { // Create the ONNXJavaType enum char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 4f05096215685..e42af5dd6c9ab 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -31,9 +31,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * inf jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info); -//TODO when C API supports inspecting the types of map and sequence types from OutputInfos -//jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info); -//jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info); +jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info); +jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); jobject createEmptyMapInfo(JNIEnv *jniEnv); jobject createEmptySequenceInfo(JNIEnv *jniEnv); diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6e1c7ea48820c..e0eeec63dd854 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -14,8 +14,8 @@ * Signature: (JJLjava/lang/String;J)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_lang_String_2J - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jstring modelPath, jlong optsHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtSession* session; @@ -43,8 +43,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la * Signature: (JJ[BJ)J */ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJ_3BJ - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jbyteArray jModelArray, jlong optsHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtSession* session; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index c3a5d16a3b9c7..739ab68357916 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -17,6 +17,10 @@ #include "onnxruntime/core/providers/nuphar/nuphar_provider_factory.h" #include "onnxruntime/core/providers/openvino/openvino_provider_factory.h" #include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h" +#include "onnxruntime/core/providers/acl/acl_provider_factory.h" +#ifdef USE_DIRECTML +#include "onnxruntime/core/providers/dml/dml_provider_factory.h" +#endif /* * Class: ai_onnxruntime_OrtSession_SessionOptions @@ -233,8 +237,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNna /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addNuphar - * Signature: (JILjava/lang/String { - })V + * Signature: (JILjava/lang/String)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNuphar (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint allowUnalignedBuffers, jstring settingsString) { @@ -248,3 +251,35 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNup throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with Nuphar support."); #endif } + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addDirectML + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) { + (void)jobj; + #ifdef USE_DIRECTML + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID)); + #else + (void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined. + throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with DirectML support."); + #endif +} + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addACL + * Signature: (JJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addACL + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint useArena) { + (void)jobj; + #ifdef USE_ACL + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle,useArena)); + #else + (void)apiHandle;(void)handle;(void)useArena; // Parameters used when ACL is defined. + throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with ACL support."); + #endif +} diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index c3ed213329381..d0521d9429a2c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -915,6 +915,12 @@ public void testModelSequenceOfMapIntFloat() throws OrtException { assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo); assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo); assertEquals(OnnxJavaType.INT64, ((TensorInfo) firstOutputInfo.getInfo()).type); + assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps); + assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType); + MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo; + assertNotNull(mapInfo); + assertEquals(OnnxJavaType.INT64, mapInfo.keyType); + assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType); Map container = new HashMap<>(); long[] shape = new long[] {1, 2}; @@ -975,6 +981,12 @@ public void testModelSequenceOfMapStringFloat() throws OrtException { assertTrue(firstOutputInfo.getInfo() instanceof TensorInfo); assertTrue(secondOutputInfo.getInfo() instanceof SequenceInfo); assertEquals(OnnxJavaType.STRING, ((TensorInfo) firstOutputInfo.getInfo()).type); + assertTrue(((SequenceInfo) secondOutputInfo.getInfo()).sequenceOfMaps); + assertEquals(OnnxJavaType.UNKNOWN, ((SequenceInfo) secondOutputInfo.getInfo()).sequenceType); + MapInfo mapInfo = ((SequenceInfo) secondOutputInfo.getInfo()).mapInfo; + assertNotNull(mapInfo); + assertEquals(OnnxJavaType.STRING, mapInfo.keyType); + assertEquals(OnnxJavaType.FLOAT, mapInfo.valueType); Map container = new HashMap<>(); long[] shape = new long[] {1, 2}; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 1732a26c7707f..fbc8f639e838e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -186,10 +186,8 @@ Status Attention::Compute(OpKernelContext* context) const { BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator)); { - size_t mask_data_bytes = 0; - if (mask_index != nullptr) { - mask_data_bytes = SafeInt(batch_size) * sequence_length * element_size; - } + size_t mask_data_bytes = + mask_index == nullptr ? SafeInt(0) : SafeInt(batch_size) * sequence_length * element_size; void* mask_data = nullptr; if (mask_data_bytes > 0) { @@ -198,7 +196,7 @@ Status Attention::Compute(OpKernelContext* context) const { } BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); - if (mask_index != nullptr) { + if (mask_index != nullptr && mask_data != nullptr) { T* p_mask = reinterpret_cast(mask_data); for (int b_i = 0; b_i < batch_size; b_i++) { // TODO: mask_index can be used in softmax to save some calculation. diff --git a/onnxruntime/contrib_ops/cpu/quantize_ops.cc b/onnxruntime/contrib_ops/cpu/quantize_ops.cc index fb4a95b7dd13e..46587733c4eba 100644 --- a/onnxruntime/contrib_ops/cpu/quantize_ops.cc +++ b/onnxruntime/contrib_ops/cpu/quantize_ops.cc @@ -2,8 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/tensor/quantize_linear.h" -#include "core/providers/common.h" - +#include "core/providers/common.h" namespace onnxruntime { namespace contrib { @@ -13,11 +12,8 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( 1, uint8_t, KernelDefBuilder() - .TypeConstraint("axis", DataTypeImpl::GetType()) - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), DequantizeLinear); ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( @@ -25,11 +21,8 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( 1, int8_t, KernelDefBuilder() - .TypeConstraint("axis", DataTypeImpl::GetType()) - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), DequantizeLinear); ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( @@ -37,12 +30,18 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( 1, uint8_t, KernelDefBuilder() - .TypeConstraint("axis", DataTypeImpl::GetType()) - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("y_scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), QuantizeLinear); -} // namespace contrib +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + QuantizeLinear, + 1, + int8_t, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QuantizeLinear); + +} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc index 315e4857c54b0..923cd6908a348 100644 --- a/onnxruntime/contrib_ops/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu_contrib_kernels.cc @@ -30,6 +30,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu); @@ -104,6 +105,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h new file mode 100644 index 0000000000000..2cd062b55c2a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h @@ -0,0 +1,108 @@ +/*All contributions by Facebook : +Copyright(c) 2016 Facebook Inc. +==============================================================================*/ +/* Modifications Copyright (c) Microsoft. */ + +#pragma once +#include "core/providers/cuda/cuda_common.h" +#include "cufft.h" +#include "cufftXt.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +//key +struct FFTState { + int64_t signal_ndim; + int64_t signal_dims[5]; + cudaDataType itype; + cudaDataType otype; + int64_t batch_size; + cudaDataType exec_type; +}; + +//value +struct CufftPlanInfo { + cufftHandle plan; + size_t ws_size_t; +}; + +// Hashing machinery for Params +// Fowler–Noll–Vo hash function +// see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function +template +struct ParamsHash { + // Params must be a POD because we read out its memory + // contenst as char* when hashing + + static_assert(std::is_pod::value, "Params is not POD"); + size_t operator()(const T& params) const { + auto ptr = reinterpret_cast(¶ms); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < (int)sizeof(T); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return (size_t)value; + } +}; + +template +struct ParamsEqual { + // Params must be a POD because we read out its memory + // contenst as char* when comparing + + static_assert(std::is_pod::value, "Params is not POD"); + + bool operator()(const T& a, const T& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(T)) == 0; + } +}; + +class CuFFTPlanCache { + public: + CufftPlanInfo TryEmplaceValue(FFTState& key) { + std::lock_guard lock(mutex); + + auto it = map.find(key); + if (it == map.end()) { + CufftPlanInfo plan_info = CreatePlanInfo(key); + map.emplace(key, plan_info); + return plan_info; + } else { + return it->second; + } + } + + int64_t GetCacheSize() { return map.size(); } + + std::mutex mutex; + + private: + CufftPlanInfo CreatePlanInfo(FFTState& key) { + cufftHandle plan; + size_t ws_size_t; + CufftPlanInfo plan_info; + + CUFFT_CALL_THROW(cufftCreate(&plan)); + + CUFFT_CALL_THROW(cufftXtMakePlanMany(plan, static_cast(key.signal_ndim), reinterpret_cast(key.signal_dims), + /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, key.itype, + /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, key.otype, + key.batch_size, &ws_size_t, key.exec_type)); + + plan_info.plan = plan; + plan_info.ws_size_t = ws_size_t; + + return plan_info; + } + + std::unordered_map, ParamsEqual> map; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc new file mode 100644 index 0000000000000..3c60644d707c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc @@ -0,0 +1,156 @@ +/*All contributions by Facebook : +Copyright(c) 2016 Facebook Inc. +==============================================================================*/ +/* Modifications Copyright (c) Microsoft. */ + +#include "fft_ops.h" +#include "fft_ops_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace { +void SetFFTState(FFTState* state, + int64_t signal_ndim, + const std::vector& signal_dims, + cudaDataType itype, + cudaDataType otype, + int64_t batch_size, + cudaDataType exec_type) { + memset(state, 0, sizeof(FFTState)); + state->signal_ndim = signal_ndim; + for (int32_t i = 0; i < signal_dims.size(); ++i) { + state->signal_dims[i] = signal_dims[i]; + } + state->itype = itype; + state->otype = otype; + state->batch_size = batch_size; + state->exec_type = exec_type; +} +} // namespace +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Rfft, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Rfft); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Irfft, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Irfft); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex_input, bool complex_output, bool inverse) const { + typedef typename ::onnxruntime::cuda::ToCudaType::MappedType CudaT; + + ORT_ENFORCE((complex_input || complex_output) && (complex_input != complex_output), + "Only support RFFT and IRFFT, so either input or output has to be complex type and the other is real type. Got complex input:", + complex_input, " complex output: ", complex_output); + + TensorShape input_shape = X->Shape(); + int64_t input_ndim = input_shape.NumDimensions(); + ORT_ENFORCE(input_ndim >= signal_ndim_, "signal_ndim cannot be greater than the dimension of Input: ", signal_ndim_, " > ", input_ndim); + auto signal_tensor_ndim = signal_ndim_ + static_cast(complex_input); // add complex dim + + cudaDataType itype, otype, exec_type; + if (X->IsDataType()) { + itype = complex_input ? CUDA_C_32F : CUDA_R_32F; + otype = complex_output ? CUDA_C_32F : CUDA_R_32F; + exec_type = CUDA_C_32F; + } else if (X->IsDataType()) { + itype = complex_input ? CUDA_C_64F : CUDA_R_64F; + otype = complex_output ? CUDA_C_64F : CUDA_R_64F; + exec_type = CUDA_C_64F; + } else if (X->IsDataType()) { + itype = complex_input ? CUDA_C_16F : CUDA_R_16F; + otype = complex_output ? CUDA_C_16F : CUDA_R_16F; + exec_type = CUDA_C_16F; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "cuFFT does not support tensor type: ", X->DataType()); + } + + //calculate batch size + int64_t batch_ndim = input_ndim - signal_tensor_ndim; + int64_t batch_size = (batch_ndim == 0 ? 1 : input_shape.SizeToDimension(batch_ndim)); + + //infer output shape + //copy the input shape up to the second last dimention + std::vector output_dims, signal_dims; + int i = 0; + for (; i < batch_ndim + signal_ndim_ - 1; ++i) { + output_dims.push_back(input_shape[i]); + if (i >= batch_ndim) { + signal_dims.push_back(input_shape[i]); + } + } + + //process the last dim(s) + if (onesided_) { + if (complex_input && !complex_output) { //IRFFT + int64_t inferred_size = input_shape[i] * 2 - 1; + output_dims.push_back(inferred_size); + signal_dims.push_back(inferred_size); + } else if (!complex_input && complex_output) { // RFFT + output_dims.push_back(input_shape[i] / 2 + 1); + signal_dims.push_back(input_shape[i]); + } + } else { // not onesided + output_dims.push_back(input_shape[i]); + signal_dims.push_back(input_shape[i]); + } + + if (complex_output) { + output_dims.push_back(2); + } + + FFTState fft_state; + + SetFFTState(&fft_state, signal_ndim_, signal_dims, itype, otype, batch_size, exec_type); + + CufftPlanInfo plan_info = cufft_cache_.TryEmplaceValue(fft_state); + + int64_t output_size = std::accumulate(output_dims.begin(), output_dims.end(), 1ll, std::multiplies()); + + Tensor* Y = const_cast(context)->Output(0, TensorShape(output_dims)); + auto* x_data = reinterpret_cast(X->template Data()); + auto* y_data = reinterpret_cast(Y->template MutableData()); + + CUFFT_RETURN_IF_ERROR(cufftXtExec(plan_info.plan, const_cast(x_data), y_data, inverse ? CUFFT_INVERSE : CUFFT_FORWARD)); + + if (inverse) { + PostProcess(signal_dims, output_size, y_data); + } + + return Status::OK(); +} + +template +Status Rfft::ComputeInternal(OpKernelContext* context) const { + const Tensor* X = context->Input(0); + + return FFTBase::DoFFT(context, X, false, true, false); +} + +template +Status Irfft::ComputeInternal(OpKernelContext* context) const { + const Tensor* X = context->Input(0); + + return FFTBase::DoFFT(context, X, true, false, true); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.h b/onnxruntime/contrib_ops/cuda/math/fft_ops.h new file mode 100644 index 0000000000000..35a730e6a90ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.h @@ -0,0 +1,52 @@ +/*All contributions by Facebook : +Copyright(c) 2016 Facebook Inc. +==============================================================================*/ +/* Modifications Copyright (c) Microsoft. */ + +#pragma once +#include "core/providers/cuda/cuda_common.h" +#include "cufft_plan_cache.h" +#include "cufft.h" +#include "cufftXt.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +class FFTBase : public ::onnxruntime::cuda::CudaKernel { + public: + FFTBase(const OpKernelInfo info) : ::onnxruntime::cuda::CudaKernel{info}, normalized_{0}, onesided_{1} { + ORT_ENFORCE((info.GetAttr("signal_ndim", &signal_ndim_)).IsOK(), + "Attribute signal_ndim is missing in Node ", info.node().Name()); + ORT_ENFORCE(signal_ndim_ >= 1 && signal_ndim_ <= 3, + "Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=", signal_ndim_); + normalized_ = info.GetAttrOrDefault("normalized", 0); + onesided_ = info.GetAttrOrDefault("onesided", 1); + ORT_ENFORCE(normalized_ == 0, "Don't support normalized FFT yet."); + ORT_ENFORCE(onesided_ != 0, "Only support onesided FFT."); + } + + protected: + int64_t signal_ndim_, normalized_, onesided_; + mutable CuFFTPlanCache cufft_cache_; + Status DoFFT(OpKernelContext* context, const Tensor* X, bool complex_input, bool complex_output, bool inverse) const; +}; + +template +class Rfft final : public FFTBase { + public: + Rfft(const OpKernelInfo info) : FFTBase{info} {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + +template +class Irfft final : public FFTBase { + public: + Irfft(const OpKernelInfo info) : FFTBase{info} {} + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu new file mode 100644 index 0000000000000..995bf544808dc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu @@ -0,0 +1,46 @@ +/*All contributions by Facebook : +Copyright(c) 2016 Facebook Inc. +==============================================================================*/ +/* Modifications Copyright (c) Microsoft. */ + +#pragma once +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +#include "fft_ops_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +__global__ void _Normalize( + T* data, + const int64_t N, + const int scale) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) + + int index = static_cast(id); + data[index] = data[index] / static_cast(scale); +} + +template +void PostProcess(const std::vector& signal_dims, int64_t N, T* output_data) { + int64_t scale = std::accumulate(signal_dims.begin(), signal_dims.end(), 1ll, std::multiplies()); + int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + _Normalize<<>>(output_data, N, scale); +} + +#define SPECIALIZED_IMPL(T) \ + template void PostProcess(const std::vector& signal_dims, int64_t N, T* output_data); + +SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(double) +SPECIALIZED_IMPL(half) + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h new file mode 100644 index 0000000000000..8a7f7789c0077 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h @@ -0,0 +1,19 @@ +/*All contributions by Facebook : +Copyright(c) 2016 Facebook Inc. +==============================================================================*/ +/* Modifications Copyright (c) Microsoft. */ + +#pragma once +#include "core/providers/cuda/cuda_common.h" +#include "fft_ops.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void PostProcess(const std::vector& signal_dims, int64_t N, T* output_data); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantize_ops.cc b/onnxruntime/contrib_ops/cuda/quantize_ops.cc new file mode 100644 index 0000000000000..f0afad58a5f03 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantize_ops.cc @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/tensor/quantize_linear.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +#define REGISTER_Q_KERNEL_TYPED(T, U) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + QuantizeLinear, \ + kMSDomain, \ + 1, \ + T##_##U, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + QuantizeLinear); + +REGISTER_Q_KERNEL_TYPED(int8_t, MLFloat16) +REGISTER_Q_KERNEL_TYPED(uint8_t, MLFloat16) + +#define REGISTER_DQ_KERNEL_TYPED(T, U) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DequantizeLinear, \ + kMSDomain, \ + 1, \ + T##_##U, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + DequantizeLinear); + +REGISTER_DQ_KERNEL_TYPED(int8_t, MLFloat16) +REGISTER_DQ_KERNEL_TYPED(uint8_t, MLFloat16) + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc index 674546219a3a5..9a139d45e9ad6 100644 --- a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc @@ -17,6 +17,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Irfft); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Irfft); // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility @@ -50,8 +56,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double_float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float, LayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); -void RegisterCudaContribKernels(KernelRegistry& kernel_registry) { +Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -61,6 +71,12 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility @@ -93,11 +109,16 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.h b/onnxruntime/contrib_ops/cuda_contrib_kernels.h index 259c7c8570d31..ef2f5866ab9cf 100644 --- a/onnxruntime/contrib_ops/cuda_contrib_kernels.h +++ b/onnxruntime/contrib_ops/cuda_contrib_kernels.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -void RegisterCudaContribKernels(KernelRegistry& kernel_registry); +Status RegisterCudaContribKernels(KernelRegistry& kernel_registry); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 33fe2e78735c9..56a3c5c974e62 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -57,8 +57,8 @@ const OrtMemoryInfo& CPUAllocator::Info() const { return *memory_info_; } std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return (out << info.ToString()); } -ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, OrtAllocatorType type, int id1, - OrtMemType mem_type1, _Out_ OrtMemoryInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, + enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) { if (strcmp(name1, onnxruntime::CPU) == 0) { *out = new OrtMemoryInfo(name1, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0) { diff --git a/onnxruntime/core/framework/error_code.cc b/onnxruntime/core/framework/error_code.cc index 1a88f1095584b..4c6f73f915e01 100644 --- a/onnxruntime/core/framework/error_code.cc +++ b/onnxruntime/core/framework/error_code.cc @@ -14,8 +14,14 @@ struct OrtStatus { char msg[1]; // a null-terminated string }; +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 28196) +#pragma warning(disable : 6387) +#endif //Even we say it may not return NULL, indeed it may. -ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtApis::CreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION { +_Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtApis::CreateStatus(OrtErrorCode code, + _In_z_ const char* msg) NO_EXCEPTION { assert(!(code == 0 && msg != nullptr)); SafeInt clen(strlen(msg)); OrtStatus* p = reinterpret_cast(::malloc(sizeof(OrtStatus) + clen)); @@ -27,17 +33,22 @@ ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtApis::CreateS } namespace onnxruntime { -OrtStatus* ToOrtStatus(const Status& st) { +_Ret_notnull_ OrtStatus* ToOrtStatus(const Status& st) { if (st.IsOK()) return nullptr; SafeInt clen(st.ErrorMessage().length()); OrtStatus* p = reinterpret_cast(::malloc(sizeof(OrtStatus) + clen)); + if (p == nullptr) + return nullptr; p->code = static_cast(st.Code()); memcpy(p->msg, st.ErrorMessage().c_str(), clen); p->msg[clen] = '\0'; return p; } } // namespace onnxruntime +#ifdef _MSC_VER +#pragma warning(pop) +#endif ORT_API(OrtErrorCode, OrtApis::GetErrorCode, _In_ const OrtStatus* status) { return status->code; } diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index f6ec883bc1f16..fd7e10c62dac9 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -5,9 +5,10 @@ #include "core/common/status.h" #include "core/common/exceptions.h" +#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { -OrtStatus* ToOrtStatus(const onnxruntime::common::Status& st); +_Ret_notnull_ OrtStatus* ToOrtStatus(const onnxruntime::common::Status& st); }; #define API_IMPL_BEGIN try { diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index ed972abb930d4..c2bc15b55e4ae 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -31,7 +31,7 @@ class KernelRegistryManager { KernelRegistryManager() = default; // Register kernels from providers - Status RegisterKernels(const ExecutionProviders& execution_providers); + Status RegisterKernels(const ExecutionProviders& execution_providers) ORT_MUST_USE_RESULT; // The registry passed in this function has highest priority than anything already in this KernelRegistryManager, // and anything registered from RegisterKernels @@ -44,10 +44,9 @@ class KernelRegistryManager { // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status CreateKernel(const onnxruntime::Node& node, - const IExecutionProvider& execution_provider, + Status CreateKernel(const onnxruntime::Node& node, const IExecutionProvider& execution_provider, const SessionState& session_state, - /*out*/ std::unique_ptr& op_kernel) const; + /*out*/ std::unique_ptr& op_kernel) const ORT_MUST_USE_RESULT; // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 107cdbbed10c2..ec08108afbda9 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -66,19 +66,20 @@ OrtStatus* OrtMapTypeInfo::Clone(OrtMapTypeInfo** out) { } // OrtMapTypeInfo Accessors -ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, const OrtMapTypeInfo* map_type_info, enum ONNXTensorElementDataType* out) { +ORT_API_STATUS_IMPL(OrtApis::GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, + _Out_ enum ONNXTensorElementDataType* out) { API_IMPL_BEGIN *out = map_type_info->map_key_type_; return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, const OrtMapTypeInfo* map_type_info, OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN return map_type_info->map_value_type_->Clone(out); API_IMPL_END } -ORT_API(void, OrtApis::ReleaseMapTypeInfo, OrtMapTypeInfo* ptr) { +ORT_API(void, OrtApis::ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo* ptr) { delete ptr; } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc index a5ee0c9a63bb1..18427fd5fcaae 100644 --- a/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_sequence_type_info.cc @@ -38,12 +38,13 @@ OrtStatus* OrtSequenceTypeInfo::Clone(OrtSequenceTypeInfo** out) { return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, const OrtSequenceTypeInfo* sequence_type_info, OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, + _Outptr_ OrtTypeInfo** out) { API_IMPL_BEGIN return sequence_type_info->sequence_key_type_->Clone(out); API_IMPL_END } -ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, OrtSequenceTypeInfo* ptr) { +ORT_API(void, OrtApis::ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo* ptr) { delete ptr; } \ No newline at end of file diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index c79409d3e9fbb..489adbbf840fc 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -49,31 +49,35 @@ OrtTypeInfo::~OrtTypeInfo() { } } -ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, ONNXType* out) { +ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeInfo* input, _Out_ ONNXType* out) { *out = input->type; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, const struct OrtTensorTypeAndShapeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, + _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { *out = input->type == ONNX_TYPE_TENSOR ? input->data : nullptr; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, const OrtTypeInfo* type_info, const OrtMapTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtMapTypeInfo** out) { API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_MAP ? type_info->map_type_info : nullptr; return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, const OrtTypeInfo* type_info, const OrtSequenceTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out) { API_IMPL_BEGIN *out = type_info->type == ONNX_TYPE_SEQUENCE ? type_info->sequence_type_info : nullptr; return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, const OrtTypeInfo* type_info, const char** const out, size_t* len) { +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const out, + _Out_ size_t* len) { API_IMPL_BEGIN *out = type_info->denotation.c_str(); *len = type_info->denotation.size(); diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc index 62d4e8cd25b0f..553875e1bc6a3 100644 --- a/onnxruntime/core/framework/op_kernel.cc +++ b/onnxruntime/core/framework/op_kernel.cc @@ -9,14 +9,9 @@ using namespace ::onnxruntime::common; namespace onnxruntime { -OpKernelContext::OpKernelContext(IExecutionFrame* frame, - const OpKernel* kernel, - concurrency::ThreadPool* threadpool, - const logging::Logger& logger) - : execution_frame_(frame), - kernel_(kernel), - threadpool_(threadpool), - logger_(&logger) { +OpKernelContext::OpKernelContext(_Inout_ IExecutionFrame* frame, _In_ const OpKernel* kernel, + _In_opt_ concurrency::ThreadPool* threadpool, _In_ const logging::Logger& logger) + : execution_frame_(frame), kernel_(kernel), threadpool_(threadpool), logger_(&logger) { ORT_ENFORCE(frame != nullptr, "Execution frame was null"); ORT_ENFORCE(kernel != nullptr, "OpKernel was null"); diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc index a50a2b7728e65..8ecb55ee483ac 100644 --- a/onnxruntime/core/framework/run_options.cc +++ b/onnxruntime/core/framework/run_options.cc @@ -6,40 +6,40 @@ #include "core/session/ort_apis.h" #include "core/framework/error_code_helper.h" -ORT_API_STATUS_IMPL(OrtApis::CreateRunOptions, OrtRunOptions** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateRunOptions, _Outptr_ OrtRunOptions** out) { API_IMPL_BEGIN *out = new OrtRunOptions(); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunLogVerbosityLevel, _In_ OrtRunOptions* options, int value) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value) { options->run_log_verbosity_level = value; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunLogSeverityLevel, _In_ OrtRunOptions* options, int value) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int value) { options->run_log_severity_level = value; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunTag, _In_ OrtRunOptions* options, _In_ const char* run_tag) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsSetRunTag, _Inout_ OrtRunOptions* options, _In_ const char* run_tag) { if (run_tag) options->run_tag = run_tag; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, int* out) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out) { *out = options->run_log_verbosity_level; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, int* out) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* out) { *out = options->run_log_severity_level; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunTag, _In_ const OrtRunOptions* options, const char** out) { +ORT_API_STATUS_IMPL(OrtApis::RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** out) { *out = options->run_tag.c_str(); return nullptr; } diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc index a80d6b6d47c85..0bf32ae6f7f4c 100644 --- a/onnxruntime/core/framework/session_state_initializer.cc +++ b/onnxruntime/core/framework/session_state_initializer.cc @@ -56,9 +56,8 @@ SessionStateInitializer::SessionStateInitializer(bool enable_mem_pattern, enable_mem_pattern_(enable_mem_pattern) {} common::Status SessionStateInitializer::CreatePlan( - const Node* parent_node, - const ConstPointerContainer>* outer_scope_node_args, - ExecutionMode execution_mode) { + _In_opt_ const Node* parent_node, + _In_opt_ const ConstPointerContainer>* outer_scope_node_args, ExecutionMode execution_mode) { session_state_.SetGraph(graph_); const GraphViewer* graph_viewer = session_state_.GetGraphViewer(); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 64bb11dbcbcfa..692531140f463 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -21,7 +21,7 @@ using onnxruntime::MLFloat16; using onnxruntime::SparseTensor; using onnxruntime::Tensor; -ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Out_ OrtTensorTypeAndShapeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN *out = new OrtTensorTypeAndShapeInfo(); return nullptr; @@ -32,7 +32,8 @@ ORT_API(void, OrtApis::ReleaseTensorTypeAndShapeInfo, _Frees_ptr_opt_ OrtTensorT delete ptr; } -ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _In_ OrtTensorTypeAndShapeInfo* this_ptr, enum ONNXTensorElementDataType type) { +ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* this_ptr, + enum ONNXTensorElementDataType type) { API_IMPL_BEGIN this_ptr->type = type; return nullptr; @@ -62,7 +63,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetDimensions, _In_ const struct OrtTensorTypeAndSh } ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, _In_ const struct OrtTensorTypeAndShapeInfo* info, - _Out_ const char** names, size_t dim_params_length) { + _Out_writes_all_(dim_params_length) const char** names, size_t dim_params_length) { for (size_t idx = 0, end = std::min(info->dim_params.size(), dim_params_length); idx < end; ++idx) { names[idx] = info->dim_params[idx].c_str(); } @@ -197,7 +198,7 @@ OrtStatus* OrtTensorTypeAndShapeInfo::Clone(OrtTensorTypeAndShapeInfo** out) return GetTensorShapeAndTypeHelper(type, shape, &dim_params, out); } -ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Out_ OrtTensorTypeAndShapeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN onnxruntime::MLDataType type = v->Type(); ORT_ENFORCE(type != nullptr, "OrtValue is not a Tensor"); @@ -238,7 +239,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueType, _In_ const OrtValue* v, _Out_ ONNXTyp * \param value * \return The returned value should be freed by OrtReleaseTypeInfo after use */ -ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, struct OrtTypeInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::GetTypeInfo, _In_ const OrtValue* v, _Outptr_result_maybenull_ struct OrtTypeInfo** out) { API_IMPL_BEGIN // TODO: This is consistent with the previous implementation but inconsistent with GetValueType which returns // ONNX_TYPE_UNKNOWN if v->Type() is null. Should we instead just call OrtTypeInfo::FromOrtValue and diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 6dee78fb240bf..bf47f74f44298 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -265,25 +265,8 @@ struct UnInitializeParam { ONNXTensorElementDataType ele_type; }; -// In the future, we may make these two function as public C API -/** - * Initialize a buffer for being used with the OrtCreateTensorWithDataAsOrtValue function - * - */ -ORT_API_STATUS(OrtInitializeBufferForTensor, _In_opt_ void* input, size_t input_len, - enum ONNXTensorElementDataType type); - -/** - * Uninitialize the buffer that was initialized by the OrtInitializeBufferForTensor function - * - */ -ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enum ONNXTensorElementDataType type); -static void UnInitTensor(void* param) noexcept { - UnInitializeParam* p = reinterpret_cast(param); - OrtUninitializeBuffer(p->preallocated, p->preallocated_size, p->ele_type); - delete p; -} + ORT_API_STATUS_IMPL(OrtInitializeBufferForTensor, _In_opt_ void* input, size_t input_len, enum ONNXTensorElementDataType type) { @@ -310,6 +293,12 @@ ORT_API(void, OrtUninitializeBuffer, _In_opt_ void* input, size_t input_len, enu } } +static void UnInitTensor(void* param) noexcept { + UnInitializeParam* p = reinterpret_cast(param); + OrtUninitializeBuffer(p->preallocated, p->preallocated_size, p->ele_type); + delete p; +} + #define CASE_PROTO(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ ORT_RETURN_IF_ERROR( \ @@ -368,7 +357,10 @@ static void MoveOrtCallback(OrtCallback& from, OrtCallback& to) { from.f = nullptr; from.param = nullptr; } - +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 6239) +#endif Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, const ONNX_NAMESPACE::TensorProto& tensor_proto, const MemBuffer& m, OrtValue& value, OrtCallback& deleter) { @@ -486,7 +478,10 @@ Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, ml_tensor->GetDeleteFunc()); return Status::OK(); } - +#ifdef _MSC_VER +#pragma warning(pop) +#pragma warning(disable : 6239) +#endif #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 48b8742624f68..d9974cb4c99d3 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -12,7 +12,6 @@ #include "core/framework/ml_value.h" #include "core/framework/mem_buffer.h" #include "core/framework/tensor_external_data_info.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/graph/onnx_protobuf.h" #include "core/platform/env.h" diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 25cfd774e14ce..b7fa3316551a4 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -963,6 +963,28 @@ Sample echo operator.)DOC"); ONNX_NAMESPACE::convPoolShapeInference(ctx, false, true, 0, 1); }); + ONNX_CONTRIB_OPERATOR_SCHEMA(Rfft) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(R"DOC()DOC") + .Input(0, "X", "input tensor", "T") + .Attr("signal_ndim", "", AttributeProto::INT) + .Attr("normalized", "", AttributeProto::INT, static_cast(0)) + .Attr("onesided", "", AttributeProto::INT, static_cast(1)) + .Output(0, "Y", "output tensor", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(Irfft) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(R"DOC()DOC") + .Input(0, "X", "input tensor", "T") + .Attr("signal_ndim", "", AttributeProto::INT) + .Attr("normalized", "", AttributeProto::INT, static_cast(0)) + .Attr("onesided", "", AttributeProto::INT, static_cast(1)) + .Output(0, "Y", "output tensor", "T") + .TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)"}, "Constrain input and output types to float or half tensors."); + ONNX_CONTRIB_OPERATOR_SCHEMA(ConvTransposeWithDynamicPads) .SetDomain(kMSDomain) .SinceVersion(1) @@ -1226,9 +1248,10 @@ activation and leaky_relu_alpha.)DOC") ONNX_CONTRIB_OPERATOR_SCHEMA_ELSEWHERE(Range, RegisterRangeOpSchema); static const char* QuantizeLinear_ver1_doc = R"DOC( -The linear quantization operator. It consumes a full precision data, a scale, a zero point and computes the quantized data. -The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), it computes the nearest integer value to arg (in floating-point format), - rounding halfway cases away from zero. Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; +The linear quantization operator. It consumes a full precision data, a scale, a zero point to compute the low precision / quantized tensor. +The quantization formula is y = saturate ((x / y_scale) + y_zero_point).For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. +For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. +Scale and zero point must have same shape. They must be either scalar (per tensor) or 1-D tensor (per 'axis').)DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(QuantizeLinear) .SetDomain(kMSDomain) @@ -1265,7 +1288,7 @@ The quantization formula is y = (x / y_scale) + y_zero_point. For (x / y_scale), "T2") .TypeConstraint( "T1", - {"tensor(float)"}, + {"tensor(float16)", "tensor(float)"}, "Constrain 'x', 'y_scale' to float tensors.") .TypeConstraint( "T2", @@ -1296,35 +1319,36 @@ Scale and zero point must have same shape. They must be either scalar (per tenso "If it's specified, it means per 'axis' quantization and input 'x_scale' and 'x_zero_point' must be 1-D tensors.", AttributeProto::INT, false) - .Input(0, - "x", - "N-D quantized Input tensor to be de-quantized.", - "T2") + .Input( + 0, + "x", + "N-D quantized Input tensor to be de-quantized.", + "T1") .Input( 1, "x_scale", "Scale for input 'x'. It could be a scalar or a 1-D tensor, which means a per-tensor or per-axis quantization." "If it's a 1-D tensor, its number of elements should be equal to the dimension value of 'axis' dimension of input 'x'.", - "T1") + "T2") .Input( 2, "x_zero_point", "Zero point for input 'x'. It could be a scalar or a 1-D tensor, which means a per-tensor or per-axis quantization." "If it's a 1-D tensor, its number of elements should be equal to the dimension value of 'axis' dimension of input 'x'.", - "T2") + "T1") .Output( 0, "y", "N-D full precision output tensor. It has same shape as input 'x'.", - "T1") + "T2") .TypeConstraint( "T1", - {"tensor(float)"}, - "Constrain 'y', 'x_scale' to float tensors.") + {"tensor(int8)", "tensor(uint8)"}, + "Constrain 'x' and 'x_zero_point' to 8-bit integer tensors.") .TypeConstraint( "T2", - {"tensor(int8)", "tensor(uint8)"}, - "Constrain 'x_zero_point' and 'x' to 8-bit integer tensors.") + {"tensor(float16)", "tensor(float)"}, + "Constrain 'y', 'x_scale' to float tensors.") .SetDoc(DequantizeLinear_ver1_doc) .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { auto y_type = ctx.getOutputType(0); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 2cba64b36a806..6d1cf7547cb0b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -778,6 +778,13 @@ Graph::Graph(const Model& owning_model, if (matching_graph_input == nullptr) { name_to_type_map[tensor.name()] = t; ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); + } else { + LOGS(logger_, WARNING) << "Initializer " << tensor.name() + << " appears in graph inputs and will not be treated as constant value/weight. " + << "This may fail some of the graph optimizations, like const folding. " + << "Move it out of graph inputs if there is no need to override it, " + << "by either re-generating the model with latest exporter/converter " + << "or with the tool onnxruntime/tools/python/remove_initializer_from_input.py."; } } } diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 22a4980e96178..cb1cc3caa3f79 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -3,7 +3,6 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph.h" -#include "core/framework/tensorprotoutils.h" #include "core/common/logging/logging.h" namespace onnxruntime { diff --git a/onnxruntime/core/mlas/lib/erf.cpp b/onnxruntime/core/mlas/lib/erf.cpp index 1fc538087cab4..34390f9582581 100644 --- a/onnxruntime/core/mlas/lib/erf.cpp +++ b/onnxruntime/core/mlas/lib/erf.cpp @@ -23,8 +23,6 @@ Module Name: #include "mlasi.h" -#include - // // Bundles the constants for use by kernels written in assembly. // diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 293efda47cb70..5128c7a7f02ea 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -21,6 +21,7 @@ Module Name: #include #include #include +#include #if defined(_WIN32) #include diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 5bcd688385729..629be7a0f430b 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -42,8 +42,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, // constant folding does not support executing a node that includes subgraphs (control flow operators, // such as If/Loop/Scan, fall into this category). individual nodes in the subgraph will be processed // by the Recurse call above - node->ContainsSubgraph() || - !graph_utils::AllNodeInputsAreConstant(graph, *node, constant_inputs)) { + node->ContainsSubgraph() || !graph_utils::AllNodeInputsAreConstant(graph, *node, constant_inputs)) { continue; } @@ -68,7 +67,9 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); auto* kernel = info.GetKernel(node->Index()); - OpKernelContext op_kernel_context(&frame, kernel, nullptr, onnxruntime::logging::LoggingManager::DefaultLogger()); + if (kernel == nullptr) + continue; + OpKernelContext op_kernel_context(&frame, kernel, nullptr, logger); ORT_RETURN_IF_ERROR(kernel->Compute(&op_kernel_context)); @@ -93,8 +94,7 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, const auto* constant_arg_out = node->OutputDefs()[fetch_idx]; ORT_ENFORCE(ort_value.IsTensor()); const Tensor& out_tensor = ort_value.Get(); - ONNX_NAMESPACE::TensorProto out_tensorproto = - utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); + ONNX_NAMESPACE::TensorProto out_tensorproto = utils::TensorToTensorProto(out_tensor, constant_arg_out->Name()); graph.AddInitializedTensor(out_tensorproto); } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 14b686cbc9199..823e15e0f1ee2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -122,7 +122,10 @@ std::vector> GenerateTransformers(TransformerL #ifndef DISABLE_CONTRIB_OPS transformers.emplace_back(onnxruntime::make_unique(cpu_execution_providers)); - transformers.emplace_back(onnxruntime::make_unique(cpu_execution_providers)); + + std::unordered_set cpu_acl_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider}; + + transformers.emplace_back(onnxruntime::make_unique(cpu_acl_execution_providers)); std::unordered_set cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider}; transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers)); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index fdcab6c247eea..327639f4257bf 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -271,9 +271,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // remove all the other nodes. graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node); +#ifdef ENABLE_TRAINING // add two extra output defs, so we have 3 output defs that match what gradient builder expected layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_mean"), nullptr)); layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std_var"), nullptr)); +#endif modified = true; } diff --git a/onnxruntime/core/platform/env.h b/onnxruntime/core/platform/env.h index 1f88b89f4689a..90019b9402643 100644 --- a/onnxruntime/core/platform/env.h +++ b/onnxruntime/core/platform/env.h @@ -93,7 +93,7 @@ class Env { * Caller is responsible for deleting the returned value */ // clang-format on - virtual EnvThread* CreateThread(_In_opt_ const ORTCHAR_T* name_prefix, int index, + virtual EnvThread* CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, _In_ unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* threadpool, const ThreadOptions& thread_options) = 0; virtual Task CreateTask(std::function f) = 0; @@ -130,7 +130,7 @@ class Env { /** * Gets the length of the specified file. */ - virtual common::Status GetFileLength(const ORTCHAR_T* file_path, size_t& length) const = 0; + virtual common::Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const = 0; /** * Copies the content of the file into the provided buffer. @@ -139,7 +139,7 @@ class Env { * @param length The length in bytes to read. * @param buffer The buffer in which to write. */ - virtual common::Status ReadFileIntoBuffer(const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, + virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, gsl::span buffer) const = 0; using MappedMemoryPtr = std::unique_ptr; @@ -154,7 +154,7 @@ class Env { * @param[out] mapped_memory A smart pointer to the mapped memory which * unmaps the memory (unless release()'d) when destroyed. */ - virtual common::Status MapFileIntoMemory(const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, + virtual common::Status MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, MappedMemoryPtr& mapped_memory) const = 0; #ifdef _WIN32 diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 229de9fabaa0b..5e762b77741fe 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -100,7 +100,7 @@ class WindowsThread : public EnvThread { class WindowsEnv : public Env { public: - EnvThread* CreateThread(const ORTCHAR_T* name_prefix, int index, + EnvThread* CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { return new WindowsThread(name_prefix, index, start_address, param, thread_options); @@ -150,7 +150,7 @@ class WindowsEnv : public Env { return GetCurrentProcessId(); } - Status GetFileLength(const ORTCHAR_T* file_path, size_t& length) const override { + Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override { wil::unique_hfile file_handle{ CreateFileW(file_path, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL)}; LARGE_INTEGER filesize; @@ -165,7 +165,7 @@ class WindowsEnv : public Env { return Status::OK(); } - Status ReadFileIntoBuffer(const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, + Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, const gsl::span buffer) const override { ORT_RETURN_IF_NOT(file_path); ORT_RETURN_IF_NOT(offset >= 0); @@ -212,7 +212,7 @@ class WindowsEnv : public Env { return Status::OK(); } - Status MapFileIntoMemory(const ORTCHAR_T*, FileOffsetType, size_t, MappedMemoryPtr&) const override { + Status MapFileIntoMemory(_In_z_ const ORTCHAR_T*, FileOffsetType, size_t, MappedMemoryPtr&) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "MapFileIntoMemory is not implemented on Windows."); } diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 0a5a46e63dfea..93b10220a2ce7 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" -#include "core/common/version.h" +#include "onnxruntime_config.h" // ETW includes // need space after Windows.h to prevent clang-format re-ordering breaking the build. @@ -105,7 +105,7 @@ void WindowsTelemetry::LogProcessInfo() const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), - TraceLoggingString(ONNXRUNTIME_VERSION_STRING, "runtimeVersion"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingBool(true, "isRedist")); process_info_logged = true; diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index f1d08995438e1..8319121abf1cc 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. // Licensed under the MIT License. #include "acl_execution_provider.h" @@ -37,6 +37,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOn // Opset 10 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, float, FusedConv); + static void RegisterACLKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); @@ -58,6 +60,8 @@ static void RegisterACLKernels(KernelRegistry& kernel_registry) { // Opset 10 kernel_registry.Register(BuildKernelCreateInfo()); + + kernel_registry.Register(BuildKernelCreateInfo()); } std::shared_ptr GetAclKernelRegistry() { diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h index c171837c6d108..7bbfbf0a917de 100644 --- a/onnxruntime/core/providers/acl/math/gemm.h +++ b/onnxruntime/core/providers/acl/math/gemm.h @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. // Licensed under the MIT License. #pragma once @@ -19,12 +19,26 @@ #include "arm_compute/runtime/MemoryManagerOnDemand.h" //NEON -#include "arm_compute/runtime/NEON/functions/NEGEMM.h" -#include "arm_compute/runtime/NEON/functions/NETranspose.h" #include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h" -#undef GEMM_ACL -#define CACHE_TRANSPOSED_DATA +template +void importDataFromTensor(arm_compute::Tensor* tensor, T* data){ + + arm_compute::Window aclInpuWindow; + aclInpuWindow.use_tensor_dimensions(tensor->info()->tensor_shape()); + + arm_compute::Iterator aclInputIt(tensor, aclInpuWindow); + const unsigned int aclWidth = tensor->info()->dimension(0); + const unsigned int aclHeight = tensor->info()->dimension(1); + + // copy input tensor into the larger buffer + arm_compute::execute_window_loop( + aclInpuWindow, + [&](const arm_compute::Coordinates& co) { + data[co.z() * (aclWidth * aclHeight) + co.y() * aclWidth + co.x()] = *reinterpret_cast(aclInputIt.ptr()); + }, + aclInputIt); +} namespace onnxruntime { namespace acl { @@ -53,27 +67,51 @@ class Gemm : public onnxruntime::Gemm { } Status Compute(OpKernelContext* context) const override { - const auto X = context->Input(0); - const auto W = context->Input(1); - const auto B = context->Input(2); + const auto A = context->Input(0); + const auto B = context->Input(1); + const auto C = context->Input(2); - GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, B->Shape()); + GemmHelper helper(A->Shape(), trans_A_ != CblasNoTrans, B->Shape(), trans_B_ != CblasNoTrans, C->Shape()); if (!helper.State().IsOK()) return helper.State(); int64_t M = helper.M(); int64_t N = helper.N(); - auto Y = context->Output(0, TensorShape({M, N})); + auto D = context->Output(0, TensorShape({M, N})); + + bool FC = alpha_ == 1 && (beta_ == 1 || beta_ == 0); + bool useC = C != nullptr && beta_ != 0; - bool FC = ((alpha_ == 1 && beta_ == 1) || (alpha_ == 1 && beta_ == 0)); + if(trans_A_ == CblasTrans){ // transpose input + return onnxruntime::Gemm::Compute(context); + } + + arm_compute::TensorShape cShape = ACLTensorShape(C->Shape()); + if(useC && + (cShape.num_dimensions() > 2 || + (cShape.num_dimensions() == 2 && cShape[0] > 1 && cShape[1] > 1))) { // Multi-dimensional Bias + return onnxruntime::Gemm::Compute(context); + } + + if(useC && (cShape.num_dimensions() == 1 && cShape[0] != (long unsigned int) N)) { // Broadcast + return onnxruntime::Gemm::Compute(context); + } + + if(useC && cShape.num_dimensions() == 2){ + if((cShape[0] == 1 && cShape[1] != (long unsigned int) N) || + (cShape[1] == 1 && cShape[0] != (long unsigned int) N)) { + return onnxruntime::Gemm::Compute(context); + } + cShape = arm_compute::TensorShape(1, N); + } int64_t K = helper.K(); LOGS_DEFAULT(VERBOSE) << "Gemm ACL:" << std::endl; - if (X) LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str() << std::endl; - if (W) LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str() << std::endl; + if (A) LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str() << std::endl; if (B) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str() << std::endl; - LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str() << std::endl; + if (C) LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str() << std::endl; + LOGS_DEFAULT(VERBOSE) << "D " << D->Shape().ToString().c_str() << std::endl; LOGS_DEFAULT(VERBOSE) << "M " << (int)M << ", N " << (int)N << ", K " << (int)K << std::endl; LOGS_DEFAULT(VERBOSE) << "Alfa " << alpha_ << ", Beta " << beta_ << std::endl; LOGS_DEFAULT(VERBOSE) << "trans_A_ " << (trans_A_ == CblasTrans) << std::endl; @@ -89,51 +127,22 @@ class Gemm : public onnxruntime::Gemm { tGEMM.c = std::make_shared(); tGEMM.d = std::make_shared(); - tGEMM.a->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32)); - tGEMM.c->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(B->Shape()), arm_compute::Format::F32)); + tGEMM.a->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(A->Shape()), arm_compute::Format::F32)); + tGEMM.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(B->Shape()), arm_compute::Format::F32)); + tGEMM.c->allocator()->init(arm_compute::TensorInfo(cShape, arm_compute::Format::F32)); // dimensions are stored in the opposite order to ACL's - tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), tGEMM.a->info()->format())); - - // transpose - if (!FC && trans_B_ == CblasTrans) { - auto trans_layer = std::make_shared(); - tGEMM.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - - arm_compute::Tensor tmp; - tmp.allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - - trans_layer->configure(&tmp, tGEMM.b.get()); - - const T* b_data = W->template Data(); - ACLImportMemory(tmp.allocator(), (void*)b_data, W->Shape().Size() * 4); - - tGEMM.b->allocator()->allocate(); - - trans_layer->run(); - } else { - tGEMM.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - } - + tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), arm_compute::Format::F32)); + tGEMM.mm_layer = ACLCreateMemoryManager(); if(FC) { auto layer = std::make_shared(tGEMM.mm_layer); - layer->configure(tGEMM.a.get(), tGEMM.b.get(), (B != nullptr && beta_ != 0) ? tGEMM.c.get() : nullptr, tGEMM.d.get()); + arm_compute::FullyConnectedLayerInfo fc_info; + fc_info.transpose_weights = trans_B_ == CblasTrans; + layer->configure(tGEMM.a.get(), tGEMM.b.get(), useC ? tGEMM.c.get() : nullptr, tGEMM.d.get(), fc_info); tGEMM.layer = std::move(layer); } else { -#ifdef GEMM_ACL - auto layer = std::make_shared(tGEMM.mm_layer); - layer->configure(tGEMM.a.get(), tGEMM.b.get(), (B != nullptr && beta_ != 0) ? tGEMM.c.get() : nullptr, tGEMM.d.get(), alpha_, beta_, arm_compute::GEMMInfo()); - tGEMM.layer = std::move(layer); -#else return onnxruntime::Gemm::Compute(context); -#endif - } - - // non-transpose - if (FC || trans_B_ != CblasTrans) { - const T* b_data = W->template Data(); - ACLImportMemory(tGEMM.b->allocator(), (void*)b_data, W->Shape().Size() * 4); } std::pair ret; @@ -142,41 +151,24 @@ class Gemm : public onnxruntime::Gemm { } else { //TODO: valildate shapes pGEMM = &it->second; + } - // transpose - if (!FC && trans_B_ == CblasTrans) { -#ifndef CACHE_TRANSPOSED_DATA - auto trans_layer = std::make_shared(); - pGEMM->b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - - arm_compute::Tensor tmp; - tmp.allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - - trans_layer->configure(&tmp, pGEMM->b.get()); - - const T* b_data = W->template Data(); - ACLImportMemory(tmp.allocator(), (void*)b_data, W->Shape().Size() * 4); - - // allocate memory for b - pGEMM->b->allocator()->allocate(); + const T* a_data = A->template Data(); + const T* b_data = B->template Data(); + T* d_data = D->template MutableData(); - trans_layer->run(); -#else - LOGS_DEFAULT(VERBOSE) << "Reuse transposed data" << std::endl; -#endif - } else { - const T* b_data = W->template Data(); - ACLImportMemory(pGEMM->b->allocator(), (void*)b_data, W->Shape().Size() * 4); - } + ACLImportMemory(pGEMM->a->allocator(), (void*)a_data, A->Shape().Size() * 4); + ACLImportMemory(pGEMM->b->allocator(), (void*)b_data, B->Shape().Size() * 4); + if(useC){ + const T* c_data = C->template Data(); + ACLImportMemory(pGEMM->c->allocator(), (void*)c_data, C->Shape().Size() * 4); } - const T* a_data = X->template Data(); - const T* c_data = B->template Data(); - const T* d_data = Y->template Data(); - - ACLImportMemory(pGEMM->a->allocator(), (void*)a_data, X->Shape().Size() * 4); - ACLImportMemory(pGEMM->c->allocator(), (void*)c_data, B->Shape().Size() * 4); - ACLImportMemory(pGEMM->d->allocator(), (void*)d_data, Y->Shape().Size() * 4); + if(D->Shape().Size() != 0 && pGEMM->d->info()->has_padding() ){ + pGEMM->d.get()->allocator()->allocate(); + } else { + ACLImportMemory(pGEMM->d->allocator(), (void*)d_data, D->Shape().Size() * 4); + } ACLPrintTensorShape("a", *pGEMM->a); ACLPrintTensorShape("b", *pGEMM->b); @@ -188,13 +180,12 @@ class Gemm : public onnxruntime::Gemm { pGEMM->layer->run(); pGEMM->mm_layer->clear(); + if(D->Shape().Size() != 0 && pGEMM->d->info()->has_padding() ){ + importDataFromTensor(pGEMM->d.get(), d_data); + } + pGEMM->a->allocator()->free(); -#ifdef CACHE_TRANSPOSED_DATA - if (trans_B_ != CblasTrans) - pGEMM->b->allocator()->free(); -#else pGEMM->b->allocator()->free(); -#endif pGEMM->c->allocator()->free(); pGEMM->d->allocator()->free(); diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index a4a42c908d24f..c7ccaf5f04ffb 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. // Licensed under the MIT License. #ifdef _WIN32 @@ -26,7 +26,8 @@ #ifdef ACL_1902 #include "arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h" -#else +#endif +#if defined(ACL_1905) || defined(ACL_1908) #include "arm_compute/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.h" #endif @@ -109,20 +110,20 @@ Status Conv::Compute(OpKernelContext* context) const { arm_compute::ActivationLayerInfo::ActivationFunction acl_activ_func; bool acl_activ_enabled = false; - if (conv_attrs_.activation == "Relu") { + if (activation_type == "Relu") { acl_activ_func = arm_compute::ActivationLayerInfo::ActivationFunction::RELU; acl_activ_enabled = true; - } else if (conv_attrs_.activation == "LeakyRelu") { + } else if (activation_type == "LeakyRelu") { acl_activ_func = arm_compute::ActivationLayerInfo::ActivationFunction::LEAKY_RELU; acl_activ_enabled = true; - } else if (conv_attrs_.activation == "Tanh") { + } else if (activation_type == "Tanh") { acl_activ_func = arm_compute::ActivationLayerInfo::ActivationFunction::TANH; acl_activ_enabled = true; - } else if (conv_attrs_.activation == "Sigmoid") { + } else if (activation_type == "Sigmoid") { acl_activ_func = arm_compute::ActivationLayerInfo::ActivationFunction::LOGISTIC; acl_activ_enabled = true; - } else if (!conv_attrs_.activation.empty()) { - ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", conv_attrs_.activation); + } else if (!activation_type.empty()) { + ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_type); } if (it == Conv::convLayers.end()) { @@ -195,7 +196,8 @@ Status Conv::Compute(OpKernelContext* context) const { tconv.in->info()->data_type(), 1 /* depth multiplier */, tconv.in->info()->data_layout()); -#else +#endif +#if defined(ACL_1905) || defined(ACL_1908) bool optimizable = arm_compute::NEDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(tconv.in->info(), tconv.k->info(), @@ -205,12 +207,18 @@ Status Conv::Compute(OpKernelContext* context) const { #endif if(optimizable) { //optimized depthwise convolution +#if defined(ACL_1902) || defined(ACL_1905) auto layer = std::make_shared(); +#endif +#ifdef ACL_1908 + auto layer = std::make_shared(); +#endif #ifdef ACL_1902 layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride, 1 /* depth multiplier */, acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo()); -#else +#endif +#if defined(ACL_1905) || defined(ACL_1908) layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), aclPadStride, 1 /* depth multiplier */, acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), @@ -225,7 +233,7 @@ Status Conv::Compute(OpKernelContext* context) const { ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); return s; } -#endif +#endif //DEPTHWISE_CPU } else { if(tconv.k->info()->tensor_shape()[0] == 1 && tconv.k->info()->tensor_shape()[1] == 1) { //pointwise convolution diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index e4e413e454349..7d4fc2b170acd 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -1,5 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. // Licensed under the MIT License. #pragma once @@ -36,7 +36,7 @@ typedef struct typedef std::map::iterator ConvLayersIterator; template -class Conv final : public onnxruntime::Conv { +class Conv : public onnxruntime::Conv { public: explicit Conv(const OpKernelInfo& info) : onnxruntime::Conv(info), conv_attrs_(info) { provider_ = (const_cast( @@ -49,10 +49,11 @@ class Conv final : public onnxruntime::Conv { Status Compute(OpKernelContext* context) const override; - private: + protected: static thread_local std::map convLayers; ConvAttributes conv_attrs_; ACLExecutionProvider* provider_; + std::string activation_type; arm_compute::TensorShape ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const; }; diff --git a/onnxruntime/core/providers/acl/nn/fused_conv.cc b/onnxruntime/core/providers/acl/nn/fused_conv.cc new file mode 100644 index 0000000000000..740e13091502b --- /dev/null +++ b/onnxruntime/core/providers/acl/nn/fused_conv.cc @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// Licensed under the MIT License. + +#ifdef _WIN32 +#pragma warning(disable : 4244) +#endif +#include +#include + +#include "core/providers/acl/nn/conv.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_fwd.h" +#include "core/providers/acl/acl_execution_provider.h" + +namespace onnxruntime { +namespace acl{ + +template +class FusedConv final : public acl::Conv { + public: + explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { + ORT_ENFORCE(info.GetAttr("activation", &(this->activation_type)).IsOK()); + // printf("fused\n"); + } + // Status Compute(OpKernelContext* context) const override; +}; + +ONNX_OPERATOR_TYPED_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + float, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + FusedConv); + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/common.h b/onnxruntime/core/providers/common.h index c23dded2f9c40..e12deb0305fd3 100644 --- a/onnxruntime/core/providers/common.h +++ b/onnxruntime/core/providers/common.h @@ -18,7 +18,7 @@ inline int64_t HandleNegativeAxis(int64_t axis, int64_t tensor_rank) { ORT_ENFORCE(axis >= -tensor_rank && axis <= tensor_rank - 1, "axis ", axis, " is not in valid range [-", tensor_rank, ",", tensor_rank - 1, "]"); // Handle negative axis - return axis = axis < 0 ? axis + tensor_rank : axis; + return axis < 0 ? axis + tensor_rank : axis; } /** diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index c0167f23b0bcf..ff4653639e8d6 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -15,6 +15,13 @@ #include "core/framework/compute_capability.h" +namespace { +struct KernelRegistryAndStatus { + std::shared_ptr kernel_registry = std::make_shared(); + Status st; +}; +} // namespace + namespace onnxruntime { // Forward declarations of op kernels @@ -151,6 +158,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMin); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceProd); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); @@ -328,6 +336,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int64_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, int32_t, ReduceSum); @@ -657,10 +666,14 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { ReduceProd)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo kernel_registry = std::make_shared(); - Status st; -}; - KernelRegistryAndStatus GetCpuKernelRegistry() { KernelRegistryAndStatus ret; ret.st = RegisterCPUKernels(*ret.kernel_registry); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc index facfdbc827806..28445122c9d7e 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_factory.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_factory.cc @@ -35,7 +35,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessio return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, _Out_ OrtMemoryInfo** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, + _Outptr_ OrtMemoryInfo** out) { *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), 0, mem_type); return nullptr; } diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index 78401e03247a4..bd602bd9f8440 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -17,12 +17,8 @@ limitations under the License. #include "core/providers/cpu/generator/random.h" -// build\windows\debug\external\eigen3\unsupported\eigen\cxx11\src/Tensor/Tensor.h(76): -// warning C4554: '&': check operator precedence for possible error; use parentheses to clarify precedence -// build\windows\relwithdebinfo\eigen\src\eigen\eigen-eigen-5a0156e40feb\unsupported\eigen\cxx11\src/Tensor/TensorChipping.h(52) -// warning C4100: 'dim': unreferenced formal parameter #ifdef _WIN32 -#pragma warning(disable : 4554 4100) +#pragma warning(disable : 28020) #endif #include diff --git a/onnxruntime/core/providers/cpu/math/softmax.cc b/onnxruntime/core/providers/cpu/math/softmax.cc index a6cfbf0b28b79..f1f0dd69fb409 100644 --- a/onnxruntime/core/providers/cpu/math/softmax.cc +++ b/onnxruntime/core/providers/cpu/math/softmax.cc @@ -2,7 +2,11 @@ // Licensed under the MIT License. #include "onnxruntime_config.h" -//Ignore a wired warning in gcc 7.4.0. The latest gcc doesn't generate this warning +#if defined(_MSC_VER) +// TODO: fix the warning +#pragma warning(disable : 28020) +#endif +// Ignore a weird warning in gcc 7.4.0. The latest gcc doesn't generate this warning #ifdef __GNUC__ #ifdef HAS_MAYBE_UNINITIALIZED #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc index c05617f808084..92a976be825be 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cpu/reduction/reduction_ops.cc @@ -129,6 +129,8 @@ REGISTER_UNARY_ELEMENTWISE_KERNEL_UINT8_ONLY(ReduceMin, 12); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 11); +REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL_INT64_ONLY(ReduceProd, 1, 10); +REGISTER_UNARY_ELEMENTWISE_KERNEL_INT64_ONLY(ReduceProd, 11); REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10); REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 11); diff --git a/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc b/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc index ec7be04b1d701..3ba3ba6837e71 100644 --- a/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/tensor/quantize_linear.cc @@ -12,10 +12,7 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( 10, uint8_t, KernelDefBuilder() - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::GetTensorType()), DequantizeLinear); ONNX_CPU_OPERATOR_TYPED_KERNEL( @@ -23,10 +20,7 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( 10, int8_t, KernelDefBuilder() - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_scale", DataTypeImpl::GetTensorType()) - .TypeConstraint("x_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", DataTypeImpl::GetTensorType()), DequantizeLinear); template @@ -35,49 +29,45 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { auto& x = *ctx->Input(0); auto& x_scale = *ctx->Input(1); auto& x_zero_point = *ctx->Input(2); - auto& y = *ctx->Output(0, x.Shape()); const auto& x_shape = x.Shape(); + auto& y = *ctx->Output(0, x_shape); - int64_t axis = 0; - int64_t broadcastDim = x_shape[axis]; - size_t stride = 0; + int64_t N; + int64_t broadcast_dim; + int64_t block_size; if (has_axis_) { - axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); - broadcastDim = x_shape[axis]; - stride = 1; + const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); + N = x_shape.SizeToDimension(axis); + broadcast_dim = x_shape[axis]; + block_size = x_shape.SizeFromDimension(axis + 1); // if an axis was specified, ensure the scale and zero point are compatible - ORT_ENFORCE(x_scale.Shape().NumDimensions() == 1 && x_scale.Shape().Size() == broadcastDim, "x_scale must be 1D tensor with size ", broadcastDim); - ORT_ENFORCE(x_zero_point.Shape().NumDimensions() == 1 && x_zero_point.Shape().Size() == broadcastDim, "x_zero_point must be 1D tensor with size ", broadcastDim); + ORT_ENFORCE(x_scale.Shape().NumDimensions() == 1 && x_scale.Shape().Size() == broadcast_dim, "x_scale must be 1D tensor with size ", broadcast_dim); + ORT_ENFORCE(x_zero_point.Shape().NumDimensions() == 1 && x_zero_point.Shape().Size() == broadcast_dim, "x_zero_point must be 1D tensor with size ", broadcast_dim); } else { + N = 1; + broadcast_dim = 1; + block_size = static_cast(x_shape.Size()); + // if no axis, enforce that scale and zero point are scalars ORT_ENFORCE(IsScalarOr1ElementVector(&x_scale), "x_scale must be a scalar or 1D tensor or size 1."); ORT_ENFORCE(IsScalarOr1ElementVector(&x_zero_point), "x_zero_point must be a scalar or 1D tensor or size 1."); } - size_t N = x_shape.SizeToDimension(axis); - size_t block_size = x_shape.SizeFromDimension(axis + 1); - const T* zero_point = x_zero_point.template Data(); const float* scale = x_scale.template Data(); const T* input = x.template Data(); float* output = y.template MutableData(); - for (size_t n = 0; n < N; n++) { - const float* current_scale = scale; - const T* current_zero_point = zero_point; - - for (size_t bd = 0; bd < static_cast(broadcastDim); bd++) { - auto zp = static_cast(*current_zero_point); - auto sc = *current_scale; + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + auto zp = static_cast(zero_point[bd]); + auto sc = scale[bd]; - for (size_t bs = 0; bs < block_size; bs++) { - *output++ = static_cast(static_cast(*input++) - zp) * sc; + for (size_t bs = 0; bs < static_cast(block_size); bs++) { + *output++ = static_cast(static_cast(*input++) - zp) * sc; } - - current_scale += stride; - current_zero_point += stride; } } @@ -89,9 +79,8 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( 10, uint8_t, KernelDefBuilder() - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), QuantizeLinear); ONNX_CPU_OPERATOR_TYPED_KERNEL( @@ -99,9 +88,8 @@ ONNX_CPU_OPERATOR_TYPED_KERNEL( 10, int8_t, KernelDefBuilder() - .TypeConstraint("x", DataTypeImpl::GetTensorType()) - .TypeConstraint("y_zero_point", DataTypeImpl::GetTensorType()) - .TypeConstraint("y", DataTypeImpl::GetTensorType()), + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), QuantizeLinear); template @@ -110,68 +98,46 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { auto& x = *ctx->Input(0); auto& y_scale = *ctx->Input(1); auto& y_zero_point = *ctx->Input(2); - auto& y = *ctx->Output(0, x.Shape()); const auto& x_shape = x.Shape(); + auto& y = *ctx->Output(0, x_shape); - const float* input = x.template Data(); - T* output = y.template MutableData(); - - const float qmax = std::numeric_limits::max(); - const float qmin_default = std::numeric_limits::min(); - // adjust qmin for int8 inputs. This is required to keep zero point as zero - const float qmin = qmin_default == -128 ? -127 : qmin_default; - - // Schema of QuantizeLinearOp changed when it was promoted to onnx domain. In order to maintain backward compatiblity - // both the versions need to be supported. - if (ctx->GetOpDomain() != kMSDomain) { - ORT_ENFORCE(IsScalarOr1ElementVector(&y_scale), "x_scale must be a scalar or 1D tensor or size 1."); - ORT_ENFORCE(IsScalarOr1ElementVector(&y_zero_point), "x_zero_point must be a scalar or 1D tensor or size 1."); + int64_t N; + int64_t broadcast_dim; + int64_t block_size; - const T zero_point = *(y_zero_point.template Data()); - const float scale = *(y_scale.template Data()); - const auto num_of_elements = x_shape.Size(); - - MlasQuantizeLinear(input, output, num_of_elements, scale, zero_point); - - } else { - size_t stride = 0; + if (has_axis_) { const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); - const auto& broadcastDim = x_shape[axis]; - - if (has_axis_) { - // if an axis was specified, ensure the scale and zero point are compatible - ORT_ENFORCE(y_scale.Shape().NumDimensions() == 1 && y_scale.Shape().Size() == broadcastDim, "x_scale must be 1D tensor with size ", broadcastDim); - ORT_ENFORCE(y_zero_point.Shape().NumDimensions() == 1 && y_zero_point.Shape().Size() == broadcastDim, "x_zero_point must be 1D tensor with size ", broadcastDim); - stride = 1; - } else { - // if no axis, enforce that scale and zero point are scalars - ORT_ENFORCE(IsScalarOr1ElementVector(&y_scale), "x_scale must be a scalar or 1D tensor or size 1."); - ORT_ENFORCE(IsScalarOr1ElementVector(&y_zero_point), "x_zero_point must be a scalar or 1D tensor or size 1."); - } + N = x_shape.SizeToDimension(axis); + broadcast_dim = x_shape[axis]; + block_size = x_shape.SizeFromDimension(axis + 1); - size_t N = x_shape.SizeToDimension(axis); - size_t block_size = x_shape.SizeFromDimension(axis + 1); - const T* zero_point = y_zero_point.template Data(); - const float* scale = y_scale.template Data(); - - for (size_t n = 0; n < N; n++) { - const float* current_scale = scale; - const T* current_zero_point = zero_point; + // if an axis was specified, ensure the scale and zero point are compatible + ORT_ENFORCE(y_scale.Shape().NumDimensions() == 1 && y_scale.Shape().Size() == broadcast_dim, "x_scale must be 1D tensor with size ", broadcast_dim); + ORT_ENFORCE(y_zero_point.Shape().NumDimensions() == 1 && y_zero_point.Shape().Size() == broadcast_dim, "x_zero_point must be 1D tensor with size ", broadcast_dim); + } else { + N = 1; + broadcast_dim = 1; + block_size = x_shape.Size(); - for (size_t bd = 0; bd < static_cast(broadcastDim); bd++) { - float zp = *current_zero_point; - auto sc = *current_scale; + // if no axis, enforce that scale and zero point are scalars + ORT_ENFORCE(IsScalarOr1ElementVector(&y_scale), "x_scale must be a scalar or 1D tensor or size 1."); + ORT_ENFORCE(IsScalarOr1ElementVector(&y_zero_point), "x_zero_point must be a scalar or 1D tensor or size 1."); + } - for (size_t bs = 0; bs < block_size; bs++) { - *output++ = static_cast(clamp(std::round(static_cast(*input++) / sc) + zp, qmin, qmax)); - } + const T* zero_point = y_zero_point.template Data(); + const float* scale = y_scale.template Data(); + const float* input = x.template Data(); + T* output = y.template MutableData(); - current_scale += stride; - current_zero_point += stride; - } + for (size_t n = 0; n < static_cast(N); n++) { + for (size_t bd = 0; bd < static_cast(broadcast_dim); bd++) { + MlasQuantizeLinear(input, output, static_cast(block_size), scale[bd], zero_point[bd]); + input += block_size; + output += block_size; } } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 7ee4b2dea0b0a..da15197574bb1 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -442,7 +442,10 @@ float CubicInterpolation1D(const T* Xdata, return result; } - +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 6001) +#endif template void ResizeBiCubic( int64_t batch_size, @@ -580,6 +583,9 @@ void ResizeBiCubic( } } } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif template Status Upsample::BaseCompute(OpKernelContext* context, diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h index c905fa32b8be3..b8ad186e9c1d0 100644 --- a/onnxruntime/core/providers/cpu/tensor/utils.h +++ b/onnxruntime/core/providers/cpu/tensor/utils.h @@ -25,6 +25,11 @@ struct TensorPitches : std::vector { if (gsl::narrow_cast(padded_rank) < 0) return false; + // Guard against Scalars + if (pitch_rank == 0) { + return true; + } + *(p.rbegin()) = 1; // The innermost axis is 1 (single values) if (tensor_rank > 1) { for (size_t i = tensor_rank - 1; i-- > 0;) { diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc index 805d3d0b465c2..a0855076e19d4 100644 --- a/onnxruntime/core/providers/cuda/cuda_call.cc +++ b/onnxruntime/core/providers/cuda/cuda_call.cc @@ -30,6 +30,7 @@ const char* CudaErrString(cudaError_t x) { cudaDeviceSynchronize(); return cudaGetErrorString(x); } + template <> const char* CudaErrString(cublasStatus_t e) { cudaDeviceSynchronize(); @@ -62,6 +63,21 @@ const char* CudaErrString(cudnnStatus_t e) { return cudnnGetErrorString(e); } +template <> +const char* CudaErrString(cufftResult e) { + cudaDeviceSynchronize(); + switch (e) { + CASE_ENUM_TO_STR(CUFFT_SUCCESS); + CASE_ENUM_TO_STR(CUFFT_ALLOC_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_VALUE); + CASE_ENUM_TO_STR(CUFFT_INTERNAL_ERROR); + CASE_ENUM_TO_STR(CUFFT_SETUP_FAILED); + CASE_ENUM_TO_STR(CUFFT_INVALID_SIZE); + default: + return "Unknown cufft error status"; + } +} + #ifdef USE_NCCL template <> const char* CudaErrString(ncclResult_t e) { @@ -122,6 +138,9 @@ template bool CudaCall(cudnnStatus_t retCode, const char* template bool CudaCall(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg); template bool CudaCall(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg); template bool CudaCall(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg); +template bool CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg); +template bool CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg); + #ifdef USE_NCCL template bool CudaCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg); #endif diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 027427319099e..3d8ce0c06f059 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -46,6 +46,11 @@ namespace cuda { ? common::Status::OK() \ : ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDNN2 error executing ", #expr)) +#define CUFFT_RETURN_IF_ERROR(expr) \ + ORT_RETURN_IF_ERROR(CUFFT_CALL(expr) \ + ? common::Status::OK() \ + : ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUFFT error executing ", #expr)) + // ----------------------------------------------------------------------- // Base class for CUDA kernels // ----------------------------------------------------------------------- @@ -165,7 +170,7 @@ class CudaKernel : public OpKernel { inline curandGenerator_t CurandGenerator() const { return provider_->PerThreadCurandGenerator(); } - + template inline const T* GetConstOnes(size_t count) const { return provider_->template GetConstOnes(count); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d4e5d39b15d0c..2a6853299cb2d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -17,6 +17,13 @@ using namespace onnxruntime::common; +namespace { +struct KernelRegistryAndStatus { + std::shared_ptr kernel_registry = std::make_shared(); + Status st; +}; +} // namespace + namespace onnxruntime { namespace cuda { @@ -745,8 +752,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int8_t, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, uint8_t, ReduceMin); - -static void RegisterCudaKernels(KernelRegistry& kernel_registry) { +static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1257,26 +1263,28 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { }; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } -} - -std::shared_ptr GetCudaKernelRegistry() { - std::shared_ptr kernel_registry = std::make_shared(); - RegisterCudaKernels(*kernel_registry); #ifndef DISABLE_CONTRIB_OPS - ::onnxruntime::contrib::cuda::RegisterCudaContribKernels(*kernel_registry); + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); #endif + return Status::OK(); +} - return kernel_registry; +KernelRegistryAndStatus GetCudaKernelRegistry() { + KernelRegistryAndStatus ret; + ret.st = RegisterCudaKernels(*ret.kernel_registry); + return ret; } } // namespace cuda std::shared_ptr CUDAExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr kernel_registry = onnxruntime::cuda::GetCudaKernelRegistry(); - return kernel_registry; + static KernelRegistryAndStatus k = onnxruntime::cuda::GetCudaKernelRegistry(); + // throw if the registry failed to initialize + ORT_THROW_IF_ERROR(k.st); + return k.kernel_registry; } static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, @@ -1344,7 +1352,7 @@ static bool ConvNeedFallbackToCPU(const onnxruntime::Node& node) { //cudnn only supports symmetric padding if ("pads" == attr_name && ::ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_INTS == attr_value.type()) { - auto pads = attr_value.ints(); + auto& pads = attr_value.ints(); int pads_size = pads.size(); ORT_ENFORCE(pads_size % 2 == 0); int rank = pads_size / 2; diff --git a/onnxruntime/core/providers/cuda/cuda_pch.h b/onnxruntime/core/providers/cuda/cuda_pch.h index 32355050608de..b5899532176ef 100644 --- a/onnxruntime/core/providers/cuda/cuda_pch.h +++ b/onnxruntime/core/providers/cuda/cuda_pch.h @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef USE_NCCL #include @@ -20,4 +21,4 @@ #if defined(_MSC_VER) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu index cda4bd850849d..4a8b44b8367ca 100644 --- a/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu +++ b/onnxruntime/core/providers/cuda/object_detection/non_max_suppression_impl.cu @@ -47,7 +47,6 @@ constexpr int kNmsBoxesPerThread = 8 * sizeof(int); // i / 32 == i >> 5. Using these bit operations should reduce the stall on host // thread. __device__ constexpr int NumBits(int n) { return (n == 0) ? 0 : NumBits(n >> 1) + 1; } -constexpr int kNmsBoxesPerThreadModuloMask = kNmsBoxesPerThread - 1; constexpr int kNmsBlockDim = 16; constexpr int kNmsBlockDimMax = 128; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu index cb53a91161dd8..ec34c3fee9407 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu @@ -398,4 +398,4 @@ template void reduce_matrix_rows( const double* data, double* output, int m, int n); } // namespace cuda -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 0a162c6f5847d..75379e492f290 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -19,6 +19,7 @@ bool CudaCall(ERRTYPE retCode, const char* exprString, const char* libName, ERRT #define CURAND_CALL(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS)) #define CUDNN_CALL(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS)) #define CUDNN_CALL2(expr, m) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m)) +#define CUFFT_CALL(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS)) #define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess)) #define CUBLAS_CALL_THROW(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS)) @@ -26,6 +27,7 @@ bool CudaCall(ERRTYPE retCode, const char* exprString, const char* libName, ERRT #define CURAND_CALL_THROW(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS)) #define CUDNN_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS)) #define CUDNN_CALL_THROW2(expr, m) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m)) +#define CUFFT_CALL_THROW(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS)) #ifdef USE_NCCL #define NCCL_CALL(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess)) diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index c87975cda10f2..3f3979a5103c3 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -65,7 +65,7 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { const int64_t* pads_tensor_raw_data = pads_tensor.template Data(); size_t pads_size = static_cast(pads_tensor.Shape().Size()); - ORT_ENFORCE(pads_size == 2 * dimension_count, + ORT_ENFORCE(pads_size == 2 * static_cast(dimension_count), "Pads tensor size should be equal to twice the input dimension count "); pads.reserve(2 * dimension_count); diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc index 626313096ed43..6cc9b6e57df44 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cc @@ -9,28 +9,10 @@ namespace onnxruntime { namespace cuda { -ONNX_OPERATOR_TYPED_KERNEL_EX(QuantizeLinear, - kOnnxDomain, - 10, - uint8_t, - kCudaExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -ONNX_OPERATOR_TYPED_KERNEL_EX(QuantizeLinear, - kOnnxDomain, - 10, - int8_t, - kCudaExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), - QuantizeLinear); - -template -Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const { +template +Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaU; + auto x = ctx->Input(0); auto y_scale = ctx->Input(1); auto y_zero_point = ctx->Input(2); @@ -42,14 +24,15 @@ Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const { const auto& x_shape = x->Shape(); - const float* input = x->template Data(); + const CudaU* input = reinterpret_cast(x->template Data()); T* output = y->template MutableData(); - ORT_ENFORCE(IsScalarOr1ElementVector(y_scale), "x_scale must be a scalar or 1D tensor of size 1."); - ORT_ENFORCE(IsScalarOr1ElementVector(y_zero_point), "x_zero_point must be a scalar or 1D tensor of size 1."); + // TO DO: support per-channel + ORT_ENFORCE(IsScalarOr1ElementVector(y_scale), "y_scale must be a scalar or 1D tensor of size 1."); + ORT_ENFORCE(IsScalarOr1ElementVector(y_zero_point), "y_zero_point must be a scalar or 1D tensor of size 1."); const T* zero_point = y_zero_point->template Data(); - const float* scale = y_scale->template Data(); + const CudaU* scale = reinterpret_cast(y_scale->template Data()); const auto num_of_elements = x_shape.Size(); CudaQuantizeLinear(input, output, scale, zero_point, num_of_elements); @@ -57,26 +40,10 @@ Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } -ONNX_OPERATOR_TYPED_KERNEL_EX(DequantizeLinear, - kOnnxDomain, - 10, - uint8_t, - kCudaExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -ONNX_OPERATOR_TYPED_KERNEL_EX(DequantizeLinear, - kOnnxDomain, - 10, - int8_t, - kCudaExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - DequantizeLinear); - -template -Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const { +template +Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const { + typedef typename ToCudaType::MappedType CudaU; + auto x = ctx->Input(0); auto y_scale = ctx->Input(1); auto y_zero_point = ctx->Input(2); @@ -90,13 +57,13 @@ Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(y != nullptr); const T* input = x->template Data(); - float* output = y->template MutableData(); + CudaU* output = reinterpret_cast(y->template MutableData()); - ORT_ENFORCE(IsScalarOr1ElementVector(y_scale), "x_scale must be a scalar or 1D tensor of size 1."); - ORT_ENFORCE(IsScalarOr1ElementVector(y_zero_point), "x_zero_point must be a scalar or 1D tensor of size 1."); + ORT_ENFORCE(IsScalarOr1ElementVector(y_scale), "y_scale must be a scalar or 1D tensor of size 1."); + ORT_ENFORCE(IsScalarOr1ElementVector(y_zero_point), "y_zero_point must be a scalar or 1D tensor of size 1."); const T* zero_point = y_zero_point->template Data(); - const float* scale = y_scale->template Data(); + const CudaU* scale = reinterpret_cast(y_scale->template Data()); const auto num_of_elements = x_shape.Size(); CudaDequantizeLinear(input, output, scale, zero_point, num_of_elements); @@ -104,5 +71,46 @@ Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const { return Status::OK(); } +// register QuantizeLinear kernels +#define REGISTER_Q_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + QuantizeLinear, \ + kOnnxDomain, \ + 10, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + QuantizeLinear); + +REGISTER_Q_KERNEL_TYPED(int8_t) +REGISTER_Q_KERNEL_TYPED(uint8_t) + +// register DequantizeLinear kernels +#define REGISTER_DQ_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DequantizeLinear, \ + kOnnxDomain, \ + 10, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + DequantizeLinear); + +REGISTER_DQ_KERNEL_TYPED(int8_t) +REGISTER_DQ_KERNEL_TYPED(uint8_t) + +// specialize QuantizeLinear::ComputeInternal and DequantizeLinear::ComputeInternal +#define SPECIALIZED_QDQ_COMPUTE(T, U) \ + template Status QuantizeLinear::ComputeInternal(OpKernelContext* ctx) const; \ + template Status DequantizeLinear::ComputeInternal(OpKernelContext* ctx) const; + +SPECIALIZED_QDQ_COMPUTE(int8_t, float) +SPECIALIZED_QDQ_COMPUTE(uint8_t, float) +SPECIALIZED_QDQ_COMPUTE(int8_t, MLFloat16) +SPECIALIZED_QDQ_COMPUTE(uint8_t, MLFloat16) + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu index ec49023ced095..b5f97f8e99cbe 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cu @@ -8,6 +8,20 @@ namespace onnxruntime { namespace cuda { +template +__global__ void QuantizeLinearKernel(const half* input, int8_t* output, const half* scale, const int8_t* zero_point, CUDA_LONG N) { + CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + int value = __half2int_rn(input[id] / (*scale)) + *zero_point; + output[id] = static_cast(max(-128, min(127, value))); + id += NumThreadsPerBlock; + } + } +} + template __global__ void QuantizeLinearKernel(const float* input, int8_t* output, const float* scale, const int8_t* zero_point, CUDA_LONG N) { CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; @@ -36,56 +50,72 @@ __global__ void QuantizeLinearKernel(const float* input, uint8_t* output, const } } -template -Status CudaQuantizeLinear(const float* input, T* output, const float* scale, const T* zero_point, size_t num_of_element) { +template +__global__ void QuantizeLinearKernel(const half* input, uint8_t* output, const half* scale, const uint8_t* zero_point, CUDA_LONG N) { + CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + int value = __half2int_rn(input[id] / (*scale)) + *zero_point; + output[id] = static_cast(max(0, min(255, value))); + id += NumThreadsPerBlock; + } + } +} + +template +Status CudaQuantizeLinear(const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element) { if (num_of_element <= 0) return Status::OK(); int blocksPerGrid = static_cast(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); - QuantizeLinearKernel - <<>>( - input, - output, - scale, - zero_point, - num_of_element); + QuantizeLinearKernel<<>>( + input, + output, + scale, + zero_point, + num_of_element); return Status::OK(); } -template -__global__ void DequantizeLinearKernel(const T* input, float* output, const float* scale, const T* zero_point, CUDA_LONG N) { +template +__global__ void DequantizeLinearKernel(const T* input, U* output, const U* scale, const T* zero_point, CUDA_LONG N) { CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; #pragma unroll for (int i = 0; i < NumElementsPerThread; i++) { if (id < N) { - output[id] = (input[id] - *zero_point) * (*scale); + output[id] = static_cast((input[id] - *zero_point)) * (*scale); id += NumThreadsPerBlock; } } } -template -Status CudaDequantizeLinear(const T* input, float* output, const float* scale, const T* zero_point, size_t num_of_element) { +template +Status CudaDequantizeLinear(const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element) { if (num_of_element <= 0) return Status::OK(); int blocksPerGrid = static_cast(CeilDiv(num_of_element, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); - DequantizeLinearKernel - <<>>( - input, - output, - scale, - zero_point, - num_of_element); + DequantizeLinearKernel<<>>( + input, + output, + scale, + zero_point, + num_of_element); return Status::OK(); } -template Status CudaQuantizeLinear(const float* input, int8_t* output, const float* scale, const int8_t* zero_point, size_t num_of_element); -template Status CudaQuantizeLinear(const float* input, uint8_t* output, const float* scale, const uint8_t* zero_point, size_t num_of_element); +template Status CudaQuantizeLinear(const float* input, int8_t* output, const float* scale, const int8_t* zero_point, size_t num_of_element); +template Status CudaQuantizeLinear(const float* input, uint8_t* output, const float* scale, const uint8_t* zero_point, size_t num_of_element); +template Status CudaQuantizeLinear(const half* input, int8_t* output, const half* scale, const int8_t* zero_point, size_t num_of_element); +template Status CudaQuantizeLinear(const half* input, uint8_t* output, const half* scale, const uint8_t* zero_point, size_t num_of_element); -template Status CudaDequantizeLinear(const int8_t* input, float* output, const float* scale, const int8_t* zero_point, size_t num_of_element); -template Status CudaDequantizeLinear(const uint8_t* input, float* output, const float* scale, const uint8_t* zero_point, size_t num_of_element); +template Status CudaDequantizeLinear(const int8_t* input, float* output, const float* scale, const int8_t* zero_point, size_t num_of_element); +template Status CudaDequantizeLinear(const uint8_t* input, float* output, const float* scale, const uint8_t* zero_point, size_t num_of_element); +template Status CudaDequantizeLinear(const int8_t* input, half* output, const half* scale, const int8_t* zero_point, size_t num_of_element); +template Status CudaDequantizeLinear(const uint8_t* input, half* output, const half* scale, const uint8_t* zero_point, size_t num_of_element); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cuh b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cuh index a5d42d1180137..5d140981d61b2 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.cuh +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.cuh @@ -11,11 +11,11 @@ namespace onnxruntime { namespace cuda { -template -Status CudaQuantizeLinear(const float* input, T* output, const float* scale, const T* zero_point, size_t num_of_element); +template +Status CudaQuantizeLinear(const U* input, T* output, const U* scale, const T* zero_point, size_t num_of_element); -template -Status CudaDequantizeLinear(const T* input, float* output, const float* scale, const T* zero_point, size_t num_of_element); +template +Status CudaDequantizeLinear(const T* input, U* output, const U* scale, const T* zero_point, size_t num_of_element); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/quantize_linear.h b/onnxruntime/core/providers/cuda/tensor/quantize_linear.h index 7c964dc4ee741..9a36c716391cf 100644 --- a/onnxruntime/core/providers/cuda/tensor/quantize_linear.h +++ b/onnxruntime/core/providers/cuda/tensor/quantize_linear.h @@ -11,7 +11,7 @@ namespace onnxruntime { namespace cuda { -template +template class QuantizeLinear final : public CudaKernel { public: QuantizeLinear(const OpKernelInfo& info) : CudaKernel(info) {} @@ -19,7 +19,7 @@ class QuantizeLinear final : public CudaKernel { Status ComputeInternal(OpKernelContext* p_op_kernel_context) const override; }; -template +template class DequantizeLinear final : public CudaKernel { public: DequantizeLinear(const OpKernelInfo& info) : CudaKernel(info) {} diff --git a/onnxruntime/core/providers/cuda/tensor/tile.cc b/onnxruntime/core/providers/cuda/tensor/tile.cc index a66a11bb4356f..697c8293eafba 100644 --- a/onnxruntime/core/providers/cuda/tensor/tile.cc +++ b/onnxruntime/core/providers/cuda/tensor/tile.cc @@ -29,7 +29,7 @@ Status Tile::ComputeInternal(OpKernelContext* ctx) const { if (repeats_tensor.Shape().NumDimensions() != 1) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must be 1 dimensional"); - if (size_t(repeats_tensor.Shape().Size()) != rank) + if (repeats_tensor.Shape().Size() != rank) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "'repeat' input tensor must have the same length as the 'input' tensor"); // Calculate the shape of the output tensor diff --git a/onnxruntime/core/providers/cuda/tensor/where.cc b/onnxruntime/core/providers/cuda/tensor/where.cc index 3fc0b9272ed25..e2dd66e3cf233 100644 --- a/onnxruntime/core/providers/cuda/tensor/where.cc +++ b/onnxruntime/core/providers/cuda/tensor/where.cc @@ -94,40 +94,30 @@ struct TernaryElementwisePreparation { output_rank_or_simple_broadcast = out_rank; - if (a_shape != output_shape) { - TensorPitches a_pitches(a_shape.GetDims()); - a_padded_strides.size_ = out_rank; - auto offset = out_rank - a_rank; - for (auto i = offset; i < out_rank; ++i) { - // the stride for broadcast dimension is kept as 0 - if (a_shape.GetDims()[i - offset] != 1) { - a_padded_strides[i] = a_pitches[i]; + auto padder = [out_rank](int32_t rank, const TensorShape& shape, TArray& padded_strides) { + padded_strides.size_ = out_rank; + if (rank > 0) { + TensorPitches pitches(shape.GetDims()); + auto offset = out_rank - rank; + for (auto i = offset; i < out_rank; ++i) { + // the stride for broadcast dimension is kept as 0 + if (shape.GetDims()[i - offset] != 1) { + padded_strides[i] = pitches[i - offset]; + } } } + }; + + if (a_shape != output_shape) { + padder(a_rank, a_shape, a_padded_strides); } if (b_shape != output_shape) { - TensorPitches b_pitches(b_shape.GetDims()); - b_padded_strides.size_ = out_rank; - auto offset = out_rank - b_rank; - for (auto i = offset; i < out_rank; ++i) { - // the stride for broadcast dimension is kept as 0 - if (b_shape.GetDims()[i - offset] != 1) { - b_padded_strides[i] = b_pitches[i]; - } - } + padder(b_rank, b_shape, b_padded_strides); } if (c_shape != output_shape) { - TensorPitches c_pitches(c_shape.GetDims()); - c_padded_strides.size_ = out_rank; - auto offset = out_rank - c_rank; - for (auto i = offset; i < out_rank; ++i) { - // the stride for broadcast dimension is kept as 0 - if (c_shape.GetDims()[i - offset] != 1) { - c_padded_strides[i] = c_pitches[i]; - } - } + padder(c_rank, c_shape, c_padded_strides); } TensorPitches output_pitches(output_shape.GetDims()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h index a821755092b59..4f33d8381a7c8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h @@ -749,7 +749,7 @@ IMLOperatorKernelFactory : IUnknown //! \interface IMLOperatorRegistry //! \brief Represents an instance of a registry for custom operator kernel and schema. -//! Custom operators may be used with Windows.AI.MachineLearning APIs by returning +//! Custom operators may be used with WinML APIs by returning //! instances of IMLOperatorRegistry through ILearningModelOperatorProviderNative. interface DECLSPEC_UUID("2AF9DD2D-B516-4672-9AB5-530C208493AD") DECLSPEC_NOVTABLE IMLOperatorRegistry : IUnknown diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 75ea82263eacb..fb17342e7f8e6 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -12,6 +12,12 @@ #include "dnnl_execution_provider.h" #include "dnnl_fwd.h" +namespace { +struct KernelRegistryAndStatus { + std::shared_ptr kernel_registry = std::make_shared(); + Status st; +}; +} // namespace namespace onnxruntime { constexpr const char* DNNL = "Dnnl"; @@ -44,26 +50,29 @@ DNNLExecutionProvider::~DNNLExecutionProvider() { namespace ort_dnnl { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kDnnlExecutionProvider, kOnnxDomain, 7, Gemm); -void RegisterDNNLKernels(KernelRegistry& kernel_registry) { +Status RegisterDNNLKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } -std::shared_ptr GetDnnlKernelRegistry() { - std::shared_ptr kernel_registry = std::make_shared(); - RegisterDNNLKernels(*kernel_registry); - return kernel_registry; +KernelRegistryAndStatus GetDnnlKernelRegistry() { + KernelRegistryAndStatus ret; + ret.st = RegisterDNNLKernels(*ret.kernel_registry); + return ret; } } // namespace ort_dnnl std::shared_ptr DNNLExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr kernel_registry = onnxruntime::ort_dnnl::GetDnnlKernelRegistry(); - return kernel_registry; + static KernelRegistryAndStatus k = onnxruntime::ort_dnnl::GetDnnlKernelRegistry(); + // throw if the registry failed to initialize + ORT_THROW_IF_ERROR(k.st); + return k.kernel_registry; } bool DNNLExecutionProvider::UseSubgraph(const onnxruntime::GraphViewer& graph_viewer) const { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 2d5db7de07211..b8a725e08f556 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -26,7 +26,12 @@ using namespace ONNX_NAMESPACE; using namespace ::onnxruntime::logging; - +namespace { +struct KernelRegistryAndStatus { + std::shared_ptr kernel_registry = std::make_shared(); + Status st; +}; +} // namespace namespace onnxruntime { ONNX_OPERATOR_KERNEL_EX( @@ -54,27 +59,29 @@ ONNX_OPERATOR_KERNEL_EX( class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost); -static void RegisterTensorrtKernels(KernelRegistry& kernel_registry) { +static Status RegisterTensorrtKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { - kernel_registry.Register(function_table_entry()); + ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); } + return Status::OK(); } -std::shared_ptr GetTensorrtKernelRegistry() { - std::shared_ptr kernel_registry = std::make_shared(); - RegisterTensorrtKernels(*kernel_registry); - - return kernel_registry; +KernelRegistryAndStatus GetTensorrtKernelRegistry() { + KernelRegistryAndStatus ret; + ret.st = RegisterTensorrtKernels(*ret.kernel_registry); + return ret; } std::shared_ptr TensorrtExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr kernel_registry = onnxruntime::GetTensorrtKernelRegistry(); - return kernel_registry; + static KernelRegistryAndStatus k = onnxruntime::GetTensorrtKernelRegistry(); + // throw if the registry failed to initialize + ORT_THROW_IF_ERROR(k.st); + return k.kernel_registry; } // Per TensorRT documentation, logger needs to be a singleton. diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 78e5188ad97d3..62f02ef309d15 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -26,7 +26,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionOptions, OrtSessionOptions** out) { API_IMPL_END } -ORT_API(void, OrtApis::ReleaseSessionOptions, OrtSessionOptions* ptr) { +ORT_API(void, OrtApis::ReleaseSessionOptions, _Frees_ptr_opt_ OrtSessionOptions* ptr) { delete ptr; } @@ -140,12 +140,12 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionGraphOptimizationLevel, _In_ OrtSessionOp return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::SetIntraOpNumThreads, _In_ OrtSessionOptions* options, int intra_op_num_threads) { +ORT_API_STATUS_IMPL(OrtApis::SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads) { options->value.intra_op_param.thread_pool_size = intra_op_num_threads; return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _In_ OrtSessionOptions* options, int inter_op_num_threads) { +ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads) { options->value.inter_op_param.thread_pool_size = inter_op_num_threads; return nullptr; } diff --git a/onnxruntime/core/session/default_cpu_allocator_c_api.cc b/onnxruntime/core/session/default_cpu_allocator_c_api.cc index e90c6c7c86284..21abc0a77b5d8 100644 --- a/onnxruntime/core/session/default_cpu_allocator_c_api.cc +++ b/onnxruntime/core/session/default_cpu_allocator_c_api.cc @@ -57,7 +57,7 @@ struct OrtDefaultAllocator : OrtAllocatorImpl { return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ } -ORT_API_STATUS_IMPL(OrtApis::GetAllocatorWithDefaultOptions, _Out_ OrtAllocator** out) { +ORT_API_STATUS_IMPL(OrtApis::GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out) { API_IMPL_BEGIN static OrtDefaultAllocator ort_default_allocator; *out = &ort_default_allocator; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 42406b1214fe9..2c93875c01893 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -164,7 +164,6 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. - logging_manager_ = session_env.GetLoggingManager(); InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -180,10 +179,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } // If the thread pool can use all the processors, then // we set affinity of each thread to each processor. - if (to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL && to.affinity_vec_len == 0) - to.auto_set_affinity = true; - else - to.auto_set_affinity = false; + to.auto_set_affinity = to.thread_pool_size == 0 && + session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL && + to.affinity_vec_len == 0; thread_pool_ = concurrency::CreateThreadPool(&Env::Default(), to, nullptr); } @@ -191,10 +189,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, OrtThreadPoolParams to = session_options_.inter_op_param; // If the thread pool can use all the processors, then // we set thread affinity. - if (to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL) - to.auto_set_affinity = true; - else - to.auto_set_affinity = false; + to.auto_set_affinity = + to.thread_pool_size == 0 && session_options_.execution_mode == ExecutionMode::ORT_SEQUENTIAL; if (to.name == nullptr) to.name = ORT_TSTR("intra-op"); inter_op_thread_pool_ = @@ -208,7 +204,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, LOGS(*session_logger_, INFO) << "Using global/env threadpools since use_per_session_threads_ is false"; intra_op_thread_pool_from_env_ = session_env.GetIntraOpThreadPool(); inter_op_thread_pool_from_env_ = session_env.GetInterOpThreadPool(); - ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools() == true, + ORT_ENFORCE(session_env.EnvCreatedWithGlobalThreadPools(), "When the session is not configured to use per session" " threadpools, the env must be created with the the CreateEnvWithGlobalThreadPools API."); } @@ -231,20 +227,20 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, session_id_ = global_session_id_.fetch_add(1); } -InferenceSession::InferenceSession(const SessionOptions& session_options, - const Environment& session_env) +InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + logging_manager_(session_env.GetLoggingManager()), insert_cast_transformer_("CastFloat16Transformer") { // Initialize assets of this session instance ConstructorCommon(session_options, session_env); } -InferenceSession::InferenceSession(const SessionOptions& session_options, - const Environment& session_env, +InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const std::string& model_uri) - : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + : model_location_(ToWideString(model_uri)), + graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + logging_manager_(session_env.GetLoggingManager()), insert_cast_transformer_("CastFloat16Transformer") { - model_location_ = ToWideString(model_uri); auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ", status.ErrorMessage()); @@ -258,6 +254,7 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const std::wstring& model_uri) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + logging_manager_(session_env.GetLoggingManager()), insert_cast_transformer_("CastFloat16Transformer") { model_location_ = ToWideString(model_uri); auto status = Model::Load(model_location_, model_proto_); @@ -269,10 +266,10 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, } #endif -InferenceSession::InferenceSession(const SessionOptions& session_options, - const Environment& session_env, +InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + logging_manager_(session_env.GetLoggingManager()), insert_cast_transformer_("CastFloat16Transformer") { google::protobuf::io::IstreamInputStream zero_copy_input(&model_istream); const bool result = model_proto_.ParseFromZeroCopyStream(&zero_copy_input) && model_istream.eof(); @@ -282,11 +279,10 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, ConstructorCommon(session_options, session_env); } -InferenceSession::InferenceSession(const SessionOptions& session_options, - const Environment& session_env, - const void* model_data, - int model_data_len) +InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, + const void* model_data, int model_data_len) : graph_transformation_mgr_(session_options.max_num_graph_transformation_steps), + logging_manager_(session_env.GetLoggingManager()), insert_cast_transformer_("CastFloat16Transformer") { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -365,7 +361,7 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector& op_domains) { std::shared_ptr custom_registry; ORT_RETURN_IF_ERROR_SESSIONID_(CreateCustomRegistry(op_domains, custom_registry)); - RegisterCustomRegistry(custom_registry); + ORT_RETURN_IF_ERROR_SESSIONID_(RegisterCustomRegistry(custom_registry)); return Status::OK(); } @@ -433,7 +429,7 @@ common::Status InferenceSession::Load(const std::basic_string& model_uri) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(model_location_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, @@ -482,7 +478,7 @@ common::Status InferenceSession::Load(const ModelProto& model_proto) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif // This call will create a copy of model_proto and the constructed model instance will own the copy thereafter @@ -504,7 +500,7 @@ common::Status InferenceSession::Load(std::unique_ptr p_model_proto) #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(*p_model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif return onnxruntime::Model::Load(std::move(*p_model_proto), PathString(), model, @@ -533,7 +529,7 @@ common::Status InferenceSession::Load(std::istream& model_istream) { #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif return onnxruntime::Model::Load(std::move(model_proto), PathString(), model, @@ -561,7 +557,7 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(model_proto, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif @@ -583,7 +579,7 @@ common::Status InferenceSession::Load() { #ifdef ENABLE_LANGUAGE_INTEROP_OPS LoadInterOp(this->model_proto_, interop_domains_, [&](const char* msg) { LOGS(*session_logger_, WARNING) << msg; }); for (const auto& domain : interop_domains_) { - AddCustomOpDomains({domain.get()}); + ORT_RETURN_IF_ERROR(AddCustomOpDomains({domain.get()})); } #endif // Pass on ownership of the parsed ModelProto to the Model instance (its job here is done by this stage) @@ -852,11 +848,11 @@ common::Status InferenceSession::Initialize() { if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL && execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) { - LOGS(*session_logger_, ERROR) << "Parallel execution is currently not supported " - "for the registered CUDA Execution Provider."; + LOGS(*session_logger_, ERROR) << "Parallel execution mode doesn't support " + "CUDA Execution Provider currently."; return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Parallel execution is currently not supported " - "for the registered CUDA Execution Provider."); + "Parallel execution mode doesn't support " + "CUDA Execution Provider currently."); } // add predefined transformers @@ -869,7 +865,7 @@ common::Status InferenceSession::Initialize() { // There are 2 kinds of kernel registries with priority from high to low as below, // 1. Custom execution provider type specific kernel registries. // 2. common execution provider type specific kernel registries. - // The 1st and 2nd ones are shared across sessions. + // Kernel registries are shared across sessions. // The 1st ones should have already been registered via session-level API into KernelRegistryManager. // // Register 2nd registries into KernelRegistryManager. diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 3b15e527042b1..c63c09b9b7da1 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -159,7 +159,7 @@ class InferenceSession { * Calling this API is optional in which case onnxruntime will use its internal CPU execution provider. * @return OK if success. */ - common::Status RegisterExecutionProvider(std::unique_ptr p_exec_provider); + common::Status RegisterExecutionProvider(std::unique_ptr p_exec_provider) ORT_MUST_USE_RESULT; /** * Register a graph transformer. If you've one to register, call this before invoking Initialize(). @@ -170,7 +170,7 @@ class InferenceSession { * @return OK if success. */ common::Status RegisterGraphTransformer(std::unique_ptr p_graph_transformer, - TransformerLevel level = TransformerLevel::Level2); + TransformerLevel level = TransformerLevel::Level2) ORT_MUST_USE_RESULT; /** * Enable a custom set of transformers. Call this before invoking Initialize(). @@ -178,12 +178,12 @@ class InferenceSession { * When this list is provided ORT ignores the levels set in session options. * @return OK if success. */ - common::Status AddCustomTransformerList(const std::vector& transformers_to_enable); + common::Status AddCustomTransformerList(const std::vector& transformers_to_enable) ORT_MUST_USE_RESULT; /** * Add custom ops. This API is not thread safe. */ - common::Status AddCustomOpDomains(const std::vector& ops); + common::Status AddCustomOpDomains(const std::vector& ops) ORT_MUST_USE_RESULT; /** * Register a custom registry for operator schema and kernels. If you've one to register, @@ -194,23 +194,23 @@ class InferenceSession { * This API is not thread safe. * @return OK if success. */ - common::Status RegisterCustomRegistry(std::shared_ptr custom_registry); + common::Status RegisterCustomRegistry(std::shared_ptr custom_registry) ORT_MUST_USE_RESULT; /** * Load an ONNX model. * @param model_uri absolute path of the model file. * @return OK if success. */ - common::Status Load(const std::string& model_uri); + common::Status Load(const std::string& model_uri) ORT_MUST_USE_RESULT; #ifdef _WIN32 - common::Status Load(const std::wstring& model_uri); + common::Status Load(const std::wstring& model_uri) ORT_MUST_USE_RESULT; #endif /** * Load an ONNX model. * @param istream object of the model. * @return OK if success. */ - common::Status Load(std::istream& model_istream); + common::Status Load(std::istream& model_istream) ORT_MUST_USE_RESULT; /** * Load an ONNX model. @@ -218,14 +218,14 @@ class InferenceSession { * @param model_data_len Model data buffer size * @return OK if success. */ - common::Status Load(const void* model_data, int model_data_len); + common::Status Load(const void* model_data, int model_data_len) ORT_MUST_USE_RESULT; /** * Load an ONNX model from the member model_proto_. * To be called only in conjunction with a ctor that takes in a model path/ model stream/ model array * @return OK if success. */ - common::Status Load(); + common::Status Load() ORT_MUST_USE_RESULT; /** * Initializes a previously loaded model. Initialization includes but is not @@ -234,11 +234,11 @@ class InferenceSession { * This API is thread-safe. * @return OK if success */ - common::Status Initialize(); + common::Status Initialize() ORT_MUST_USE_RESULT; common::Status Run(const RunOptions& run_options, const std::vector& feed_names, const std::vector& feeds, const std::vector& output_names, - std::vector* p_fetches); + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * Run a pre-loaded and pre-intialized model. @@ -251,7 +251,7 @@ class InferenceSession { * @return OK if success. */ common::Status Run(const NameMLValMap& feeds, const std::vector& output_names, - std::vector* p_fetches); + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * See Run(const NameMLValMap& feeds, const std::vector& output_names, std::vector* p_fetches) @@ -259,17 +259,18 @@ class InferenceSession { * @param run_options use this to tune the Run call to your needs. */ common::Status Run(const RunOptions& run_options, const NameMLValMap& feeds, - const std::vector& output_names, std::vector* p_fetches); + const std::vector& output_names, + std::vector* p_fetches) ORT_MUST_USE_RESULT; /** * Creates a new binding object for binding inputs and outputs. * @param provider_type specifies the location where the inputs need to be potentially copied. * See IOBinding class for more info. */ - common::Status NewIOBinding(std::unique_ptr* io_binding); + common::Status NewIOBinding(std::unique_ptr* io_binding) ORT_MUST_USE_RESULT; - common::Status Run(const RunOptions& run_options, IOBinding& io_binding); - common::Status Run(IOBinding& io_binding); + common::Status Run(const RunOptions& run_options, IOBinding& io_binding) ORT_MUST_USE_RESULT; + common::Status Run(IOBinding& io_binding) ORT_MUST_USE_RESULT; /** * @return pair.first = OK; FAIL otherwise. pair.second is non-NULL when pair.first = OK. @@ -345,16 +346,16 @@ class InferenceSession { * @param protobuf object corresponding to the model file. model_proto will be copied by the API. * @return OK if success. */ - common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto); + common::Status Load(const ONNX_NAMESPACE::ModelProto& model_proto) ORT_MUST_USE_RESULT; /** * Load an ONNX model. * @param protobuf object corresponding to the model file. This is primarily to support large models. * @return OK if success. */ - common::Status Load(std::unique_ptr p_model_proto); + common::Status Load(std::unique_ptr p_model_proto) ORT_MUST_USE_RESULT; - common::Status DoPostLoadProcessing(onnxruntime::Model& model); + common::Status DoPostLoadProcessing(onnxruntime::Model& model) ORT_MUST_USE_RESULT; /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; @@ -382,7 +383,7 @@ class InferenceSession { return !custom_schema_registries_.empty(); } - common::Status SaveModelMetadata(const onnxruntime::Model& model); + common::Status SaveModelMetadata(const onnxruntime::Model& model) ORT_MUST_USE_RESULT; // Create a Logger for a single execution if possible. Otherwise use the default logger. // If a new logger is created, it will also be stored in new_run_logger, @@ -392,18 +393,18 @@ class InferenceSession { const logging::Logger& CreateLoggerForRun(const RunOptions& run_options, std::unique_ptr& new_run_logger); - common::Status Load(std::function&)> loader, const std::string& event_name); + common::Status Load(std::function&)> loader, + const std::string& event_name) ORT_MUST_USE_RESULT; common::Status TransformGraph(onnxruntime::Graph& graph, const onnxruntime::GraphTransformerManager& graph_transformer_mgr, - const ExecutionProviders& providers, - KernelRegistryManager& kernel_registry_manager, + const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, const InsertCastTransformer& insert_cast_transformer, - SessionState& session_state); + SessionState& session_state) ORT_MUST_USE_RESULT; - common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state); + common::Status CreateSubgraphSessionState(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; - common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state); + common::Status InitializeSubgraphSessions(Graph& graph, SessionState& session_state) ORT_MUST_USE_RESULT; void AddPredefinedTransformers(GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, @@ -411,18 +412,19 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - common::Status CheckShapes(const std::string& input_name, - const TensorShape& input_shape, - const TensorShape& expected_shape) const; + common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, + const TensorShape& expected_shape) const ORT_MUST_USE_RESULT; - common::Status ValidateInputs(const std::vector& feed_names, const std::vector& feeds) const; + common::Status ValidateInputs(const std::vector& feed_names, + const std::vector& feeds) const ORT_MUST_USE_RESULT; - common::Status ValidateOutputs(const std::vector& output_names, const std::vector* p_fetches) const; + common::Status ValidateOutputs(const std::vector& output_names, + const std::vector* p_fetches) const ORT_MUST_USE_RESULT; - common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms); + common::Status WaitForNotification(Notification* p_executor_done, int64_t timeout_in_ms) ORT_MUST_USE_RESULT; template - common::Status Load(const std::basic_string& model_uri); + common::Status Load(const std::basic_string& model_uri) ORT_MUST_USE_RESULT; template void StartProfiling(const std::basic_string& file_prefix); @@ -437,7 +439,7 @@ class InferenceSession { std::vector transformers_to_enable_; /// Logging manager if provided. - logging::LoggingManager* logging_manager_ = nullptr; + logging::LoggingManager* const logging_manager_; /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. std::unique_ptr owned_session_logger_ = nullptr; @@ -483,9 +485,6 @@ class InferenceSession { KernelRegistryManager kernel_registry_manager_; std::list> custom_schema_registries_; - // A set of executors that can run in parallel. - std::vector> executors_; // TODO do we need this vector? - ModelMetadata model_metadata_; std::unordered_set required_inputs_; diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index 571590836bbd1..5dbcb4f25498d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -6,7 +6,14 @@ #include "core/session/inference_session.h" #include "core/framework/session_options.h" #include "core/common/common.h" +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 28020) +#endif #include "single_include/nlohmann/json.hpp" +#ifdef _WIN32 +#pragma warning(pop) +#endif using json = nlohmann::json; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 12be053c81bb4..a9ee933d5739c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -13,6 +13,7 @@ #include "core/common/logging/logging.h" #include "core/common/status.h" +#include "core/common/safeint.h" #include "core/graph/graph.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" @@ -43,6 +44,14 @@ using onnxruntime::common::Status; using namespace onnxruntime; +#ifndef ORT_STATUS_PTR +#ifdef _WIN32 +#define ORT_STATUS_PTR _Check_return_ _Ret_maybenull_ OrtStatusPtr +#else +#define ORT_STATUS_PTR OrtStatus* +#endif +#endif + #define ORT_API_RETURN_IF_ERROR(expr) \ do { \ auto _status = (expr); \ @@ -111,8 +120,8 @@ ORT_API_STATUS_IMPL(OrtApis::DisableTelemetryEvents, _In_ const OrtEnv* ort_env) API_IMPL_END } -OrtStatus* CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, OrtAllocator* allocator, - std::unique_ptr* out) { +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + _Inout_ OrtAllocator* allocator, std::unique_ptr* out) { std::vector shapes(shape_len); for (size_t i = 0; i != shape_len; ++i) { shapes[i] = shape[i]; @@ -122,8 +131,7 @@ OrtStatus* CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t sha return nullptr; } -OrtStatus* CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, - Tensor& out) { +ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { std::vector shapes(shape_len); for (size_t i = 0; i != shape_len; ++i) { shapes[i] = shape[i]; @@ -144,8 +152,8 @@ OrtStatus* CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, si * * this function will create a copy of the allocator info */ -OrtStatus* CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, std::unique_ptr* out) { +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, std::unique_ptr* out) { size_t elem_count = 1; std::vector shapes(shape_len); for (size_t i = 0; i != shape_len; ++i) { @@ -169,15 +177,15 @@ OrtStatus* CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t sha namespace c_api_internal { template -inline OrtStatus* CallCreateTensorImpl(const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, std::unique_ptr* out) { +inline ORT_STATUS_PTR CallCreateTensorImpl(const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, std::unique_ptr* out) { auto ml_value = DataTypeImpl::GetType(); return CreateTensorImpl(ml_value, shape, shape_len, info, p_data, p_data_len, out); } template -inline OrtStatus* CallCreateTensorImpl(const int64_t* shape, size_t shape_len, OrtAllocator* allocator, - std::unique_ptr* out) { +inline ORT_STATUS_PTR CallCreateTensorImpl(const int64_t* shape, size_t shape_len, _Inout_ OrtAllocator* allocator, + std::unique_ptr* out) { auto ml_type = DataTypeImpl::GetType(); return CreateTensorImpl(ml_type, shape, shape_len, allocator, out); } @@ -327,18 +335,19 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCustomOpDomain, _In_ const char* domain, _Out API_IMPL_END } -ORT_API(void, OrtApis::ReleaseCustomOpDomain, OrtCustomOpDomain* ptr) { +ORT_API(void, OrtApis::ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain* ptr) { delete ptr; } -ORT_API_STATUS_IMPL(OrtApis::CustomOpDomain_Add, _In_ OrtCustomOpDomain* custom_op_domain, OrtCustomOp* op) { +ORT_API_STATUS_IMPL(OrtApis::CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op) { API_IMPL_BEGIN custom_op_domain->custom_ops_.emplace_back(op); return nullptr; API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _In_ OrtSessionOptions* options, OrtCustomOpDomain* custom_op_domain) { +ORT_API_STATUS_IMPL(OrtApis::AddCustomOpDomain, _Inout_ OrtSessionOptions* options, + _In_ OrtCustomOpDomain* custom_op_domain) { API_IMPL_BEGIN options->custom_op_domains_.emplace_back(custom_op_domain); return nullptr; @@ -363,9 +372,9 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions } namespace { -OrtStatus* LoadAndInitializeSession(_In_ const OrtEnv* /*env*/, _In_ const OrtSessionOptions* options, - _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, - _Outptr_ OrtSession** out) { +ORT_STATUS_PTR LoadAndInitializeSession(_In_ const OrtEnv* /*env*/, _In_ const OrtSessionOptions* options, + _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, + _Outptr_ OrtSession** out) { // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of // byte addressable memory std::vector> provider_list; @@ -447,10 +456,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, - _In_opt_ const OrtRunOptions* run_options, - _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, - _In_ const char* const* output_names1, size_t output_names_len, _Outptr_ OrtValue** output) { +ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output) { API_IMPL_BEGIN auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); const int queue_id = 0; @@ -509,7 +519,7 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::IsTensor, _In_ const OrtValue* value, int* out) { +ORT_API_STATUS_IMPL(OrtApis::IsTensor, _In_ const OrtValue* value, _Out_ int* out) { auto v = reinterpret_cast(value); *out = v->IsTensor() ? 1 : 0; return nullptr; @@ -554,8 +564,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetStringTensorDataLength, _In_ const OrtValue* val API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, - _Out_ void* s, size_t s_len, _Out_ size_t* offsets, size_t offsets_len) { +ORT_API_STATUS_IMPL(OrtApis::GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len) { TENSOR_READ_API_BEGIN const auto* input = tensor.Data(); auto len = static_cast(tensor.Shape().Size()); @@ -600,7 +610,7 @@ const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelOutputs(); }; const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetOverridableInitializers(); }; -static OrtStatus* GetNodeDefListCountHelper(const OrtSession* sess, GetDefListFn get_fn, size_t* out) { +static ORT_STATUS_PTR GetNodeDefListCountHelper(const OrtSession* sess, GetDefListFn get_fn, size_t* out) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); std::pair p = get_fn(session); @@ -623,7 +633,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerCount, _In_ const O return GetNodeDefListCountHelper(sess, get_overridable_initializers_fn, out); } -static OrtStatus* GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn get_fn, size_t index, _Outptr_ struct OrtTypeInfo** out) { +static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn get_fn, size_t index, + _Outptr_ struct OrtTypeInfo** out) { API_IMPL_BEGIN auto session = reinterpret_cast(sess); std::pair p = get_fn(session); @@ -648,16 +659,15 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerTypeInfo, _In_ cons return GetNodeDefTypeInfoHelper(sess, get_overridable_initializers_fn, index, out); } -static char* StrDup(const std::string& str, OrtAllocator* allocator) { +static char* StrDup(const std::string& str, _Inout_ OrtAllocator* allocator) { char* output_string = reinterpret_cast(allocator->Alloc(allocator, str.size() + 1)); memcpy(output_string, str.c_str(), str.size()); output_string[str.size()] = '\0'; return output_string; } -static OrtStatus* GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index, - _Inout_ OrtAllocator* allocator, GetDefListFn get_fn, - _Outptr_ char** output) { +static ORT_STATUS_PTR GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index, _Inout_ OrtAllocator* allocator, + GetDefListFn get_fn, _Outptr_ char** output) { auto session = reinterpret_cast(sess); std::pair p = get_fn(session); if (!p.first.IsOK()) @@ -672,7 +682,7 @@ static OrtStatus* GetNodeDefNameImpl(_In_ const OrtSession* sess, size_t index, } ORT_API_STATUS_IMPL(OrtApis::SessionEndProfiling, _In_ OrtSession* sess, _Inout_ OrtAllocator* allocator, - _Out_ char** out) { + _Outptr_ char** out) { API_IMPL_BEGIN auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess); auto profile_file_name = session->EndProfiling(); @@ -733,10 +743,8 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetDescription, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, - _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, - _In_ const char* key, _Outptr_ char** value) { +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value) { API_IMPL_BEGIN auto custom_metadata_map = reinterpret_cast(model_metadata)->custom_metadata_map; @@ -755,6 +763,38 @@ ORT_API_STATUS_IMPL(OrtApis::ModelMetadataLookupCustomMetadataMap, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetCustomMetadataMapKeys, + _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys) { + API_IMPL_BEGIN + const auto& custom_metadata_map = + reinterpret_cast(model_metadata)->custom_metadata_map; + + auto count = custom_metadata_map.size(); + if (count == 0) { + *keys = nullptr; + } else { + // To guard against overflow in the next step where we compute bytes to allocate + SafeInt alloc_count(count); + + // alloc_count * sizeof(...) will throw if there was an overflow which will be caught in API_IMPL_END + // and be returned to the user as a status + char** p = reinterpret_cast(allocator->Alloc(allocator, alloc_count * sizeof(char*))); + assert(p != nullptr); + auto map_iter = custom_metadata_map.cbegin(); + int64_t i = 0; + while (map_iter != custom_metadata_map.cend()) { + p[i++] = StrDup(map_iter->first, allocator); + ++map_iter; + } + *keys = p; + } + + *num_keys = static_cast(count); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value) { @@ -806,23 +846,23 @@ ORT_API_STATUS_IMPL(OrtApis::AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Ou API_IMPL_END } -const int NUM_MAP_INDICES = 2; +static const int NUM_MAP_INDICES = 2; template -OrtStatus* OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) { +ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) { auto& data = p_ml_value->Get(); *out = data.size(); return nullptr; } template <> -OrtStatus* OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) { +ORT_STATUS_PTR OrtGetNumSequenceElements(const OrtValue* p_ml_value, size_t* out) { auto& data = p_ml_value->Get(); *out = data.Size(); return nullptr; } -static OrtStatus* OrtGetValueCountImpl(const OrtValue* value, size_t* out) { +static ORT_STATUS_PTR OrtGetValueCountImpl(const OrtValue* value, size_t* out) { ONNXType value_type; if (auto status = OrtApis::GetValueType(value, &value_type)) return status; @@ -851,7 +891,7 @@ static OrtStatus* OrtGetValueCountImpl(const OrtValue* value, size_t* out) { } } -ORT_API_STATUS_IMPL(OrtApis::GetValueCount, const OrtValue* value, size_t* out) { +ORT_API_STATUS_IMPL(OrtApis::GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out) { API_IMPL_BEGIN return OrtGetValueCountImpl(value, out); API_IMPL_END @@ -860,7 +900,7 @@ ORT_API_STATUS_IMPL(OrtApis::GetValueCount, const OrtValue* value, size_t* out) /////////////////// // OrtGetValueImplSeqOfMap template -static OrtStatus* OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int index, OrtValue** out) { +static ORT_STATUS_PTR OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int index, _Outptr_ OrtValue** out) { using TKey = typename T::value_type::key_type; using TVal = typename T::value_type::mapped_type; using MapType = std::map; @@ -876,7 +916,8 @@ static OrtStatus* OrtGetValueImplSeqOfMap(const OrtValue* p_ml_value, int index, return nullptr; } -OrtStatus* PopulateTensorWithData(OrtValue* oval, const void* data_elem, size_t num_elems, size_t elem_size) { +ORT_STATUS_PTR PopulateTensorWithData(_Inout_ OrtValue* oval, _In_ const void* data_elem, size_t num_elems, + size_t elem_size) { void* raw_data = nullptr; auto st = OrtApis::GetTensorMutableData(oval, &raw_data); if (st) { @@ -886,8 +927,8 @@ OrtStatus* PopulateTensorWithData(OrtValue* oval, const void* data_elem, size_t return nullptr; } -OrtStatus* PopulateTensorWithData(OrtValue* oval, const std::string* data_elem, - size_t num_elems, size_t /* elem_size */) { +ORT_STATUS_PTR PopulateTensorWithData(_Inout_ OrtValue* oval, _In_reads_(num_elems) const std::string* data_elem, + size_t num_elems, size_t /* elem_size */) { auto v = reinterpret_cast(oval); auto tensor = v->GetMutable(); auto* dst = tensor->MutableData(); @@ -904,7 +945,8 @@ OrtStatus* PopulateTensorWithData(OrtValue* oval, const std::string* data_elem, namespace c_api_internal { template struct CallGetValueImpl { - OrtStatus* operator()(OrtAllocator* allocator, const onnxruntime::Tensor& tensor, OrtValue** out) const { + ORT_STATUS_PTR operator()(_Inout_ OrtAllocator* allocator, const onnxruntime::Tensor& tensor, + _Outptr_ OrtValue** out) const { const auto& shape = tensor.Shape(); const auto* tensor_data = tensor.Data(); OrtStatus* st = OrtApis::CreateTensorAsOrtValue(allocator, shape.GetDims().data(), shape.NumDimensions(), @@ -916,28 +958,35 @@ struct CallGetValueImpl { // Return status instead of throwing if unsupported type specified struct UnsupportedReturnFailStatus { - OrtStatus* operator()(int32_t dt_type) const { + ORT_STATUS_PTR operator()(int32_t dt_type) const { std::string msg("Unsupported tensor element type in the input: "); msg.append(std::to_string(dt_type)); return OrtApis::CreateStatus(ORT_FAIL, msg.c_str()); } }; } // namespace c_api_internal - -OrtStatus* OrtGetValueImplSeqOfTensors(const OrtValue* p_ml_value, int index, OrtAllocator* allocator, - OrtValue** out) { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 6101) +#endif +ORT_STATUS_PTR OrtGetValueImplSeqOfTensors(_In_ const OrtValue* p_ml_value, int index, _In_opt_ OrtAllocator* allocator, + _Outptr_ OrtValue** out) { auto& data = p_ml_value->Get(); auto& one_tensor = data.Get(index); using namespace c_api_internal; - utils::MLTypeCallDispatcherRet t_disp(one_tensor.GetElementType()); return t_disp.template InvokeWithUnsupportedPolicy(allocator, one_tensor, out); } -static OrtStatus* OrtGetValueImplSeq(const OrtValue* value, int index, OrtAllocator* allocator, - OrtValue** out) { +#ifdef _MSVC_VER +#pragma warning(pop) +#endif + +static ORT_STATUS_PTR OrtGetValueImplSeq(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out) { auto p_ml_value = reinterpret_cast(value); auto type = p_ml_value->Type(); // Note: keep these in sync with the registered types in data_types.h @@ -956,8 +1005,8 @@ static OrtStatus* OrtGetValueImplSeq(const OrtValue* value, int index, OrtAlloca } template -static OrtStatus* OrtGetValueImplMapHelper(const OrtValue* p_ml_value, int index, OrtAllocator* allocator, - OrtValue** out) { +static ORT_STATUS_PTR OrtGetValueImplMapHelper(_In_ const OrtValue* p_ml_value, int index, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { using namespace onnxruntime::utils; using TKey = typename T::key_type; using TVal = typename T::mapped_type; @@ -994,8 +1043,8 @@ static OrtStatus* OrtGetValueImplMapHelper(const OrtValue* p_ml_value, int index } } -static OrtStatus* OrtGetValueImplMap(const OrtValue* value, int index, OrtAllocator* allocator, - OrtValue** out) { +static ORT_STATUS_PTR OrtGetValueImplMap(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out) { auto p_ml_value = reinterpret_cast(value); auto type = p_ml_value->Type(); // Note: keep these in sync with the registered types in data_types.h @@ -1022,8 +1071,8 @@ static OrtStatus* OrtGetValueImplMap(const OrtValue* value, int index, OrtAlloca return OrtApis::CreateStatus(ORT_FAIL, "Input is not of one of the supported map types."); } -static OrtStatus* OrtGetValueImpl(const OrtValue* value, int index, OrtAllocator* allocator, - OrtValue** out) { +static ORT_STATUS_PTR OrtGetValueImpl(_In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out) { ONNXType value_type; if (auto status = OrtApis::GetValueType(value, &value_type)) return status; @@ -1037,8 +1086,8 @@ static OrtStatus* OrtGetValueImpl(const OrtValue* value, int index, OrtAllocator } } -ORT_API_STATUS_IMPL(OrtApis::GetValue, const OrtValue* value, int index, OrtAllocator* allocator, - OrtValue** out) { +ORT_API_STATUS_IMPL(OrtApis::GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out) { API_IMPL_BEGIN return OrtGetValueImpl(value, index, allocator, out); API_IMPL_END @@ -1047,7 +1096,8 @@ ORT_API_STATUS_IMPL(OrtApis::GetValue, const OrtValue* value, int index, OrtAllo /////////////////// // OrtCreateValue template -static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size_t num_values, OrtValue** out) { +static OrtStatus* OrtCreateValueImplSeqHelperMap(const OrtValue* const* in, size_t num_values, + _Outptr_ OrtValue** out) { using SeqType = std::vector; auto seq_ptr = onnxruntime::make_unique(); seq_ptr->reserve(num_values); @@ -1099,8 +1149,8 @@ struct CallCreateValueImpl { } // namespace c_api_internal -static OrtStatus* OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t num_values, - OrtValue** out) { +static ORT_STATUS_PTR OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t num_values, + _Outptr_ OrtValue** out) { using namespace c_api_internal; std::vector tensors; tensors.resize(num_values); @@ -1140,8 +1190,8 @@ static OrtStatus* OrtCreateValueImplSeqHelper(const OrtValue* const* in, size_t return nullptr; } -static OrtStatus* OrtCreateValueImplSeq(const OrtValue* const* in, size_t num_values, - OrtValue** out) { +static ORT_STATUS_PTR OrtCreateValueImplSeq(_In_reads_(num_values) const OrtValue* const* in, size_t num_values, + _Outptr_ OrtValue** out) { // We only support limited sequence types. For the sake of simplicity the type of the first // OrtValue* in OrtValue** will determine the type of the vector used to create the output OrtValue // this type should be either a tensor of limited types or map of limited types @@ -1189,8 +1239,7 @@ static OrtStatus* OrtCreateValueImplSeq(const OrtValue* const* in, size_t num_va } template -static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& value_tensor, - OrtValue** out) { +static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& value_tensor, _Outptr_ OrtValue** out) { using MapType = std::map; auto map_ptr = onnxruntime::make_unique(); // iterate through the key and value tensors and populate map @@ -1213,8 +1262,8 @@ static OrtStatus* OrtCreateMapMLValue(const Tensor& key_tensor, const Tensor& va } template -static OrtStatus* OrtCreateValueImplMapHelper(const Tensor& key_tensor, const Tensor& value_tensor, - OrtValue** out) { +static ORT_STATUS_PTR OrtCreateValueImplMapHelper(const Tensor& key_tensor, const Tensor& value_tensor, + _Outptr_ OrtValue** out) { auto value_type = value_tensor.DataType()->AsPrimitiveDataType(); ORT_ENFORCE(value_type != nullptr, "Tensor must always contain primitive types. Found: ", DataTypeImpl::ToString(value_tensor.DataType())); @@ -1241,7 +1290,7 @@ static OrtStatus* OrtCreateValueImplMapHelper(const Tensor& key_tensor, const Te return OrtApis::CreateStatus(ORT_FAIL, msg.c_str()); } -static OrtStatus* OrtCreateValueImplMap(const OrtValue* const* in, size_t num_values, OrtValue** out) { +static ORT_STATUS_PTR OrtCreateValueImplMap(const OrtValue* const* in, size_t num_values, _Outptr_ OrtValue** out) { if (num_values != NUM_MAP_INDICES) { return OrtApis::CreateStatus(ORT_FAIL, "For map type num_values MUST be 2"); } @@ -1273,8 +1322,8 @@ static OrtStatus* OrtCreateValueImplMap(const OrtValue* const* in, size_t num_va return OrtApis::CreateStatus(ORT_FAIL, "Key type is not supported yet."); } -static OrtStatus* OrtCreateValueImpl(const OrtValue* const* in, size_t num_values, enum ONNXType value_type, - OrtValue** out) { +static ORT_STATUS_PTR OrtCreateValueImpl(_In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out) { if (num_values <= 0) { return OrtApis::CreateStatus(ORT_FAIL, "Number of values should be at least 1."); } @@ -1287,15 +1336,15 @@ static OrtStatus* OrtCreateValueImpl(const OrtValue* const* in, size_t num_value return OrtApis::CreateStatus(ORT_FAIL, "Input is not of type sequence or map."); } -ORT_API_STATUS_IMPL(OrtApis::CreateValue, const OrtValue* const* in, size_t num_values, enum ONNXType value_type, - OrtValue** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out) { API_IMPL_BEGIN return OrtCreateValueImpl(in, num_values, value_type, out); API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateOpaqueValue, const char* domain_name, const char* type_name, const void* data_container, - size_t data_container_size, OrtValue** out) { +ORT_API_STATUS_IMPL(OrtApis::CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, + _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out) { API_IMPL_BEGIN std::string dtype("opaque("); dtype.append(domain_name).append(",").append(type_name).append(")"); @@ -1311,8 +1360,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateOpaqueValue, const char* domain_name, const c return nullptr; } -ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, const char* domain_name, const char* type_name, const OrtValue* in, - void* data_container, size_t data_container_size) { +ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, + _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size) { API_IMPL_BEGIN std::string dtype("opaque("); dtype.append(domain_name).append(",").append(type_name).append(")"); @@ -1515,7 +1564,8 @@ static constexpr OrtApi ort_api_1_to_3 = { &OrtApis::CreateEnvWithGlobalThreadPools, &OrtApis::DisablePerSessionThreads, &OrtApis::CreateThreadingOptions, - &OrtApis::ReleaseThreadingOptions}; + &OrtApis::ReleaseThreadingOptions, + &OrtApis::ModelMetadataGetCustomMetadataMapKeys}; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) // If this assert hits, read the above 'Rules on how to add a new Ort API version' @@ -1536,7 +1586,7 @@ const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION { return &ort_api_base; } -ORT_API(void, OrtApis::ReleaseEnv, _Frees_ptr_opt_ OrtEnv* value) { +ORT_API(void, OrtApis::ReleaseEnv, OrtEnv* value) { OrtEnv::Release(value); } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b1b50c6659255..930ebfaabd63a 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -7,20 +7,22 @@ ORT_API(const OrtApi*, GetApi, uint32_t version); ORT_API(const char*, GetVersionString); ORT_API(void, ReleaseEnv, OrtEnv*); -ORT_API(void, ReleaseStatus, OrtStatus*); -ORT_API(void, ReleaseMemoryInfo, OrtMemoryInfo*); -ORT_API(void, ReleaseSession, OrtSession*); -ORT_API(void, ReleaseValue, OrtValue*); -ORT_API(void, ReleaseRunOptions, OrtRunOptions*); -ORT_API(void, ReleaseTypeInfo, OrtTypeInfo*); -ORT_API(void, ReleaseTensorTypeAndShapeInfo, OrtTensorTypeAndShapeInfo*); -ORT_API(void, ReleaseSessionOptions, OrtSessionOptions*); -ORT_API(void, ReleaseCustomOpDomain, OrtCustomOpDomain*); -ORT_API(void, ReleaseMapTypeInfo, OrtMapTypeInfo*); -ORT_API(void, ReleaseSequenceTypeInfo, OrtSequenceTypeInfo*); -ORT_API(void, ReleaseModelMetadata, OrtModelMetadata*); - -ORT_API_STATUS_IMPL(CreateStatus, OrtErrorCode code, _In_ const char* msg); +ORT_API(void, ReleaseStatus, _Frees_ptr_opt_ OrtStatus*); +ORT_API(void, ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo*); +ORT_API(void, ReleaseSession, _Frees_ptr_opt_ OrtSession*); +ORT_API(void, ReleaseValue, _Frees_ptr_opt_ OrtValue*); +ORT_API(void, ReleaseRunOptions, _Frees_ptr_opt_ OrtRunOptions*); +ORT_API(void, ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo*); +ORT_API(void, ReleaseTensorTypeAndShapeInfo, _Frees_ptr_opt_ OrtTensorTypeAndShapeInfo*); +ORT_API(void, ReleaseSessionOptions, _Frees_ptr_opt_ OrtSessionOptions*); +ORT_API(void, ReleaseCustomOpDomain, _Frees_ptr_opt_ OrtCustomOpDomain*); +ORT_API(void, ReleaseMapTypeInfo, _Frees_ptr_opt_ OrtMapTypeInfo*); +ORT_API(void, ReleaseSequenceTypeInfo, _Frees_ptr_opt_ OrtSequenceTypeInfo*); +ORT_API(void, ReleaseModelMetadata, _Frees_ptr_opt_ OrtModelMetadata*); + +_Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL CreateStatus(OrtErrorCode code, _In_z_ const char* msg) + NO_EXCEPTION ORT_MUST_USE_RESULT; + OrtErrorCode ORT_API_CALL GetErrorCode(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; const char* ORT_API_CALL GetErrorMessage(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; @@ -36,10 +38,11 @@ ORT_API_STATUS_IMPL(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* ORT_API_STATUS_IMPL(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); -ORT_API_STATUS_IMPL(Run, _Inout_ OrtSession* sess, - _In_opt_ const OrtRunOptions* run_options, - _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, - _In_ const char* const* output_names, size_t output_names_len, _Outptr_ OrtValue** output); +ORT_API_STATUS_IMPL(Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names1, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output); ORT_API_STATUS_IMPL(CreateSessionOptions, OrtSessionOptions** out); ORT_API_STATUS_IMPL(CloneSessionOptions, const OrtSessionOptions* input, OrtSessionOptions** out); @@ -88,8 +91,7 @@ ORT_API_STATUS_IMPL(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_m ORT_API_STATUS_IMPL(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); ORT_API_STATUS_IMPL(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, - _In_ const char* key, _Outptr_ char** value); + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); ORT_API_STATUS_IMPL(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); @@ -98,7 +100,7 @@ ORT_API_STATUS_IMPL(CreateRunOptions, _Outptr_ OrtRunOptions** out); ORT_API_STATUS_IMPL(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int value); ORT_API_STATUS_IMPL(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int value); -ORT_API_STATUS_IMPL(RunOptionsSetRunTag, _In_ OrtRunOptions*, _In_ const char* run_tag); +ORT_API_STATUS_IMPL(RunOptionsSetRunTag, _Inout_ OrtRunOptions*, _In_ const char* run_tag); ORT_API_STATUS_IMPL(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); ORT_API_STATUS_IMPL(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* out); @@ -117,9 +119,10 @@ ORT_API_STATUS_IMPL(IsTensor, _In_ const OrtValue* value, _Out_ int* out); ORT_API_STATUS_IMPL(GetTensorMutableData, _Inout_ OrtValue* value, _Outptr_ void** out); ORT_API_STATUS_IMPL(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); ORT_API_STATUS_IMPL(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); -ORT_API_STATUS_IMPL(GetStringTensorContent, _In_ const OrtValue* value, _Out_ void* s, size_t s_len, - _Out_ size_t* offsets, size_t offsets_len); -ORT_API_STATUS_IMPL(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*, _Out_ const OrtTensorTypeAndShapeInfo** out); +ORT_API_STATUS_IMPL(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); +ORT_API_STATUS_IMPL(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo*, + _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ enum ONNXType* out); ORT_API_STATUS_IMPL(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo*, enum ONNXTensorElementDataType type); @@ -127,10 +130,11 @@ ORT_API_STATUS_IMPL(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const i ORT_API_STATUS_IMPL(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo*, _Out_ enum ONNXTensorElementDataType* out); ORT_API_STATUS_IMPL(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); -ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ const char* dim_params[], size_t dim_params_length); +ORT_API_STATUS_IMPL(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); -ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out); +ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); ORT_API_STATUS_IMPL(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* symbolic_dim, _In_ int64_t dim_override); @@ -146,13 +150,13 @@ ORT_API_STATUS_IMPL(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtA ORT_API_STATUS_IMPL(AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out); ORT_API_STATUS_IMPL(AllocatorFree, _Inout_ OrtAllocator* ptr, void* p); -ORT_API_STATUS_IMPL(AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Out_ const OrtMemoryInfo** out); +ORT_API_STATUS_IMPL(AllocatorGetInfo, _In_ const OrtAllocator* ptr, _Outptr_ const struct OrtMemoryInfo** out); ORT_API_STATUS_IMPL(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); ORT_API_STATUS_IMPL(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); ORT_API_STATUS_IMPL(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); -ORT_API_STATUS_IMPL(CreateValue, _In_ const OrtValue* const* in, size_t num_values, enum ONNXType value_type, - _Outptr_ OrtValue** out); -ORT_API_STATUS_IMPL(CreateOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, +ORT_API_STATUS_IMPL(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out); +ORT_API_STATUS_IMPL(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); ORT_API_STATUS_IMPL(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, _Out_ void* data_container, size_t data_container_size); @@ -168,8 +172,10 @@ ORT_API_STATUS_IMPL(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, // OrtTypeInfo methods ORT_API_STATUS_IMPL(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo*, _Out_ const char** const denotation, _Out_ size_t* len); -ORT_API_STATUS_IMPL(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtMapTypeInfo** out); -ORT_API_STATUS_IMPL(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const OrtSequenceTypeInfo** out); +ORT_API_STATUS_IMPL(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtMapTypeInfo** out); +ORT_API_STATUS_IMPL(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); // OrtMapTypeInfo Accessors ORT_API_STATUS_IMPL(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); @@ -185,4 +191,8 @@ ORT_ALL_ARGS_NONNULL; ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*); + +ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); + } // namespace OrtApis diff --git a/onnxruntime/core/util/eigen_common_wrapper.h b/onnxruntime/core/util/eigen_common_wrapper.h index a99eea26ce396..711c9dcea3dc6 100644 --- a/onnxruntime/core/util/eigen_common_wrapper.h +++ b/onnxruntime/core/util/eigen_common_wrapper.h @@ -27,9 +27,8 @@ #pragma warning(disable : 4554) #pragma warning(disable : 4245) #pragma warning(disable : 4127) -//The following warning can be fixed by updating eigen to the latest, however, the new code will trigger a MSVC bug -//that will slow down the build time to 3-5 hours. -#pragma warning(disable : 4723) +#pragma warning(disable : 6313) +#pragma warning(disable : 6294) #endif #include "unsupported/Eigen/CXX11/Tensor" diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 8930aafbf46f1..5e482efbd9cd1 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -26,6 +26,7 @@ #pragma warning(push) #pragma warning(disable : 4267) #pragma warning(disable : 4127) +#pragma warning(disable : 6255) #endif #include "Eigen/src/Core/arch/Default/Half.h" #if defined(__GNUC__) diff --git a/onnxruntime/core/util/math_cpuonly.h b/onnxruntime/core/util/math_cpuonly.h index e46cd55aea330..73af5f27b8ae3 100644 --- a/onnxruntime/core/util/math_cpuonly.h +++ b/onnxruntime/core/util/math_cpuonly.h @@ -46,16 +46,17 @@ #pragma warning(disable : 4324) #pragma warning(disable : 4245) #pragma warning(disable : 4127) +#pragma warning(disable : 6255) +#pragma warning(disable : 6294) #endif #include "Eigen/Core" - +#include "Eigen/Dense" #if defined(__GNUC__) #pragma GCC diagnostic pop #else #pragma warning(pop) #endif -#include "Eigen/Dense" #include "core/framework/tensor.h" namespace onnxruntime { diff --git a/onnxruntime/featurizers_ops/cpu/cat_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/cat_imputer_transformer.cc index 0e7d4b14f4245..f57c6c5ae788e 100644 --- a/onnxruntime/featurizers_ops/cpu/cat_imputer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/cat_imputer_transformer.cc @@ -36,7 +36,7 @@ struct CatImputerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::CatImputerTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/count_vectorizer_transfromer.cc b/onnxruntime/featurizers_ops/cpu/count_vectorizer_transfromer.cc index 28828572c870b..40b9df4743b42 100644 --- a/onnxruntime/featurizers_ops/cpu/count_vectorizer_transfromer.cc +++ b/onnxruntime/featurizers_ops/cpu/count_vectorizer_transfromer.cc @@ -19,7 +19,7 @@ void CountVectorizerTransformerImpl(OpKernelContext* ctx) { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::CountVectorizerTransformer(archive); }()); @@ -28,8 +28,10 @@ void CountVectorizerTransformerImpl(OpKernelContext* ctx) { const std::string* input_data = input_tensor->template Data(); // Prepare the callback that would output directly to output memory std::function)> callback; - callback = [ctx](NS::Featurizers::SparseVectorEncoding result) { + bool callback_allow = true; + callback = [ctx, callback_allow](NS::Featurizers::SparseVectorEncoding result) { // Prepare output + ORT_ENFORCE(callback_allow, "callback function can only be called during execute() and special flush() when needed"); ORT_ENFORCE(result.NumElements < static_cast(std::numeric_limits::max()), "NumElements in SparseVectorEncoding is GE than max(int64)"); auto* output_tensor = ctx->Output(0, TensorShape{static_cast(result.NumElements)}); @@ -40,6 +42,9 @@ void CountVectorizerTransformerImpl(OpKernelContext* ctx) { } }; transformer.execute(*input_data, callback); + // The flush() does nothing but shows Featurizers concept + callback_allow = false; + transformer.flush(callback); }; class CountVectorizerTransformer final : public OpKernel { diff --git a/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc b/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc index 8a5c4310d62a4..15c9712ac3073 100644 --- a/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/date_time_transformer.cc @@ -24,7 +24,7 @@ class DateTimeTransformer final : public OpKernel { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::DateTimeTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/forecasting_pivot_transformer.cc b/onnxruntime/featurizers_ops/cpu/forecasting_pivot_transformer.cc index 2727c04a8394d..514a059550ade 100644 --- a/onnxruntime/featurizers_ops/cpu/forecasting_pivot_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/forecasting_pivot_transformer.cc @@ -13,19 +13,18 @@ namespace NS = Microsoft::Featurizer; namespace onnxruntime { namespace featurizers { -template //float, double +template struct ForecastingPivotTransformerImpl { void operator()(OpKernelContext* ctx) const { - using TransformerT = Microsoft::Featurizer::Featurizers::ForecastingPivotTransformer; - using MatrixT = NS::RowMajMatrix; - using InputMatrixT = Eigen::Map; - using InputType = std::vector; + using MatrixT = NS::RowMajMatrix::nullable_type>; + using InputType = std::vector>; using OutputType = std::vector; + using TransformerT = Microsoft::Featurizer::Featurizers::ForecastingPivotTransformer>; //Get the transformer const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); TransformerT transformer(archive); // Get the Number of Rows @@ -60,15 +59,16 @@ struct ForecastingPivotTransformerImpl { std::tuple info_tuple(input_data, input_dim_1, input_dim_2); dataPtrMap.insert(std::pair>(index, info_tuple)); } - const T* input_data(std::get<0>(dataPtrMap.at(index))); - const int64_t input_dim_1(std::get<1>(dataPtrMap.at(index))); - const int64_t input_dim_2(std::get<2>(dataPtrMap.at(index))); - input.push_back(InputMatrixT(input_data, input_dim_1, input_dim_2)); + std::tuple &inputTuple(dataPtrMap.at(index)); + const T* input_data(std::get<0>(inputTuple)); + const int64_t input_dim_1(std::get<1>(inputTuple)); + const int64_t input_dim_2(std::get<2>(inputTuple)); + input.push_back(typename InputType::value_type(input_data, input_dim_1, input_dim_2)); //Increment data pointer input_data += input_dim_1 * input_dim_2; } //Execute - transformer.execute(input, callback_fn); + transformer.execute(std::make_tuple(input.begin(), input.end()), callback_fn); } transformer.flush(callback_fn); diff --git a/onnxruntime/featurizers_ops/cpu/from_string_transformer.cc b/onnxruntime/featurizers_ops/cpu/from_string_transformer.cc index da3c63705ce33..0a237b8409473 100644 --- a/onnxruntime/featurizers_ops/cpu/from_string_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/from_string_transformer.cc @@ -21,7 +21,7 @@ struct FromStringTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::FromStringTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/hash_one_hot_vectorizer_transformer.cc b/onnxruntime/featurizers_ops/cpu/hash_one_hot_vectorizer_transformer.cc index c06157fba2987..44f880728c5ff 100644 --- a/onnxruntime/featurizers_ops/cpu/hash_one_hot_vectorizer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/hash_one_hot_vectorizer_transformer.cc @@ -21,7 +21,7 @@ struct HashOneHotVectorizerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::HashOneHotVectorizerTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/imputation_marker_transformer.cc b/onnxruntime/featurizers_ops/cpu/imputation_marker_transformer.cc index 2630badc171a8..478dbf20f5d2e 100644 --- a/onnxruntime/featurizers_ops/cpu/imputation_marker_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/imputation_marker_transformer.cc @@ -27,7 +27,7 @@ struct ImputationMarkerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::ImputationMarkerTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/label_encoder_transformer.cc b/onnxruntime/featurizers_ops/cpu/label_encoder_transformer.cc index ed5bf1b9ccac8..a6843f07805b0 100644 --- a/onnxruntime/featurizers_ops/cpu/label_encoder_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/label_encoder_transformer.cc @@ -21,7 +21,7 @@ struct LabelEncoderTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::LabelEncoderTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/laglead_operator_transformer.cc b/onnxruntime/featurizers_ops/cpu/laglead_operator_transformer.cc index 7535b0eef3fff..703c6068e7839 100644 --- a/onnxruntime/featurizers_ops/cpu/laglead_operator_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/laglead_operator_transformer.cc @@ -28,7 +28,7 @@ struct LagLeadOperatorTransformerImpl { //Get the transformer const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); typename EstimatorT::TransformerType transformer(archive); // Get the Grains @@ -52,8 +52,8 @@ struct LagLeadOperatorTransformerImpl { bool has_allocate_output_data = false; std::function callback_fn; callback_fn = [ctx, &output_grains_data, &output_data, &has_allocate_output_data, output_dim_0](OutputType value) -> void { - GrainT & output_grains(std::get(value)); - const OutputMatrixType & output_matrix(std::get(value)); + GrainT & output_grains(std::get<0>(value)); + const OutputMatrixType & output_matrix(std::get<1>(value)); //Allocate tensor memory after first output is generated if (!has_allocate_output_data) { TensorShape output_shape({output_dim_0, output_matrix.rows(), output_matrix.cols()}); diff --git a/onnxruntime/featurizers_ops/cpu/max_abs_scaler_transformer.cc b/onnxruntime/featurizers_ops/cpu/max_abs_scaler_transformer.cc index af5b00415746c..8cb9ddd781981 100644 --- a/onnxruntime/featurizers_ops/cpu/max_abs_scaler_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/max_abs_scaler_transformer.cc @@ -44,7 +44,7 @@ struct MaxAbsScalerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::MaxAbsScalerTransformer::type>(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/mean_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/mean_imputer_transformer.cc index 884da425e06fc..073f21608df31 100644 --- a/onnxruntime/featurizers_ops/cpu/mean_imputer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/mean_imputer_transformer.cc @@ -36,7 +36,7 @@ struct MeanImputerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return MeanImputerTransformerT(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/median_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/median_imputer_transformer.cc index 6b3aa5bd73057..bcd5602dfbca2 100644 --- a/onnxruntime/featurizers_ops/cpu/median_imputer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/median_imputer_transformer.cc @@ -50,7 +50,7 @@ struct MedianImputerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return MedianImputerTransformerT(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/min_max_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/min_max_imputer_transformer.cc index 7e7f0b3c79a36..94a6e399ceb77 100644 --- a/onnxruntime/featurizers_ops/cpu/min_max_imputer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/min_max_imputer_transformer.cc @@ -45,7 +45,7 @@ struct MinMaxImputerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return MinMaxImputerTransformerT(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/min_max_scaler_transformer.cc b/onnxruntime/featurizers_ops/cpu/min_max_scaler_transformer.cc index 6d9ae461868ef..25d1e6c74f893 100644 --- a/onnxruntime/featurizers_ops/cpu/min_max_scaler_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/min_max_scaler_transformer.cc @@ -21,7 +21,7 @@ struct MinMaxScalerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::MinMaxScalerTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/missing_dummies_transformer.cc b/onnxruntime/featurizers_ops/cpu/missing_dummies_transformer.cc index 908855a17da34..465b58e9f086a 100644 --- a/onnxruntime/featurizers_ops/cpu/missing_dummies_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/missing_dummies_transformer.cc @@ -27,7 +27,7 @@ struct MissingDummiesTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::MissingDummiesTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/mode_imputer_transformer.cc b/onnxruntime/featurizers_ops/cpu/mode_imputer_transformer.cc index c310d05ca018f..6b8e39ca635aa 100644 --- a/onnxruntime/featurizers_ops/cpu/mode_imputer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/mode_imputer_transformer.cc @@ -45,7 +45,7 @@ struct ModeImputerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return ModeImputerTransformerT(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/normalize_transformer.cc b/onnxruntime/featurizers_ops/cpu/normalize_transformer.cc index 1337623ae65cc..41c809aec7175 100644 --- a/onnxruntime/featurizers_ops/cpu/normalize_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/normalize_transformer.cc @@ -23,7 +23,7 @@ struct NormalizeTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Transformer(archive); }()); @@ -43,7 +43,9 @@ struct NormalizeTransformerImpl { std::vector result; std::function)> callback; - callback = [&result](std::vector val) mutable { + bool callback_allow = true; + callback = [&result, callback_allow](std::vector val) { + ORT_ENFORCE(callback_allow, "callback function can only be called during execute() and special flush() when needed"); result = std::move(val); }; @@ -58,6 +60,9 @@ struct NormalizeTransformerImpl { std::copy(result.cbegin(), result.cend(), output_data); output_data += row_size; } + // The flush() does nothing but shows Featurizers concept + callback_allow = false; + transformer.flush(callback); } }; diff --git a/onnxruntime/featurizers_ops/cpu/numericalize_transformer.cc b/onnxruntime/featurizers_ops/cpu/numericalize_transformer.cc index cd9e32f056b2c..710654fbe0a9d 100644 --- a/onnxruntime/featurizers_ops/cpu/numericalize_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/numericalize_transformer.cc @@ -21,7 +21,7 @@ struct NumericalizeTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::NumericalizeTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/one_hot_encoder_transformer.cc b/onnxruntime/featurizers_ops/cpu/one_hot_encoder_transformer.cc index 0231fc9ec8da7..54c94622b037e 100644 --- a/onnxruntime/featurizers_ops/cpu/one_hot_encoder_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/one_hot_encoder_transformer.cc @@ -21,7 +21,7 @@ struct OneHotEncoderTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::OneHotEncoderTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/pca_transformer.cc b/onnxruntime/featurizers_ops/cpu/pca_transformer.cc index 03a07b01f4260..2dd644fa35ea4 100644 --- a/onnxruntime/featurizers_ops/cpu/pca_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/pca_transformer.cc @@ -25,7 +25,7 @@ struct PCATransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::PCATransformer(archive); }()); @@ -47,10 +47,15 @@ struct PCATransformerImpl { Eigen::Map output_matrix(output_data, dim_0, dim_1); std::function callback; - callback = [&output_matrix](MatrixT val) { + bool callback_allow = true; + callback = [&output_matrix, callback_allow](MatrixT val) { + ORT_ENFORCE(callback_allow, "callback function can only be called during execute() and special flush() when needed"); output_matrix = val; }; transformer.execute(input_matrix, callback); + // The flush() does nothing but shows Featurizers concept + callback_allow = false; + transformer.flush(callback); } }; diff --git a/onnxruntime/featurizers_ops/cpu/robust_scaler_transformer.cc b/onnxruntime/featurizers_ops/cpu/robust_scaler_transformer.cc index 608f1c338afb7..4d67b86c3ebc7 100644 --- a/onnxruntime/featurizers_ops/cpu/robust_scaler_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/robust_scaler_transformer.cc @@ -44,7 +44,7 @@ struct RobustScalerTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::RobustScalerTransformer::type>(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/rolling_window_transformer.cc b/onnxruntime/featurizers_ops/cpu/rolling_window_transformer.cc index 4ee900f7cd251..6fbe055ab4f3d 100644 --- a/onnxruntime/featurizers_ops/cpu/rolling_window_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/rolling_window_transformer.cc @@ -24,7 +24,7 @@ struct RollingWindowTransformerImpl { //Get the transformer const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); typename EstimatorT::TransformerType transformer(archive); // Get the Grains @@ -69,6 +69,7 @@ struct RollingWindowTransformerImpl { target_data++; grains_data += grains_num; } + transformer.flush(callback_fn); } }; diff --git a/onnxruntime/featurizers_ops/cpu/short_grain_dropper_transformer.cc b/onnxruntime/featurizers_ops/cpu/short_grain_dropper_transformer.cc index cda8ca2e13a85..7cfc34fad2134 100644 --- a/onnxruntime/featurizers_ops/cpu/short_grain_dropper_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/short_grain_dropper_transformer.cc @@ -19,7 +19,7 @@ void ShortGrainDropperTransformerImpl(OpKernelContext* ctx) { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::ShortGrainDropperTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/standard_scale_wrapper_transformer.cc b/onnxruntime/featurizers_ops/cpu/standard_scale_wrapper_transformer.cc index c0918f0cadad5..ed0213d1e0410 100644 --- a/onnxruntime/featurizers_ops/cpu/standard_scale_wrapper_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/standard_scale_wrapper_transformer.cc @@ -21,7 +21,7 @@ struct StandardScaleWrapperTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::StandardScalerTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/string_transformer.cc b/onnxruntime/featurizers_ops/cpu/string_transformer.cc index 60a2c70106488..ac9342e0f7148 100644 --- a/onnxruntime/featurizers_ops/cpu/string_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/string_transformer.cc @@ -21,7 +21,7 @@ struct StringTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::StringTransformer(archive); }()); diff --git a/onnxruntime/featurizers_ops/cpu/tfidf_vectorizer_transformer.cc b/onnxruntime/featurizers_ops/cpu/tfidf_vectorizer_transformer.cc index 4b7b54c1bdbb1..9365c3742c9ab 100644 --- a/onnxruntime/featurizers_ops/cpu/tfidf_vectorizer_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/tfidf_vectorizer_transformer.cc @@ -19,7 +19,7 @@ void TfidfVectorizerTransformerImpl(OpKernelContext* ctx) { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::TfidfVectorizerTransformer(archive); }()); @@ -29,8 +29,10 @@ void TfidfVectorizerTransformerImpl(OpKernelContext* ctx) { // Prepare the callback that would output directly to output memory std::function)> callback; - callback = [ctx](NS::Featurizers::SparseVectorEncoding result) { + bool callback_allow = true; + callback = [ctx, callback_allow](NS::Featurizers::SparseVectorEncoding result) { // Prepare output + ORT_ENFORCE(callback_allow, "callback function can only be called during execute() and special flush() when needed"); ORT_ENFORCE(result.NumElements < static_cast(std::numeric_limits::max()), "NumElements in SparseVectorEncoding is GE than max(int64)"); auto* output_tensor = ctx->Output(0, TensorShape{static_cast(result.NumElements)}); @@ -41,6 +43,9 @@ void TfidfVectorizerTransformerImpl(OpKernelContext* ctx) { } }; transformer.execute(*input_data, callback); + // The flush() does nothing but shows Featurizers concept + callback_allow = false; + transformer.flush(callback); } class TfidfVectorizerTransformer final : public OpKernel { diff --git a/onnxruntime/featurizers_ops/cpu/truncated_svd_transformer.cc b/onnxruntime/featurizers_ops/cpu/truncated_svd_transformer.cc index 5e1a805523d66..f707bbdf91ac0 100644 --- a/onnxruntime/featurizers_ops/cpu/truncated_svd_transformer.cc +++ b/onnxruntime/featurizers_ops/cpu/truncated_svd_transformer.cc @@ -25,7 +25,7 @@ struct TruncatedSVDTransformerImpl { const auto* state_tensor(ctx->Input(0)); const uint8_t* const state_data(state_tensor->Data()); - Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().GetDims()[0]); + Microsoft::Featurizer::Archive archive(state_data, state_tensor->Shape().Size()); return Microsoft::Featurizer::Featurizers::TruncatedSVDTransformer(archive); }()); @@ -47,10 +47,15 @@ struct TruncatedSVDTransformerImpl { Eigen::Map output_matrix(output_data, dim_0, dim_1); std::function callback; - callback = [&output_matrix](MatrixT val) { + bool callback_allow = true; + callback = [&output_matrix, callback_allow](MatrixT val) { + ORT_ENFORCE(callback_allow, "callback function can only be called during execute() and special flush() when needed"); output_matrix = val; }; transformer.execute(input_matrix, callback); + // The flush() does nothing but shows Featurizers concept + callback_allow = false; + transformer.flush(callback); } }; diff --git a/onnxruntime/gsl/gsl-lite.hpp b/onnxruntime/gsl/gsl-lite.hpp index 876defb5aad77..27c0774bce57a 100644 --- a/onnxruntime/gsl/gsl-lite.hpp +++ b/onnxruntime/gsl/gsl-lite.hpp @@ -1148,12 +1148,7 @@ gsl_DISABLE_MSVC_WARNINGS(26410 26415 26418 26472 26439 26440 26473 26481 26482 } #if gsl_HAVE(TYPE_TRAITS) -#if gsl_COMPILER_MSVC_VERSION - // Suppress MSVC level 4 warning C4127 (conditional expression is constant) - if (0, !detail::is_same_signedness::value && ((t < T()) != (u < U()))) -#else if (!detail::is_same_signedness::value && ((t < T()) != (u < U()))) -#endif #else // Don't assume T() works: if ((t < 0) != (u < 0)) diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc new file mode 100644 index 0000000000000..c259cbaf25ca3 --- /dev/null +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" + +namespace onnxruntime { +namespace test { +TEST(ContribOpTest, Rfft) { + if (DefaultCudaExecutionProvider() == nullptr) return; + + OpTester test("Rfft", 1, onnxruntime::kMSDomain); + test.AddAttribute("signal_ndim", static_cast(2)); + test.AddAttribute("onesided", static_cast(1)); + test.AddAttribute("normalized", static_cast(0)); + test.AddInput("X", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f}); + test.AddOutput("Y", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(ContribOpTest, Irfft) { + if (DefaultCudaExecutionProvider() == nullptr) return; + + OpTester test("Irfft", 1, onnxruntime::kMSDomain); + test.AddAttribute("signal_ndim", static_cast(2)); + test.AddAttribute("onesided", static_cast(1)); + test.AddAttribute("normalized", static_cast(0)); + test.AddInput("X", {4, 3, 2}, std::vector{-5.6404f, 0.0000f, -3.6965f, -1.3401f, -6.6836f, -3.5202f, -3.3891f, 0.0769f, 1.4521f, 3.2068f, 5.9398f, -1.2344f, -0.1682f, 0.0000f, 1.9681f, -1.6241f, -3.3442f, 1.6817f, -3.3891f, -0.0769f, 2.9557f, -2.9384f, -1.2900f, -4.8683f}); + test.AddOutput("Y", {4, 5}, std::vector{-0.8992f, 0.6117f, -1.6091f, -0.4155f, -0.8346f, -2.1596f, -0.0853f, 0.7232f, 0.1941f, -0.0789f, -2.0329f, 1.1031f, 0.6869f, -0.5042f, 0.9895f, -0.1884f, 0.2858f, -1.5831f, 0.9917f, -0.8356f}); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index 4da03df8b106a..9baae4b1f4402 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -2,11 +2,57 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" namespace onnxruntime { namespace test { +// scalar zero & scale with uint8 +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_uint8) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {0, 3, 128, 255}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {128}); + test.AddOutput("y", dims, {-256.0f, -250.0f, 0.0f, 254.0f}); + test.Run(); +} + +// scalar zero & scale with int8 +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_float_int8) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {-10}); + test.AddOutput("y", dims, {-40.0f, 14.0f, 220.0f, 274.0f}); + test.Run(); +} + +#ifdef USE_CUDA +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_half_uint8) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {0, 3, 128, 255}); + test.AddInput("x_scale", {}, ToFloat16({2.0f})); + test.AddInput("x_zero_point", {}, {128}); + test.AddOutput("y", dims, ToFloat16({-256.0f, -250.0f, 0.0f, 254.0f})); + test.Run(); +} + +// scalar zero & scale with int8 +TEST(DequantizeLinearOpTest, DequantizeLinear_per_tensor_half_int8) { + OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); + test.AddInput("x_scale", {}, ToFloat16({2.0f})); + test.AddInput("x_zero_point", {}, {-10}); + test.AddOutput("y", dims, ToFloat16({-40.0f, 14.0f, 220.0f, 274.0f})); + test.Run(); +} +#endif + // 1d zero & scale with uint8 broadcast axis 0 TEST(DequantizeLinearContribOpTest, DequantizeLinear_0) { OpTester test("DequantizeLinear", 1, onnxruntime::kMSDomain); @@ -120,18 +166,114 @@ TEST(DequantizeLinearContribOpTest, DequantizeLinear_3) { } // quantize with scalar zero point and scale -TEST(QuantizeLinearContribOpTest, QuantizeLinear_0) { +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint8) { OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); - std::vector dims{6}; - test.AddInput("x", dims, {0, 2, 3, 1000, -254, -1000}); + std::vector dims{16}; + test.AddInput("x", dims, { + 0.f, 2.f, + 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // low case + 3.1f, -3.1f, // up case + 254.f, -256.f, // critical point + 255.f, -257.f, // critical point + 256.f, -258.f, // critical point + 1000.f, -1000.f // saturate case + }); test.AddInput("y_scale", {}, {2.0f}); test.AddInput("y_zero_point", {}, {128}); - test.AddOutput("y", dims, {128, 129, 130, 255, 1, 0}); + test.AddOutput("y", dims, {128, 129, + 130, 126, + 129, 127, + 130, 126, + 255, 0, + 255, 0, + 255, 0, + 255, 0}); + test.Run(); +} + +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{16}; + test.AddInput("x", dims, { + 0.f, 2.f, + 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // low case + 3.1f, -3.1f, // up case + 254.f, -256.f, // critical point + 255.f, -257.f, // critical point + 256.f, -258.f, // critical point + 1000.f, -1000.f // saturate case + }); + test.AddInput("y_scale", {}, {2.0f}); + test.AddInput("y_zero_point", {}, {1}); + test.AddOutput("y", dims, {1, 2, + 3, -1, + 2, 0, + 3, -1, + 127, -127, + 127, -127, + 127, -128, + 127, -128}); + test.Run(); +} + +#ifdef USE_CUDA +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_half_uint8) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{16}; + test.AddInput("x", dims, ToFloat16({ + 0.f, 2.f, + 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // low case + 3.1f, -3.1f, // up case + 254.f, -256.f, // critical point + 255.f, -257.f, // critical point + 256.f, -258.f, // critical point + 1000.f, -1000.f // saturate case + })); + test.AddInput("y_scale", {}, ToFloat16({2.0f})); + test.AddInput("y_zero_point", {}, {128}); + test.AddOutput("y", dims, {128, 129, + 130, 126, + 129, 127, + 130, 126, + 255, 0, + 255, 0, + 255, 0, + 255, 0}); + test.Run(); +} + +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_half_int8) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{16}; + test.AddInput("x", dims, ToFloat16({ + 0.f, 2.f, + 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // low case + 3.1f, -3.1f, // up case + 254.f, -256.f, // critical point + 255.f, -257.f, // critical point + 256.f, -258.f, // critical point + 1000.f, -1000.f // saturate case + })); + test.AddInput("y_scale", {}, ToFloat16({2.0f})); + test.AddInput("y_zero_point", {}, {1}); + test.AddOutput("y", dims, {1, 2, + 3, -1, + 2, 0, + 3, -1, + 127, -127, + 127, -127, + 127, -128, + 127, -128}); test.Run(); } +#endif // quantize with broadcasting -TEST(QuantizeLinearContribOpTest, QuantizeLinear_1) { +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_channel) { OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); std::vector dims{3, 4}; test.AddInput("X", dims, @@ -144,12 +286,12 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_1) { test.AddOutput("Y", dims, {0, 2, 3, 255, 0, 1, 2, 255, - 0, 1, 1, 250}); + 0, 0, 1, 250}); test.Run(); } // quantize with broadcasting and negative axis (-2 resolves to axis 0) -TEST(QuantizeLinearContribOpTest, QuantizeLinear_2) { +TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_channel_negative_axis) { OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); std::vector dims{3, 4}; test.AddInput("X", dims, @@ -162,7 +304,7 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_2) { test.AddOutput("Y", dims, {0, 2, 3, 255, 0, 1, 2, 255, - 0, 1, 1, 250}); + 0, 0, 1, 250}); test.Run(); } } // namespace test diff --git a/onnxruntime/test/featurizers_ops/forecasting_pivot_transformer_test.cc b/onnxruntime/test/featurizers_ops/forecasting_pivot_transformer_test.cc index 6c15a68e79db7..903efcdf1f4d4 100644 --- a/onnxruntime/test/featurizers_ops/forecasting_pivot_transformer_test.cc +++ b/onnxruntime/test/featurizers_ops/forecasting_pivot_transformer_test.cc @@ -14,12 +14,13 @@ namespace NS = Microsoft::Featurizer; namespace onnxruntime { namespace test { - namespace { template std::vector GetStream() { - NS::Featurizers::ForecastingPivotTransformer transformer; + using MatrixT = NS::RowMajMatrix::nullable_type>; + using InputType = std::vector>; + NS::Featurizers::ForecastingPivotTransformer> transformer; NS::Archive ar; transformer.save(ar); return ar.commit(); diff --git a/onnxruntime/test/featurizers_ops/truncated_svdtransformer_test.cc b/onnxruntime/test/featurizers_ops/truncated_svdtransformer_test.cc index 6e4fa47a4e3a4..96fc5f0dc83e8 100644 --- a/onnxruntime/test/featurizers_ops/truncated_svdtransformer_test.cc +++ b/onnxruntime/test/featurizers_ops/truncated_svdtransformer_test.cc @@ -17,7 +17,8 @@ template std::vector GetStream(const MatrixT& training_matrix) { using EstimatorT = NS::Featurizers::TruncatedSVDEstimator; NS::AnnotationMapsPtr const pAllColumnAnnotations(NS::CreateTestAnnotationMapsPtr(1)); - EstimatorT estimator(pAllColumnAnnotations, 0); + //Hardcode the seed = 42 + EstimatorT estimator(pAllColumnAnnotations, 0, static_cast(42)); std::vector> trainingBatches = NS::TestHelpers::make_vector>( NS::TestHelpers::make_vector(training_matrix)); @@ -56,7 +57,7 @@ void TruncatedSVDTransformerTestRowMajStandard() { // platform to platform enough so we choose to check max STD deviation. OpTester::CustomOutputVerifierFn ver_fn = [&verify_matrix](const std::vector& fetches, const std::string& provider) { std::cout << "Verifying TruncatedSVDTransformerTestRowMajStandard:" << provider << std::endl; - const float eps = 0.0003f; + const float eps = 0.0001f; ASSERT_TRUE(fetches.size() == 1); const auto& fetch = fetches.at(0); const auto& tensor = fetch.Get(); @@ -69,15 +70,13 @@ void TruncatedSVDTransformerTestRowMajStandard() { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNupharExecutionProvider}, nullptr, {}, ORT_SEQUENTIAL, ver_fn); } -// restrict the test for now -// TEST(FeaturizersTests, TruncatedSVDTransformer_double) { -// TruncatedSVDTransformerTestRowMajStandard(); -// } +TEST(FeaturizersTests, TruncatedSVDTransformer_double) { + TruncatedSVDTransformerTestRowMajStandard(); +} -// restrict the test for now -// TEST(FeaturizersTests, TruncatedSVDTransformer_float) { -// TruncatedSVDTransformerTestRowMajStandard(); -// } +TEST(FeaturizersTests, TruncatedSVDTransformer_float) { + TruncatedSVDTransformerTestRowMajStandard(); +} } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/framework/cuda/fence_cuda_test.cc b/onnxruntime/test/framework/cuda/fence_cuda_test.cc index 2367393769057..6ac5ee619c8d9 100644 --- a/onnxruntime/test/framework/cuda/fence_cuda_test.cc +++ b/onnxruntime/test/framework/cuda/fence_cuda_test.cc @@ -25,6 +25,7 @@ #include "test/framework/test_utils.h" #include "gtest/gtest.h" #include "core/util/protobuf_parsing_utils.h" +#include "asserts.h" using namespace std; using namespace ONNX_NAMESPACE; @@ -120,12 +121,13 @@ TEST(CUDAFenceTests, DISABLED_PartOnCPU) { FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; - session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); + ASSERT_STATUS_OK(session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info))); ASSERT_TRUE(session.Initialize().IsOK()); ASSERT_TRUE(1 == CountCopyNodes(graph)); vector outputs; - session.Run(std::unordered_map{{"X1", value}}, std::vector{"Out"}, &outputs); + ASSERT_STATUS_OK( + session.Run(std::unordered_map{{"X1", value}}, std::vector{"Out"}, &outputs)); ASSERT_TRUE(1 == outputs.size()); const Tensor& output = outputs[0].Get(); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 @@ -174,11 +176,12 @@ TEST(CUDAFenceTests, TileWithInitializer) { FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; - session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); + ASSERT_STATUS_OK(session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info))); ASSERT_TRUE(session.Initialize().IsOK()); vector outputs; - session.Run(std::unordered_map{{"X1", value}}, std::vector{"Y"}, &outputs); + ASSERT_STATUS_OK( + session.Run(std::unordered_map{{"X1", value}}, std::vector{"Y"}, &outputs)); ASSERT_TRUE(1 == outputs.size()); const Tensor& output = outputs[0].Get(); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 @@ -239,11 +242,12 @@ TEST(CUDAFenceTests, TileWithComputedInput) { FenceCudaTestInferenceSession session(so, GetEnvironment()); LoadInferenceSessionFromModel(session, *model); CUDAExecutionProviderInfo xp_info; - session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info)); + ASSERT_STATUS_OK(session.RegisterExecutionProvider(onnxruntime::make_unique(xp_info))); ASSERT_TRUE(session.Initialize().IsOK()); vector outputs; - session.Run(std::unordered_map{{"X1", value}}, std::vector{"Out"}, &outputs); + ASSERT_STATUS_OK( + session.Run(std::unordered_map{{"X1", value}}, std::vector{"Out"}, &outputs)); ASSERT_TRUE(1 == outputs.size()); const Tensor& output = outputs[0].Get(); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index 68100d814628d..41cf044cffa85 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -9,7 +9,7 @@ #include "core/session/inference_session.h" #include "test_utils.h" #include "test/test_environment.h" - +#include "asserts.h" #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -52,66 +52,63 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { onnxruntime::Node* node = &graph.AddNode("node1", "Relu", "Relu operator", ArgMap{&input_def}, ArgMap{&output_def}); node->SetExecutionProviderType(kCpuExecutionProvider); - Status status = graph.Resolve(); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(graph.Resolve()); auto cpu_xp = CreateCPUExecutionProvider(); auto xp_typ = cpu_xp->Type(); ExecutionProviders execution_providers; execution_providers.Add(xp_typ, std::move(cpu_xp)); KernelRegistryManager kernel_registry_manager; - status = kernel_registry_manager.RegisterKernels(execution_providers); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); SessionState state{execution_providers, true, &tp_, nullptr}; - status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(state.SetGraphAndCreateKernels(graph, kernel_registry_manager)); node->SetExecutionProviderType(xp_typ); std::unique_ptr p_seq_exec_plan; // TODO below line is for testing only. In production use SequentialPlanner::CreatePlan() SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager, - state.GetOrtValueNameIdxMap(), context, p_seq_exec_plan); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager, + state.GetOrtValueNameIdxMap(), context, p_seq_exec_plan)); state.SetExecutionPlan(std::move(p_seq_exec_plan)); vector outputs; ExecutionFrame frame({}, {}, {}, outputs, {}, state); int start_index = frame.GetNodeOffset(node->Index()); - EXPECT_EQ(start_index, 0); + ASSERT_EQ(start_index, 0); TensorShape shape(std::vector{2, 3}); OrtValue& mlvalue0 = *frame.GetMutableNodeInputOrOutputMLValue(start_index); - status = frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue0, start_index, DataTypeImpl::GetType(), - execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue0, start_index, DataTypeImpl::GetType(), + execution_providers.Get(xp_typ)->GetAllocator(0, OrtMemTypeDefault)->Info(), shape)); OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); - Tensor* p_tensor = p_ml_value ? p_ml_value->GetMutable() : nullptr; - EXPECT_TRUE(p_tensor); + ASSERT_TRUE(p_ml_value != nullptr); + Tensor* p_tensor = p_ml_value->GetMutable(); + ASSERT_TRUE(p_tensor != nullptr); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 - EXPECT_EQ(*reinterpret_cast*>(&p_tensor->Shape()), *reinterpret_cast*>(&shape)); - EXPECT_EQ(p_tensor->DataType(), DataTypeImpl::GetType()); + ASSERT_EQ(*reinterpret_cast*>(&p_tensor->Shape()), + *reinterpret_cast*>(&shape)); + ASSERT_EQ(p_tensor->DataType(), DataTypeImpl::GetType()); //test share memory from tensor TensorShape shape2(std::vector{3, 2}); OrtValue& mlvalue1 = *frame.GetMutableNodeInputOrOutputMLValue(start_index + 1); - status = frame.AllocateMLValueTensorPreAllocateBuffer(mlvalue1, + ASSERT_STATUS_OK(frame.AllocateMLValueTensorPreAllocateBuffer(mlvalue1, start_index, DataTypeImpl::GetType(), p_tensor->Location(), - shape2); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + shape2)); const OrtValue* p_ml_value_const = frame.GetNodeInputOrOutputMLValue(1); auto tensor2 = p_ml_value_const ? &(p_ml_value_const->Get()) : nullptr; - EXPECT_TRUE(tensor2); + ASSERT_TRUE(tensor2); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 - EXPECT_EQ(*reinterpret_cast*>(&tensor2->Shape()), *reinterpret_cast*>(&shape2)); - EXPECT_EQ(tensor2->template Data(), p_tensor->template Data()); + ASSERT_EQ(*reinterpret_cast*>(&tensor2->Shape()), + *reinterpret_cast*>(&shape2)); + ASSERT_EQ(tensor2->template Data(), p_tensor->template Data()); } TEST_F(ExecutionFrameTest, FeedInDataTest) { @@ -143,13 +140,12 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; execution_providers.Add(xp_typ, std::move(cpu_xp)); - EXPECT_TRUE(kernel_registry_manager.RegisterKernels(execution_providers).IsOK()); + ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); SessionState state{execution_providers, true, &tp_, nullptr}; - auto status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(state.SetGraphAndCreateKernels(graph, kernel_registry_manager)); const OrtValueNameIdxMap& mlvalue_name_idx_map = state.GetOrtValueNameIdxMap(); - int x_idx, y_idx; + int x_idx = -1, y_idx = -1; ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X", x_idx).IsOK()); ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); @@ -158,11 +154,12 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable() : nullptr; - EXPECT_TRUE(p_tensor_arg_0); + ASSERT_TRUE(p_tensor_arg_0); //Use reinterpret_cast to bypass a gcc bug: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51213 - EXPECT_EQ(*reinterpret_cast*>(&p_tensor_arg_0->Shape()), *reinterpret_cast*>(&shape)); - EXPECT_EQ(p_tensor_arg_0->DataType(), DataTypeImpl::GetType()); - EXPECT_EQ(p_tensor_arg_0->MutableData(), value.GetMutable()->MutableData()); + ASSERT_EQ(*reinterpret_cast*>(&p_tensor_arg_0->Shape()), + *reinterpret_cast*>(&shape)); + ASSERT_EQ(p_tensor_arg_0->DataType(), DataTypeImpl::GetType()); + ASSERT_EQ(p_tensor_arg_0->MutableData(), value.GetMutable()->MutableData()); } TEST_F(ExecutionFrameTest, MemPatternTest) { @@ -189,23 +186,21 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { graph.AddNode("node3", "Clip", "clip1", ArgMap{&gemm2_out_def}, ArgMap{&clip_out_def}) .SetExecutionProviderType(xp_type); - auto status = graph.Resolve(); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(graph.Resolve()); KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; execution_providers.Add(xp_type, std::move(cpu_xp)); - kernel_registry_manager.RegisterKernels(execution_providers); + ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); //1. prepare input SessionState state{execution_providers, true, &tp_, nullptr}; - status = state.SetGraphAndCreateKernels(graph, kernel_registry_manager); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(state.SetGraphAndCreateKernels(graph, kernel_registry_manager)); const OrtValueNameIdxMap& mlvalue_name_idx_map(state.GetOrtValueNameIdxMap()); - int x1_idx, x2_idx, x3_idx; - int t1_idx, t2_idx, t3_idx; + int x1_idx = -1, x2_idx = -1, x3_idx = -1; + int t1_idx = -1, t2_idx = -1, t3_idx = -1; ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X1", x1_idx).IsOK()); ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X2", x2_idx).IsOK()); ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("X3", x3_idx).IsOK()); @@ -229,9 +224,8 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::unique_ptr p_seq_exec_plan = onnxruntime::make_unique(); SequentialPlannerContext context(ExecutionMode::ORT_SEQUENTIAL); - status = SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager, - mlvalue_name_idx_map, context, p_seq_exec_plan); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(SequentialPlanner::CreatePlan(nullptr, GraphViewer(graph), {}, execution_providers, kernel_registry_manager, + mlvalue_name_idx_map, context, p_seq_exec_plan)); state.SetExecutionPlan(std::move(p_seq_exec_plan)); @@ -242,34 +236,29 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4); OrtValue& mlvalue5 = *frame.GetMutableNodeInputOrOutputMLValue(5); - status = frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue3, 3, + ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue3, 3, DataTypeImpl::GetType(), cpu_allocator->Info(), - TensorShape(std::vector{2, 2})); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + TensorShape(std::vector{2, 2}))); - status = frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue4, 4, + ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue4, 4, DataTypeImpl::GetType(), cpu_allocator->Info(), - TensorShape(std::vector{2, 3})); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + TensorShape(std::vector{2, 3}))); - status = frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue5, 5, + ASSERT_STATUS_OK(frame.AllocateMLValueTensorSelfOwnBuffer(mlvalue5, 5, DataTypeImpl::GetType(), cpu_allocator->Info(), - TensorShape(std::vector{2, 3})); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); - + TensorShape(std::vector{2, 3}))); MemoryPatternGroup pattern; - status = frame.GeneratePatterns(&pattern); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(frame.GeneratePatterns(&pattern)); - EXPECT_EQ(pattern.patterns.size(), pattern.locations.size()); - EXPECT_EQ(pattern.patterns.size(), 1u); + ASSERT_EQ(pattern.patterns.size(), pattern.locations.size()); + ASSERT_EQ(pattern.patterns.size(), 1u); auto p = pattern.GetPatterns(cpu_allocator->Info()); - EXPECT_EQ(p->PeakSize(), 2u * 64u); // each allocation is 64-byte aligned - EXPECT_EQ(p->GetBlock(3)->offset_, 0u); - EXPECT_EQ(p->GetBlock(4)->offset_, 64u); + ASSERT_EQ(p->PeakSize(), 2u * 64u); // each allocation is 64-byte aligned + ASSERT_EQ(p->GetBlock(3)->offset_, 0u); + ASSERT_EQ(p->GetBlock(4)->offset_, 64u); } TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) { @@ -280,9 +269,8 @@ TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) { so.session_logid = "BadModelInvalidDimParamUsage"; InferenceSession session_object{so, GetEnvironment()}; - Status st; - ASSERT_TRUE((st = session_object.Load("testdata/invalid_dim_param_value_repetition.onnx")).IsOK()) << st; - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st; + ASSERT_STATUS_OK(session_object.Load("testdata/invalid_dim_param_value_repetition.onnx")); + ASSERT_STATUS_OK(session_object.Initialize()); std::vector dims_X = {10, 6}; std::vector values_X; @@ -303,7 +291,7 @@ TEST(ExecutionFrameTestWithoutSessionState, BadModelInvalidDimParamUsage) { // Now run RunOptions run_options; - st = session_object.Run(run_options, feeds, output_names, &fetches); + auto st = session_object.Run(run_options, feeds, output_names, &fetches); EXPECT_FALSE(st.IsOK()) << st; EXPECT_THAT(st.ErrorMessage(), testing::HasSubstr("Shape mismatch attempting to re-use buffer.")); diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 201e7d679352b..0638d65c2da8e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -44,11 +44,17 @@ using namespace std; using namespace ONNX_NAMESPACE; using namespace onnxruntime::logging; - +namespace { +struct KernelRegistryAndStatus { + std::shared_ptr kernel_registry = std::make_shared(); + Status st; +}; +} // namespace namespace onnxruntime { class FuseAdd : public OpKernel { public: - FuseAdd(const OpKernelInfo& info) : OpKernel(info) {} + explicit FuseAdd(const OpKernelInfo& info) : OpKernel(info) { + } Status Compute(OpKernelContext* context) const override { auto X = context->Input(0); @@ -62,8 +68,8 @@ class FuseAdd : public OpKernel { return Status::OK(); } }; -std::string kFuseTest = "FuseTest"; -std::string kFuseExecutionProvider = "FuseExecutionProvider"; +constexpr const char* kFuseTest = "FuseTest"; +constexpr const char* kFuseExecutionProvider = "FuseExecutionProvider"; class ONNX_OPERATOR_KERNEL_CLASS_NAME(kFuseExecutionProvider, kFuseTest, 1, FuseAdd); ONNX_OPERATOR_KERNEL_EX(FuseAdd, kFuseTest, @@ -72,22 +78,27 @@ ONNX_OPERATOR_KERNEL_EX(FuseAdd, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), FuseAdd); -void RegisterOperatorKernels(KernelRegistry& kernel_registry) { - kernel_registry.Register(BuildKernelCreateInfo()); +Status RegisterOperatorKernels(KernelRegistry& kernel_registry) { + return kernel_registry.Register( + BuildKernelCreateInfo()); } -std::shared_ptr GetFusedKernelRegistry() { - std::shared_ptr kernel_registry = std::make_shared(); - RegisterOperatorKernels(*kernel_registry); - return kernel_registry; +KernelRegistryAndStatus GetFusedKernelRegistry() { + KernelRegistryAndStatus ret; + ret.st = RegisterOperatorKernels(*ret.kernel_registry); + return ret; } class FuseExecutionProvider : public IExecutionProvider { public: explicit FuseExecutionProvider() : IExecutionProvider{kFuseExecutionProvider} { - DeviceAllocatorRegistrationInfo device_info({OrtMemTypeDefault, - [](int) { return onnxruntime::make_unique(); }, - std::numeric_limits::max()}); + DeviceAllocatorRegistrationInfo device_info( + {OrtMemTypeDefault, + [](int) { + return onnxruntime::make_unique( + onnxruntime::make_unique("Fuse", OrtAllocatorType::OrtDeviceAllocator)); + }, + std::numeric_limits::max()}); InsertAllocator(device_info.factory(0)); } @@ -113,8 +124,10 @@ class FuseExecutionProvider : public IExecutionProvider { } std::shared_ptr GetKernelRegistry() const override { - static std::shared_ptr kernel_registry = GetFusedKernelRegistry(); - return kernel_registry; + static KernelRegistryAndStatus k = GetFusedKernelRegistry(); + // throw if the registry failed to initialize + ORT_THROW_IF_ERROR(k.st); + return k.kernel_registry; } }; @@ -341,8 +354,8 @@ TEST(InferenceSessionTests, NoTimeout) { InferenceSession session_object{so, GetEnvironment()}; Status st; - ASSERT_TRUE((st = session_object.Load(MODEL_URI)).IsOK()) << st.ErrorMessage(); - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "one session/one tag"; @@ -356,8 +369,8 @@ TEST(InferenceSessionTests, DisableCPUArena) { so.enable_cpu_mem_arena = false; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "one session/one tag"; @@ -385,7 +398,7 @@ TEST(InferenceSessionTests, TestModelSerialization) { so.optimized_model_filepath = ToWideString(test_model + "-TransformLevel-" + std::to_string(static_cast(so.graph_optimization_level))); InferenceSessionGetGraphWrapper session_object{so, GetEnvironment()}; ASSERT_TRUE(session_object.Load(test_model).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); // Assert that model has been transformed and identity Node is removed. const auto& graph = session_object.GetGraph(); @@ -459,11 +472,10 @@ TEST(InferenceSessionTests, ModelMetadata) { so.session_logid = "InferenceSessionTests.ModelMetadata"; InferenceSession session_object{so, GetEnvironment()}; auto model_uri = ORT_TSTR("../models/opset8/test_squeezenet/model.onnx"); - ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + ASSERT_STATUS_OK(session_object.Load(model_uri)); std::shared_ptr p_model; - Status st = onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); - ASSERT_TRUE(st.IsOK()); + ASSERT_STATUS_OK(onnxruntime::Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger())); const onnxruntime::Graph& graph = p_model->MainGraph(); // 1. first test the model meta @@ -530,8 +542,8 @@ TEST(InferenceSessionTests, CheckRunLogger) { std::unique_ptr env; auto st = Environment::Create(std::move(logging_manager), env); InferenceSession session_object{so, *env.get()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "RunTag"; @@ -560,8 +572,8 @@ TEST(InferenceSessionTests, CheckRunProfilerWithSessionOptions) { so.profile_file_prefix = ORT_TSTR("onnxprofile_profile_test"); InferenceSession session_object(so, GetEnvironment()); - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "RunTag"; @@ -599,8 +611,8 @@ TEST(InferenceSessionTests, CheckRunProfilerWithStartProfile) { so.session_logid = "CheckRunProfiler"; InferenceSession session_object(so, GetEnvironment()); - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "RunTag"; @@ -637,8 +649,8 @@ TEST(InferenceSessionTests, MultipleSessionsNoTimeout) { session_options.session_logid = "InferenceSessionTests.MultipleSessionsNoTimeout"; InferenceSession session_object{session_options, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); std::thread thread1{[&session_object]() { RunOptions run_options; @@ -662,8 +674,8 @@ TEST(InferenceSessionTests, PreAllocateOutputVector) { so.session_logid = "InferenceSessionTests.PreAllocateOutputVector"; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "InferenceSessionTests.PreAllocateOutputVector"; @@ -691,8 +703,8 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { std::unique_ptr env; auto st = Environment::Create(std::move(logging_manager), env); InferenceSession session_object{so, *env.get()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "ConfigureVerbosityLevel"; @@ -730,7 +742,7 @@ TEST(InferenceSessionTests, TestWithIstream) { std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); ASSERT_TRUE(model_file_stream.good()); ASSERT_TRUE(session_object.Load(model_file_stream).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "InferenceSessionTests.TestWithIstream"; @@ -749,7 +761,7 @@ TEST(InferenceSessionTests, TestRegisterExecutionProvider) { std::ifstream model_file_stream(MODEL_URI, ios::in | ios::binary); ASSERT_TRUE(model_file_stream.good()); ASSERT_TRUE(session_object.Load(model_file_stream).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "InferenceSessionTests.TestWithIstream"; @@ -812,7 +824,7 @@ TEST(InferenceSessionTests, TestIOBindingReuse) { p_model->ToProto().SerializeToString(&s1); std::stringstream sstr(s1); ASSERT_TRUE(session_object.Load(sstr).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); unique_ptr io_binding; Status st = session_object.NewIOBinding(&io_binding); ASSERT_TRUE(st.IsOK()); @@ -846,8 +858,8 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { so.session_logid = "InferenceSessionTests.InvalidInputTypeOfTensorElement"; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = so.session_logid; @@ -1112,8 +1124,9 @@ TEST(ExecutionProviderTest, FunctionTest) { VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m); InferenceSession session_object_2{so, GetEnvironment()}; - session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider)); - session_object_2.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::FuseExecutionProvider>()); + ASSERT_STATUS_OK(session_object_2.RegisterExecutionProvider(std::move(testCPUExecutionProvider))); + ASSERT_STATUS_OK( + session_object_2.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::FuseExecutionProvider>())); status = session_object_2.Load(model_file_name); ASSERT_TRUE(status.IsOK()); status = session_object_2.Initialize(); @@ -1636,7 +1649,7 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { SessionOptions so; InferenceSession session_object(so, GetEnvironment()); ASSERT_TRUE(session_object.Load(LSTM_MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "one session/one tag"; @@ -1746,12 +1759,12 @@ TEST(InferenceSessionTests, TestCopyToFromDevices) { so.session_logid = "InferenceSessionTests.TestCopyToFromDevices"; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); auto dummy_provider = onnxruntime::make_unique(); auto* p_dummy_provider = dummy_provider.get(); - session_object.RegisterExecutionProvider(std::move(dummy_provider)); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(dummy_provider))); // prepare inputs std::vector dims_mul_x = {3, 2}; @@ -1811,10 +1824,10 @@ TEST(InferenceSessionTests, TestRegisterTransformers) { // Create and register dummy graph transformer auto dummy_transformer_unique_ptr = onnxruntime::make_unique("DummyTransformer"); const auto* dummy_transformer = dummy_transformer_unique_ptr.get(); - session_object.RegisterGraphTransformer(std::move(dummy_transformer_unique_ptr)); + ASSERT_STATUS_OK(session_object.RegisterGraphTransformer(std::move(dummy_transformer_unique_ptr))); - session_object.Load(model_uri); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); // Validate transformer was called after Session.Initialize ASSERT_TRUE(dummy_transformer->IsTransformerInvoked()); @@ -1837,8 +1850,8 @@ TEST(InferenceSessionTests, TestL1AndL2Transformers) { so.session_logid = "InferenceSessionTests.TestL1AndL2Transformers"; so.graph_optimization_level = TransformerLevel::Level2; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(model_uri).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); } } @@ -1907,7 +1920,7 @@ TEST(InferenceSessionTests, TestParallelExecutionWithCudaProvider) { epi.device_id = 0; EXPECT_TRUE(session_object.RegisterExecutionProvider(onnxruntime::make_unique(epi)).IsOK()); - ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + ASSERT_STATUS_OK(session_object.Load(model_uri)); auto status = session_object.Initialize(); @@ -1928,7 +1941,7 @@ TEST(InferenceSessionTests, ModelThatTriggersAllocationPlannerToReuseDoubleTenso Status st; ASSERT_TRUE((st = session_object.Load("testdata/test_cast_back_to_back_non_const_mixed_types_origin.onnx")).IsOK()) << st.ErrorMessage(); - ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK(session_object.Initialize()); RunOptions run_options; run_options.run_tag = "one session/one tag"; @@ -1980,7 +1993,7 @@ static char ort_load_config_from_model_env_var_disabled[] = "ORT_LOAD_CONFIG_FRO TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { // Part 1 - Load config from model feature enabled #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_enabled); + (void)_putenv(ort_load_config_from_model_env_var_enabled); #else putenv(ort_load_config_from_model_env_var_enabled); #endif @@ -2018,7 +2031,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { // Part 2 - Load config from model feature disabled #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_disabled); + (void)_putenv(ort_load_config_from_model_env_var_disabled); #else putenv(ort_load_config_from_model_env_var_disabled); #endif @@ -2046,7 +2059,7 @@ TEST(InferenceSessionTests, LoadModelWithValidOrtConfigJson) { TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { // Part 1 - Load config from model feature enabled #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_enabled); + (void)_putenv(ort_load_config_from_model_env_var_enabled); #else putenv(ort_load_config_from_model_env_var_enabled); #endif @@ -2066,7 +2079,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { // Part 2 - Load config from model feature disabled // The invalid/improperly formed config json in the model should not come into the picture here #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_disabled); + (void)_putenv(ort_load_config_from_model_env_var_disabled); #else putenv(ort_load_config_from_model_env_var_disabled); #endif @@ -2093,7 +2106,7 @@ TEST(InferenceSessionTests, LoadModelWithInValidOrtConfigJson) { TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { // Part 1 - Load config from model feature enabled #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_enabled); + (void)_putenv(ort_load_config_from_model_env_var_enabled); #else putenv(ort_load_config_from_model_env_var_enabled); #endif @@ -2120,7 +2133,7 @@ TEST(InferenceSessionTests, LoadModelWithNoOrtConfigJson) { // Part 2 - Load config from model feature disabled // The missing config json should not come into the picture #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_disabled); + (void)_putenv(ort_load_config_from_model_env_var_disabled); #else putenv(ort_load_config_from_model_env_var_disabled); #endif @@ -2141,7 +2154,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { // "10" is unsupported for ORT_LOAD_CONFIG_FROM_MODEL char env_var_value_set_to_unsupported_val[] = "ORT_LOAD_CONFIG_FROM_MODEL=10"; #ifdef _WIN32 - _putenv(env_var_value_set_to_unsupported_val); + (void)_putenv(env_var_value_set_to_unsupported_val); #else putenv(env_var_value_set_to_unsupported_val); #endif @@ -2160,7 +2173,7 @@ TEST(InferenceSessionTests, LoadModelWithEnvVarSetToUnsupportedVal) { // Disable the feature before exiting the test as this process is likely to be used for running other tests #ifdef _WIN32 - _putenv(ort_load_config_from_model_env_var_disabled); + (void)_putenv(ort_load_config_from_model_env_var_disabled); #else putenv(ort_load_config_from_model_env_var_disabled); #endif @@ -2202,8 +2215,8 @@ TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed) { ASSERT_TRUE(st.IsOK()); InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); // make sure we're using the per session threadpools auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); @@ -2242,8 +2255,8 @@ TEST(InferenceSessionTests, CheckIfGlobalThreadPoolsAreBeingUsed) { ASSERT_TRUE(st.IsOK()); InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); // make sure we're using the global threadpools in both session and session state auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); @@ -2280,8 +2293,8 @@ TEST(InferenceSessionTests, CheckIfPerSessionThreadPoolsAreBeingUsed2) { ASSERT_TRUE(st.IsOK()); InferenceSessionTestGlobalThreadPools session_object{so, *env.get()}; - ASSERT_TRUE(session_object.Load(MODEL_URI).IsOK()); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); // make sure we're using the per session threadpools auto intra_tp_from_session = session_object.GetIntraOpThreadPoolToUse(); diff --git a/onnxruntime/test/framework/memcpy_transformer_test.cc b/onnxruntime/test/framework/memcpy_transformer_test.cc index 198e7e7b54521..3fbe5f5be3399 100644 --- a/onnxruntime/test/framework/memcpy_transformer_test.cc +++ b/onnxruntime/test/framework/memcpy_transformer_test.cc @@ -9,6 +9,7 @@ #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" +#include "asserts.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -23,8 +24,8 @@ void ExpectSame(const onnxruntime::Node& source, const onnxruntime::Node& target EXPECT_EQ(source_output, target_input); } -void ExpectCopy(const onnxruntime::Node& source, const std::string copy_op, - const onnxruntime::Node& target, int argnum) { +void ExpectCopy(const onnxruntime::Node& source, const std::string& copy_op, const onnxruntime::Node& target, + int argnum) { // Check that source's output is consumed by a copy_op; for (auto it = source.OutputNodesBegin(); it != source.OutputNodesEnd(); ++it) { auto& copy_node = *it; @@ -109,7 +110,7 @@ TEST(TransformerTest, MemcpyTransformerTest) { execution_providers.Add(onnxruntime::kCpuExecutionProvider, onnxruntime::make_unique(CPUExecutionProviderInfo())); KernelRegistryManager test_registry_manager; - test_registry_manager.RegisterKernels(execution_providers); + ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); @@ -165,7 +166,7 @@ TEST(TransformerTest, MemcpyTransformerTestCudaFirst) { execution_providers.Add(onnxruntime::kCpuExecutionProvider, onnxruntime::make_unique(CPUExecutionProviderInfo())); KernelRegistryManager test_registry_manager; - test_registry_manager.RegisterKernels(execution_providers); + ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); @@ -248,8 +249,7 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { &graph.GetOrCreateNodeArg("parent_constant", &tensor_float_type)}, ArgMap{&o2_def}); - auto status = subgraph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(subgraph.Resolve()); // main graph continued // create the 'If' node @@ -272,8 +272,7 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { node.SetExecutionProviderType(onnxruntime::kCpuExecutionProvider); } - status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(graph.Resolve()); KernelRegistryManager kernel_registry_manager; ExecutionProviders execution_providers; @@ -282,13 +281,12 @@ TEST(TransformerTest, TestCopyNodeInsertionInitializerInSubgraph) { execution_providers.Add(onnxruntime::kCpuExecutionProvider, onnxruntime::make_unique(CPUExecutionProviderInfo())); KernelRegistryManager test_registry_manager; - test_registry_manager.RegisterKernels(execution_providers); + ASSERT_STATUS_OK(test_registry_manager.RegisterKernels(execution_providers)); MemcpyTransformer transformer({onnxruntime::kCudaExecutionProvider}, test_registry_manager); bool modified = false; - status = transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); - EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + ASSERT_STATUS_OK(transformer.Apply(graph, modified, DefaultLoggingManager().DefaultLogger())); EXPECT_TRUE(modified); } diff --git a/onnxruntime/test/global_thread_pools/test_main.cc b/onnxruntime/test/global_thread_pools/test_main.cc index 4d5df10066f39..7d31cbe86bf1d 100644 --- a/onnxruntime/test/global_thread_pools/test_main.cc +++ b/onnxruntime/test/global_thread_pools/test_main.cc @@ -24,6 +24,9 @@ #pragma warning(disable : 4506) /*no definition for inline function 'function'*/ #pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ #pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#pragma warning(disable : 6011) /*Dereferencing NULL pointer*/ +#pragma warning(disable : 6387) /*'value' could be '0'*/ +#pragma warning(disable : 26495) /*Variable is uninitialized.*/ #endif #include #ifdef __GNUC__ diff --git a/onnxruntime/test/mlas/unittest.cpp b/onnxruntime/test/mlas/unittest.cpp index c8cc422b86102..6f9f6b74cff39 100644 --- a/onnxruntime/test/mlas/unittest.cpp +++ b/onnxruntime/test/mlas/unittest.cpp @@ -511,7 +511,7 @@ class MlasQgemmU8X8Test : public MlasTestBase for (size_t f = 0; f < M * N; f++) { if (C[f] != CReference[f]) { - printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d!\n", M, N, K, offa, offb); + printf("mismatch M=%zd, N=%zd, K=%zd, offa=%d, offb=%d!\n", M, N, K, (int)offa, (int)offb); } } } @@ -1944,7 +1944,7 @@ class MlasActivationTest : public MlasTestBase for (unsigned i = 0; i < _countof(TestData); i++) { // Sensitive to comparing positive/negative zero and NaNs. if (Buffer[i].u != TestData[i][kind].u && Buffer[i].f != TestData[i][kind].f) { - printf("mismatch activation kind=%d i=%d value=%08x expected=%08x\n", kind, i, Buffer[i].u, TestData[i][kind].u); + printf("mismatch activation kind=%d i=%d value=%08x expected=%08x\n", (int)kind, (int)i, Buffer[i].u, TestData[i][kind].u); } } @@ -1960,7 +1960,7 @@ class MlasActivationTest : public MlasTestBase for (unsigned i = 0; i < _countof(TestData); i++) { // Sensitive to comparing positive/negative zero and NaNs. if (Buffer[i].u != TestData[i][kind].u && Buffer[i].f != TestData[i][kind].f) { - printf("mismatch activation kind=%d i=%d value=%08x expected=%08x\n", kind, i, Buffer[i].u, TestData[i][kind].u); + printf("mismatch activation kind=%d i=%d value=%08x expected=%08x\n", (int)kind, (int)i, Buffer[i].u, TestData[i][kind].u); } } } diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 96fb654ffe7ac..bcaad4133cdc8 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -394,7 +394,7 @@ class OnnxTestCase : public ITestCase { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OnnxTestCase); public: - OnnxTestCase(const std::string& test_case_name, TestModelInfo* model, double default_per_sample_tolerance, + OnnxTestCase(const std::string& test_case_name, _In_ TestModelInfo* model, double default_per_sample_tolerance, double default_relative_per_sample_tolerance); ~OnnxTestCase() override { delete model_info_; } Status GetPerSampleTolerance(double* value) override; diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 9d1b76969b5f9..332e083cc0b22 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -662,6 +662,18 @@ int real_main(int argc, char* argv[], Ort::Env& env) { broken_tests.insert({"dynamic_slice_end_out_of_bounds", "This model uses contrib ops."}); broken_tests.insert({"dynamic_slice_neg", "This model uses contrib ops."}); broken_tests.insert({"mvn", "This model uses contrib ops.", {"onnx130"}}); + broken_tests.insert({"cdist_float32_euclidean_1000_2000_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float32_euclidean_1000_2000_500", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float32_euclidean_1_1_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float32_sqeuclidean_1000_2000_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float32_sqeuclidean_1000_2000_500", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float32_sqeuclidean_1_1_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_euclidean_1000_2000_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_euclidean_1000_2000_500", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_euclidean_1_1_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_sqeuclidean_1000_2000_1", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_sqeuclidean_1000_2000_500", "This model uses contrib ops."}); + broken_tests.insert({"cdist_float64_sqeuclidean_1_1_1", "This model uses contrib ops."}); #endif int result = 0; diff --git a/onnxruntime/test/onnx/pb_helper.h b/onnxruntime/test/onnx/pb_helper.h index 91de12a51ff45..7625b229b43da 100644 --- a/onnxruntime/test/onnx/pb_helper.h +++ b/onnxruntime/test/onnx/pb_helper.h @@ -33,6 +33,11 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" #pragma GCC diagnostic ignored "-Wunused-parameter" +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 6011) +#pragma warning(disable : 6387) +#pragma warning(disable : 28182) #endif #include #include @@ -40,6 +45,8 @@ #include "tml.pb.h" #ifdef __GNUC__ #pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) #endif namespace onnxruntime { bool ParseDelimitedFromCodedStream(google::protobuf::MessageLite* message, diff --git a/onnxruntime/test/opaque_api/test_opaque_api.cc b/onnxruntime/test/opaque_api/test_opaque_api.cc index a27030dcb5076..c1ff0f0a14e94 100644 --- a/onnxruntime/test/opaque_api/test_opaque_api.cc +++ b/onnxruntime/test/opaque_api/test_opaque_api.cc @@ -153,9 +153,10 @@ static void RegisterCustomKernel() { // Register kernel directly to KernelRegistry // because we can not create custom ops with Opaque types // as input + // TODO: But that registry is process-wide, such modification is super dangerous. BuildKernelCreateInfoFn fn = BuildKernelCreateInfo; auto kernel_registry = CPUExecutionProvider(CPUExecutionProviderInfo()).GetKernelRegistry(); - kernel_registry->Register(fn()); + ORT_ENFORCE(kernel_registry->Register(fn()).IsOK()); } namespace test { diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 73212d59b8e42..dcb870b759ed8 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -41,7 +41,7 @@ #include "test/framework/test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" - +#include "asserts.h" #include "gtest/gtest.h" using namespace std; @@ -52,10 +52,19 @@ namespace test { #define MODEL_FOLDER ORT_TSTR("testdata/transform/") -TEST(GraphTransformationTests, IdentityElimination) { +class GraphTransformationTests : public ::testing::Test { + protected: + GraphTransformationTests() { + logger_ = DefaultLoggingManager().CreateLogger("GraphTransformationTests"); + } + + std::unique_ptr logger_; +}; + +TEST_F(GraphTransformationTests, IdentityElimination) { auto model_uri = MODEL_FOLDER "abs-id-max.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 1); @@ -64,16 +73,16 @@ TEST(GraphTransformationTests, IdentityElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 0); } -TEST(GraphTransformationTests, DropoutElimination) { +TEST_F(GraphTransformationTests, DropoutElimination) { auto model_uri = MODEL_FOLDER "dropout.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Identity"] == 5); @@ -83,7 +92,7 @@ TEST(GraphTransformationTests, DropoutElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); // Of the 6 Dropout nodes in the graph, all but the ones named `d1` and `d6` should have been removed. @@ -94,12 +103,12 @@ TEST(GraphTransformationTests, DropoutElimination) { ASSERT_TRUE(op_to_count["Dropout"] == 2); } -TEST(GraphTransformationTests, SliceElimination) { +TEST_F(GraphTransformationTests, SliceElimination) { std::vector > model_names = {ORT_TSTR("slice-v1-elim.onnx"), ORT_TSTR("slice-v11-elim.onnx")}; for (const auto& model_name : model_names) { auto model_uri = MODEL_FOLDER + model_name; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); int initial_slice_num = op_to_count["Slice"]; @@ -108,7 +117,7 @@ TEST(GraphTransformationTests, SliceElimination) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); // Only one Slice operator is redundant and is removed. @@ -116,10 +125,10 @@ TEST(GraphTransformationTests, SliceElimination) { } } -TEST(GraphTransformationTests, ConstantFolding) { +TEST_F(GraphTransformationTests, ConstantFolding) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 2); @@ -127,16 +136,16 @@ TEST(GraphTransformationTests, ConstantFolding) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); } -TEST(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { +TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 2); @@ -149,7 +158,7 @@ TEST(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { node.SetExecutionProviderType(kCudaExecutionProvider); } - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); @@ -160,7 +169,7 @@ TEST(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) { } } -TEST(GraphTransformationTests, ConstantFoldingSubgraph) { +TEST_F(GraphTransformationTests, ConstantFoldingSubgraph) { TensorProto value_tensor; value_tensor.add_dims(1); value_tensor.add_float_data(1.f); @@ -172,7 +181,7 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { auto create_subgraph = [&](GraphProto& graph_proto) { // create subgraph that has an Add node to add a local and parent graph initializer - Model model("ConstantFoldingSubgraphTest_subgraph", false, DefaultLoggingManager().DefaultLogger()); + Model model("ConstantFoldingSubgraphTest_subgraph", false, *logger_); auto& graph = model.MainGraph(); TensorProto local_constant(value_tensor); @@ -189,12 +198,11 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { auto& subgraph_out = graph.GetOrCreateNodeArg("subgraph_out", &float_tensor_type); graph.AddNode("identity", "Identity", "So Add isn't providing graph output.", {&add_out}, {&subgraph_out}); - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status; + ASSERT_STATUS_OK(graph.Resolve()); graph_proto = graph.ToGraphProto(); }; - Model model("ConstantFoldingSubgraphTest_main_graph", false, DefaultLoggingManager().DefaultLogger()); + Model model("ConstantFoldingSubgraphTest_main_graph", false, *logger_); auto& graph = model.MainGraph(); // add initializer at parent level @@ -217,8 +225,7 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { if_node.AddAttribute("then_branch", subgraph); if_node.AddAttribute("else_branch", subgraph); - auto status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status; + ASSERT_STATUS_OK(graph.Resolve()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 2); // one in each subgraph @@ -226,18 +233,17 @@ TEST(GraphTransformationTests, ConstantFoldingSubgraph) { onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); - ASSERT_TRUE(status.IsOK()) << status; + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0) << "Constant folding should have been able to remove the Add node in both subgraphs"; } -TEST(GraphTransformationTests, ShapeToInitializer) { +TEST_F(GraphTransformationTests, ShapeToInitializer) { auto model_uri = MODEL_FOLDER "shape-add.onnx"; std::shared_ptr model; - ASSERT_TRUE(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); Graph& graph = model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Shape"] == 4); @@ -247,7 +253,7 @@ TEST(GraphTransformationTests, ShapeToInitializer) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); op_to_count = CountOpsInGraph(graph); // Two of the Shapes are not eliminated because: @@ -257,19 +263,19 @@ TEST(GraphTransformationTests, ShapeToInitializer) { } // Check transformations in the case of a subgraph with constant inputs. -TEST(GraphTransformationTests, SubgraphWithConstantInputs) { +TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { auto model_uri = MODEL_FOLDER "constant-subgraph.onnx"; SessionOptions so; so.graph_optimization_level = TransformerLevel::Level2; so.session_logid = "GraphTransformationTests.LoadModelToTransform"; InferenceSession session_object{so, GetEnvironment()}; - ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + ASSERT_STATUS_OK(session_object.Load(model_uri)); std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); - ASSERT_TRUE(session_object.Initialize().IsOK()); + ASSERT_STATUS_OK(session_object.Initialize()); NameMLValMap feeds; RunOptions run_options; @@ -277,14 +283,14 @@ TEST(GraphTransformationTests, SubgraphWithConstantInputs) { std::vector output_names = {"output"}; std::vector fetches; - ASSERT_TRUE(session_object.Run(run_options, feeds, output_names, &fetches).IsOK()); + ASSERT_STATUS_OK(session_object.Run(run_options, feeds, output_names, &fetches)); } -TEST(GraphTransformationTests, FuseConvBNNoBias) { +TEST_F(GraphTransformationTests, FuseConvBNNoBias) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); std::string bn_output_name; @@ -302,7 +308,7 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); @@ -315,11 +321,11 @@ TEST(GraphTransformationTests, FuseConvBNNoBias) { } } -TEST(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { +TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); // add an optional output to the BN node. should not fuse if this is present @@ -336,13 +342,13 @@ TEST(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 1); } -TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { +TEST_F(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { std::vector > test_models = {ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze.onnx"), ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze.negative_axes.onnx"), ORT_TSTR("fusion/fuse-conv-bn-mul-add-unsqueeze-no-bias.onnx")}; @@ -350,7 +356,7 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { auto model_uri = MODEL_FOLDER + model; std::shared_ptr p_model; - ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -361,7 +367,7 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); @@ -372,7 +378,7 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { } #ifndef DISABLE_CONTRIB_OPS -TEST(GraphTransformationTests, FuseConvActivation) { +TEST_F(GraphTransformationTests, FuseConvActivation) { std::unordered_map, std::string> model_to_op_name{{ORT_TSTR("fusion/conv_relu.onnx"), "Relu"}, {ORT_TSTR("fusion/conv_clip.onnx"), "Clip"}, {ORT_TSTR("fusion/conv_sigmoid.onnx"), "Sigmoid"}, @@ -382,7 +388,7 @@ TEST(GraphTransformationTests, FuseConvActivation) { for (const auto& model : model_to_op_name) { auto model_uri = MODEL_FOLDER + model.first; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); @@ -391,18 +397,17 @@ TEST(GraphTransformationTests, FuseConvActivation) { // Apply transformer onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count[model.second] == 0); } } -TEST(GraphTransformationTests, FuseConvClip11Activation) { +TEST_F(GraphTransformationTests, FuseConvClip11Activation) { auto model_uri = MODEL_FOLDER "fusion/conv_clip11.onnx"; std::shared_ptr p_model; - auto status = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); - ASSERT_TRUE(status.IsOK()) << status; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); std::map op_to_count = CountOpsInGraph(graph); @@ -411,7 +416,7 @@ TEST(GraphTransformationTests, FuseConvClip11Activation) { // Apply transformer onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); op_to_count = CountOpsInGraph(graph); ASSERT_EQ(op_to_count["Clip"], 1); @@ -437,11 +442,11 @@ TEST(GraphTransformationTests, FuseConvClip11Activation) { } #endif -TEST(GraphTransformationTests, FuseConvMulNoBias) { +TEST_F(GraphTransformationTests, FuseConvMulNoBias) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-mul-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -450,18 +455,18 @@ TEST(GraphTransformationTests, FuseConvMulNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Mul"] == 0); ASSERT_TRUE(op_to_count["Unsqueeze"] == 0); } -TEST(GraphTransformationTests, FuseConvAddNoBias) { +TEST_F(GraphTransformationTests, FuseConvAddNoBias) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -470,7 +475,7 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); @@ -479,11 +484,11 @@ TEST(GraphTransformationTests, FuseConvAddNoBias) { // if IR version is 4 or higher the weights can be overridden if there's a matching graph input. // check that we don't fuse if that is the case -TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { +TEST_F(GraphTransformationTests, NegativeFuseConvAddNoBias) { auto model_uri = MODEL_FOLDER "fusion/negative-fuse-conv-add-no-bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -492,7 +497,7 @@ TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); // Nodes are not fused because the weights to conv/add are not constants (they appear in the graph inputs). // Unsqueeze is also not eliminated as the initializer that is its input is also not constant @@ -501,11 +506,11 @@ TEST(GraphTransformationTests, NegativeFuseConvAddNoBias) { ASSERT_TRUE(op_to_count["Unsqueeze"] != 0); } -TEST(GraphTransformationTests, FuseConvAddMul3D) { +TEST_F(GraphTransformationTests, FuseConvAddMul3D) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-mul-3d.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -514,18 +519,18 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); ASSERT_TRUE(op_to_count["Mul"] == 0); } -TEST(GraphTransformationTests, FuseConvAddMul3D_2) { +TEST_F(GraphTransformationTests, FuseConvAddMul3D_2) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-add-mul-3d-2.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -534,23 +539,23 @@ TEST(GraphTransformationTests, FuseConvAddMul3D_2) { rule_transformer_L1->Register(onnxruntime::make_unique()); graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Add"] == 0); ASSERT_TRUE(op_to_count["Mul"] == 0); } -TEST(GraphTransformationTests, MatMulAddFusion_two_input) { +TEST_F(GraphTransformationTests, MatMulAddFusion_two_input) { auto model_uri = MODEL_FOLDER "matmul_add_fusion/2Input/model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["MatMul"] == 0); @@ -558,16 +563,16 @@ TEST(GraphTransformationTests, MatMulAddFusion_two_input) { ASSERT_TRUE(op_to_count["Gemm"] == 1); } -TEST(GraphTransformationTests, MatMulAddFusion_three_input) { +TEST_F(GraphTransformationTests, MatMulAddFusion_three_input) { auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["MatMul"] == 0); @@ -579,16 +584,16 @@ TEST(GraphTransformationTests, MatMulAddFusion_three_input) { // We can do the fusion by changing shape to [1,k]*[k,N]+[1,N], then add a reshape [1,N]=>[N] // This will bring extra cost. And there's only very limited gain to fuse Matmul+Add to Gemm // Since the basic implementation is almost same -TEST(GraphTransformationTests, MatMulAddFusion_negitive_case) { +TEST_F(GraphTransformationTests, MatMulAddFusion_negitive_case) { auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/neg_model.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["MatMul"] == 1); @@ -597,31 +602,31 @@ TEST(GraphTransformationTests, MatMulAddFusion_negitive_case) { } #ifndef DISABLE_CONTRIB_OPS -TEST(GraphTransformationTests, Gemm_Relu_three_input) { +TEST_F(GraphTransformationTests, Gemm_Relu_three_input) { auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); std::map op_to_count1 = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Relu"] == 0); } -TEST(GraphTransformationTests, Gemm_LeakyRelu_Fusion) { +TEST_F(GraphTransformationTests, Gemm_LeakyRelu_Fusion) { auto model_uri = MODEL_FOLDER "gemm_activation_fusion/gemm_activation_fusion.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); std::map op_to_count1 = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["LeakyRelu"] == 0); @@ -630,7 +635,7 @@ TEST(GraphTransformationTests, Gemm_LeakyRelu_Fusion) { } #endif -TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { +TEST_F(GraphTransformationTests, FuseConvBnAddMulFloat16) { auto model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-add-mul-float16.onnx"; SessionOptions so; @@ -639,13 +644,13 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { ASSERT_TRUE(session_object.Load(model_uri).IsOK()); std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); auto rule_transformer_L1 = onnxruntime::make_unique("RuleTransformerL1"); rule_transformer_L1->Register(onnxruntime::make_unique()); rule_transformer_L1->Register(onnxruntime::make_unique()); rule_transformer_L1->Register(onnxruntime::make_unique()); - session_object.RegisterGraphTransformer(std::move(rule_transformer_L1), TransformerLevel::Level1); + ASSERT_STATUS_OK(session_object.RegisterGraphTransformer(std::move(rule_transformer_L1), TransformerLevel::Level1)); ASSERT_TRUE(session_object.Initialize().IsOK()); @@ -687,11 +692,11 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) { ASSERT_EQ(expected_values_prod, found); } -TEST(GraphTransformationTests, ReluClip6Fusion) { +TEST_F(GraphTransformationTests, ReluClip6Fusion) { // Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use // older opset. Model model("ReluClip6Fusion", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, - {}, DefaultLoggingManager().DefaultLogger()); + {}, *logger_); auto& graph = model.MainGraph(); std::vector inputs; @@ -743,7 +748,7 @@ TEST(GraphTransformationTests, ReluClip6Fusion) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_).IsOK()); op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Relu"] == 0); @@ -758,12 +763,12 @@ TEST(GraphTransformationTests, ReluClip6Fusion) { } // test handling of Clip 11 -TEST(GraphTransformationTests, ReluClip11Fusion) { +TEST_F(GraphTransformationTests, ReluClip11Fusion) { std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = 11; - Model model("ReluClip6Fusion", false, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version, std::vector(), DefaultLoggingManager().DefaultLogger()); //, true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 11}}, {}); + Model model("ReluClip6Fusion", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + *logger_); //, true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 11}}, {}); auto& graph = model.MainGraph(); std::vector inputs; @@ -835,7 +840,7 @@ TEST(GraphTransformationTests, ReluClip11Fusion) { rule_transformer_L1->Register(onnxruntime::make_unique()); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); - status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(status.IsOK()) << status; op_to_count = CountOpsInGraph(graph); @@ -882,15 +887,15 @@ TEST(GraphTransformationTests, ReluClip11Fusion) { } // Test Reshape Fusion with 2 constant initializers for Concat inputs. -TEST(GraphTransformationTests, ReshapeFusionTest) { +TEST_F(GraphTransformationTests, ReshapeFusionTest) { auto model_uri = MODEL_FOLDER "fusion/reshape.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -919,15 +924,15 @@ TEST(GraphTransformationTests, ReshapeFusionTest) { } // Test Reshape Fusion with one constant initializer for Concat inputs. -TEST(GraphTransformationTests, ReshapeFusionOneConstTest) { +TEST_F(GraphTransformationTests, ReshapeFusionOneConstTest) { auto model_uri = MODEL_FOLDER "fusion/reshape_one_const.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1069,15 +1074,15 @@ static void ValidateAttention(Graph& graph) { } // Test Attention Fusion with int32 mask -TEST(GraphTransformationTests, AttentionFusionInt32Test) { +TEST_F(GraphTransformationTests, AttentionFusionInt32Test) { auto model_uri = MODEL_FOLDER "fusion/attention_int32_mask.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1093,15 +1098,15 @@ TEST(GraphTransformationTests, AttentionFusionInt32Test) { } // Test Attention Fusion with int64 mask and symbolic batch dimension -TEST(GraphTransformationTests, AttentionFusionInt64Test) { +TEST_F(GraphTransformationTests, AttentionFusionInt64Test) { auto model_uri = MODEL_FOLDER "fusion/attention_symbolic_batch.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1116,15 +1121,15 @@ TEST(GraphTransformationTests, AttentionFusionInt64Test) { ValidateAttention(graph); } -TEST(GraphTransformationTests, GeluFusionTest) { +TEST_F(GraphTransformationTests, GeluFusionTest) { auto model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1135,15 +1140,15 @@ TEST(GraphTransformationTests, GeluFusionTest) { ASSERT_TRUE(op_to_count["Gelu"] == 1); } -TEST(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) { +TEST_F(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) { auto model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1154,15 +1159,15 @@ TEST(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) { ASSERT_TRUE(op_to_count["Gelu"] == 1); } -TEST(GraphTransformationTests, GeluFusionTestFormat2) { +TEST_F(GraphTransformationTests, GeluFusionTestFormat2) { auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1173,15 +1178,15 @@ TEST(GraphTransformationTests, GeluFusionTestFormat2) { ASSERT_TRUE(op_to_count["Gelu"] == 1); } -TEST(GraphTransformationTests, GeluFusionTestFormat2GraphInput) { +TEST_F(GraphTransformationTests, GeluFusionTestFormat2GraphInput) { auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1_use_graph_input.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1192,16 +1197,16 @@ TEST(GraphTransformationTests, GeluFusionTestFormat2GraphInput) { ASSERT_TRUE(op_to_count["Gelu"] == 1); } -TEST(GraphTransformationTests, BiasGeluTest) { +TEST_F(GraphTransformationTests, BiasGeluTest) { auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Div"] == 0); @@ -1213,15 +1218,15 @@ TEST(GraphTransformationTests, BiasGeluTest) { } // Test Gelu -> FastGelu -TEST(GraphTransformationTests, GeluApproximation_Gelu) { +TEST_F(GraphTransformationTests, GeluApproximation_Gelu) { auto model_uri = MODEL_FOLDER "approximation/gelu.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1230,15 +1235,15 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu) { } // Test AddGeluFusion -> FastGelu -TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_Bias) { +TEST_F(GraphTransformationTests, GeluApproximation_Gelu_Add_Bias) { auto model_uri = MODEL_FOLDER "approximation/gelu_add_bias.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1247,15 +1252,15 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_Bias) { } // Test MatMul & AddGeluFusion -> MatMul & FastGelu -TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) { +TEST_F(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) { auto model_uri = MODEL_FOLDER "approximation/gelu_add_matmul.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1264,16 +1269,16 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) { EXPECT_EQ(op_to_count["FastGelu"], 1); } -TEST(GraphTransformationTests, FastGeluFusionTest) { +TEST_F(GraphTransformationTests, FastGeluFusionTest) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1284,16 +1289,16 @@ TEST(GraphTransformationTests, FastGeluFusionTest) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest) { +TEST_F(GraphTransformationTests, FastGeluUseGraphInputFusionTest) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu_use_graph_input.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1303,17 +1308,17 @@ TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluWithBiasFusionTest) { +TEST_F(GraphTransformationTests, FastGeluWithBiasFusionTest) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1323,17 +1328,17 @@ TEST(GraphTransformationTests, FastGeluWithBiasFusionTest) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) { +TEST_F(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias_use_graph_input.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1343,16 +1348,16 @@ TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluFusionTest2) { +TEST_F(GraphTransformationTests, FastGeluFusionTest2) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu2.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1362,16 +1367,16 @@ TEST(GraphTransformationTests, FastGeluFusionTest2) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) { +TEST_F(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_use_graph_input.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1381,17 +1386,17 @@ TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluWithBiasFusionTest2) { +TEST_F(GraphTransformationTests, FastGeluWithBiasFusionTest2) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1401,17 +1406,17 @@ TEST(GraphTransformationTests, FastGeluWithBiasFusionTest2) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) { +TEST_F(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) { auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias_use_graph_input.onnx"; std::shared_ptr p_model; - auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); ASSERT_TRUE(load_ret.IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1421,15 +1426,15 @@ TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) { ASSERT_TRUE(op_to_count["FastGelu"] == 1); } -TEST(GraphTransformationTests, LayerNormFusionTest) { +TEST_F(GraphTransformationTests, LayerNormFusionTest) { auto model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1457,15 +1462,15 @@ TEST(GraphTransformationTests, LayerNormFusionTest) { } } -TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { +TEST_F(GraphTransformationTests, LayerNormWithSubDupFusionTest) { auto model_uri = MODEL_FOLDER "fusion/layer_norm_sub_dup.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1493,15 +1498,16 @@ TEST(GraphTransformationTests, LayerNormWithSubDupFusionTest) { } } -static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, int skip_ln_count) { +static void TestSkipLayerNormFusion(const std::basic_string& file_path, int add_count, int ln_count, + int skip_ln_count, logging::Logger* logger) { std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1515,24 +1521,24 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat ASSERT_TRUE(op_to_count["SkipLayerNormalization"] == skip_ln_count); } -TEST(GraphTransformationTests, SkipLayerNormFusionTest) { - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1); - TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0); +TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3.onnx", 0, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_partial.onnx", 1, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_partial.onnx", 1, 0, 1, logger_.get()); + TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_no_fusion.onnx", 1, 1, 0, logger_.get()); } -TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) { +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1544,15 +1550,15 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat1) { ASSERT_TRUE(op_to_count["EmbedLayerNormalization"] == 1); } -TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) { +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format2.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1571,15 +1577,15 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat2) { ASSERT_TRUE(op_to_count["EmbedLayerNormalization"] == 1); } -TEST(GraphTransformationTests, EmbedLayerNormFusionFormat3) { +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1597,15 +1603,15 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat3) { EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); } -TEST(GraphTransformationTests, EmbedLayerNormFusionFormat4) { +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat4) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format4.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); @@ -1625,15 +1631,15 @@ TEST(GraphTransformationTests, EmbedLayerNormFusionFormat4) { ASSERT_TRUE(op_to_count["EmbedLayerNormalization"] == 1); } -TEST(GraphTransformationTests, EmbedLayerNormFusionFormat5) { +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format5.onnx"; std::shared_ptr p_model; - ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK()); + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); Graph& graph = p_model->MainGraph(); onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); - auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); ASSERT_TRUE(ret.IsOK()); std::map op_to_count = CountOpsInGraph(graph); diff --git a/onnxruntime/test/platform/threadpool_test.cc b/onnxruntime/test/platform/threadpool_test.cc index a517236520ba4..678acbecac404 100644 --- a/onnxruntime/test/platform/threadpool_test.cc +++ b/onnxruntime/test/platform/threadpool_test.cc @@ -115,8 +115,10 @@ TEST(ThreadPoolTest, TestStackSize) { ULONG_PTR low_limit, high_limit; bool has_thread_limit_info = false; tp->Schedule([&]() { - FnGetCurrentThreadStackLimits GetTS = (FnGetCurrentThreadStackLimits)GetProcAddress( - GetModuleHandle(TEXT("kernel32.dll")), "GetCurrentThreadStackLimits"); + HMODULE kernel32_module = GetModuleHandle(TEXT("kernel32.dll")); + assert(kernel32_module != nullptr); + FnGetCurrentThreadStackLimits GetTS = + (FnGetCurrentThreadStackLimits)GetProcAddress(kernel32_module, "GetCurrentThreadStackLimits"); if (GetTS != nullptr) { GetTS(&low_limit, &high_limit); has_thread_limit_info = true; diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index 607332d17aae0..783c94db37103 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -1467,6 +1467,22 @@ TEST(ReductionOpTest, ReduceProd_int32) { test.Run(); } +TEST(ReductionOpTest, ReduceProd_int64) { + OpTester test("ReduceProd"); + test.AddAttribute("axes", std::vector{0, 2}); + test.AddInput("data", {3, 2, 2}, + {1, 2, + 3, 4, + + 5, 6, + 7, 8, + + 9, 10, + 11, 12}); + test.AddOutput("reduced", {1, 2, 1}, {5400, 88704}); + test.Run(); +} + #if !(defined USE_TENSORRT) && !(defined USE_TVM) TEST(ReductionOpTest, ReduceProd0DTensor) { OpTester test("ReduceProd"); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index fdea852b9393e..db4554aeb6a81 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -29,7 +29,7 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_1) { } // 2d inputs -TEST(DequantizeLinearOpTest, DequantizeLinear_2) { +TEST(DequantizeLinearOpTest, DequantizeLinear_2D) { OpTester test("DequantizeLinear", 10); std::vector dims{3, 4}; test.AddInput("X", dims, @@ -45,6 +45,15 @@ TEST(DequantizeLinearOpTest, DequantizeLinear_2) { test.Run(); } +// dequantize with scalar data +TEST(DequantizeLinearOpTest, DequantizeLinear_Scalar) { + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {}, {100}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {-10}); + test.AddOutput("y", {}, {220.0f}); + test.Run(); +} // quantize with scalar zero point and scale TEST(QuantizeLinearOpTest, QuantizeLinear_uint8) { @@ -91,7 +100,7 @@ TEST(QuantizeLinearOpTest, QuantizeLinear_int8_PositiveZeroPoint) { } // quantize with 2D data -TEST(QuantizeLinearOpTest, QuantizeLinear_1) { +TEST(QuantizeLinearOpTest, QuantizeLinear_2D) { OpTester test("QuantizeLinear", 10); std::vector dims{3, 4}; test.AddInput("X", dims, @@ -106,5 +115,16 @@ TEST(QuantizeLinearOpTest, QuantizeLinear_1) { 0, 0, 1, 250}); test.Run(); } + +// quantize with scalar data +TEST(QuantizeLinearOpTest, QuantizeLinear_Scalar) { + OpTester test("QuantizeLinear", 10); + test.AddInput("x", {}, {3}); + test.AddInput("y_scale", {}, {2.0f}); + test.AddInput("y_zero_point", {}, {128}); + test.AddOutput("y", {}, {130}); + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc index e953761df0a9c..cc941cd12df17 100644 --- a/onnxruntime/test/providers/cpu/tensor/where_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/where_op_test.cc @@ -122,5 +122,20 @@ TEST(WhereOpTest, BroadcastDimWithZero) { // exclude NGraph as this isn't handled by that EP test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider}); } + +TEST(WhereOpTest, BroadcastWithScalar) { + OpTester test{kOpName, kOpVersion}; + + test.AddInput("condition", {3}, {true, false, true}); + test.AddInput("X", {1, 3}, {1, 2, 3}); + test.AddInput("Y", {}, {1}); + + test.AddOutput("output", {1, 3}, {1, 1, 3}); + + // exclude NGraph as this isn't handled by that EP + //test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider}); + test.Run(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/memcpy_test.cc b/onnxruntime/test/providers/memcpy_test.cc index f9c7aa5c74c8a..d72f1241814ff 100644 --- a/onnxruntime/test/providers/memcpy_test.cc +++ b/onnxruntime/test/providers/memcpy_test.cc @@ -13,6 +13,7 @@ #include "core/platform/path_lib.h" #include #include "test/test_environment.h" +#include "asserts.h" namespace onnxruntime { namespace { @@ -34,7 +35,7 @@ TEST(MemcpyTest, copy1) { SessionState s{execution_providers, true, &tp, nullptr}; s.SetLogger(logging::LoggingManager::DefaultLogger()); KernelRegistryManager kernel_registry_manager; - kernel_registry_manager.RegisterKernels(execution_providers); + ASSERT_STATUS_OK(kernel_registry_manager.RegisterKernels(execution_providers)); ONNX_NAMESPACE::ModelProto mp; std::ifstream model_istream("testdata/matmul_1.onnx", std::ifstream::in | std::ifstream::binary); diff --git a/onnxruntime/test/providers/provider_test_utils.cc b/onnxruntime/test/providers/provider_test_utils.cc index 4faa3ea922785..d65869d073e42 100644 --- a/onnxruntime/test/providers/provider_test_utils.cc +++ b/onnxruntime/test/providers/provider_test_utils.cc @@ -632,6 +632,12 @@ void OpTester::Run( run_options, execution_providers, custom_output_verifier); } +#define ASSERT_PROVIDER_STATUS_OK(function) \ + do { \ + Status _tmp_status = function; \ + ASSERT_TRUE(_tmp_status.IsOK()) << "provider: " << provider_type << ", error: " << _tmp_status; \ + } while (false) + void OpTester::Run( SessionOptions so, // Take the SessionOptions by value (i.e. make a copy) // because we may need to modify it @@ -717,8 +723,7 @@ void OpTester::Run( for (auto& entry : *execution_providers) { provider_types += entry->Type() + ":"; - EXPECT_TRUE( - session_object.RegisterExecutionProvider(std::move(entry)).IsOK()); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(entry))); } fetches_ = ExecuteModel( @@ -740,7 +745,7 @@ void OpTester::Run( InferenceSession session_object{so, GetEnvironment()}; for (auto& custom_session_registry : custom_session_registries_) - session_object.RegisterCustomRegistry(custom_session_registry); + ASSERT_PROVIDER_STATUS_OK(session_object.RegisterCustomRegistry(custom_session_registry)); std::unique_ptr execution_provider; if (provider_type == onnxruntime::kCpuExecutionProvider) @@ -801,14 +806,11 @@ void OpTester::Run( continue; for (auto& custom_session_registry : custom_session_registries_) - session_object.RegisterCustomRegistry(custom_session_registry); + ASSERT_PROVIDER_STATUS_OK(session_object.RegisterCustomRegistry(custom_session_registry)); has_run = true; - EXPECT_TRUE( - session_object - .RegisterExecutionProvider(std::move(execution_provider)) - .IsOK()); + ASSERT_PROVIDER_STATUS_OK(session_object.RegisterExecutionProvider(std::move(execution_provider))); fetches_ = ExecuteModel( *p_model, session_object, expect_result, expected_failure_string, diff --git a/onnxruntime/test/providers/test_main.cc b/onnxruntime/test/providers/test_main.cc index a0fd66a5e39c1..dae5e1d16c8dc 100644 --- a/onnxruntime/test/providers/test_main.cc +++ b/onnxruntime/test/providers/test_main.cc @@ -24,6 +24,9 @@ #pragma warning(disable : 4506) /*no definition for inline function 'function'*/ #pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ #pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#pragma warning(disable : 6011) /*Dereferencing NULL pointer*/ +#pragma warning(disable : 6387) /*'value' could be '0'*/ +#pragma warning(disable : 26495) /*Variable is uninitialized.*/ #endif #include #ifdef __GNUC__ diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index fb1eafca2c8db..c9469f86cfd48 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -172,7 +172,8 @@ def create_backend_test(testname=None): '^test_resize_downsample_scales_cubic_align_corners_cpu', # results mismatch with onnx tests '^test_resize_downsample_scales_linear_align_corners_cpu' # results mismatch with onnx tests ] - + if platform.architecture()[0] == '32bit': + current_failing_tests += ['^test_vgg19', '^test_zfnet512', '^test_bvlc_alexnet_cpu'] # Example of how to disable tests for a specific provider. # if c2.supports_device('NGRAPH'): # current_failing_tests.append('^test_operator_repeat_dim_overflow_cpu') diff --git a/onnxruntime/test/shared_lib/onnx_protobuf.h b/onnxruntime/test/shared_lib/onnx_protobuf.h index f8312038fd970..62e993d06b1f3 100644 --- a/onnxruntime/test/shared_lib/onnx_protobuf.h +++ b/onnxruntime/test/shared_lib/onnx_protobuf.h @@ -4,15 +4,20 @@ #pragma once // TODO(): delete this file from public interface #ifdef __GNUC__ +#include "onnxruntime_config.h" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wignored-qualifiers" #pragma GCC diagnostic ignored "-Wunused-parameter" +#ifdef HAS_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#endif #else #pragma warning(push) #pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */ #pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/ #pragma warning(disable : 4100) #pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/ +#pragma warning(disable : 4127) #pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/ #pragma warning( \ disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/ @@ -26,6 +31,9 @@ #pragma warning(disable : 4506) /*no definition for inline function 'function'*/ #pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ #pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ +#pragma warning(disable : 6011) /*Dereferencing NULL pointer*/ +#pragma warning(disable : 6387) /*'value' could be '0'*/ +#pragma warning(disable : 26495) /*Variable is uninitialized.*/ #endif #include "onnx/onnx-ml.pb.h" #ifdef __GNUC__ diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 51294e5dc091e..7a85b1f081b0b 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -100,7 +100,7 @@ void TestInference(Ort::Env& env, T model_uri, if (custom_op_library_filename) { void* library_handle = nullptr; // leak this, no harm. - Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &library_handle); + Ort::ThrowOnError(Ort::GetApi().RegisterCustomOpsLibrary((OrtSessionOptions*)session_options, custom_op_library_filename, &library_handle)); } // if session creation passes, model loads fine @@ -318,7 +318,11 @@ TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { } #endif +#ifndef __ANDROID__ +TEST(CApiTest, test_custom_op_library) { +#else TEST(CApiTest, DISABLED_test_custom_op_library) { +#endif std::cout << "Running inference using custom op shared library" << std::endl; std::vector inputs(2); @@ -346,7 +350,7 @@ TEST(CApiTest, DISABLED_test_custom_op_library) { #elif defined(__APPLE__) lib_name = "libcustom_op_library.dylib"; #else - lib_name = "libcustom_op_library.so"; + lib_name = "./libcustom_op_library.so"; #endif TestInference(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y, expected_values_y, 0, nullptr, lib_name.c_str()); @@ -509,37 +513,71 @@ TEST(CApiTest, end_profiling) { } TEST(CApiTest, model_metadata) { - Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); auto allocator = onnxruntime::make_unique(); + // The following all tap into the c++ APIs which internally wrap over C APIs - // Create session - Ort::SessionOptions session_options; - Ort::Session session(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options); + // The following section tests a model containing all metadata supported via the APIs + { + Ort::SessionOptions session_options; + Ort::Session session(*ort_env, MODEL_WITH_CUSTOM_MODEL_METADATA, session_options); - // Fetch model metadata - // The following all tap into the c++ APIs which internally wrap over C APIs - auto model_metadata = session.GetModelMetadata(); + // Fetch model metadata + auto model_metadata = session.GetModelMetadata(); + + char* producer_name = model_metadata.GetProducerName(allocator.get()); + ASSERT_TRUE(strcmp("Hari", producer_name) == 0); + allocator.get()->Free(producer_name); - char* producer_name = model_metadata.GetProducerName(allocator.get()); - ASSERT_TRUE(strcmp("Hari", producer_name) == 0); + char* graph_name = model_metadata.GetGraphName(allocator.get()); + ASSERT_TRUE(strcmp("matmul test", graph_name) == 0); + allocator.get()->Free(graph_name); - char* graph_name = model_metadata.GetGraphName(allocator.get()); - ASSERT_TRUE(strcmp("matmul test", graph_name) == 0); + char* domain = model_metadata.GetDomain(allocator.get()); + ASSERT_TRUE(strcmp("", domain) == 0); + allocator.get()->Free(domain); - char* domain = model_metadata.GetDomain(allocator.get()); - ASSERT_TRUE(strcmp("", domain) == 0); + char* description = model_metadata.GetDescription(allocator.get()); + ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0); + allocator.get()->Free(description); - char* description = model_metadata.GetDescription(allocator.get()); - ASSERT_TRUE(strcmp("This is a test model with a valid ORT config Json", description) == 0); + int64_t version = model_metadata.GetVersion(); + ASSERT_TRUE(version == 1); - int64_t version = model_metadata.GetVersion(); - ASSERT_TRUE(version == 1); + int64_t num_keys_in_custom_metadata_map; + char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); + ASSERT_TRUE(num_keys_in_custom_metadata_map == 1); + ASSERT_TRUE(strcmp(custom_metadata_map_keys[0], "ort_config") == 0); + allocator.get()->Free(custom_metadata_map_keys[0]); + allocator.get()->Free(custom_metadata_map_keys); - char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); - ASSERT_TRUE(strcmp(lookup_value, - "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); + char* lookup_value = model_metadata.LookupCustomMetadataMap("ort_config", allocator.get()); + ASSERT_TRUE(strcmp(lookup_value, + "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}") == 0); + allocator.get()->Free(lookup_value); - // key doesn't exist in custom metadata map - lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); - ASSERT_TRUE(lookup_value == nullptr); + // key doesn't exist in custom metadata map + lookup_value = model_metadata.LookupCustomMetadataMap("key_doesnt_exist", allocator.get()); + ASSERT_TRUE(lookup_value == nullptr); + } + + // The following section tests a model with some missing metadata info + // Adding this just to make sure the API implementation is able to handle empty/missing info + { + Ort::SessionOptions session_options; + Ort::Session session(*ort_env, MODEL_URI, session_options); + + // Fetch model metadata + auto model_metadata = session.GetModelMetadata(); + + // Model description is empty + char* description = model_metadata.GetDescription(allocator.get()); + ASSERT_TRUE(strcmp("", description) == 0); + allocator.get()->Free(description); + + // Model does not contain custom metadata map + int64_t num_keys_in_custom_metadata_map; + char** custom_metadata_map_keys = model_metadata.GetCustomMetadataMapKeys(allocator.get(), num_keys_in_custom_metadata_map); + ASSERT_TRUE(num_keys_in_custom_metadata_map == 0); + ASSERT_TRUE(custom_metadata_map_keys == nullptr); + } } diff --git a/onnxruntime/test/shared_lib/test_nontensor_types.cc b/onnxruntime/test/shared_lib/test_nontensor_types.cc index 0c6f676805594..372699a7dd95d 100644 --- a/onnxruntime/test/shared_lib/test_nontensor_types.cc +++ b/onnxruntime/test/shared_lib/test_nontensor_types.cc @@ -190,7 +190,7 @@ TEST(CApiTest, CreateGetSeqStringTensors) { std::vector shape{2}; auto value = Ort::Value::CreateTensor(Ort::AllocatorWithDefaultOptions(), shape.data(), shape.size(), ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING); - Ort::GetApi().FillStringTensor(value, string_input_data, 2); + Ort::ThrowOnError(Ort::GetApi().FillStringTensor(value, string_input_data, 2)); in.push_back(std::move(value)); } diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.lds b/onnxruntime/test/testdata/custom_op_library/custom_op_library.lds new file mode 100644 index 0000000000000..8e58857c2aafb --- /dev/null +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.lds @@ -0,0 +1,6 @@ +VERS_1.0.0 { + global: + RegisterCustomOps; + local: + *; +}; diff --git a/onnxruntime/test/util/test_allocator.cc b/onnxruntime/test/util/test_allocator.cc index d2c640de7e626..b2e4e24fb28e3 100644 --- a/onnxruntime/test/util/test_allocator.cc +++ b/onnxruntime/test/util/test_allocator.cc @@ -19,6 +19,8 @@ void* MockedOrtAllocator::Alloc(size_t size) { constexpr size_t extra_len = sizeof(size_t); memory_inuse.fetch_add(size += extra_len); void* p = ::malloc(size); + if (p == nullptr) + return p; *(size_t*)p = size; return (char*)p + extra_len; } diff --git a/onnxruntime/tool/etw/eparser.cc b/onnxruntime/tool/etw/eparser.cc index 883a1aef3e401..6be8b4d57fe9d 100644 --- a/onnxruntime/tool/etw/eparser.cc +++ b/onnxruntime/tool/etw/eparser.cc @@ -171,7 +171,7 @@ DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, (pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) { *PropertyLength = pInfo->EventPropertyInfoArray[i].length; } else { - wprintf(L"Unexpected length of 0 for intype %d and outtype %d\n", + wprintf(L"Unexpected length of 0 for intype %ud and outtype %ud\n", pInfo->EventPropertyInfoArray[i].nonStructType.InType, pInfo->EventPropertyInfoArray[i].nonStructType.OutType); diff --git a/requirements-dev.txt b/requirements-dev.txt index 2877c2bba698b..93763a9611018 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,5 @@ pytest pytest-cov scikit-learn scipy -six sympy wheel diff --git a/requirements.txt b/requirements.txt index 1cf243d912f28..52a73514b63ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -six -numpy +numpy >= 1.18.0 +onnx >= 1.2.3 protobuf diff --git a/setup.py b/setup.py index 34d77b699f96c..34e1f14d678a9 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ nightly_build = False package_name = 'onnxruntime' +wheel_name_suffix = None if '--use_tensorrt' in sys.argv: package_name = 'onnxruntime-gpu-tensorrt' @@ -34,7 +35,7 @@ elif '--use_ngraph' in sys.argv: package_name = 'onnxruntime-ngraph' sys.argv.remove('--use_ngraph') - + elif '--use_dnnl' in sys.argv: package_name = 'onnxruntime-dnnl' sys.argv.remove('--use_dnnl') @@ -51,6 +52,15 @@ nightly_build = True sys.argv.remove('--nightly_build') +for arg in sys.argv[1:]: + if arg.startswith("--wheel_name_suffix="): + wheel_name_suffix = arg[len("--wheel_name_suffix="):] + nightly_build = True + + sys.argv.remove(arg) + + break + is_manylinux1 = False if environ.get('AUDITWHEEL_PLAT', None) == 'manylinux1_x86_64' or environ.get('AUDITWHEEL_PLAT', None) == 'manylinux2010_x86_64' : is_manylinux1 = True @@ -188,19 +198,32 @@ def run(self): version_number = f.readline().strip() if nightly_build: #https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables - date_suffix = environ.get('BUILD_BUILDNUMBER') - if date_suffix is None: + build_suffix = environ.get('BUILD_BUILDNUMBER') + if build_suffix is None: #The following line is only for local testing - date_suffix = str(datetime.datetime.now().date().strftime("%Y%m%d")) + build_suffix = str(datetime.datetime.now().date().strftime("%Y%m%d")) else: - date_suffix = date_suffix.replace('.','') - version_number = version_number + ".dev" + date_suffix + build_suffix = build_suffix.replace('.','') + + version_number = version_number + ".dev" + build_suffix + +if wheel_name_suffix: + package_name = "{}_{}".format(package_name, wheel_name_suffix) cmd_classes = {} if bdist_wheel is not None : cmd_classes['bdist_wheel'] = bdist_wheel cmd_classes['build_ext'] = build_ext +requirements_path = path.join(getcwd(), "requirements.txt") +if not path.exists(requirements_path): + this = path.dirname(__file__) + requirements_path = path.join(this, "requirements.txt") +if not path.exists(requirements_path): + raise FileNotFoundError("Unable to find 'requirements.txt'") +with open(requirements_path) as f: + install_requires = f.read().splitlines() + # Setup setup( name=package_name, @@ -222,10 +245,7 @@ def run(self): 'onnxruntime': data + examples + extra, }, py_modules=python_modules_list, - install_requires=[ - 'onnx>=1.2.3', - 'numpy>=1.18.0' - ], + install_requires=install_requires, entry_points= { 'console_scripts': [ 'onnxruntime_test = onnxruntime.tools.onnxruntime_test:main', diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 31d6fb6b79335..15d090d9e3e48 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -91,6 +91,8 @@ def parse_arguments(): # Python bindings parser.add_argument("--enable_pybind", action='store_true', help="Enable Python Bindings.") parser.add_argument("--build_wheel", action='store_true', help="Build Python Wheel. ") + parser.add_argument("--wheel_name_suffix", help="Suffix to append to created wheel names." + "This value is currently only used for nightly builds.") parser.add_argument("--numpy_version", help="Installs a specific version of numpy " "before building the python binding.") parser.add_argument("--skip-keras-test", action='store_true', help="Skip tests with Keras if keras is installed") @@ -168,6 +170,7 @@ def parse_arguments(): parser.add_argument("--enable_multi_device_test", action='store_true', help="Test with multi-device. Mostly used for multi-device GPU") parser.add_argument("--use_dml", action='store_true', help="Build with DirectML.") parser.add_argument("--use_winml", action='store_true', help="Build with WinML.") + parser.add_argument("--winml_root_namespace_override", type=str, help="Specify the namespace that WinML builds into.") parser.add_argument("--use_telemetry", action='store_true', help="Only official builds can set this flag to enable telemetry.") parser.add_argument("--enable_wcos", action='store_true', help="Build for Windows Core OS.") parser.add_argument("--enable_lto", action='store_true', help="Enable Link Time Optimization") @@ -288,13 +291,19 @@ def setup_test_data(build_dir, configs): # create a shortcut for test models if there is a 'models' folder in build_dir if is_windows(): src_model_dir = os.path.join(build_dir, 'models') + if os.path.exists('C:\\local\\models') and not os.path.exists(src_model_dir): + log.debug("creating shortcut %s -> %s" % ('C:\\local\\models', src_model_dir)) + run_subprocess(['mklink', '/D', '/J', src_model_dir, 'C:\\local\\models'], shell=True) for config in configs: config_build_dir = get_config_build_dir(build_dir, config) os.makedirs(config_build_dir, exist_ok=True) dest_model_dir = os.path.join(config_build_dir, 'models') - if os.path.exists(src_model_dir) and not os.path.exists(dest_model_dir): + if os.path.exists('C:\\local\\models') and not os.path.exists(dest_model_dir): + log.debug("creating shortcut %s -> %s" % ('C:\\local\\models', dest_model_dir)) + run_subprocess(['mklink', '/D', '/J', dest_model_dir, 'C:\\local\\models'], shell=True) + elif os.path.exists(src_model_dir) and not os.path.exists(dest_model_dir): log.debug("creating shortcut %s -> %s" % (src_model_dir, dest_model_dir)) - run_subprocess(['mklink', '/D', '/J', dest_model_dir, src_model_dir], shell=True) + run_subprocess(['mklink', '/D', '/J', dest_model_dir, src_model_dir], shell=True) def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home, tensorrt_home, path_to_protoc_exe, configs, cmake_extra_defines, args, cmake_extra_args): log.info("Generating CMake build tree") @@ -351,6 +360,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), ] + if args.winml_root_namespace_override: + cmake_args += ["-Donnxruntime_WINML_NAMESPACE_OVERRIDE=" + args.winml_root_namespace_override] + # nGraph and TensorRT providers currently only supports full_protobuf option. if args.use_full_protobuf or args.use_ngraph or args.use_tensorrt or args.gen_doc: cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] @@ -805,7 +817,7 @@ def nuphar_run_python_tests(build_dir, configs): run_subprocess([sys.executable, 'onnxruntime_test_python_nuphar.py'], cwd=cwd, dll_path=dll_path) -def build_python_wheel(source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl, use_tensorrt, use_openvino, use_nuphar, nightly_build = False): +def build_python_wheel(source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl, use_tensorrt, use_openvino, use_nuphar, wheel_name_suffix, nightly_build = False): for config in configs: cwd = get_config_build_dir(build_dir, config) if is_windows(): @@ -825,6 +837,9 @@ def build_python_wheel(source_dir, build_dir, configs, use_cuda, use_ngraph, use args.append('--use_openvino') elif use_nuphar: args.append('--use_nuphar') + if wheel_name_suffix: + args.append('--wheel_name_suffix={}'.format(wheel_name_suffix)) + run_subprocess(args, cwd=cwd) def build_protoc_for_host(cmake_path, source_dir, build_dir, args): @@ -1079,7 +1094,19 @@ def main(): if args.build: if args.build_wheel: nightly_build = bool(os.getenv('NIGHTLY_BUILD') == '1') - build_python_wheel(source_dir, build_dir, configs, args.use_cuda, args.use_ngraph, args.use_dnnl, args.use_tensorrt, args.use_openvino, args.use_nuphar, nightly_build) + build_python_wheel( + source_dir, + build_dir, + configs, + args.use_cuda, + args.use_ngraph, + args.use_dnnl, + args.use_tensorrt, + args.use_openvino, + args.use_nuphar, + args.wheel_name_suffix, + nightly_build=nightly_build, + ) if args.gen_doc and (args.build or args.test): generate_documentation(source_dir, build_dir, configs) diff --git a/tools/ci_build/github/azure-pipelines/azure-pipelines-py-packaging.yml b/tools/ci_build/github/azure-pipelines/azure-pipelines-py-packaging.yml index f16e3bcc5ca4a..ec3d9b534bfa2 100644 --- a/tools/ci_build/github/azure-pipelines/azure-pipelines-py-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/azure-pipelines-py-packaging.yml @@ -1,3 +1,12 @@ +# Note that any variable set in this file overrides a value set within the Azure pipeline UX. +# Therefore, do not set default values here, but instead do it within a template script. +# +# parameters: +# - name: is_featurizers_build +# displayName: "Is Featurizers Build" +# type: boolean +# default: false + jobs: - job: Manylinux2010_py_Wheels workspace: @@ -22,8 +31,16 @@ jobs: python.dir: '/opt/python/cp38-cp38' python.include.dir: '/opt/python/cp38-cp38/include/python3.8' steps: + - checkout: self + clean: true + submodules: recursive + - template: templates/set-test-data-variables-step.yml + - template: templates/set-featurizer-build-flag-step.yml + parameters: + is_featurizers_build: $(is_featurizers_build) + - task: CmdLine@2 displayName: 'Download azcopy' inputs: @@ -47,7 +64,7 @@ jobs: - task: CmdLine@2 inputs: script: | - docker run --rm --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build -e NIGHTLY_BUILD -e BUILD_BUILDNUMBER onnxruntime-manylinux-$(python.version) $(python.dir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp --enable_onnx_tests --cmake_extra_defines PYTHON_INCLUDE_DIR=$(python.include.dir) PYTHON_LIBRARY=/usr/lib64/librt.so + docker run --rm --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build -e NIGHTLY_BUILD -e BUILD_BUILDNUMBER onnxruntime-manylinux-$(python.version) $(python.dir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp --enable_onnx_tests $(FeaturizerBuildFlag) --cmake_extra_defines PYTHON_INCLUDE_DIR=$(python.include.dir) PYTHON_LIBRARY=/usr/lib64/librt.so workingDirectory: $(Build.SourcesDirectory) - task: CopyFiles@2 @@ -89,8 +106,16 @@ jobs: python.dir: '/opt/python/cp38-cp38' python.include.dir: '/opt/python/cp38-cp38/include/python3.8' steps: + - checkout: self + clean: true + submodules: recursive + - template: templates/set-test-data-variables-step.yml + - template: templates/set-featurizer-build-flag-step.yml + parameters: + is_featurizers_build: $(is_featurizers_build) + - task: CmdLine@2 displayName: 'Download azcopy' inputs: @@ -114,7 +139,7 @@ jobs: - task: CmdLine@2 inputs: script: | - docker run --gpus all -e NVIDIA_VISIBLE_DEVICES=all --rm --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build -e NIGHTLY_BUILD -e BUILD_BUILDNUMBER onnxruntime-manylinux-gpu-$(python.version) $(python.dir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp --enable_onnx_tests --cmake_extra_defines PYTHON_INCLUDE_DIR=$(python.include.dir) PYTHON_LIBRARY=/usr/lib64/librt.so --use_cuda --cuda_version=10.1 --cuda_home=/usr/local/cuda-10.1 --cudnn_home=/usr/local/cuda-10.1 + docker run --gpus all -e NVIDIA_VISIBLE_DEVICES=all --rm --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build -e NIGHTLY_BUILD -e BUILD_BUILDNUMBER onnxruntime-manylinux-gpu-$(python.version) $(python.dir)/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp --enable_onnx_tests $(FeaturizerBuildFlag) --cmake_extra_defines PYTHON_INCLUDE_DIR=$(python.include.dir) PYTHON_LIBRARY=/usr/lib64/librt.so --use_cuda --cuda_version=10.1 --cuda_home=/usr/local/cuda-10.1 --cudnn_home=/usr/local/cuda-10.1 workingDirectory: $(Build.SourcesDirectory) - task: CopyFiles@2 @@ -156,14 +181,23 @@ jobs: timeoutInMinutes: 60 workspace: clean: all - steps: + + steps: + - checkout: self + clean: true + submodules: recursive + - task: UsePythonVersion@0 - inputs: - versionSpec: $(python.version) - addToPath: true + inputs: + versionSpec: $(python.version) + addToPath: true architecture: 'x64' - template: templates/set-test-data-variables-step.yml + - template: templates/set-featurizer-build-flag-step.yml + parameters: + is_featurizers_build: $(is_featurizers_build) + - task: BatchScript@1 displayName: 'setup env' inputs: @@ -173,24 +207,26 @@ jobs: - script: | python -m pip install -q pyopenssl setuptools wheel numpy==1.18 - + workingDirectory: '$(Build.BinariesDirectory)' - displayName: 'Install python modules' + displayName: 'Install python modules' - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' - task: PythonScript@0 displayName: 'BUILD' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --enable_lto --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wheel --use_openmp --enable_onnx_tests --parallel' - workingDirectory: '$(Build.BinariesDirectory)' - + arguments: '--config RelWithDebInfo --enable_lto --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wheel --use_openmp --enable_onnx_tests $(FeaturizerBuildFlag) --parallel' + workingDirectory: '$(Build.BinariesDirectory)' + - task: CopyFiles@2 displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' inputs: @@ -213,9 +249,10 @@ jobs: workspace: clean: all pool: 'Win-GPU-2019' - timeoutInMinutes: 60 + timeoutInMinutes: 60 variables: CUDA_VERSION: '10.1' + buildArch: x64 EnvSetupScript: setup_env_cuda.bat strategy: matrix: @@ -226,10 +263,14 @@ jobs: Python37: python.version: '3.7' steps: + - checkout: self + clean: true + submodules: recursive + - task: UsePythonVersion@0 - inputs: - versionSpec: $(python.version) - addToPath: true + inputs: + versionSpec: $(python.version) + addToPath: true architecture: 'x64' - task: BatchScript@1 @@ -239,26 +280,32 @@ jobs: modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' - - script: | python -m pip install -q pyopenssl setuptools wheel numpy==1.18 workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - template: templates/set-test-data-variables-step.yml - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) + + - template: templates/set-featurizer-build-flag-step.yml + parameters: + is_featurizers_build: $(is_featurizers_build) - task: PythonScript@0 displayName: 'build' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: --config RelWithDebInfo --enable_lto --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wheel --use_openmp --enable_onnx_tests --parallel --use_cuda --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cudnn_home="C:\local\cudnn-$(CUDA_VERSION)-windows10-x64-v7.6.5.32\cuda" - workingDirectory: '$(Build.BinariesDirectory)' + arguments: --config RelWithDebInfo --enable_lto --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wheel --use_openmp --enable_onnx_tests $(FeaturizerBuildFlag) --parallel --use_cuda --cuda_version=$(CUDA_VERSION) --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$(CUDA_VERSION)" --cudnn_home="C:\local\cudnn-$(CUDA_VERSION)-windows10-x64-v7.6.5.32\cuda" + workingDirectory: '$(Build.BinariesDirectory)' - task: CopyFiles@2 displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' @@ -275,7 +322,7 @@ jobs: - template: templates/component-governance-component-detection-steps.yml - template: templates/clean-agent-build-directory-step.yml - + - job: MacOS_py_Wheels workspace: clean: all @@ -292,6 +339,14 @@ jobs: Python38: python.version: '3.8' steps: + - checkout: self + clean: true + submodules: recursive + + - template: templates/set-featurizer-build-flag-step.yml + parameters: + is_featurizers_build: $(is_featurizers_build) + - task: UsePythonVersion@0 displayName: 'Use Python' inputs: @@ -300,9 +355,9 @@ jobs: - script: | sudo python -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' sudo xcode-select --switch /Applications/Xcode_10.app/Contents/Developer - ./build.sh --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp - displayName: 'Command Line Script' - + ./build.sh --config Release --skip_submodule_sync --parallel --build_wheel --use_openmp $(FeaturizerBuildFlag) + displayName: 'Command Line Script' + - task: CopyFiles@2 displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' inputs: diff --git a/tools/ci_build/github/azure-pipelines/c-api-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-packaging-pipelines.yml index 56ceebe0d9be4..b4ffece4c0bc5 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-packaging-pipelines.yml @@ -108,20 +108,20 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' - + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - template: templates/set-test-data-variables-step.yml - template: templates/set-version-number-variables-step.yml - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - - - task: PythonScript@0 displayName: 'Generate cmake config' inputs: @@ -185,19 +185,20 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' - + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - template: templates/set-test-data-variables-step.yml - template: templates/set-version-number-variables-step.yml - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - - task: PythonScript@0 displayName: 'Generate cmake config' diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/bundle_dlls.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/bundle_dlls.yml index d93c58789b0a5..2a0223e59fbae 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/bundle_dlls.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/bundle_dlls.yml @@ -22,9 +22,9 @@ steps: move win-x86\runtimes\win-x86\native\onnxruntime.dll %%~ni\runtimes\win-x86\native\onnxruntime.dll move win-x86\runtimes\win-x86\native\onnxruntime.lib %%~ni\runtimes\win-x86\native\onnxruntime.lib move win-x86\runtimes\win-x86\native\onnxruntime.pdb %%~ni\runtimes\win-x86\native\onnxruntime.pdb - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.dll %%~ni\runtimes\win-x86\native\Windows.AI.MachineLearning.dll - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.lib %%~ni\runtimes\win-x86\native\Windows.AI.MachineLearning.lib - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.pdb %%~ni\runtimes\win-x86\native\Windows.AI.MachineLearning.pdb + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.dll %%~ni\runtimes\win-x86\native\Microsoft.AI.MachineLearning.dll + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.lib %%~ni\runtimes\win-x86\native\Microsoft.AI.MachineLearning.lib + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.pdb %%~ni\runtimes\win-x86\native\Microsoft.AI.MachineLearning.pdb move linux-x64\linux-x64\libonnxruntime.so %%~ni\runtimes\linux-x64\native\libonnxruntime.so unzip osx-x64.zip -d osx-x64 dir osx-x64 /s diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/gpu.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/gpu.yml index f1dab8d2a1c64..58a08e651f575 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/gpu.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/gpu.yml @@ -188,9 +188,9 @@ jobs: move win-x86\runtimes\win-x86\native\onnxruntime.lib %%~ni\runtimes\win-x86\native\onnxruntime.lib move win-x86\runtimes\win-x86\native\onnxruntime.pdb %%~ni\runtimes\win-x86\native\onnxruntime.pdb - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.dll %%~ni\runtimes\win-x86\native\windows.ai.machinelearning.dll - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.lib %%~ni\runtimes\win-x86\native\windows.ai.machinelearning.lib - move win-x86\runtimes\win-x86\native\windows.ai.machinelearning.pdb %%~ni\runtimes\win-x86\native\windows.ai.machinelearning.pdb + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.dll %%~ni\runtimes\win-x86\native\microsoft.ai.machinelearning.dll + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.lib %%~ni\runtimes\win-x86\native\microsoft.ai.machinelearning.lib + move win-x86\runtimes\win-x86\native\microsoft.ai.machinelearning.pdb %%~ni\runtimes\win-x86\native\microsoft.ai.machinelearning.pdb move win-x86\runtimes\win-x86\native\directml.dll %%~ni\runtimes\win-x86\native\directml.dll diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index b666b7586354a..f6d06b2136811 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -28,7 +28,7 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy mkdir $(Build.SourcesDirectory)\$(BuildConfig) workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' @@ -67,12 +67,6 @@ jobs: - template: templates/set-test-data-variables-step.yml - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - task: BatchScript@1 displayName: 'Setup VS2019 env vars' diff --git a/tools/ci_build/github/azure-pipelines/templates/set-featurizer-build-flag-step.yml b/tools/ci_build/github/azure-pipelines/templates/set-featurizer-build-flag-step.yml new file mode 100644 index 0000000000000..be919eda25eed --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/set-featurizer-build-flag-step.yml @@ -0,0 +1,17 @@ +parameters: + is_featurizers_build: false + +steps: +- task: PythonScript@0 + displayName: "Set FeaturizerBuildFlag variable" + inputs: + scriptSource: inline + script: |- + import os + + if "${{ parameters.is_featurizers_build }}".lower() == "true": + flags = "--use_featurizers --wheel_name_suffix=featurizer" + else: + flags = "" + + print("##vso[task.setvariable variable=FeaturizerBuildFlag]%s" % flags) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci-2019.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci-2019.yml index 01016ea1b90f6..d85133038e5ac 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci-2019.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci-2019.yml @@ -59,11 +59,19 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' - + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - task: PythonScript@0 displayName: 'Generate cmake config' inputs: @@ -117,14 +125,7 @@ jobs: arguments: '--configuration RelWithDebInfo -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=${{ parameters.OrtPackageId }}' workingDirectory: '$(Build.SourcesDirectory)\csharp' - - ${{ if in(parameters['sln_platform'], 'Win32', 'x64') }}: - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - + - ${{ if in(parameters['sln_platform'], 'Win32', 'x64') }}: - task: DotNetCoreCLI@2 displayName: 'Test C#' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci-arm.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci-arm.yml index da4c6b91a622f..765f09cdea8aa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci-arm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci-arm.yml @@ -64,7 +64,7 @@ jobs: architecture: x64 - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' diff --git a/tools/ci_build/github/azure-pipelines/templates/windows-build-tools-setup-steps.yml b/tools/ci_build/github/azure-pipelines/templates/windows-build-tools-setup-steps.yml index ec13ca83b2cfe..40e513b015e6f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/windows-build-tools-setup-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/windows-build-tools-setup-steps.yml @@ -52,14 +52,6 @@ steps: arguments: 'install -q --insecure -y pyopenssl setuptools wheel numpy' timeoutInMinutes: 10 - - task: PythonScript@0 - displayName: 'Download test data' - condition: ${{parameters.DoDataDownload}} - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - pythonInterpreter: '$(Build.BinariesDirectory)\packages\python\python.exe' - workingDirectory: $(Build.BinariesDirectory) - task: CmdLine@1 continueOnError: true diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 12d301481c0b3..0e8814a25d424 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -16,7 +16,7 @@ jobs: EnvSetupScript: setup_env.bat buildArch: x64 setVcvars: true - timeoutInMinutes: 90 + timeoutInMinutes: 120 workspace: clean: all steps: @@ -34,10 +34,19 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - task: NuGetToolInstaller@0 displayName: Use Nuget 4.9 inputs: @@ -81,12 +90,202 @@ jobs: - template: templates/set-test-data-variables-step.yml + - task: DotNetCoreCLI@2 + displayName: 'Restore nuget packages' + inputs: + command: restore + projects: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + configuration: '$(BuildConfig)' + arguments: '--configuration $(BuildConfig) -p:Platform="Any CPU" -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: DotNetCoreCLI@2 + displayName: 'Build C#' + inputs: + command: build + projects: '$(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.sln' + configuration: '$(BuildConfig)' + arguments: '--configuration $(BuildConfig) -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - task: DotNetCoreCLI@2 + displayName: 'Test C#' + condition: and(succeeded(), eq(variables['BuildConfig'], 'RelWithDebInfo')) + inputs: + command: test + projects: '$(Build.SourcesDirectory)\csharp\test\Microsoft.ML.OnnxRuntime.Tests\Microsoft.ML.OnnxRuntime.Tests.csproj' + configuration: '$(BuildConfig)' + arguments: '--configuration $(BuildConfig) -p:Platform="Any CPU" -p:OnnxRuntimeBuildDirectory="$(Build.BinariesDirectory)" -p:OrtPackageId=$(OrtPackageId)' + workingDirectory: '$(Build.SourcesDirectory)\csharp' + + - script: | + mklink /D /J $(Build.BinariesDirectory)\$(BuildConfig)\models $(Build.BinariesDirectory)\models + DIR dist\ /S /B > wheel_filename_file + set /p WHEEL_FILENAME= wheel_filename_file + set /p WHEEL_FILENAME= wheel_filename_file + set /p WHEEL_FILENAME= wheel_filename_file @@ -123,7 +557,7 @@ jobs: del wheel_filename_file python.exe -m pip install -q --upgrade %WHEEL_FILENAME% set PATH=$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig);%PATH% - python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --use_dnnl --build_wheel --enable_onnx_tests + python $(Build.SourcesDirectory)\tools\ci_build\build.py --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 16 2019" --build_wheel --disable_contrib_ops --enable_msvc_static_runtime --enable_onnx_tests workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run tests' @@ -133,4 +567,5 @@ jobs: parameters : condition : 'succeeded' - - template: templates/clean-agent-build-directory-step.yml \ No newline at end of file + - template: templates/clean-agent-build-directory-step.yml + diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 4ba0e1020663e..fdd4e76c7c794 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -34,7 +34,7 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' @@ -72,15 +72,6 @@ jobs: inputs: versionSpec: 4.9.4 - - - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - - task: DotNetCoreCLI@2 displayName: 'Restore nuget packages' inputs: diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml index c8c764a2c4aca..84a47ee3b9d95 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml @@ -67,12 +67,6 @@ jobs: - template: templates/set-test-data-variables-step.yml - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - script: | mklink /D /J $(Build.BinariesDirectory)\$(BuildConfig)\models $(Build.BinariesDirectory)\models diff --git a/tools/ci_build/github/azure-pipelines/win-nocontribops-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-nocontribops-ci-pipeline.yml index 8eeb6592d0e08..08ae78fdcc87c 100644 --- a/tools/ci_build/github/azure-pipelines/win-nocontribops-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-nocontribops-ci-pipeline.yml @@ -37,7 +37,16 @@ jobs: python -m pip install -q pyopenssl setuptools wheel numpy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' - + + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - task: PythonScript@0 displayName: 'Generate cmake config' inputs: @@ -70,16 +79,7 @@ jobs: - task: NuGetToolInstaller@0 displayName: Use Nuget 4.9 inputs: - versionSpec: 4.9.4 - - - - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) + versionSpec: 4.9.4 - task: DotNetCoreCLI@2 displayName: 'Restore nuget packages' diff --git a/tools/ci_build/github/azure-pipelines/win-x86-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-x86-ci-pipeline.yml index 8e2930315f344..22efc60ae8ce6 100644 --- a/tools/ci_build/github/azure-pipelines/win-x86-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-x86-ci-pipeline.yml @@ -34,10 +34,19 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - task: PythonScript@0 displayName: 'Generate cmake config' inputs: @@ -72,14 +81,6 @@ jobs: inputs: versionSpec: 4.9.4 - - - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) - task: DotNetCoreCLI@2 displayName: 'Restore nuget packages' diff --git a/tools/ci_build/github/azure-pipelines/win-x86-nocontribops-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-x86-nocontribops-ci-pipeline.yml index 8a866e9f8b77d..6f95cef3ca4c5 100644 --- a/tools/ci_build/github/azure-pipelines/win-x86-nocontribops-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-x86-nocontribops-ci-pipeline.yml @@ -34,10 +34,19 @@ jobs: workingFolder: '$(Build.BinariesDirectory)' - script: | - python -m pip install -q pyopenssl setuptools wheel numpy + python -m pip install -q pyopenssl setuptools wheel numpy scipy workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' - + + - powershell: | + $Env:USE_MSVC_STATIC_RUNTIME=1 + $Env:ONNX_ML=1 + $Env:CMAKE_ARGS="-DONNX_USE_PROTOBUF_SHARED_LIBS=OFF -DProtobuf_USE_STATIC_LIBS=ON -DONNX_USE_LITE_PROTO=ON -DCMAKE_TOOLCHAIN_FILE=C:/vcpkg/scripts/buildsystems/vcpkg.cmake -DVCPKG_TARGET_TRIPLET=$(buildArch)-windows-static" + python setup.py bdist_wheel + Get-ChildItem -Path dist/*.whl | foreach {pip install --upgrade $_.fullname} + workingDirectory: '$(Build.SourcesDirectory)\cmake\external\onnx' + displayName: 'Install ONNX' + - task: PythonScript@0 displayName: 'Generate cmake config' inputs: @@ -71,15 +80,7 @@ jobs: displayName: Use Nuget 4.9 inputs: versionSpec: 4.9.4 - - - - - task: PythonScript@0 - displayName: 'Download test data' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\github\download_test_data.py' - arguments: --test_data_url $(TestDataUrl) --build_dir $(Build.BinariesDirectory) - workingDirectory: $(Build.BinariesDirectory) + - task: DotNetCoreCLI@2 displayName: 'Restore nuget packages' diff --git a/tools/ci_build/github/download_test_data.py b/tools/ci_build/github/download_test_data.py index bda0e0840cb43..d7d173cb758cb 100755 --- a/tools/ci_build/github/download_test_data.py +++ b/tools/ci_build/github/download_test_data.py @@ -151,10 +151,6 @@ def download_additional_data(build_dir, azure_region): print("Starting test data download %s" % url) download_and_unzip(args.build_dir, url, models_folder) - # On windows download additional data - if is_windows(): - download_additional_data(args.build_dir, azure_region) - all_downloads_done = True except Exception as e: diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index f2a4d641a2585..f9c0ca6b24570 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -134,26 +134,27 @@ def generate_files(list, args): files_list.append('') files_list.append('') - # Process Windows.AI.MachineLearning lib, dll, and pdb - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.native_build_path, 'windows.ai.machinelearning.lib')): - files_list.append('') - - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.native_build_path, 'windows.ai.machinelearning.dll')): - files_list.append('') - - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.native_build_path, 'windows.ai.machinelearning.pdb')): - files_list.append('') - - # Process windows.ai.machinelearning.winmd - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.ort_build_path, args.build_config, 'windows.ai.machinelearning.winmd')): - files_list.append('') - - # Process windows.ai.machinelearning headers - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.ort_build_path, args.build_config, 'windows.ai.machinelearning.h')): - files_list.append('') - - if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.ort_build_path, args.build_config, 'windows.ai.machinelearning.native.h')): - files_list.append('') + if (is_windows()): + # Process Microsoft.AI.MachineLearning lib, dll, and pdb + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + # Process microsoft.ai.machinelearning.winmd + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + # Process microsoft.ai.machinelearning headers + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime'): + files_list.append('') + + if (args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' or args.package_name == 'Microsoft.ML.OnnxRuntime') and os.path.exists(os.path.join(args.ort_build_path, args.build_config, 'dualapipartitionattribute.h')): + files_list.append('') # Process dnll.dll if os.path.exists(os.path.join(args.native_build_path, 'dnnl.dll')): diff --git a/tools/python/remove_initializer_from_input.py b/tools/python/remove_initializer_from_input.py new file mode 100644 index 0000000000000..6099a0fc35c56 --- /dev/null +++ b/tools/python/remove_initializer_from_input.py @@ -0,0 +1,37 @@ +import onnx +import sys +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, help="input model") + parser.add_argument("--output", required=True, help="output model") + args = parser.parse_args() + return args + + +def remove_initializer_from_input(): + args = get_args() + + model = onnx.load(args.input) + if model.ir_version < 4: + print( + 'Model with ir_version below 4 requires to include initilizer in graph input' + ) + return + + inputs = model.graph.input + name_to_input = {} + for input in inputs: + name_to_input[input.name] = input + + for initializer in model.graph.initializer: + if initializer.name in name_to_input: + inputs.remove(name_to_input[initializer.name]) + + onnx.save(model, args.output) + + +if __name__ == '__main__': + remove_initializer_from_input() diff --git a/winml/api/Windows.AI.MachineLearning.idl b/winml/api/Windows.AI.MachineLearning.idl index 2debbc75e5c9e..fcc2c88fce38a 100644 --- a/winml/api/Windows.AI.MachineLearning.idl +++ b/winml/api/Windows.AI.MachineLearning.idl @@ -16,9 +16,18 @@ import "windows.graphics.imaging.idl"; import "windows.storage.idl"; #endif +#ifndef ROOT_NS +#define ROOT_NS Windows +#endif + +#define STRINGIFY(x) #x +#define XSTRINGIFY(x) STRINGIFY(x) +#define CREATE_OBJECT_TOKEN(root_ns, obj) root_ns##.AI.MachineLearning.##obj +#define EMBED_IN_NAMESPACE(x) XSTRINGIFY(CREATE_OBJECT_TOKEN(ROOT_NS, x)) + #include -namespace Windows.AI.MachineLearning +namespace ROOT_NS.AI.MachineLearning { [contractversion(3)] apicontract MachineLearningContract{}; @@ -77,8 +86,8 @@ namespace Windows.AI.MachineLearning //! one of the Load constructors. You can then enumerate the InputFeatures and //! OutputFeatures. To bind and evaluate you create a LearningModelSession. [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ILearningModelStatics", e3b977e8-6952-4e47-8ef4-1f7f07897c6d)] - [interface_name("Windows.AI.MachineLearning.ILearningModel", 5b8e4920-489f-4e86-9128-265a327b78fa)] + [static_name(EMBED_IN_NAMESPACE(ILearningModelStatics), e3b977e8-6952-4e47-8ef4-1f7f07897c6d)] + [interface_name(EMBED_IN_NAMESPACE(ILearningModel), 5b8e4920-489f-4e86-9128-265a327b78fa)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -144,11 +153,11 @@ namespace Windows.AI.MachineLearning //! \class LearningModelDevice //! \brief Create an instance specific to which device you want to evaluate the machine learning model on. - //! \namespace Windows.AI.MachineLearning + //! \namespace *.AI.MachineLearning [contract(MachineLearningContract, 1)] - [constructor_name("Windows.AI.MachineLearning.ILearningModelDeviceFactory", 9cffd74d-b1e5-4f20-80ad-0a56690db06b)] - [static_name("Windows.AI.MachineLearning.ILearningModelDeviceStatics", 49f32107-a8bf-42bb-92c7-10b12dc5d21f)] - [interface_name("Windows.AI.MachineLearning.ILearningModelDevice", f5c2c8fe-3f56-4a8c-ac5f-fdb92d8b8252)] + [constructor_name(EMBED_IN_NAMESPACE(ILearningModelDeviceFactory), 9cffd74d-b1e5-4f20-80ad-0a56690db06b)] + [static_name(EMBED_IN_NAMESPACE(ILearningModelDeviceStatics), 49f32107-a8bf-42bb-92c7-10b12dc5d21f)] + [interface_name(EMBED_IN_NAMESPACE(ILearningModelDevice), f5c2c8fe-3f56-4a8c-ac5f-fdb92d8b8252)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -168,7 +177,7 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [interface_name("Windows.AI.MachineLearning.ILearningModelEvaluationResult", b2f9bfcd-960e-49c0-8593-eb190ae3eee2)] + [interface_name(EMBED_IN_NAMESPACE(ILearningModelEvaluationResult), b2f9bfcd-960e-49c0-8593-eb190ae3eee2)] [marshaling_behavior(agile)] [dualapipartition(1)] runtimeclass LearningModelEvaluationResult @@ -212,8 +221,8 @@ namespace Windows.AI.MachineLearning //! \class LearningModelSession //! \brief TODO:Docs [contract(MachineLearningContract, 1)] - [constructor_name("Windows.AI.MachineLearning.ILearningModelSessionFactory", 0f6b881d-1c9b-47b6-bfe0-f1cf62a67579)] - [interface_name("Windows.AI.MachineLearning.ILearningModelSession", 8e58f8f6-b787-4c11-90f0-7129aeca74a9)] + [constructor_name(EMBED_IN_NAMESPACE(ILearningModelSessionFactory), 0f6b881d-1c9b-47b6-bfe0-f1cf62a67579)] + [interface_name(EMBED_IN_NAMESPACE(ILearningModelSession), 8e58f8f6-b787-4c11-90f0-7129aeca74a9)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -265,8 +274,8 @@ namespace Windows.AI.MachineLearning //! \class LearningModelBinding //! \brief Holder for associations between model inputs/outputs and variable instances. [contract(MachineLearningContract, 1)] - [constructor_name("Windows.AI.MachineLearning.ILearningModelBindingFactory", c95f7a7a-e788-475e-8917-23aa381faf0b)] - [interface_name("Windows.AI.MachineLearning.ILearningModelBinding", ea312f20-168f-4f8c-94fe-2e7ac31b4aa8)] + [constructor_name(EMBED_IN_NAMESPACE(ILearningModelBindingFactory), c95f7a7a-e788-475e-8917-23aa381faf0b)] + [interface_name(EMBED_IN_NAMESPACE(ILearningModelBinding), ea312f20-168f-4f8c-94fe-2e7ac31b4aa8)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -325,7 +334,7 @@ namespace Windows.AI.MachineLearning //! \class MapFeatureDescriptor //! \brief TODO:Docs [contract(MachineLearningContract, 1)] - [interface_name("Windows.AI.MachineLearning.IMapFeatureDescriptor", 530424bd-a257-436d-9e60-c2981f7cc5c4)] + [interface_name(EMBED_IN_NAMESPACE(IMapFeatureDescriptor), 530424bd-a257-436d-9e60-c2981f7cc5c4)] [marshaling_behavior(agile)] [dualapipartition(1)] runtimeclass MapFeatureDescriptor : ILearningModelFeatureDescriptor @@ -339,7 +348,7 @@ namespace Windows.AI.MachineLearning //! \class SequenceFeatureDescriptor //! \brief TODO:Docs [contract(MachineLearningContract, 1)] - [interface_name("Windows.AI.MachineLearning.ISequenceFeatureDescriptor", 84f6945a-562b-4d62-a851-739aced96668)] + [interface_name(EMBED_IN_NAMESPACE(ISequenceFeatureDescriptor), 84f6945a-562b-4d62-a851-739aced96668)] [marshaling_behavior(agile)] [dualapipartition(1)] runtimeclass SequenceFeatureDescriptor : ILearningModelFeatureDescriptor @@ -351,7 +360,7 @@ namespace Windows.AI.MachineLearning //! \class TensorFeatureDescriptor //! \brief TODO:Docs [contract(MachineLearningContract, 1)] - [interface_name("Windows.AI.MachineLearning.ITensorFeatureDescriptor", 74455c80-946a-4310-a19c-ee0af028fce4)] + [interface_name(EMBED_IN_NAMESPACE(ITensorFeatureDescriptor), 74455c80-946a-4310-a19c-ee0af028fce4)] [marshaling_behavior(agile)] [dualapipartition(1)] runtimeclass TensorFeatureDescriptor : ILearningModelFeatureDescriptor @@ -365,7 +374,7 @@ namespace Windows.AI.MachineLearning //! \class ImageFeatureDescriptor //! \brief TODO:Docs [contract(MachineLearningContract, 1)] - [interface_name("Windows.AI.MachineLearning.IImageFeatureDescriptor", 365585a5-171a-4a2a-985f-265159d3895a)] + [interface_name(EMBED_IN_NAMESPACE(IImageFeatureDescriptor), 365585a5-171a-4a2a-985f-265159d3895a)] [marshaling_behavior(agile)] [dualapipartition(1)] runtimeclass ImageFeatureDescriptor : ILearningModelFeatureDescriptor @@ -395,8 +404,8 @@ namespace Windows.AI.MachineLearning //! \class TensorFloat //! \brief A 32bit float tensor object. [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorFloatStatics", dbcd395b-3ba3-452f-b10d-3c135e573fa9)] - [interface_name("Windows.AI.MachineLearning.ITensorFloat", f2282d82-aa02-42c8-a0c8-df1efc9676e1)] + [static_name(EMBED_IN_NAMESPACE(ITensorFloatStatics), dbcd395b-3ba3-452f-b10d-3c135e573fa9)] + [interface_name(EMBED_IN_NAMESPACE(ITensorFloat), f2282d82-aa02-42c8-a0c8-df1efc9676e1)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -425,8 +434,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorFloat16BitStatics", a52db6f5-318a-44d4-820b-0cdc7054a84a)] - [interface_name("Windows.AI.MachineLearning.ITensorFloat16Bit", 0ab994fc-5b89-4c3c-b5e4-5282a5316c0a)] + [static_name(EMBED_IN_NAMESPACE(ITensorFloat16BitStatics), a52db6f5-318a-44d4-820b-0cdc7054a84a)] + [interface_name(EMBED_IN_NAMESPACE(ITensorFloat16Bit), 0ab994fc-5b89-4c3c-b5e4-5282a5316c0a)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -450,8 +459,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorUInt8BitStatics", 05f67583-bc24-4220-8a41-2dcd8c5ed33c)] - [interface_name("Windows.AI.MachineLearning.ITensorUInt8Bit", 58e1ae27-622b-48e3-be22-d867aed1daac)] + [static_name(EMBED_IN_NAMESPACE(ITensorUInt8BitStatics), 05f67583-bc24-4220-8a41-2dcd8c5ed33c)] + [interface_name(EMBED_IN_NAMESPACE(ITensorUInt8Bit), 58e1ae27-622b-48e3-be22-d867aed1daac)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -475,8 +484,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorInt8BitStatics", b1a12284-095c-4c76-a661-ac4cee1f3e8b)] - [interface_name("Windows.AI.MachineLearning.ITensorInt8Bit", cddd97c5-ffd8-4fef-aefb-30e1a485b2ee)] + [static_name(EMBED_IN_NAMESPACE(ITensorInt8BitStatics), b1a12284-095c-4c76-a661-ac4cee1f3e8b)] + [interface_name(EMBED_IN_NAMESPACE(ITensorInt8Bit), cddd97c5-ffd8-4fef-aefb-30e1a485b2ee)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -499,8 +508,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorUInt16BitStatics", 5df745dd-028a-481a-a27c-c7e6435e52dd)] - [interface_name("Windows.AI.MachineLearning.ITensorUInt16Bit", 68140f4b-23c0-42f3-81f6-a891c011bc3f)] + [static_name(EMBED_IN_NAMESPACE(ITensorUInt16BitStatics), 5df745dd-028a-481a-a27c-c7e6435e52dd)] + [interface_name(EMBED_IN_NAMESPACE(ITensorUInt16Bit), 68140f4b-23c0-42f3-81f6-a891c011bc3f)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -523,8 +532,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorInt16BitStatics", 98646293-266e-4b1a-821f-e60d70898b91)] - [interface_name("Windows.AI.MachineLearning.ITensorInt16Bit", 98a32d39-e6d6-44af-8afa-baebc44dc020)] + [static_name(EMBED_IN_NAMESPACE(ITensorInt16BitStatics), 98646293-266e-4b1a-821f-e60d70898b91)] + [interface_name(EMBED_IN_NAMESPACE(ITensorInt16Bit), 98a32d39-e6d6-44af-8afa-baebc44dc020)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -547,8 +556,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorUInt32BitStatics", 417c3837-e773-4378-8e7f-0cc33dbea697)] - [interface_name("Windows.AI.MachineLearning.ITensorUInt32Bit", d8c9c2ff-7511-45a3-bfac-c38f370d2237)] + [static_name(EMBED_IN_NAMESPACE(ITensorUInt32BitStatics), 417c3837-e773-4378-8e7f-0cc33dbea697)] + [interface_name(EMBED_IN_NAMESPACE(ITensorUInt32Bit), d8c9c2ff-7511-45a3-bfac-c38f370d2237)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -571,8 +580,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorInt32BitStatics", 6539864b-52fa-4e35-907c-834cac417b50)] - [interface_name("Windows.AI.MachineLearning.ITensorInt32Bit", 2c0c28d3-207c-4486-a7d2-884522c5e589)] + [static_name(EMBED_IN_NAMESPACE(ITensorInt32BitStatics), 6539864b-52fa-4e35-907c-834cac417b50)] + [interface_name(EMBED_IN_NAMESPACE(ITensorInt32Bit), 2c0c28d3-207c-4486-a7d2-884522c5e589)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -595,8 +604,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorUInt64BitStatics", 7a7e20eb-242f-47cb-a9c6-f602ecfbfee4)] - [interface_name("Windows.AI.MachineLearning.ITensorUInt64Bit", 2e70ffad-04bf-4825-839a-82baef8c7886)] + [static_name(EMBED_IN_NAMESPACE(ITensorUInt64BitStatics), 7a7e20eb-242f-47cb-a9c6-f602ecfbfee4)] + [interface_name(EMBED_IN_NAMESPACE(ITensorUInt64Bit), 2e70ffad-04bf-4825-839a-82baef8c7886)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -619,8 +628,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorInt64BitStatics", 9648ad9d-1198-4d74-9517-783ab62b9cc2)] - [interface_name("Windows.AI.MachineLearning.ITensorInt64Bit", 499665ba-1fa2-45ad-af25-a0bd9bda4c87)] + [static_name(EMBED_IN_NAMESPACE(ITensorInt64BitStatics), 9648ad9d-1198-4d74-9517-783ab62b9cc2)] + [interface_name(EMBED_IN_NAMESPACE(ITensorInt64Bit), 499665ba-1fa2-45ad-af25-a0bd9bda4c87)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -643,8 +652,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorBooleanStatics", 2796862c-2357-49a7-b476-d0aa3dfe6866)] - [interface_name("Windows.AI.MachineLearning.ITensorBoolean", 50f311ed-29e9-4a5c-a44d-8fc512584eed)] + [static_name(EMBED_IN_NAMESPACE(ITensorBooleanStatics), 2796862c-2357-49a7-b476-d0aa3dfe6866)] + [interface_name(EMBED_IN_NAMESPACE(ITensorBoolean), 50f311ed-29e9-4a5c-a44d-8fc512584eed)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -668,8 +677,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorDoubleStatics", a86693c5-9538-44e7-a3ca-5df374a5a70c)] - [interface_name("Windows.AI.MachineLearning.ITensorDouble", 91e41252-7a8f-4f0e-a28f-9637ffc8a3d0)] + [static_name(EMBED_IN_NAMESPACE(ITensorDoubleStatics), a86693c5-9538-44e7-a3ca-5df374a5a70c)] + [interface_name(EMBED_IN_NAMESPACE(ITensorDouble), 91e41252-7a8f-4f0e-a28f-9637ffc8a3d0)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -692,8 +701,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.ITensorStringStatics", 83623324-cf26-4f17-a2d4-20ef8d097d53)] - [interface_name("Windows.AI.MachineLearning.ITensorString", 582335c8-bdb1-4610-bc75-35e9cbf009b7)] + [static_name(EMBED_IN_NAMESPACE(ITensorStringStatics), 83623324-cf26-4f17-a2d4-20ef8d097d53)] + [interface_name(EMBED_IN_NAMESPACE(ITensorString), 582335c8-bdb1-4610-bc75-35e9cbf009b7)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] @@ -713,8 +722,8 @@ namespace Windows.AI.MachineLearning } [contract(MachineLearningContract, 1)] - [static_name("Windows.AI.MachineLearning.IImageFeatureValueStatics", 1bc317fd-23cb-4610-b085-c8e1c87ebaa0)] - [interface_name("Windows.AI.MachineLearning.IImageFeatureValue", f0414fd9-c9aa-4405-b7fb-94f87c8a3037)] + [static_name(EMBED_IN_NAMESPACE(IImageFeatureValueStatics), 1bc317fd-23cb-4610-b085-c8e1c87ebaa0)] + [interface_name(EMBED_IN_NAMESPACE(IImageFeatureValue), f0414fd9-c9aa-4405-b7fb-94f87c8a3037)] [threading(both)] [marshaling_behavior(agile)] [dualapipartition(1)] diff --git a/winml/api/dualapipartitionattribute.idl b/winml/api/dualapipartitionattribute.idl index c38ce9a117b7d..41c7c2b2f2ae3 100644 --- a/winml/api/dualapipartitionattribute.idl +++ b/winml/api/dualapipartitionattribute.idl @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -//! \file windows.ai.machinelearning.attributes.idl +//! \file dualapipartitionattribute.idl import "Windows.Foundation.idl"; #include diff --git a/winml/dll/module.cpp b/winml/dll/module.cpp index 717fb3fcede76..804182ca534d7 100644 --- a/winml/dll/module.cpp +++ b/winml/dll/module.cpp @@ -8,7 +8,7 @@ #include "LearningModelDevice.h" #include "OnnxruntimeProvider.h" -using namespace winrt::Windows::AI::MachineLearning::implementation; +using namespace winmlp; void __stdcall OnErrorReported(bool alreadyReported, wil::FailureInfo const& failure) WI_NOEXCEPT { if (!alreadyReported) { @@ -57,7 +57,7 @@ extern "C" BOOL WINAPI DllMain(_In_ HINSTANCE hInstance, DWORD dwReason, _In_ vo } extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry) try { - winrt::com_ptr engine_factory; + winrt::com_ptr<_winml::IEngineFactory> engine_factory; WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory.put())); WINML_THROW_IF_FAILED(engine_factory->CreateCustomRegistry(registry)); return S_OK; @@ -65,7 +65,7 @@ extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegis CATCH_RETURN(); STDAPI DllCanUnloadNow() { - // The windows.ai.machinelearning.dll should not be freed by + // This dll should not be freed by // CoFreeUnusedLibraries since there can be outstanding COM object // references to many objects (AbiCustomRegistry, IMLOperatorKernelContext, // IMLOperatorTensor, etc) that are not reference counted in this path. @@ -79,8 +79,7 @@ STDAPI DllCanUnloadNow() { // that are shared out as a consequence of the MLCreateOperatorRegistry API // will be a complex task to complete in RS5. // - // As a temporary workaround we simply prevent the windows.ai.machinelearning.dll - // from unloading. + // As a temporary workaround we simply prevent the dll from unloading. // // There are no known code paths that rely on opportunistic dll unload. return S_FALSE; diff --git a/winml/dll/Windows.AI.MachineLearning.def b/winml/dll/winml.def similarity index 100% rename from winml/dll/Windows.AI.MachineLearning.def rename to winml/dll/winml.def diff --git a/winml/dll/winml.rc b/winml/dll/winml.rc index 9027e5d607dc0..61ae2e0f8ddbb 100644 --- a/winml/dll/winml.rc +++ b/winml/dll/winml.rc @@ -30,9 +30,9 @@ BEGIN VALUE "CompanyName", "Microsoft Corporation" VALUE "FileDescription", "Windows Machine Learning Runtime" VALUE "FileVersion", VER_STRING - VALUE "InternalName", "Windows.AI.MachineLearning.Runtime" + VALUE "InternalName", "Windows Machine Learning Runtime" VALUE "LegalCopyright", "\251 Microsoft Corporation. All rights reserved." - VALUE "OriginalFilename", "windows.ai.machinelearning.dll" + VALUE "OriginalFilename", BINARY_NAME VALUE "ProductName", "Microsoft\256 Windows\256 Operating System" VALUE "ProductVersion", VER_STRING END diff --git a/winml/lib/Api.Image/ConverterResourceStore.cpp b/winml/lib/Api.Image/ConverterResourceStore.cpp index f698a05f29835..4927a911ef47a 100644 --- a/winml/lib/Api.Image/ConverterResourceStore.cpp +++ b/winml/lib/Api.Image/ConverterResourceStore.cpp @@ -10,17 +10,12 @@ #include #include -using namespace Windows::AI::MachineLearning; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Graphics::DirectX; -using namespace Windows::Graphics::DirectX::Direct3D11; +using namespace _winml; ConverterResources::ConverterResources(Pool& pool, ConverterResourceDescription& descriptor) : m_pool(pool), Descriptor(descriptor), - Tensorizer(std::make_unique()), - Detensorizer(std::make_unique()) { + Tensorizer(std::make_unique()), + Detensorizer(std::make_unique()) { } void ConverterResources::ReturnToCache() { diff --git a/winml/lib/Api.Image/CpuDetensorizer.h b/winml/lib/Api.Image/CpuDetensorizer.h index 4059f85eaf211..264e100d36b0a 100644 --- a/winml/lib/Api.Image/CpuDetensorizer.h +++ b/winml/lib/Api.Image/CpuDetensorizer.h @@ -5,7 +5,8 @@ #include "inc/ImageConversionTypes.h" -namespace Windows::AI::MachineLearning::Internal { +namespace _winml { + class CpuDetensorizer { public: template @@ -234,4 +235,4 @@ class CpuDetensorizer { } #endif }; -} // namespace Windows::AI::MachineLearning::Internal +} // namespace _winml diff --git a/winml/lib/Api.Image/CpuTensorizer.h b/winml/lib/Api.Image/CpuTensorizer.h index fa60cef9a145b..a653abf8f29ff 100644 --- a/winml/lib/Api.Image/CpuTensorizer.h +++ b/winml/lib/Api.Image/CpuTensorizer.h @@ -5,10 +5,8 @@ #include "inc/ImageConversionTypes.h" -using namespace Windows::AI::MachineLearning::Internal; -using namespace winrt::Windows::Graphics::Imaging; +namespace _winml { -namespace Windows::AI::MachineLearning::Internal { class CpuTensorizer { public: template @@ -17,7 +15,7 @@ class CpuTensorizer { _In_ ImageTensorChannelType formatTo, _In_ BYTE* pBuffer, _In_ UINT32 bufferWidth, - _In_ const BitmapBounds& inputBounds, + _In_ const wgi::BitmapBounds& inputBounds, _Inout_ T* pCPUTensor) { #pragma warning(push) #pragma warning(disable : 26014) // warning about possible out of bounds accesing pData, but input is checked for BGRA8 format, so uiCapacity should be in multiples of 4 @@ -263,4 +261,4 @@ class CpuTensorizer { } #endif }; -} // namespace Windows::AI::MachineLearning::Internal +} // namespace _winml diff --git a/winml/lib/Api.Image/D3DDeviceCache.cpp b/winml/lib/Api.Image/D3DDeviceCache.cpp index 94119425f1ae1..1aedd23709731 100644 --- a/winml/lib/Api.Image/D3DDeviceCache.cpp +++ b/winml/lib/Api.Image/D3DDeviceCache.cpp @@ -38,32 +38,33 @@ namespace float16 { using namespace Microsoft::WRL; -namespace winrt::Windows::AI::MachineLearning::implementation { -D3DDeviceCache::D3DDeviceCache(Windows::AI::MachineLearning::LearningModelDeviceKind const& deviceKind) { +using namespace _winml; + +D3DDeviceCache::D3DDeviceCache(winml::LearningModelDeviceKind const& deviceKind) { WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_)); - if (deviceKind == LearningModelDeviceKind::Cpu || deviceKind == LearningModelDeviceKind::Default) { + if (deviceKind == winml::LearningModelDeviceKind::Cpu || deviceKind == winml::LearningModelDeviceKind::Default) { // CPU device don't make any GPU devices device_luid_.HighPart = device_luid_.LowPart = 0; return; } DXGI_GPU_PREFERENCE preference; - WINML_THROW_IF_FAILED(DeviceHelpers::GetGPUPreference(deviceKind, &preference)); + WINML_THROW_IF_FAILED(GetGPUPreference(deviceKind, &preference)); CommonDeviceHelpers::AdapterEnumerationSupport support; WINML_THROW_IF_FAILED(CommonDeviceHelpers::GetAdapterEnumerationSupport(&support)); const char errStr[] = "No hardware adapters available"; if (support.has_dxgi) { - com_ptr spAdapter; - WINML_THROW_IF_FAILED_MSG(DeviceHelpers::GetDXGIHardwareAdapterWithPreference(preference, spAdapter.put()), errStr); + winrt::com_ptr spAdapter; + WINML_THROW_IF_FAILED_MSG(GetDXGIHardwareAdapterWithPreference(preference, spAdapter.put()), errStr); WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put()))); } #ifdef ENABLE_DXCORE if (support.has_dxgi == false) { com_ptr spAdapter; - WINML_THROW_IF_FAILED_MSG(DeviceHelpers::GetDXCoreHardwareAdapterWithPreference(preference, spAdapter.put()), errStr); + WINML_THROW_IF_FAILED_MSG(GetDXCoreHardwareAdapterWithPreference(preference, spAdapter.put()), errStr); WINML_THROW_IF_FAILED(D3D12CreateDevice(spAdapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put()))); } #endif @@ -72,29 +73,29 @@ D3DDeviceCache::D3DDeviceCache(Windows::AI::MachineLearning::LearningModelDevice device_luid_ = device_->GetAdapterLuid(); } -D3DDeviceCache::D3DDeviceCache(Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice const& device) { +D3DDeviceCache::D3DDeviceCache(wgdx::Direct3D11::IDirect3DDevice const& device) { WINML_THROW_IF_FAILED(CoCreateGuid(&fence_guid_)); // Use the 11 device to initialize 12 winrt_device_ = device; // they told us which device to run on, crack the interop wrapper to get the dxgi device - com_ptr<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess> dxgi; + winrt::com_ptr<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess> dxgi; dxgi = device.as<::Windows::Graphics::DirectX::Direct3D11::IDirect3DDxgiInterfaceAccess>(); - com_ptr dxgiDevice; + winrt::com_ptr dxgiDevice; WINML_THROW_IF_FAILED(dxgi->GetInterface(IID_PPV_ARGS(dxgiDevice.put()))); device_11_ = dxgiDevice.as(); - com_ptr spContext; + winrt::com_ptr spContext; device_11_->GetImmediateContext(spContext.put()); spContext.as(device_context11_); - com_ptr pDXGIDevice; + winrt::com_ptr pDXGIDevice; WINML_THROW_IF_FAILED(dxgi->GetInterface(IID_PPV_ARGS(pDXGIDevice.put()))); - com_ptr adapter; + winrt::com_ptr adapter; WINML_THROW_IF_FAILED(pDXGIDevice->GetAdapter(adapter.put())); WINML_THROW_IF_FAILED(D3D12CreateDevice(adapter.get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(device_.put()))); @@ -151,7 +152,7 @@ ID3D11DeviceContext4* D3DDeviceCache::GetD3D11DeviceContext() { return device_context11_.get(); } -Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice D3DDeviceCache::GetWinrtDevice() { +wgdx::Direct3D11::IDirect3DDevice D3DDeviceCache::GetWinrtDevice() { EnsureD3D11FromD3D12(); return winrt_device_; } @@ -184,19 +185,19 @@ void D3DDeviceCache::EnsureD3D11FromD3D12() { if (winrt_device_ != nullptr) return; - com_ptr<::IInspectable> spInspectable; - com_ptr spDXGIDevice; + winrt::com_ptr<::IInspectable> spInspectable; + winrt::com_ptr spDXGIDevice; // call our SEH version (for delay loading) - WINML_THROW_IF_FAILED(DeviceHelpers::CreateD3D11On12Device(device_.get(), device_11_.put())); - com_ptr spContext; + WINML_THROW_IF_FAILED(CreateD3D11On12Device(device_.get(), device_11_.put())); + winrt::com_ptr spContext; device_11_->GetImmediateContext(spContext.put()); spContext.as(device_context11_); WINML_THROW_IF_FAILED(device_11_->QueryInterface(IID_PPV_ARGS(spDXGIDevice.put()))); // Convert to Winrt wrapper. This doesn't actually make a new device. WINML_THROW_IF_FAILED(CreateDirect3D11DeviceFromDXGIDevice(spDXGIDevice.get(), spInspectable.put())); - WINML_THROW_IF_FAILED(spInspectable->QueryInterface(winrt::guid_of(), reinterpret_cast(winrt::put_abi(winrt_device_)))); + WINML_THROW_IF_FAILED(spInspectable->QueryInterface(winrt::guid_of(), reinterpret_cast(winrt::put_abi(winrt_device_)))); } void D3DDeviceCache::EnsureD3D12Fence() { @@ -232,12 +233,12 @@ void D3DDeviceCache::EnsureSharedFences() { // ensure the d11 stack is alive, the 11 stack doesn't exist on WCOSHeadless yet, so be resilient EnsureD3D11FromD3D12(); - com_ptr spD3D12DeviceChild; + winrt::com_ptr spD3D12DeviceChild; d3d12_fence_.as(spD3D12DeviceChild); HANDLE hSharedFence; WINML_THROW_IF_FAILED(device_->CreateSharedHandle(spD3D12DeviceChild.get(), NULL, GENERIC_ALL, nullptr, &hSharedFence)); - com_ptr spD3D11Device5; + winrt::com_ptr spD3D11Device5; device_11_.as(spD3D11Device5); wil::unique_handle safe(hSharedFence); WINML_THROW_IF_FAILED(spD3D11Device5->OpenSharedFence(safe.get(), IID_PPV_ARGS(d3d11_fence_.put()))); @@ -295,7 +296,7 @@ void D3DDeviceCache::WaitForFenceValue(UINT64 fenceValue) { ID3D12RootSignature* D3DDeviceCache::GetTensorizeRootSignature() { if (tensorize_root_signature_ == nullptr) { - com_ptr newRootSignature; + winrt::com_ptr newRootSignature; D3D12_FEATURE_DATA_ROOT_SIGNATURE featureData = {}; // This is the highest version the sample supports. If CheckFeatureSupport succeeds, the HighestVersion returned will not be greater than this. @@ -319,8 +320,8 @@ ID3D12RootSignature* D3DDeviceCache::GetTensorizeRootSignature() { CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC computeRootSignatureDesc; computeRootSignatureDesc.Init_1_1(_countof(rootParameters), rootParameters, 0, nullptr); - com_ptr signature; - com_ptr error; + winrt::com_ptr signature; + winrt::com_ptr error; WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(&computeRootSignatureDesc, featureData.HighestVersion, signature.put(), error.put())); WINML_THROW_IF_FAILED(device_->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(newRootSignature.put()))); newRootSignature->SetName(L"Tensorize Rootsignature"); @@ -340,7 +341,7 @@ ID3D12RootSignature* D3DDeviceCache::GetTensorizeRootSignature() { ID3D12RootSignature* D3DDeviceCache::GetDetensorizeRootSignature() { if (detensorize_root_signature_ == nullptr) { - com_ptr newRootSignature; + winrt::com_ptr newRootSignature; D3D12_FEATURE_DATA_ROOT_SIGNATURE featureData = {}; // This is the highest version the sample supports. If CheckFeatureSupport succeeds, the HighestVersion returned will not be greater than this. @@ -364,8 +365,8 @@ ID3D12RootSignature* D3DDeviceCache::GetDetensorizeRootSignature() { CD3DX12_VERSIONED_ROOT_SIGNATURE_DESC rootSignatureDesc; rootSignatureDesc.Init_1_1(_countof(rootParameters), rootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT); - com_ptr signature; - com_ptr error; + winrt::com_ptr signature; + winrt::com_ptr error; WINML_THROW_IF_FAILED(D3DX12SerializeVersionedRootSignature(&rootSignatureDesc, featureData.HighestVersion, signature.put(), error.put())); WINML_THROW_IF_FAILED(device_->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(newRootSignature.put()))); newRootSignature->SetName(L"Detensorize Rootsignature"); @@ -385,7 +386,7 @@ ID3D12RootSignature* D3DDeviceCache::GetDetensorizeRootSignature() { ID3D12PipelineState* D3DDeviceCache::GetCachedPipelineState(PipelineStateCacheType type, PipelineStateCacheFormat formatFrom, PipelineStateCacheFormat formatTo, PipelineStateCacheOperation operation) { if (cached_pipeline_state[static_cast(type)][static_cast(formatFrom)][static_cast(formatTo)][static_cast(operation)] == nullptr) { - com_ptr newPSO; + winrt::com_ptr newPSO; if (operation == PipelineStateCacheOperation::kTensorize) { newPSO.attach(CreateTensorizePipelineState(type, formatFrom, formatTo)); } else { @@ -475,7 +476,7 @@ ID3D12PipelineState* D3DDeviceCache::CreateTensorizePipelineState(PipelineStateC computePsoDesc.pRootSignature = GetTensorizeRootSignature(); computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, static_cast(shaderBytecodeSize)); - com_ptr pipelineState; + winrt::com_ptr pipelineState; WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&computePsoDesc, IID_PPV_ARGS(pipelineState.put()))); return pipelineState.detach(); @@ -568,7 +569,7 @@ ID3D12PipelineState* D3DDeviceCache::CreateDetensorizePipelineState(PipelineStat computePsoDesc.pRootSignature = GetDetensorizeRootSignature(); computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(shaderBytecode, static_cast(shaderBytecodeSize)); - com_ptr pipelineState; + winrt::com_ptr pipelineState; WINML_THROW_IF_FAILED(device_->CreateComputePipelineState(&computePsoDesc, IID_PPV_ARGS(pipelineState.put()))); return pipelineState.detach(); @@ -576,7 +577,7 @@ ID3D12PipelineState* D3DDeviceCache::CreateDetensorizePipelineState(PipelineStat ID3D12Resource* D3DDeviceCache::GetDetensorizeVertexBuffer(_Out_ UINT* vertexBufferSize) { if (detensorize_vertex_buffer_ == nullptr) { - com_ptr newResource; + winrt::com_ptr newResource; // Create the vertex buffer. // 2 triangles for full screen DirectX::XMFLOAT3 triangleVertices[] = @@ -672,4 +673,3 @@ void D3DDeviceCache::SyncD3D11DeviceToConverter(_In_ ID3D11Fence* pD3D11Fence) { bool D3DDeviceCache::SharedHandleInitialized() { return d3d11_fence_ != nullptr; } -} // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/winml/lib/Api.Image/DeviceHelpers.cpp b/winml/lib/Api.Image/DeviceHelpers.cpp index 9cd06b70641fa..97a968d6e302a 100644 --- a/winml/lib/Api.Image/DeviceHelpers.cpp +++ b/winml/lib/Api.Image/DeviceHelpers.cpp @@ -12,9 +12,6 @@ #include "CommonDeviceHelpers.h" #include "LearningModelDevice.h" -namespace DeviceHelpers { - - HRESULT IsWarpAdapter(IDXGIAdapter1* pAdapter, bool* isWarpAdapter) { DXGI_ADAPTER_DESC1 pDesc; RETURN_IF_FAILED(pAdapter->GetDesc1(&pDesc)); @@ -28,7 +25,7 @@ HRESULT IsWarpAdapter(IDXGIAdapter1* pAdapter, bool* isWarpAdapter) { return S_OK; } -HRESULT GetDXGIHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, IDXGIAdapter1** ppAdapter) { +HRESULT _winml::GetDXGIHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, IDXGIAdapter1** ppAdapter) { winrt::com_ptr spAdapter; UINT i = 0; @@ -70,7 +67,7 @@ HRESULT GetDXGIHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, IDX // Return the first adapter that matches the preference: // DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE => DXCoreAdapterProperty::IsDetachable // DXGI_GPU_PREFERENCE_MINIMUM_POWER => DXCoreAdapterProperty::IsIntegrated -HRESULT GetDXCoreHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, IDXCoreAdapter** ppAdapter) { +HRESULT _winml::GetDXCoreHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, IDXCoreAdapter** ppAdapter) { winrt::com_ptr spFactory; RETURN_IF_FAILED(DXCoreCreateAdapterFactory(IID_PPV_ARGS(spFactory.put()))); @@ -125,7 +122,7 @@ HRESULT GetDXCoreHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, I } #endif -HRESULT CreateD3D11On12Device(ID3D12Device* device12, ID3D11Device** device11) { +HRESULT _winml::CreateD3D11On12Device(ID3D12Device* device12, ID3D11Device** device11) { return CommonDeviceHelpers::RunDelayLoadedApi( D3D11On12CreateDevice, device12, // pointer to d3d12 device @@ -140,17 +137,17 @@ HRESULT CreateD3D11On12Device(ID3D12Device* device12, ID3D11Device** device11) { nullptr); // pointer to the returned feature level (unused) } -HRESULT GetGPUPreference(winrt::Windows::AI::MachineLearning::LearningModelDeviceKind deviceKind, DXGI_GPU_PREFERENCE* preference) noexcept { +HRESULT _winml::GetGPUPreference(winml::LearningModelDeviceKind deviceKind, DXGI_GPU_PREFERENCE* preference) noexcept { switch (deviceKind) { - case winrt::Windows::AI::MachineLearning::LearningModelDeviceKind::DirectX: { + case winml::LearningModelDeviceKind::DirectX: { *preference = DXGI_GPU_PREFERENCE_UNSPECIFIED; return S_OK; } - case winrt::Windows::AI::MachineLearning::LearningModelDeviceKind::DirectXHighPerformance: { + case winml::LearningModelDeviceKind::DirectXHighPerformance: { *preference = DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE; return S_OK; } - case winrt::Windows::AI::MachineLearning::LearningModelDeviceKind::DirectXMinPower: { + case winml::LearningModelDeviceKind::DirectXMinPower: { *preference = DXGI_GPU_PREFERENCE_MINIMUM_POWER; return S_OK; } @@ -159,4 +156,3 @@ HRESULT GetGPUPreference(winrt::Windows::AI::MachineLearning::LearningModelDevic return E_INVALIDARG; } } -} // namespace DeviceHelpers diff --git a/winml/lib/Api.Image/ImageConversionHelpers.cpp b/winml/lib/Api.Image/ImageConversionHelpers.cpp index b57662480d141..c8aa181045879 100644 --- a/winml/lib/Api.Image/ImageConversionHelpers.cpp +++ b/winml/lib/Api.Image/ImageConversionHelpers.cpp @@ -5,351 +5,347 @@ #include "inc/ImageConversionHelpers.h" using namespace Microsoft::WRL; -using namespace Windows::AI::MachineLearning::Internal; using namespace Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::DirectX; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; - -namespace Windows::AI::MachineLearning::Internal::ImageConversionHelpers { - LUID GetLUIDFromDirect3DSurface(const IDirect3DSurface& surface) { - ComPtr spDx11Device; - ComPtr spDxgiInterfaceAccess; - ComPtr spDx11Texture2D; - ComPtr spDXGIDevice; - ComPtr spDXGIAdapter; - DXGI_ADAPTER_DESC adapterDesc = {0}; - - spDxgiInterfaceAccess = surface.as().get(); - WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&spDx11Texture2D))); - spDx11Texture2D->GetDevice(&spDx11Device); - WINML_THROW_IF_FAILED(spDx11Device->QueryInterface(IID_PPV_ARGS(&spDXGIDevice))); - WINML_THROW_IF_FAILED(spDXGIDevice->GetAdapter(&spDXGIAdapter)); - WINML_THROW_IF_FAILED(spDXGIAdapter->GetDesc(&adapterDesc)); - - return adapterDesc.AdapterLuid; - } - HRESULT GetVideoFrameInfo( - _In_ const winrt::Windows::Media::IVideoFrame& inputVideoFrame, - _Out_ DWORD& format, - _Out_ int& width, - _Out_ int& height, - _Out_ LUID& luid) { - winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface spInputSurface = inputVideoFrame.Direct3DSurface(); - if (spInputSurface != nullptr) { - Direct3DSurfaceDescription description; - description = spInputSurface.Description(); - format = (DWORD)description.Format; - width = description.Width; - height = description.Height; - luid = GetLUIDFromDirect3DSurface(spInputSurface); +static LUID GetLUIDFromDirect3DSurface(const wgdx::Direct3D11::IDirect3DSurface& surface) { + ComPtr spDx11Device; + ComPtr spDxgiInterfaceAccess; + ComPtr spDx11Texture2D; + ComPtr spDXGIDevice; + ComPtr spDXGIAdapter; + DXGI_ADAPTER_DESC adapterDesc = {0}; + + spDxgiInterfaceAccess = surface.as().get(); + WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&spDx11Texture2D))); + spDx11Texture2D->GetDevice(&spDx11Device); + WINML_THROW_IF_FAILED(spDx11Device->QueryInterface(IID_PPV_ARGS(&spDXGIDevice))); + WINML_THROW_IF_FAILED(spDXGIDevice->GetAdapter(&spDXGIAdapter)); + WINML_THROW_IF_FAILED(spDXGIAdapter->GetDesc(&adapterDesc)); + + return adapterDesc.AdapterLuid; +} + +static HRESULT GetVideoFrameInfo( + _In_ const wm::IVideoFrame& inputVideoFrame, + _Out_ DWORD& format, + _Out_ int& width, + _Out_ int& height, + _Out_ LUID& luid) { + wgdx::Direct3D11::IDirect3DSurface spInputSurface = inputVideoFrame.Direct3DSurface(); + if (spInputSurface != nullptr) { + wgdx::Direct3D11::Direct3DSurfaceDescription description; + description = spInputSurface.Description(); + format = (DWORD)description.Format; + width = description.Width; + height = description.Height; + luid = GetLUIDFromDirect3DSurface(spInputSurface); + } else { + wgi::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); + if (spInputSoftwareBitmap != nullptr) { + format = (DWORD)spInputSoftwareBitmap.BitmapPixelFormat(); + height = spInputSoftwareBitmap.PixelHeight(); + width = spInputSoftwareBitmap.PixelWidth(); + luid.HighPart = luid.LowPart = 0; } else { - winrt::Windows::Graphics::Imaging::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); - if (spInputSoftwareBitmap != nullptr) { - format = (DWORD)spInputSoftwareBitmap.BitmapPixelFormat(); - height = spInputSoftwareBitmap.PixelHeight(); - width = spInputSoftwareBitmap.PixelWidth(); - luid.HighPart = luid.LowPart = 0; - } else { - return E_INVALIDARG; - } + return E_INVALIDARG; } - return S_OK; } - - void ConvertVideoFrameToVideoFrame( - _In_ const IVideoFrame& inputVideoFrame, - _In_ const BitmapBounds& inputBounds, - _In_ UINT32 outputWidth, - _In_ UINT32 outputHeight, - _Inout_ winrt::Windows::Media::VideoFrame& pOutputVideoFrame) { - BitmapBounds outputBounds = { - 0, - 0, - outputWidth, - outputHeight}; - - winrt::Windows::Graphics::Imaging::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); - winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface spInputDirect3DSurface = inputVideoFrame.Direct3DSurface(); - - // only one of softwarebitmap or direct3Dsurface should be non-null - if ((spInputSoftwareBitmap == nullptr && spInputDirect3DSurface == nullptr) || (spInputSoftwareBitmap != nullptr && spInputDirect3DSurface != nullptr)) { - WINML_THROW_HR(E_INVALIDARG); - } - - auto pInputVideoFrame2 = inputVideoFrame.as(); - pInputVideoFrame2.CopyToAsync(pOutputVideoFrame, inputBounds, outputBounds).get(); + return S_OK; +} + +void _winmli::ConvertVideoFrameToVideoFrame( + _In_ const wm::IVideoFrame& inputVideoFrame, + _In_ const wgi::BitmapBounds& inputBounds, + _In_ UINT32 outputWidth, + _In_ UINT32 outputHeight, + _Inout_ wm::VideoFrame& pOutputVideoFrame) { + wgi::BitmapBounds outputBounds = { + 0, + 0, + outputWidth, + outputHeight}; + + wgi::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); + wgdx::Direct3D11::IDirect3DSurface spInputDirect3DSurface = inputVideoFrame.Direct3DSurface(); + + // only one of softwarebitmap or direct3Dsurface should be non-null + if ((spInputSoftwareBitmap == nullptr && spInputDirect3DSurface == nullptr) || (spInputSoftwareBitmap != nullptr && spInputDirect3DSurface != nullptr)) { + WINML_THROW_HR(E_INVALIDARG); } - bool SoftwareBitmapFormatSupported(const SoftwareBitmap& softwareBitmap) { - assert(softwareBitmap != nullptr); + auto pInputVideoFrame2 = inputVideoFrame.as(); + pInputVideoFrame2.CopyToAsync(pOutputVideoFrame, inputBounds, outputBounds).get(); +} - switch (softwareBitmap.BitmapPixelFormat()) { - case BitmapPixelFormat::Bgra8: - case BitmapPixelFormat::Rgba8: - case BitmapPixelFormat::Gray8: - return true; - } +bool _winmli::SoftwareBitmapFormatSupported(const wgi::SoftwareBitmap& softwareBitmap) { + assert(softwareBitmap != nullptr); - return false; + switch (softwareBitmap.BitmapPixelFormat()) { + case wgi::BitmapPixelFormat::Bgra8: + case wgi::BitmapPixelFormat::Rgba8: + case wgi::BitmapPixelFormat::Gray8: + return true; } - bool DirectXPixelFormatSupported(DirectXPixelFormat format) { - switch (format) { - case DirectXPixelFormat::B8G8R8X8UIntNormalized: - case DirectXPixelFormat::B8G8R8A8UIntNormalized: - case DirectXPixelFormat::R8G8B8A8UIntNormalized: - case DirectXPixelFormat::R8UIntNormalized: - return true; - } - - return false; - } - - bool FormatSupportedForUAV(_In_ ID3D12Device1* device, _In_ DXGI_FORMAT format) { - assert(device != nullptr); - - D3D12_FEATURE_DATA_FORMAT_SUPPORT formatSupport = {format}; - HRESULT hr = device->CheckFeatureSupport(D3D12_FEATURE_FORMAT_SUPPORT, &formatSupport, sizeof(formatSupport)); + return false; +} - return SUCCEEDED(hr) && (formatSupport.Support1 & D3D12_FORMAT_SUPPORT1_TYPED_UNORDERED_ACCESS_VIEW); +bool _winmli::DirectXPixelFormatSupported(wgdx::DirectXPixelFormat format) { + switch (format) { + case wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized: + case wgdx::DirectXPixelFormat::B8G8R8A8UIntNormalized: + case wgdx::DirectXPixelFormat::R8G8B8A8UIntNormalized: + case wgdx::DirectXPixelFormat::R8UIntNormalized: + return true; } - // This helper method uses the input parameters do determine if a conversion is necessary - // A conversion is not necessary if - // 1. input bounds cover the entire input bitmap/surface (else we are cropping) - // 2. desired output size is equal to input size (else we are resizing) - // 3. (mapping softwarebitmap to softwarebitmap) OR (mapping from d3dsurface to d3dsurface AND the two surfaces are on the same device) - // 4. the input is already in the desired format (BGRA8/B8G8R8X8UIntNormalized) - bool NeedsVideoFrameConversion( - _In_ const IVideoFrame& inputVideoFrame, - _In_ LUID outputLuid, - _In_ const BitmapBounds& inputBounds, - _In_ UINT32 outputWidth, - _In_ UINT32 outputHeight) { - bool bNeedConversion = false; - HRESULT hr = S_OK; - - DWORD format = 0; - int width = 0, height = 0; - LUID luid; - - if (FAILED((hr = GetVideoFrameInfo(inputVideoFrame, format, width, height, luid)))) { + return false; +} + +bool _winmli::FormatSupportedForUAV(_In_ ID3D12Device1* device, _In_ DXGI_FORMAT format) { + assert(device != nullptr); + + D3D12_FEATURE_DATA_FORMAT_SUPPORT formatSupport = {format}; + HRESULT hr = device->CheckFeatureSupport(D3D12_FEATURE_FORMAT_SUPPORT, &formatSupport, sizeof(formatSupport)); + + return SUCCEEDED(hr) && (formatSupport.Support1 & D3D12_FORMAT_SUPPORT1_TYPED_UNORDERED_ACCESS_VIEW); +} + +// This helper method uses the input parameters do determine if a conversion is necessary +// A conversion is not necessary if +// 1. input bounds cover the entire input bitmap/surface (else we are cropping) +// 2. desired output size is equal to input size (else we are resizing) +// 3. (mapping softwarebitmap to softwarebitmap) OR (mapping from d3dsurface to d3dsurface AND the two surfaces are on the same device) +// 4. the input is already in the desired format (BGRA8/B8G8R8X8UIntNormalized) +bool _winmli::NeedsVideoFrameConversion( + _In_ const wm::IVideoFrame& inputVideoFrame, + _In_ LUID outputLuid, + _In_ const wgi::BitmapBounds& inputBounds, + _In_ UINT32 outputWidth, + _In_ UINT32 outputHeight) { + bool bNeedConversion = false; + HRESULT hr = S_OK; + + DWORD format = 0; + int width = 0, height = 0; + LUID luid; + + if (FAILED((hr = GetVideoFrameInfo(inputVideoFrame, format, width, height, luid)))) { + bNeedConversion = true; + } else if (((int)inputBounds.Width != outputWidth) || + (inputBounds.X != 0) || + ((int)inputBounds.Height != outputHeight) || + (inputBounds.Y != 0) || + (inputVideoFrame == nullptr)) // Check crop + { + bNeedConversion = true; + } else if (luid.HighPart != outputLuid.HighPart || + luid.LowPart != outputLuid.LowPart) { + bNeedConversion = true; + } else if (static_cast(width) != outputWidth || + static_cast(height) != outputHeight) { + bNeedConversion = true; + } else if (outputLuid.HighPart != 0 || + outputLuid.LowPart != 0) { + if (format != (DWORD)wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized) { bNeedConversion = true; - } else if (((int)inputBounds.Width != outputWidth) || - (inputBounds.X != 0) || - ((int)inputBounds.Height != outputHeight) || - (inputBounds.Y != 0) || - (inputVideoFrame == nullptr)) // Check crop - { - bNeedConversion = true; - } else if (luid.HighPart != outputLuid.HighPart || - luid.LowPart != outputLuid.LowPart) { - bNeedConversion = true; - } else if (static_cast(width) != outputWidth || - static_cast(height) != outputHeight) { + } + } else { + if (format != (DWORD)wgi::BitmapPixelFormat::Bgra8) { bNeedConversion = true; - } else if (outputLuid.HighPart != 0 || - outputLuid.LowPart != 0) { - if (format != (DWORD)DirectXPixelFormat::B8G8R8X8UIntNormalized) { - bNeedConversion = true; - } - } else { - if (format != (DWORD)BitmapPixelFormat::Bgra8) { - bNeedConversion = true; - } } - - TraceLoggingWrite( - winml_trace_logging_provider, - "InputVideoFrame", - TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), - TraceLoggingBool(bNeedConversion, "Convert"), - TraceLoggingHexInt32(hr, "HRESULT"), - TraceLoggingInt32(width, "iWidth"), - TraceLoggingInt32(outputWidth, "oWidth"), - TraceLoggingInt32(height, "iHeight"), - TraceLoggingInt32(outputWidth, "oHeight"), - TraceLoggingHexInt64(*((ULONGLONG*)&luid), "iLuid"), - TraceLoggingHexInt64(*((ULONGLONG*)&outputLuid), "oLuid"), - TraceLoggingHexInt32(format, "iFormat"), - TraceLoggingInt32(inputBounds.X, "rX"), - TraceLoggingInt32(inputBounds.Y, "rY"), - TraceLoggingInt32(inputBounds.Width, "rW"), - TraceLoggingInt32(inputBounds.Height, "rH")); - - return bNeedConversion; } - ImageTensorChannelType GetChannelTypeFromSoftwareBitmap(const SoftwareBitmap& softwareBitmap) { - assert(softwareBitmap != nullptr); - - switch (softwareBitmap.BitmapPixelFormat()) { - case BitmapPixelFormat::Bgra8: - return kImageTensorChannelTypeBGR8; - case BitmapPixelFormat::Rgba8: - return kImageTensorChannelTypeRGB8; - case BitmapPixelFormat::Gray8: - return kImageTensorChannelTypeGRAY8; - } - - WINML_THROW_HR(E_INVALIDARG); + TraceLoggingWrite( + winml_trace_logging_provider, + "InputVideoFrame", + TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), + TraceLoggingBool(bNeedConversion, "Convert"), + TraceLoggingHexInt32(hr, "HRESULT"), + TraceLoggingInt32(width, "iWidth"), + TraceLoggingInt32(outputWidth, "oWidth"), + TraceLoggingInt32(height, "iHeight"), + TraceLoggingInt32(outputWidth, "oHeight"), + TraceLoggingHexInt64(*((ULONGLONG*)&luid), "iLuid"), + TraceLoggingHexInt64(*((ULONGLONG*)&outputLuid), "oLuid"), + TraceLoggingHexInt32(format, "iFormat"), + TraceLoggingInt32(inputBounds.X, "rX"), + TraceLoggingInt32(inputBounds.Y, "rY"), + TraceLoggingInt32(inputBounds.Width, "rW"), + TraceLoggingInt32(inputBounds.Height, "rH")); + + return bNeedConversion; +} + +_winml::ImageTensorChannelType _winmli::GetChannelTypeFromSoftwareBitmap(const wgi::SoftwareBitmap& softwareBitmap) { + assert(softwareBitmap != nullptr); + + switch (softwareBitmap.BitmapPixelFormat()) { + case wgi::BitmapPixelFormat::Bgra8: + return _winml::kImageTensorChannelTypeBGR8; + case wgi::BitmapPixelFormat::Rgba8: + return _winml::kImageTensorChannelTypeRGB8; + case wgi::BitmapPixelFormat::Gray8: + return _winml::kImageTensorChannelTypeGRAY8; } - BitmapPixelFormat GetBitmapPixelFormatFromChannelType(ImageTensorChannelType channelType) { - switch (channelType) { - case kImageTensorChannelTypeBGR8: - return BitmapPixelFormat::Bgra8; - case kImageTensorChannelTypeRGB8: - return BitmapPixelFormat::Rgba8; - case kImageTensorChannelTypeGRAY8: - return BitmapPixelFormat::Gray8; - } - - WINML_THROW_HR(E_INVALIDARG); + WINML_THROW_HR(E_INVALIDARG); +} + +wgi::BitmapPixelFormat _winmli::GetBitmapPixelFormatFromChannelType(_winml::ImageTensorChannelType channelType) { + switch (channelType) { + case _winml::kImageTensorChannelTypeBGR8: + return wgi::BitmapPixelFormat::Bgra8; + case _winml::kImageTensorChannelTypeRGB8: + return wgi::BitmapPixelFormat::Rgba8; + case _winml::kImageTensorChannelTypeGRAY8: + return wgi::BitmapPixelFormat::Gray8; } - ImageTensorChannelType GetChannelTypeFromDirect3DSurface(const IDirect3DSurface& direct3DSurface) { - assert(direct3DSurface != nullptr); + WINML_THROW_HR(E_INVALIDARG); +} - switch (direct3DSurface.Description().Format) { - case DirectXPixelFormat::B8G8R8A8UIntNormalized: - case DirectXPixelFormat::B8G8R8X8UIntNormalized: - return kImageTensorChannelTypeBGR8; +_winml::ImageTensorChannelType _winmli::GetChannelTypeFromDirect3DSurface( + const wgdx::Direct3D11::IDirect3DSurface& direct3DSurface) { + assert(direct3DSurface != nullptr); - case DirectXPixelFormat::R8G8B8A8UIntNormalized: - return kImageTensorChannelTypeRGB8; + switch (direct3DSurface.Description().Format) { + case wgdx::DirectXPixelFormat::B8G8R8A8UIntNormalized: + case wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized: + return _winml::kImageTensorChannelTypeBGR8; - case DirectXPixelFormat::R8UIntNormalized: - return kImageTensorChannelTypeGRAY8; - } + case wgdx::DirectXPixelFormat::R8G8B8A8UIntNormalized: + return _winml::kImageTensorChannelTypeRGB8; - WINML_THROW_HR(E_INVALIDARG); + case wgdx::DirectXPixelFormat::R8UIntNormalized: + return _winml::kImageTensorChannelTypeGRAY8; } - DirectXPixelFormat GetDirectXPixelFormatFromDXGIFormat(DXGI_FORMAT dxgiFormat) { - switch (dxgiFormat) { - case DXGI_FORMAT_B8G8R8A8_UNORM: - return DirectXPixelFormat::B8G8R8A8UIntNormalized; - case DXGI_FORMAT_B8G8R8X8_UNORM: - return DirectXPixelFormat::B8G8R8X8UIntNormalized; - case DXGI_FORMAT_R8G8B8A8_UNORM: - return DirectXPixelFormat::R8G8B8A8UIntNormalized; - case DXGI_FORMAT_R8_UNORM: - return DirectXPixelFormat::R8UIntNormalized; - } - - WINML_THROW_HR(E_INVALIDARG); + WINML_THROW_HR(E_INVALIDARG); +} + +wgdx::DirectXPixelFormat _winmli::GetDirectXPixelFormatFromDXGIFormat(DXGI_FORMAT dxgiFormat) { + switch (dxgiFormat) { + case DXGI_FORMAT_B8G8R8A8_UNORM: + return wgdx::DirectXPixelFormat::B8G8R8A8UIntNormalized; + case DXGI_FORMAT_B8G8R8X8_UNORM: + return wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized; + case DXGI_FORMAT_R8G8B8A8_UNORM: + return wgdx::DirectXPixelFormat::R8G8B8A8UIntNormalized; + case DXGI_FORMAT_R8_UNORM: + return wgdx::DirectXPixelFormat::R8UIntNormalized; } - DXGI_FORMAT GetDXGIFormatFromDirectXPixelFormat(DirectXPixelFormat directXPixelFormat) { - switch (directXPixelFormat) { - case DirectXPixelFormat::B8G8R8A8UIntNormalized: - return DXGI_FORMAT_B8G8R8A8_UNORM; - case DirectXPixelFormat::B8G8R8X8UIntNormalized: - return DXGI_FORMAT_B8G8R8X8_UNORM; - case DirectXPixelFormat::R8G8B8A8UIntNormalized: - return DXGI_FORMAT_R8G8B8A8_UNORM; - case DirectXPixelFormat::R8UIntNormalized: - return DXGI_FORMAT_R8_UNORM; - } - - WINML_THROW_HR(E_INVALIDARG); + WINML_THROW_HR(E_INVALIDARG); +} + +DXGI_FORMAT _winmli::GetDXGIFormatFromDirectXPixelFormat(wgdx::DirectXPixelFormat directXPixelFormat) { + switch (directXPixelFormat) { + case wgdx::DirectXPixelFormat::B8G8R8A8UIntNormalized: + return DXGI_FORMAT_B8G8R8A8_UNORM; + case wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized: + return DXGI_FORMAT_B8G8R8X8_UNORM; + case wgdx::DirectXPixelFormat::R8G8B8A8UIntNormalized: + return DXGI_FORMAT_R8G8B8A8_UNORM; + case wgdx::DirectXPixelFormat::R8UIntNormalized: + return DXGI_FORMAT_R8_UNORM; } - DirectXPixelFormat GetDirectXPixelFormatFromChannelType(ImageTensorChannelType channelType) { - switch (channelType) { - case kImageTensorChannelTypeBGR8: - return DirectXPixelFormat::B8G8R8A8UIntNormalized; - case kImageTensorChannelTypeRGB8: - return DirectXPixelFormat::R8G8B8A8UIntNormalized; - case kImageTensorChannelTypeGRAY8: - return DirectXPixelFormat::R8UIntNormalized; - } - - WINML_THROW_HR(E_INVALIDARG); + WINML_THROW_HR(E_INVALIDARG); +} + +wgdx::DirectXPixelFormat _winmli::GetDirectXPixelFormatFromChannelType(_winml::ImageTensorChannelType channelType) { + switch (channelType) { + case _winml::kImageTensorChannelTypeBGR8: + return wgdx::DirectXPixelFormat::B8G8R8A8UIntNormalized; + case _winml::kImageTensorChannelTypeRGB8: + return wgdx::DirectXPixelFormat::R8G8B8A8UIntNormalized; + case _winml::kImageTensorChannelTypeGRAY8: + return wgdx::DirectXPixelFormat::R8UIntNormalized; } - IDirect3DDevice GetDeviceFromDirect3DSurface(const IDirect3DSurface& d3dSurface) { - assert(d3dSurface != nullptr); + WINML_THROW_HR(E_INVALIDARG); +} - ComPtr spDx11Texture2D; - ComPtr spDxgiInterfaceAccess = d3dSurface.as().get(); - WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&spDx11Texture2D))); +wgdx::Direct3D11::IDirect3DDevice _winmli::GetDeviceFromDirect3DSurface(const wgdx::Direct3D11::IDirect3DSurface& d3dSurface) { + assert(d3dSurface != nullptr); - ComPtr spDx11Device; - spDx11Texture2D->GetDevice(&spDx11Device); + ComPtr spDx11Texture2D; + ComPtr spDxgiInterfaceAccess = d3dSurface.as().get(); + WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&spDx11Texture2D))); - ComPtr spDXGIDevice; - WINML_THROW_IF_FAILED(spDx11Device->QueryInterface(IID_PPV_ARGS(&spDXGIDevice))); + ComPtr spDx11Device; + spDx11Texture2D->GetDevice(&spDx11Device); - ComPtr<::IInspectable> spInspectable; - WINML_THROW_IF_FAILED(CreateDirect3D11DeviceFromDXGIDevice(spDXGIDevice.Get(), &spInspectable)); + ComPtr spDXGIDevice; + WINML_THROW_IF_FAILED(spDx11Device->QueryInterface(IID_PPV_ARGS(&spDXGIDevice))); - IDirect3DDevice d3dDevice; - WINML_THROW_IF_FAILED(spInspectable->QueryInterface(winrt::guid_of(), reinterpret_cast(winrt::put_abi(d3dDevice)))); + ComPtr<::IInspectable> spInspectable; + WINML_THROW_IF_FAILED(CreateDirect3D11DeviceFromDXGIDevice(spDXGIDevice.Get(), &spInspectable)); - return d3dDevice; - } + wgdx::Direct3D11::IDirect3DDevice d3dDevice; + WINML_THROW_IF_FAILED(spInspectable->QueryInterface( + winrt::guid_of(), + reinterpret_cast(winrt::put_abi(d3dDevice)))); - bool TexturesHaveSameDevice(_In_ ID3D11Texture2D* pTexture1, _In_ ID3D11Texture2D* pTexture2) { - if (pTexture1 && pTexture2) { - ComPtr spDevice1; - pTexture1->GetDevice(&spDevice1); + return d3dDevice; +} - ComPtr spDevice2; - pTexture2->GetDevice(&spDevice2); +bool _winmli::TexturesHaveSameDevice(_In_ ID3D11Texture2D* pTexture1, _In_ ID3D11Texture2D* pTexture2) { + if (pTexture1 && pTexture2) { + ComPtr spDevice1; + pTexture1->GetDevice(&spDevice1); - return spDevice1.Get() == spDevice2.Get(); - } + ComPtr spDevice2; + pTexture2->GetDevice(&spDevice2); - return false; + return spDevice1.Get() == spDevice2.Get(); } - bool TextureIsOnDevice(_In_ ID3D11Texture2D* pTexture, _In_ ID3D11Device* pDevice) { - if (pTexture && pDevice) { - ComPtr spDevice1; - pTexture->GetDevice(&spDevice1); + return false; +} - return spDevice1.Get() == pDevice; - } +bool _winmli::TextureIsOnDevice(_In_ ID3D11Texture2D* pTexture, _In_ ID3D11Device* pDevice) { + if (pTexture && pDevice) { + ComPtr spDevice1; + pTexture->GetDevice(&spDevice1); - return false; + return spDevice1.Get() == pDevice; } - ComPtr GetTextureFromDirect3DSurface(const IDirect3DSurface& d3dSurface) { - auto spDxgiInterfaceAccess = d3dSurface.as(); - ComPtr d3d11Texture; - WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&d3d11Texture))); + return false; +} - return d3d11Texture; - } +ComPtr _winmli::GetTextureFromDirect3DSurface(const wgdx::Direct3D11::IDirect3DSurface& d3dSurface) { + auto spDxgiInterfaceAccess = d3dSurface.as(); + ComPtr d3d11Texture; + WINML_THROW_IF_FAILED(spDxgiInterfaceAccess->GetInterface(IID_PPV_ARGS(&d3d11Texture))); - bool VideoFramesHaveSameDimensions(const IVideoFrame& videoFrame1, const IVideoFrame& videoFrame2) { - if (videoFrame1 && videoFrame2) { - Direct3DSurfaceDescription desc1 = videoFrame1.Direct3DSurface().Description(); - Direct3DSurfaceDescription desc2 = videoFrame2.Direct3DSurface().Description(); + return d3d11Texture; +} - return desc1.Width == desc2.Width && desc1.Height == desc2.Height; - } +bool _winmli::VideoFramesHaveSameDimensions(const wm::IVideoFrame& videoFrame1, const wm::IVideoFrame& videoFrame2) { + if (videoFrame1 && videoFrame2) { + auto desc1 = videoFrame1.Direct3DSurface().Description(); + auto desc2 = videoFrame2.Direct3DSurface().Description(); - return false; + return desc1.Width == desc2.Width && desc1.Height == desc2.Height; } - bool VideoFramesHaveSameDevice(const IVideoFrame& videoFrame1, const IVideoFrame& videoFrame2) { - if (videoFrame1 && videoFrame2) { - ComPtr spTexture1 = GetTextureFromDirect3DSurface(videoFrame1.Direct3DSurface()); - ComPtr spTexture2 = GetTextureFromDirect3DSurface(videoFrame2.Direct3DSurface()); + return false; +} - ComPtr spDevice1, spDevice2; - spTexture1->GetDevice(&spDevice1); - spTexture2->GetDevice(&spDevice2); +bool _winmli::VideoFramesHaveSameDevice(const wm::IVideoFrame& videoFrame1, const wm::IVideoFrame& videoFrame2) { + if (videoFrame1 && videoFrame2) { + ComPtr spTexture1 = _winmli::GetTextureFromDirect3DSurface(videoFrame1.Direct3DSurface()); + ComPtr spTexture2 = _winmli::GetTextureFromDirect3DSurface(videoFrame2.Direct3DSurface()); - return spDevice1.Get() == spDevice2.Get(); - } + ComPtr spDevice1, spDevice2; + spTexture1->GetDevice(&spDevice1); + spTexture2->GetDevice(&spDevice2); - return false; + return spDevice1.Get() == spDevice2.Get(); } + + return false; } \ No newline at end of file diff --git a/winml/lib/Api.Image/ImageConverter.cpp b/winml/lib/Api.Image/ImageConverter.cpp index 956d9d69d8f8e..8df0249c2b488 100644 --- a/winml/lib/Api.Image/ImageConverter.cpp +++ b/winml/lib/Api.Image/ImageConverter.cpp @@ -7,13 +7,8 @@ #include "inc/D3DDeviceCache.h" using namespace Microsoft::WRL; -using namespace Windows::Graphics::DirectX::Direct3D11; -using namespace Windows::AI::MachineLearning::Internal; -using namespace winrt::Windows::AI::MachineLearning::implementation; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; + +using namespace _winml; void ImageConverter::SyncD3D11ToD3D12(_In_ D3DDeviceCache& device_cache, _In_ ID3D11Texture2D* pD3D11Texture) { assert(pD3D11Texture != nullptr); @@ -90,18 +85,20 @@ void ImageConverter::ResetAllocator() { } ComPtr ImageConverter::CreateTextureFromUnsupportedColorFormat( - const IVideoFrame& videoFrame, - const BitmapBounds& inputBounds, - const BitmapBounds& outputBounds, - DirectXPixelFormat newFormat) { + const wm::IVideoFrame& videoFrame, + const wgi::BitmapBounds& inputBounds, + const wgi::BitmapBounds& outputBounds, + wgdx::DirectXPixelFormat newFormat) { assert(videoFrame != nullptr); // Make sure we create the new video frame on the same device. We don't want the VideoFrame pipeline to implicitly share the texture between // 2 devices since we will need to do it ourselves anyway. - IDirect3DDevice device = ImageConversionHelpers::GetDeviceFromDirect3DSurface(videoFrame.Direct3DSurface()); + auto device = _winmli::GetDeviceFromDirect3DSurface(videoFrame.Direct3DSurface()); - VideoFrame spNewVideoFrame = VideoFrame::CreateAsDirect3D11SurfaceBacked(newFormat, outputBounds.Width, outputBounds.Height, device); - videoFrame.as().CopyToAsync(spNewVideoFrame, inputBounds, outputBounds).get(); + auto spNewVideoFrame = wm::VideoFrame::CreateAsDirect3D11SurfaceBacked(newFormat, outputBounds.Width, outputBounds.Height, device); + videoFrame.as().CopyToAsync(spNewVideoFrame, inputBounds, outputBounds).get(); + + using namespace Windows::Graphics::DirectX::Direct3D11; auto spDxgiInterfaceAccess = spNewVideoFrame.Direct3DSurface().as(); ComPtr d3d11Texture; @@ -110,7 +107,8 @@ ComPtr ImageConverter::CreateTextureFromUnsupportedColorFormat( return d3d11Texture; } -void ImageConverter::CopyTextureIntoTexture(_In_ ID3D11Texture2D* pTextureFrom, _In_ const BitmapBounds& inputBounds, _Inout_ ID3D11Texture2D* pTextureTo) { +void ImageConverter::CopyTextureIntoTexture(_In_ ID3D11Texture2D* pTextureFrom, + _In_ const wgi::BitmapBounds& inputBounds, _Inout_ ID3D11Texture2D* pTextureTo) { assert(pTextureFrom != nullptr); assert(pTextureTo != nullptr); diff --git a/winml/lib/Api.Image/TensorToVideoFrameConverter.cpp b/winml/lib/Api.Image/TensorToVideoFrameConverter.cpp index bf717f648c7f0..d3726fb809079 100644 --- a/winml/lib/Api.Image/TensorToVideoFrameConverter.cpp +++ b/winml/lib/Api.Image/TensorToVideoFrameConverter.cpp @@ -12,16 +12,13 @@ #include "inc/TensorToVideoFrameConverter.h" #include "CpuDetensorizer.h" +#include "inc/ImageConversionHelpers.h" #include "LearningModelDevice.h" using namespace Microsoft::WRL; -using namespace Windows::AI::MachineLearning::Internal; using namespace Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::AI::MachineLearning::implementation; -using namespace winrt::Windows::Graphics::DirectX; + +using namespace _winml; class GPUTensorToDX12TextureTelemetryEvent { public: @@ -69,36 +66,36 @@ class ConvertCPUTensorToVideoFrameWithSoftwareBitmapTelemetryEvent { void TensorToVideoFrameConverter::DX12TensorToVideoFrame( _In_ UINT32 batchIdx, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ ID3D12Resource* pInputTensor, - _In_ const ImageTensorDescription& tensorDesc, - _Inout_ VideoFrame& destVideoFrame) { + _In_ const _winml::ImageTensorDescription& tensorDesc, + _Inout_ wm::VideoFrame& destVideoFrame) { CWinMLAutoLock lock(&lock_); auto spDevice = session.Device().as(); - D3DDeviceCache* pDeviceCache = spDevice->GetD3DDeviceCache(); + _winml::D3DDeviceCache* pDeviceCache = spDevice->GetD3DDeviceCache(); - IDirect3DSurface spDestDirect3DSurface = destVideoFrame.Direct3DSurface(); - SoftwareBitmap softwareBitmap = destVideoFrame.SoftwareBitmap(); + wgdx::Direct3D11::IDirect3DSurface spDestDirect3DSurface = destVideoFrame.Direct3DSurface(); + wgi::SoftwareBitmap softwareBitmap = destVideoFrame.SoftwareBitmap(); if (softwareBitmap) { ConvertGPUTensorToSoftwareBitmap(batchIdx, pInputTensor, *pDeviceCache, tensorDesc, softwareBitmap); } else if (spDestDirect3DSurface) { - bool isUAVSupportedFormat = ImageConversionHelpers::FormatSupportedForUAV( + bool isUAVSupportedFormat = _winmli::FormatSupportedForUAV( pDeviceCache->GetD3D12Device(), - ImageConversionHelpers::GetDXGIFormatFromDirectXPixelFormat(spDestDirect3DSurface.Description().Format)); + _winmli::GetDXGIFormatFromDirectXPixelFormat(spDestDirect3DSurface.Description().Format)); // UAV support for formats is device dependent if (!isUAVSupportedFormat) { ConvertDX12TensorToUnsupportedVideoFrameFormat(batchIdx, pInputTensor, *pDeviceCache, tensorDesc, destVideoFrame); } else { - ComPtr spVideoFrameTexture = ImageConversionHelpers::GetTextureFromDirect3DSurface(destVideoFrame.Direct3DSurface()); + ComPtr spVideoFrameTexture = _winmli::GetTextureFromDirect3DSurface(destVideoFrame.Direct3DSurface()); D3D11_TEXTURE2D_DESC videoFrameTextureDesc; spVideoFrameTexture->GetDesc(&videoFrameTextureDesc); - BitmapBounds bounds = {0, 0, videoFrameTextureDesc.Width, videoFrameTextureDesc.Height}; + wgi::BitmapBounds bounds = {0, 0, videoFrameTextureDesc.Width, videoFrameTextureDesc.Height}; - if (ImageConversionHelpers::TextureIsOnDevice(spVideoFrameTexture.Get(), pDeviceCache->GetD3D11Device())) { + if (_winmli::TextureIsOnDevice(spVideoFrameTexture.Get(), pDeviceCache->GetD3D11Device())) { // The texture is on our device, so we can just create own texture, share it and cache it if (!output_resource_) { output_resource_ = CreateShareableD3D12Texture(videoFrameTextureDesc, pDeviceCache->GetD3D12Device()); @@ -189,20 +186,20 @@ ComPtr TensorToVideoFrameConverter::CreateShareableD3D12Texture( void TensorToVideoFrameConverter::ConvertDX12TensorToUnsupportedVideoFrameFormat( _In_ UINT32 batchIdx, _In_ ID3D12Resource* pInputTensor, - _In_ D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensorDesc, - _Inout_ VideoFrame& unsupportedVideoFrame) { + _Inout_ wm::VideoFrame& unsupportedVideoFrame) { assert(pInputTensor != nullptr); // Find the first supported format and convert to it auto supportedFormatIter = std::find_if( - ImageConversionHelpers::supportedWinMLFormats.begin(), - ImageConversionHelpers::supportedWinMLFormats.end(), - [&device_cache](DXGI_FORMAT format) { return ImageConversionHelpers::FormatSupportedForUAV(device_cache.GetD3D12Device(), format); }); + _winmli::supportedWinMLFormats.begin(), + _winmli::supportedWinMLFormats.end(), + [&device_cache](DXGI_FORMAT format) { return _winmli::FormatSupportedForUAV(device_cache.GetD3D12Device(), format); }); WINML_THROW_HR_IF_FALSE_MSG( E_INVALIDARG, - supportedFormatIter != ImageConversionHelpers::supportedWinMLFormats.end(), + supportedFormatIter != _winmli::supportedWinMLFormats.end(), "Detensorization for this format is unsupported on the current device."); D3D11_TEXTURE2D_DESC supportedDesc {}; @@ -215,7 +212,7 @@ void TensorToVideoFrameConverter::ConvertDX12TensorToUnsupportedVideoFrameFormat supportedDesc.SampleDesc.Quality = 0; supportedDesc.Usage = D3D11_USAGE_DEFAULT; - ComPtr unsupportedTexture = ImageConversionHelpers::GetTextureFromDirect3DSurface(unsupportedVideoFrame.Direct3DSurface()); + ComPtr unsupportedTexture = _winmli::GetTextureFromDirect3DSurface(unsupportedVideoFrame.Direct3DSurface()); ComPtr d3d11Device; unsupportedTexture->GetDevice(&d3d11Device); @@ -229,9 +226,9 @@ void TensorToVideoFrameConverter::ConvertDX12TensorToUnsupportedVideoFrameFormat ComPtr inspectableSurface; WINML_THROW_IF_FAILED(CreateDirect3D11SurfaceFromDXGISurface(dxgiSurface.Get(), &inspectableSurface)); - IDirect3DSurface surface; + wgdx::Direct3D11::IDirect3DSurface surface; WINML_THROW_IF_FAILED(inspectableSurface->QueryInterface(winrt::guid_of(), reinterpret_cast(winrt::put_abi(surface)))); - converted_video_frame_ = VideoFrame::CreateWithDirect3D11Surface(surface); + converted_video_frame_ = wm::VideoFrame::CreateWithDirect3D11Surface(surface); // Detensorize ConvertGPUTensorToDX12Texture(batchIdx, pInputTensor, device_cache, tensorDesc, output_resource_.Get()); @@ -268,27 +265,27 @@ ComPtr TensorToVideoFrameConverter::ShareD3D12Texture(ID3D12Res } void TensorToVideoFrameConverter::SoftwareTensorToVideoFrame( - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ BYTE* pCPUTensorToConvert, _In_ ImageTensorDescription tensorDesc, - _Inout_ winrt::Windows::Media::VideoFrame& pDestVideoFrame) { + _Inout_ wm::VideoFrame& pDestVideoFrame) { CWinMLAutoLock lock(&lock_); - winrt::Windows::Media::IVideoFrame spTensorFrame; + wm::IVideoFrame spTensorFrame; UINT32 outputWidth = 0; UINT32 outputHeight = 0; UINT32 tensorHeight = static_cast(tensorDesc.sizes[2]); UINT32 tensorWidth = static_cast(tensorDesc.sizes[3]); // create a bitmap bounds for the whole image/tensor - BitmapBounds inputBounds = + wgi::BitmapBounds inputBounds = { 0, 0, tensorWidth, tensorHeight}; - winrt::Windows::Graphics::Imaging::SoftwareBitmap spOutputSoftwareBitmap = pDestVideoFrame.SoftwareBitmap(); - winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface spOutputSurface = pDestVideoFrame.Direct3DSurface(); + wgi::SoftwareBitmap spOutputSoftwareBitmap = pDestVideoFrame.SoftwareBitmap(); + wgdx::Direct3D11::IDirect3DSurface spOutputSurface = pDestVideoFrame.Direct3DSurface(); // only one of softwarebitmap or direct3Dsurface should be non-null if ((spOutputSoftwareBitmap == nullptr && spOutputSurface == nullptr) || (spOutputSoftwareBitmap != nullptr && spOutputSurface != nullptr)) { @@ -298,16 +295,17 @@ void TensorToVideoFrameConverter::SoftwareTensorToVideoFrame( outputWidth = spOutputSoftwareBitmap.PixelWidth(); outputHeight = spOutputSoftwareBitmap.PixelHeight(); } else { - Direct3DSurfaceDescription description; + wgdx::Direct3D11::Direct3DSurfaceDescription description; description = spOutputSurface.Description(); outputWidth = description.Width; outputHeight = description.Height; } - if (ImageConversionHelpers::NeedsVideoFrameConversion(pDestVideoFrame, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { + if (_winmli::NeedsVideoFrameConversion(pDestVideoFrame, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { if (converted_video_frame_ == nullptr || - ImageConversionHelpers::NeedsVideoFrameConversion(converted_video_frame_, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { - converted_video_frame_ = VideoFrame::CreateWithSoftwareBitmap(SoftwareBitmap(BitmapPixelFormat::Bgra8, tensorWidth, tensorHeight)); + _winmli::NeedsVideoFrameConversion(converted_video_frame_, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { + converted_video_frame_ = wm::VideoFrame::CreateWithSoftwareBitmap( + wgi::SoftwareBitmap(wgi::BitmapPixelFormat::Bgra8, tensorWidth, tensorHeight)); } spTensorFrame = converted_video_frame_; @@ -322,7 +320,7 @@ void TensorToVideoFrameConverter::SoftwareTensorToVideoFrame( bitmap); if (converted_video_frame_) { - ImageConversionHelpers::ConvertVideoFrameToVideoFrame( + _winmli::ConvertVideoFrameToVideoFrame( converted_video_frame_, inputBounds, outputWidth, @@ -334,7 +332,7 @@ void TensorToVideoFrameConverter::SoftwareTensorToVideoFrame( void TensorToVideoFrameConverter::ConvertGPUTensorToDX12Texture( _In_ UINT32 batchIdx, _In_ ID3D12Resource* pInputResource, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensorDesc, _Inout_ ID3D12Resource* pOutputResource) { assert(pInputResource != nullptr); @@ -479,9 +477,9 @@ void TensorToVideoFrameConverter::ConvertGPUTensorToDX12Texture( void TensorToVideoFrameConverter::ConvertGPUTensorToSoftwareBitmap( _In_ UINT32 batchIdx, _In_ ID3D12Resource* pInputTensor, - _In_ D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensorDesc, - _Inout_ SoftwareBitmap& softwareBitmap) { + _Inout_ wgi::SoftwareBitmap& softwareBitmap) { assert(pInputTensor != nullptr); assert(softwareBitmap != nullptr); @@ -523,7 +521,7 @@ void TensorToVideoFrameConverter::ConvertGPUTensorToSoftwareBitmap( D3D12_SHADER_RESOURCE_VIEW_DESC TensorToVideoFrameConverter::CreateSRVDescriptor( const UINT32 batchIdx, const D3D12_RESOURCE_DESC& resourceDesc, - const ImageTensorDescription& desc) { + const _winml::ImageTensorDescription& desc) { UINT uiTensorElementSize = desc.dataType == kImageTensorDataTypeFloat32 ? sizeof(UINT) : sizeof(uint16_t); @@ -559,7 +557,7 @@ D3D12_SHADER_RESOURCE_VIEW_DESC TensorToVideoFrameConverter::CreateSRVDescriptor void TensorToVideoFrameConverter::ConvertCPUTensorToSoftwareBitmap( _In_ void* pCPUTensor, _In_ const ImageTensorDescription& tensorDesc, - _Inout_ SoftwareBitmap& softwareBitmap) { + _Inout_ wgi::SoftwareBitmap& softwareBitmap) { ConvertCPUTensorToVideoFrameWithSoftwareBitmapTelemetryEvent telemetrylogger(tensorDesc); auto height = softwareBitmap.PixelHeight(); @@ -569,7 +567,7 @@ void TensorToVideoFrameConverter::ConvertCPUTensorToSoftwareBitmap( // Validate input description WINML_THROW_HR_IF_FALSE_MSG( E_INVALIDARG, - format == BitmapPixelFormat::Bgra8 || format == BitmapPixelFormat::Rgba8 || format == BitmapPixelFormat::Gray8, + format == wgi::BitmapPixelFormat::Bgra8 || format == wgi::BitmapPixelFormat::Rgba8 || format == wgi::BitmapPixelFormat::Gray8, "Format was input image %d. Input image format must Bgra8, Rgba8 or Gray8.", format); WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, height > 0, "Output input image height provided. Height is set to zero."); @@ -594,14 +592,14 @@ void TensorToVideoFrameConverter::ConvertCPUTensorToSoftwareBitmap( BYTE* pData = nullptr; UINT32 uiCapacity = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Write)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + wgi::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Write)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); auto spByteAccess = reference.as(); WINML_THROW_IF_FAILED(spByteAccess->GetBuffer(&pData, &uiCapacity)); uint32_t bufferWidth = uiCapacity / height; - ImageTensorChannelType targetChannelType = ImageConversionHelpers::GetChannelTypeFromSoftwareBitmap(softwareBitmap); + ImageTensorChannelType targetChannelType = _winmli::GetChannelTypeFromSoftwareBitmap(softwareBitmap); if (tensorDesc.dataType == kImageTensorDataTypeFloat32) { WINML_THROW_IF_FAILED(CpuDetensorizer::Detensorize(tensorDesc.channelType, targetChannelType, static_cast(pCPUTensor), bufferWidth, height, width, pData)); diff --git a/winml/lib/Api.Image/VideoFrameToTensorConverter.cpp b/winml/lib/Api.Image/VideoFrameToTensorConverter.cpp index da1b582b3ae8a..21478f1dd0cef 100644 --- a/winml/lib/Api.Image/VideoFrameToTensorConverter.cpp +++ b/winml/lib/Api.Image/VideoFrameToTensorConverter.cpp @@ -15,13 +15,9 @@ #include "LearningModelDevice.h" using namespace Microsoft::WRL; -using namespace Windows::AI::MachineLearning::Internal; using namespace Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::AI::MachineLearning::implementation; -using namespace winrt::Windows::Graphics::DirectX; + +using namespace _winml; class DX12TextureToGPUTensorTelemetryEvent { public: @@ -68,14 +64,14 @@ class ConvertVideoFrameWithSoftwareBitmapToCPUTensorTelemetryEvent { }; void VideoFrameToTensorConverter::VideoFrameToSoftwareTensor( - _In_ const IVideoFrame& inputVideoFrame, - _In_ const BitmapBounds& inputBounds, + _In_ const wm::IVideoFrame& inputVideoFrame, + _In_ const wgi::BitmapBounds& inputBounds, _In_ const ImageTensorDescription& tensorDesc, _Out_ BYTE* pOutputCPUTensor) { CWinMLAutoLock lock(&lock_); - winrt::Windows::Graphics::Imaging::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); - winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface spInputSurface = inputVideoFrame.Direct3DSurface(); + wgi::SoftwareBitmap spInputSoftwareBitmap = inputVideoFrame.SoftwareBitmap(); + wgdx::Direct3D11::IDirect3DSurface spInputSurface = inputVideoFrame.Direct3DSurface(); // only one of softwarebitmap or direct3Dsurface should be non-null if ((spInputSoftwareBitmap == nullptr && spInputSurface == nullptr) || (spInputSoftwareBitmap != nullptr && spInputSurface != nullptr)) { @@ -84,14 +80,15 @@ void VideoFrameToTensorConverter::VideoFrameToSoftwareTensor( UINT32 tensorHeight = static_cast(tensorDesc.sizes[2]); UINT32 tensorWidth = static_cast(tensorDesc.sizes[3]); - if (spInputSurface || ImageConversionHelpers::NeedsVideoFrameConversion(inputVideoFrame, {}, inputBounds, tensorWidth, tensorHeight)) { + if (spInputSurface || _winmli::NeedsVideoFrameConversion(inputVideoFrame, {}, inputBounds, tensorWidth, tensorHeight)) { if (converted_video_frame_ == nullptr || - ImageConversionHelpers::NeedsVideoFrameConversion(converted_video_frame_, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { - converted_video_frame_ = VideoFrame::CreateWithSoftwareBitmap(SoftwareBitmap(BitmapPixelFormat::Bgra8, tensorWidth, tensorHeight)); + _winmli::NeedsVideoFrameConversion(converted_video_frame_, {}, {0, 0, (UINT32)tensorWidth, (UINT32)tensorHeight}, tensorWidth, tensorHeight)) { + converted_video_frame_ = wm::VideoFrame::CreateWithSoftwareBitmap( + wgi::SoftwareBitmap(wgi::BitmapPixelFormat::Bgra8, tensorWidth, tensorHeight)); } // Resize the input VideoFrame to converted_video_frame_ - ImageConversionHelpers::ConvertVideoFrameToVideoFrame( + _winmli::ConvertVideoFrameToVideoFrame( inputVideoFrame, inputBounds, tensorWidth, @@ -135,9 +132,9 @@ ComPtr VideoFrameToTensorConverter::ShareD3D11Texture(ID3D11Text void VideoFrameToTensorConverter::VideoFrameToDX12Tensor( _In_ const UINT32 batchIdx, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, - _In_ const IVideoFrame& inputVideoFrame, - _In_ const BitmapBounds& inputBounds, + _In_ winml::LearningModelSession& session, + _In_ const wm::IVideoFrame& inputVideoFrame, + _In_ const wgi::BitmapBounds& inputBounds, _In_ const ImageTensorDescription& tensorDesc, _Inout_ ID3D12Resource* pOutputTensor) { // Validate Tensor description @@ -147,22 +144,22 @@ void VideoFrameToTensorConverter::VideoFrameToDX12Tensor( WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, tensorDesc.channelType != kImageTensorChannelTypeGRAY8 || tensorDesc.sizes[1] == 1, "Target tensor description expects kImageTensorChannelTypeGRAY8, but has %lld channels specified instead of 1.", tensorDesc.sizes[1]); CWinMLAutoLock lock(&lock_); - auto device = session.Device().as(); - D3DDeviceCache* pDeviceCache = device->GetD3DDeviceCache(); - IDirect3DSurface spDirect3DSurface = inputVideoFrame.Direct3DSurface(); + auto device = session.Device().as(); + _winml::D3DDeviceCache* pDeviceCache = device->GetD3DDeviceCache(); + wgdx::Direct3D11::IDirect3DSurface spDirect3DSurface = inputVideoFrame.Direct3DSurface(); if (inputVideoFrame.SoftwareBitmap()) { ConvertSoftwareBitmapToGPUTensor(batchIdx, inputVideoFrame, *pDeviceCache, inputBounds, tensorDesc, pOutputTensor); } else if (spDirect3DSurface) { ComPtr spVideoFrameTexture; - BitmapBounds scaledBounds = inputBounds; + wgi::BitmapBounds scaledBounds = inputBounds; // TODO: Scale during the tensorization phase instead of using the video frame pipeline when the input bounds are not the same size as the tensor - if (!ImageConversionHelpers::DirectXPixelFormatSupported(spDirect3DSurface.Description().Format) || static_cast(inputBounds.Width) != tensorDesc.sizes[3] || static_cast(inputBounds.Height) != tensorDesc.sizes[2]) { + if (!_winmli::DirectXPixelFormatSupported(spDirect3DSurface.Description().Format) || static_cast(inputBounds.Width) != tensorDesc.sizes[3] || static_cast(inputBounds.Height) != tensorDesc.sizes[2]) { // Force the VideoFrame to not do a conversion if the format is supported since we do it during the tensorization anyway - DirectXPixelFormat newFormat = ImageConversionHelpers::DirectXPixelFormatSupported(spDirect3DSurface.Description().Format) + wgdx::DirectXPixelFormat newFormat = _winmli::DirectXPixelFormatSupported(spDirect3DSurface.Description().Format) ? spDirect3DSurface.Description().Format - : ImageConversionHelpers::GetDirectXPixelFormatFromChannelType(tensorDesc.channelType); + : _winmli::GetDirectXPixelFormatFromChannelType(tensorDesc.channelType); // Change the input bounds since the video frame pipeline already cropped the texture scaledBounds = {0, 0, static_cast(tensorDesc.sizes[3]), static_cast(tensorDesc.sizes[2])}; @@ -171,13 +168,13 @@ void VideoFrameToTensorConverter::VideoFrameToDX12Tensor( spVideoFrameTexture = CreateTextureFromUnsupportedColorFormat(inputVideoFrame, inputBounds, scaledBounds, newFormat); } else { // If the color format is known or the input widths are not smaller than the tensor desc, just use the video frame as is - spVideoFrameTexture = ImageConversionHelpers::GetTextureFromDirect3DSurface(spDirect3DSurface); + spVideoFrameTexture = _winmli::GetTextureFromDirect3DSurface(spDirect3DSurface); } D3D11_TEXTURE2D_DESC videoFrameTextureDesc; spVideoFrameTexture->GetDesc(&videoFrameTextureDesc); - if (ImageConversionHelpers::TextureIsOnDevice(spVideoFrameTexture.Get(), pDeviceCache->GetD3D11Device())) { + if (_winmli::TextureIsOnDevice(spVideoFrameTexture.Get(), pDeviceCache->GetD3D11Device())) { // The texture is on our device, so we can just create own texture, share it and cache it if (!D3D11_cached_texture_) { WINML_THROW_IF_FAILED(pDeviceCache->GetD3D11Device()->CreateTexture2D(&videoFrameTextureDesc, nullptr, &D3D11_cached_texture_)); @@ -234,7 +231,7 @@ void VideoFrameToTensorConverter::VideoFrameToDX12Tensor( void VideoFrameToTensorConverter::ConvertDX12TextureToGPUTensor( _In_ UINT32 batchIdx, _In_ ID3D12Resource* pInputResource, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensorDesc, _Inout_ ID3D12Resource* pOutputResource) { assert(pInputResource != nullptr); @@ -394,9 +391,9 @@ void VideoFrameToTensorConverter::ConvertDX12TextureToGPUTensor( void VideoFrameToTensorConverter::ConvertSoftwareBitmapToGPUTensor( _In_ UINT32 batchIdx, - _In_ const IVideoFrame& videoFrame, - _In_ D3DDeviceCache& device_cache, - _In_ const BitmapBounds& inputBounds, + _In_ const wm::IVideoFrame& videoFrame, + _In_ _winml::D3DDeviceCache& device_cache, + _In_ const wgi::BitmapBounds& inputBounds, _In_ const ImageTensorDescription& tensorDesc, _Inout_ ID3D12Resource* pOutputResource) { assert(pOutputResource != nullptr); @@ -404,25 +401,25 @@ void VideoFrameToTensorConverter::ConvertSoftwareBitmapToGPUTensor( DX12TextureToGPUTensorTelemetryEvent telemetrylogger(tensorDesc); - SoftwareBitmap convertedSoftwareBitmap = nullptr; - BitmapBounds scaledBounds = inputBounds; + wgi::SoftwareBitmap convertedSoftwareBitmap = nullptr; + wgi::BitmapBounds scaledBounds = inputBounds; // TODO: Scale during the tensorization phase instead of using the video frame pipeline when the input bounds are not the same size as the tensor if (static_cast(inputBounds.Width) != tensorDesc.sizes[3] || static_cast(inputBounds.Height) != tensorDesc.sizes[2]) { scaledBounds = {0, 0, static_cast(tensorDesc.sizes[3]), static_cast(tensorDesc.sizes[2])}; // Force the VideoFrame to not do a conversion if the format is supported since we do it during the tensorization anyway - BitmapPixelFormat newPixelFormat = ImageConversionHelpers::SoftwareBitmapFormatSupported(videoFrame.SoftwareBitmap()) + wgi::BitmapPixelFormat newPixelFormat = _winmli::SoftwareBitmapFormatSupported(videoFrame.SoftwareBitmap()) ? videoFrame.SoftwareBitmap().BitmapPixelFormat() - : ImageConversionHelpers::GetBitmapPixelFormatFromChannelType(tensorDesc.channelType); + : _winmli::GetBitmapPixelFormatFromChannelType(tensorDesc.channelType); - convertedSoftwareBitmap = SoftwareBitmap(newPixelFormat, static_cast(tensorDesc.sizes[3]), static_cast(tensorDesc.sizes[2])); - VideoFrame convertedVideoFrame = VideoFrame::CreateWithSoftwareBitmap(convertedSoftwareBitmap); - videoFrame.as().CopyToAsync(convertedVideoFrame, inputBounds, scaledBounds).get(); + convertedSoftwareBitmap = wgi::SoftwareBitmap(newPixelFormat, static_cast(tensorDesc.sizes[3]), static_cast(tensorDesc.sizes[2])); + wm::VideoFrame convertedVideoFrame = wm::VideoFrame::CreateWithSoftwareBitmap(convertedSoftwareBitmap); + videoFrame.as().CopyToAsync(convertedVideoFrame, inputBounds, scaledBounds).get(); convertedSoftwareBitmap = convertedVideoFrame.SoftwareBitmap(); - } else if (!ImageConversionHelpers::SoftwareBitmapFormatSupported(videoFrame.SoftwareBitmap())) { - convertedSoftwareBitmap = SoftwareBitmap::Convert(videoFrame.SoftwareBitmap(), ImageConversionHelpers::GetBitmapPixelFormatFromChannelType(tensorDesc.channelType)); + } else if (!_winmli::SoftwareBitmapFormatSupported(videoFrame.SoftwareBitmap())) { + convertedSoftwareBitmap = wgi::SoftwareBitmap::Convert(videoFrame.SoftwareBitmap(), _winmli::GetBitmapPixelFormatFromChannelType(tensorDesc.channelType)); } else { // We don't need a conversion convertedSoftwareBitmap = videoFrame.SoftwareBitmap(); @@ -467,7 +464,7 @@ void VideoFrameToTensorConverter::ConvertSoftwareBitmapToGPUTensor( D3D12_UNORDERED_ACCESS_VIEW_DESC VideoFrameToTensorConverter::CreateUAVDescription( const UINT32 batchIdx, const D3D12_RESOURCE_DESC& resourceDesc, - const ImageTensorDescription& desc) { + const _winml::ImageTensorDescription& desc) { UINT uiTensorElementSize = desc.dataType == kImageTensorDataTypeFloat32 ? sizeof(UINT) : sizeof(uint16_t); @@ -501,9 +498,9 @@ D3D12_UNORDERED_ACCESS_VIEW_DESC VideoFrameToTensorConverter::CreateUAVDescripti } void VideoFrameToTensorConverter::ConvertSoftwareBitmapToCPUTensor( - _In_ const SoftwareBitmap& softwareBitmap, - _In_ const ImageTensorDescription& tensorDesc, - _In_ const BitmapBounds& inputBounds, + _In_ const wgi::SoftwareBitmap& softwareBitmap, + _In_ const _winml::ImageTensorDescription& tensorDesc, + _In_ const wgi::BitmapBounds& inputBounds, _Inout_ void* pCPUTensor) { assert(softwareBitmap != nullptr); @@ -516,7 +513,7 @@ void VideoFrameToTensorConverter::ConvertSoftwareBitmapToCPUTensor( // Validate input description WINML_THROW_HR_IF_FALSE_MSG( E_INVALIDARG, - format == BitmapPixelFormat::Bgra8 || format == BitmapPixelFormat::Rgba8 || format == BitmapPixelFormat::Gray8, + format == wgi::BitmapPixelFormat::Bgra8 || format == wgi::BitmapPixelFormat::Rgba8 || format == wgi::BitmapPixelFormat::Gray8, "Format was input image %d. Input image format must Bgra8, Rgba8 or Gray8.", format); WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, height > 0, "Invalid input image height provided. Height is set to zero."); @@ -540,18 +537,18 @@ void VideoFrameToTensorConverter::ConvertSoftwareBitmapToCPUTensor( // get the byte buffer out of a softwarebitmap BYTE* pData = nullptr; UINT32 bufferSize = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + wgi::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); auto spByteAccess = reference.as(); WINML_THROW_IF_FAILED(spByteAccess->GetBuffer(&pData, &bufferSize)); UINT32 bufferWidth = bufferSize / height; - ImageTensorChannelType channelType = ImageConversionHelpers::GetChannelTypeFromSoftwareBitmap(softwareBitmap); + ImageTensorChannelType channelType = _winmli::GetChannelTypeFromSoftwareBitmap(softwareBitmap); - if (tensorDesc.dataType == kImageTensorDataTypeFloat32) { + if (tensorDesc.dataType == _winml::kImageTensorDataTypeFloat32) { WINML_THROW_IF_FAILED(CpuTensorizer::TensorizeData(channelType, tensorDesc.channelType, pData, bufferWidth, inputBounds, reinterpret_cast(pCPUTensor))); - } else if (tensorDesc.dataType == kImageTensorDataTypeFloat16) { + } else if (tensorDesc.dataType == _winml::kImageTensorDataTypeFloat16) { WINML_THROW_IF_FAILED(CpuTensorizer::TensorizeData(channelType, tensorDesc.channelType, pData, bufferWidth, inputBounds, reinterpret_cast(pCPUTensor))); } } \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/ConverterResourceStore.h b/winml/lib/Api.Image/inc/ConverterResourceStore.h index ed3cd10907ec1..9162d334f2b10 100644 --- a/winml/lib/Api.Image/inc/ConverterResourceStore.h +++ b/winml/lib/Api.Image/inc/ConverterResourceStore.h @@ -7,7 +7,7 @@ #include "VideoFrameToTensorConverter.h" #include "TensorToVideoFrameConverter.h" -namespace Windows::AI::MachineLearning { +namespace _winml { // Forward Declare class ConverterResourceStore; @@ -51,8 +51,8 @@ class ConverterResources : public std::enable_shared_from_this Tensorizer; - std::unique_ptr Detensorizer; + std::unique_ptr<_winml::VideoFrameToTensorConverter> Tensorizer; + std::unique_ptr<_winml::TensorToVideoFrameConverter> Detensorizer; private: Pool m_pool; @@ -115,4 +115,4 @@ class PoolObjectWrapper { std::shared_ptr m_resources; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/D3DDeviceCache.h b/winml/lib/Api.Image/inc/D3DDeviceCache.h index 43df3dcadf070..d18b950d88ad0 100644 --- a/winml/lib/Api.Image/inc/D3DDeviceCache.h +++ b/winml/lib/Api.Image/inc/D3DDeviceCache.h @@ -14,7 +14,8 @@ #define VcppException(sev, err) ((sev) | (FACILITY_VISUALCPP << 16) | err) -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace _winml { + enum class PipelineStateCacheType : unsigned char { kFloat32 = 0, kFloat16 = 1, @@ -37,8 +38,8 @@ enum class PipelineStateCacheOperation : unsigned char { class D3DDeviceCache { public: ~D3DDeviceCache(); - D3DDeviceCache(Windows::AI::MachineLearning::LearningModelDeviceKind const& device_kind); - D3DDeviceCache(Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice const& device); + D3DDeviceCache(winml::LearningModelDeviceKind const& device_kind); + D3DDeviceCache(wgdx::Direct3D11::IDirect3DDevice const& device); D3DDeviceCache(ID3D12CommandQueue* queue); ID3D11Device* GetD3D11Device(); @@ -47,7 +48,7 @@ class D3DDeviceCache { ID3D12Device1* GetD3D12Device() { return device_.get(); } ID3D12CommandQueue* GetCommandQueue() { return command_queue_.get(); } - Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice GetWinrtDevice(); + wgdx::Direct3D11::IDirect3DDevice GetWinrtDevice(); ID3D12RootSignature* GetTensorizeRootSignature(); ID3D12RootSignature* GetDetensorizeRootSignature(); @@ -83,28 +84,28 @@ class D3DDeviceCache { ID3D12PipelineState* CreateTensorizePipelineState(PipelineStateCacheType type, PipelineStateCacheFormat format_from, PipelineStateCacheFormat format_to); ID3D12PipelineState* CreateDetensorizePipelineState(PipelineStateCacheType type, PipelineStateCacheFormat format_from, PipelineStateCacheFormat format_to); - com_ptr device_; - com_ptr command_queue_; - com_ptr sharing_contract_; + winrt::com_ptr device_; + winrt::com_ptr command_queue_; + winrt::com_ptr sharing_contract_; - com_ptr device_11_; - Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice winrt_device_; - com_ptr device_context11_; + winrt::com_ptr device_11_; + wgdx::Direct3D11::IDirect3DDevice winrt_device_; + winrt::com_ptr device_context11_; - com_ptr tensorize_root_signature_; - com_ptr detensorize_root_signature_; + winrt::com_ptr tensorize_root_signature_; + winrt::com_ptr detensorize_root_signature_; - com_ptr cached_pipeline_state[PipelineStateCacheType::kCount][PipelineStateCacheFormat::kCount][PipelineStateCacheFormat::kCount][PipelineStateCacheOperation::kCount]; + winrt::com_ptr cached_pipeline_state[PipelineStateCacheType::kCount][PipelineStateCacheFormat::kCount][PipelineStateCacheFormat::kCount][PipelineStateCacheOperation::kCount]; - com_ptr detensorize_vertex_buffer_; + winrt::com_ptr detensorize_vertex_buffer_; - com_ptr d3d11_fence_; - com_ptr d3d12_fence_; + winrt::com_ptr d3d11_fence_; + winrt::com_ptr d3d12_fence_; std::atomic fence_value_ = 1; GUID fence_guid_; - com_ptr converter_fence_; + winrt::com_ptr converter_fence_; wil::unique_handle converter_fence_handle_; std::atomic converter_fence_value_ = 1; @@ -115,4 +116,4 @@ class D3DDeviceCache { // initialization happen later, we need make it thread safe. CWinMLLock lock_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/DeviceHelpers.h b/winml/lib/Api.Image/inc/DeviceHelpers.h index f1b82217b020a..9d73ed71ac10a 100644 --- a/winml/lib/Api.Image/inc/DeviceHelpers.h +++ b/winml/lib/Api.Image/inc/DeviceHelpers.h @@ -14,11 +14,16 @@ #include #endif -namespace DeviceHelpers { +namespace _winml { + HRESULT CreateD3D11On12Device(ID3D12Device* device12, ID3D11Device** device11); + #ifdef ENABLE_DXCORE HRESULT GetDXCoreHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, _COM_Outptr_ IDXCoreAdapter** ppAdapter); #endif + HRESULT GetDXGIHardwareAdapterWithPreference(DXGI_GPU_PREFERENCE preference, _COM_Outptr_ IDXGIAdapter1** adapter); -HRESULT GetGPUPreference(winrt::Windows::AI::MachineLearning::LearningModelDeviceKind deviceKind, DXGI_GPU_PREFERENCE* preference) noexcept; -} // namespace DeviceHelpers + +HRESULT GetGPUPreference(winml::LearningModelDeviceKind deviceKind, DXGI_GPU_PREFERENCE* preference) noexcept; + +} // namespace _winml diff --git a/winml/lib/Api.Image/inc/ImageConversionHelpers.h b/winml/lib/Api.Image/inc/ImageConversionHelpers.h index a4805664bbb81..7c0e96d017bf2 100644 --- a/winml/lib/Api.Image/inc/ImageConversionHelpers.h +++ b/winml/lib/Api.Image/inc/ImageConversionHelpers.h @@ -6,15 +6,16 @@ #include #include "ImageConversionTypes.h" -namespace Windows::AI::MachineLearning::Internal::ImageConversionHelpers { +namespace _winml::Imaging { + // This API that takes a video frame and converts it to a video frame of desired format (DXGI_FORMAT_B8G8R8X8_UNORM/BitmapPixelFormat::Bgra8) and size (after any scale/crop operations). // This should also cover any DX adapter hop (if needed in a multi GPU scenario) and CPU->GPU / GPU->CPU conversion void ConvertVideoFrameToVideoFrame( - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wm::IVideoFrame& input_video_frame, + _In_ const wgi::BitmapBounds& input_bounds, _In_ UINT32 output_width, _In_ UINT32 output_height, - _Inout_ winrt::Windows::Media::VideoFrame& output_video_frame); + _Inout_ wm::VideoFrame& output_video_frame); // This helper method uses the input parameters do determine if a conversion is necessary // A conversion is not necessary if @@ -23,32 +24,32 @@ namespace Windows::AI::MachineLearning::Internal::ImageConversionHelpers { // 3. (mapping softwarebitmap to softwarebitmap) OR (mapping from d3dsurface to d3dsurface AND the two surfaces are on the same device) // 4. the input is already in the desired format (BGRA8/B8G8R8X8UIntNormalized) bool NeedsVideoFrameConversion( - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, + _In_ const wm::IVideoFrame& input_video_frame, _In_ LUID output_luid, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wgi::BitmapBounds& input_bounds, _In_ UINT32 output_width, _In_ UINT32 output_height); - bool SoftwareBitmapFormatSupported(const winrt::Windows::Graphics::Imaging::SoftwareBitmap& software_bitmap); - bool DirectXPixelFormatSupported(winrt::Windows::Graphics::DirectX::DirectXPixelFormat format); + bool SoftwareBitmapFormatSupported(const wgi::SoftwareBitmap& software_bitmap); + bool DirectXPixelFormatSupported(wgdx::DirectXPixelFormat format); bool FormatSupportedForUAV(_In_ ID3D12Device1* device, _In_ DXGI_FORMAT format); - ImageTensorChannelType GetChannelTypeFromSoftwareBitmap(const winrt::Windows::Graphics::Imaging::SoftwareBitmap& software_bitmap); - ImageTensorChannelType GetChannelTypeFromDirect3DSurface(const winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface& direct3D_surface); - winrt::Windows::Graphics::Imaging::BitmapPixelFormat GetBitmapPixelFormatFromChannelType(ImageTensorChannelType channel_type); - winrt::Windows::Graphics::DirectX::DirectXPixelFormat GetDirectXPixelFormatFromDXGIFormat(DXGI_FORMAT dxgi_format); - DXGI_FORMAT GetDXGIFormatFromDirectXPixelFormat(_In_ winrt::Windows::Graphics::DirectX::DirectXPixelFormat directX_pixel_format); - winrt::Windows::Graphics::DirectX::DirectXPixelFormat GetDirectXPixelFormatFromChannelType(_In_ ImageTensorChannelType channel_type); - Microsoft::WRL::ComPtr GetTextureFromDirect3DSurface(const winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface& d3d_surface); + ImageTensorChannelType GetChannelTypeFromSoftwareBitmap(const wgi::SoftwareBitmap& software_bitmap); + ImageTensorChannelType GetChannelTypeFromDirect3DSurface(const wgdx::Direct3D11::IDirect3DSurface& direct3D_surface); + wgi::BitmapPixelFormat GetBitmapPixelFormatFromChannelType(ImageTensorChannelType channel_type); + wgdx::DirectXPixelFormat GetDirectXPixelFormatFromDXGIFormat(DXGI_FORMAT dxgi_format); + DXGI_FORMAT GetDXGIFormatFromDirectXPixelFormat(_In_ wgdx::DirectXPixelFormat directX_pixel_format); + wgdx::DirectXPixelFormat GetDirectXPixelFormatFromChannelType(_In_ ImageTensorChannelType channel_type); + Microsoft::WRL::ComPtr GetTextureFromDirect3DSurface(const wgdx::Direct3D11::IDirect3DSurface& d3d_surface); bool TexturesHaveSameDevice(_In_ ID3D11Texture2D* pTexture1, _In_ ID3D11Texture2D* texture2d); bool TextureIsOnDevice(_In_ ID3D11Texture2D* pTexture, _In_ ID3D11Device* device); - bool VideoFramesHaveSameDimensions(const winrt::Windows::Media::IVideoFrame& video_frame_1, const winrt::Windows::Media::IVideoFrame& video_frame_2); - bool VideoFramesHaveSameDevice(const winrt::Windows::Media::IVideoFrame& video_frame_1, const winrt::Windows::Media::IVideoFrame& video_frame_2); + bool VideoFramesHaveSameDimensions(const wm::IVideoFrame& video_frame_1, const wm::IVideoFrame& video_frame_2); + bool VideoFramesHaveSameDevice(const wm::IVideoFrame& video_frame_1, const wm::IVideoFrame& video_frame_2); - winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice GetDeviceFromDirect3DSurface( - const winrt::Windows::Graphics::DirectX::Direct3D11::IDirect3DSurface& d3dSurface); + wgdx::Direct3D11::IDirect3DDevice GetDeviceFromDirect3DSurface( + const wgdx::Direct3D11::IDirect3DSurface& d3dSurface); constexpr std::array supportedWinMLFormats = { DXGI_FORMAT_R8G8B8A8_UNORM, DXGI_FORMAT_B8G8R8A8_UNORM, DXGI_FORMAT_B8G8R8X8_UNORM}; -} // namespace Windows::AI::MachineLearning::Internal::ImageConversionHelpers \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/ImageConversionTypes.h b/winml/lib/Api.Image/inc/ImageConversionTypes.h index 667167743b501..489c00c01f69f 100644 --- a/winml/lib/Api.Image/inc/ImageConversionTypes.h +++ b/winml/lib/Api.Image/inc/ImageConversionTypes.h @@ -3,7 +3,7 @@ #pragma once -namespace Windows::AI::MachineLearning::Internal { +namespace _winml { const UINT kImageTensorDimensionCountMax = 4; // NCHW format enum ImageTensorDataType { @@ -30,4 +30,4 @@ struct ImageTensorDescription { ImageTensorChannelType channelType; int64_t sizes[kImageTensorDimensionCountMax]; }; -} // namespace Windows::AI::MachineLearning::Internal \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/ImageConverter.h b/winml/lib/Api.Image/inc/ImageConverter.h index c564c51ce455f..74ed16e23565d 100644 --- a/winml/lib/Api.Image/inc/ImageConverter.h +++ b/winml/lib/Api.Image/inc/ImageConverter.h @@ -5,7 +5,6 @@ #include #include "WinML_Lock.h" -#include "ImageConversionHelpers.h" // Assign a name to the object to aid with debugging. #if defined(_DEBUG) @@ -25,12 +24,11 @@ inline void SetNameIndexed(ID3D12Object*, LPCWSTR, UINT) { } #endif +namespace _winml { + // Forward declaration -namespace winrt::Windows::AI::MachineLearning::implementation { class D3DDeviceCache; -} -namespace Windows::AI::MachineLearning::Internal { struct ConstantBufferCS { UINT height; UINT width; @@ -55,23 +53,23 @@ class ImageConverter { Microsoft::WRL::ComPtr pipeline_state_; Microsoft::WRL::ComPtr descriptor_heap_; Microsoft::WRL::ComPtr D3D11_cached_texture_; - winrt::Windows::Media::VideoFrame converted_video_frame_; + wm::VideoFrame converted_video_frame_; CWinMLLock lock_; - void SyncD3D11ToD3D12(_In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, _In_ ID3D11Texture2D* D3D11_texture); - void SyncD3D12ToD3D11(_In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, _In_ ID3D11Texture2D* texture); - void ResetCommandList(_In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache); - Microsoft::WRL::ComPtr FetchOrCreateFenceOnDevice(_In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, _In_ ID3D11Device* D3D11_device); + void SyncD3D11ToD3D12(_In_ _winml::D3DDeviceCache& device_cache, _In_ ID3D11Texture2D* D3D11_texture); + void SyncD3D12ToD3D11(_In_ _winml::D3DDeviceCache& device_cache, _In_ ID3D11Texture2D* texture); + void ResetCommandList(_In_ _winml::D3DDeviceCache& device_cache); + Microsoft::WRL::ComPtr FetchOrCreateFenceOnDevice(_In_ _winml::D3DDeviceCache& device_cache, _In_ ID3D11Device* D3D11_device); Microsoft::WRL::ComPtr CreateTextureFromUnsupportedColorFormat( - const winrt::Windows::Media::IVideoFrame& video_frame, - const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, - const winrt::Windows::Graphics::Imaging::BitmapBounds& output_bounds, - winrt::Windows::Graphics::DirectX::DirectXPixelFormat new_format); + const wm::IVideoFrame& video_frame, + const wgi::BitmapBounds& input_bounds, + const wgi::BitmapBounds& output_bounds, + wgdx::DirectXPixelFormat new_format); static void CopyTextureIntoTexture( _In_ ID3D11Texture2D* texture_from, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wgi::BitmapBounds& input_bounds, _Inout_ ID3D11Texture2D* texture_to); }; -} // namespace Windows::AI::MachineLearning::Internal \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/TensorToVideoFrameConverter.h b/winml/lib/Api.Image/inc/TensorToVideoFrameConverter.h index 319f9a07406c9..8321891582904 100644 --- a/winml/lib/Api.Image/inc/TensorToVideoFrameConverter.h +++ b/winml/lib/Api.Image/inc/TensorToVideoFrameConverter.h @@ -6,21 +6,21 @@ #include "ImageConverter.h" #include "ImageConversionTypes.h" -namespace Windows::AI::MachineLearning::Internal { +namespace _winml { class ITensorToVideoFrameConverter { public: virtual void DX12TensorToVideoFrame( _In_ UINT32 batch_index, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ ID3D12Resource* input_tensor, _In_ const ImageTensorDescription& tensor_description, - _Inout_ winrt::Windows::Media::VideoFrame& destination_video_frame) = 0; + _Inout_ wm::VideoFrame& destination_video_frame) = 0; virtual void SoftwareTensorToVideoFrame( - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ BYTE* CPU_tensor_to_convert, _In_ ImageTensorDescription tensor_description, - _Inout_ winrt::Windows::Media::VideoFrame& destination_video_frame) = 0; + _Inout_ wm::VideoFrame& destination_video_frame) = 0; }; class TensorToVideoFrameConverter : ITensorToVideoFrameConverter, public ImageConverter { @@ -31,18 +31,18 @@ class TensorToVideoFrameConverter : ITensorToVideoFrameConverter, public ImageCo // converts it to a VideoFrame backed by either a SoftwareBitmap or D3DSurface void DX12TensorToVideoFrame( _In_ UINT32 batch_index, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ ID3D12Resource* input_tensor, _In_ const ImageTensorDescription& tensor_description, - _Inout_ winrt::Windows::Media::VideoFrame& destination_video_frame); + _Inout_ wm::VideoFrame& destination_video_frame); // Function takes in a byte pointer to a CPUTensor // converts it to VideoFrame backed by either a SoftwareBitmap or D3DSurface, void SoftwareTensorToVideoFrame( - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, + _In_ winml::LearningModelSession& session, _In_ BYTE* CPU_tensor_to_convert, _In_ ImageTensorDescription tensor_description, - _Inout_ winrt::Windows::Media::VideoFrame& destination_video_frame); + _Inout_ wm::VideoFrame& destination_video_frame); private: GUID _d3d11TextureGUID = {0x14bf1054, 0x6ce7, 0x4c00, {0xa1, 0x32, 0xb0, 0xf2, 0x11, 0x5D, 0xE0, 0x7f}}; // {14BF1054-6CE7-4C00-A132-B0F2115DE07F} @@ -58,23 +58,23 @@ class TensorToVideoFrameConverter : ITensorToVideoFrameConverter, public ImageCo void ConvertGPUTensorToSoftwareBitmap( _In_ UINT32 batch_index, _In_ ID3D12Resource* input_tensor, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensor_description, - _Inout_ winrt::Windows::Graphics::Imaging::SoftwareBitmap& software_bitmap); + _Inout_ wgi::SoftwareBitmap& software_bitmap); void ConvertGPUTensorToDX12Texture( _In_ UINT32 batch_index, _In_ ID3D12Resource* input_resource, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensor_description, _Inout_ ID3D12Resource* output_resource); void ConvertDX12TensorToUnsupportedVideoFrameFormat( _In_ UINT32 batch_index, _In_ ID3D12Resource* input_tensor, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensor_description, - _Inout_ winrt::Windows::Media::VideoFrame& unsupported_video_frame); + _Inout_ wm::VideoFrame& unsupported_video_frame); static D3D12_SHADER_RESOURCE_VIEW_DESC TensorToVideoFrameConverter::CreateSRVDescriptor( const UINT32 batch_index, @@ -84,10 +84,10 @@ class TensorToVideoFrameConverter : ITensorToVideoFrameConverter, public ImageCo static void ConvertCPUTensorToSoftwareBitmap( _In_ void* CPU_tensor, _In_ const ImageTensorDescription& tensor_description, - _Inout_ winrt::Windows::Graphics::Imaging::SoftwareBitmap& software_bitmap); + _Inout_ wgi::SoftwareBitmap& software_bitmap); static Microsoft::WRL::ComPtr CreateShareableD3D12Texture( const D3D11_TEXTURE2D_DESC& d3d11Desc, ID3D12Device* d3d12Device); }; -} // namespace Windows::AI::MachineLearning::Internal \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Image/inc/VideoFrameToTensorConverter.h b/winml/lib/Api.Image/inc/VideoFrameToTensorConverter.h index 82b2cfc0b6e80..e2e19c434ae5c 100644 --- a/winml/lib/Api.Image/inc/VideoFrameToTensorConverter.h +++ b/winml/lib/Api.Image/inc/VideoFrameToTensorConverter.h @@ -7,20 +7,20 @@ #include "ImageConversionHelpers.h" #include "ImageConversionTypes.h" -namespace Windows::AI::MachineLearning::Internal { +namespace _winml { class IVideoFrameToTensorConverter { public: virtual void VideoFrameToDX12Tensor( _In_ const UINT32 batch_index, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ winml::LearningModelSession& session, + _In_ const wm::IVideoFrame& input_video_frame, + _In_ const wgi::BitmapBounds& input_bounds, _In_ const ImageTensorDescription& tensor_description, _Inout_ ID3D12Resource* output_tensor) = 0; virtual void VideoFrameToSoftwareTensor( - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wm::IVideoFrame& input_video_frame, + _In_ const wgi::BitmapBounds& input_bounds, _In_ const ImageTensorDescription& tensor_description, _Out_ BYTE* output_CPU_tensor) = 0; }; @@ -38,9 +38,9 @@ class VideoFrameToTensorConverter : IVideoFrameToTensorConverter, public ImageCo // If the region of interest is the entire VideoFrame, the input BitmapBounds should describe the entire image. void VideoFrameToDX12Tensor( _In_ const UINT32 batch_index, - _In_ winrt::Windows::AI::MachineLearning::LearningModelSession& session, - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ winml::LearningModelSession& session, + _In_ const wm::IVideoFrame& input_video_frame, + _In_ const wgi::BitmapBounds& input_bounds, _In_ const ImageTensorDescription& tensor_description, _Inout_ ID3D12Resource* output_tensor); @@ -50,8 +50,8 @@ class VideoFrameToTensorConverter : IVideoFrameToTensorConverter, public ImageCo // {upperleft X, upperleft Y, width, height} to be turned into a tensor. // If the region of interest is the entire VideoFrame, the input BitmapBounds should describe the entire image. void VideoFrameToSoftwareTensor( - _In_ const winrt::Windows::Media::IVideoFrame& input_video_frame, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wm::IVideoFrame& input_video_frame, + _In_ const wgi::BitmapBounds& input_bounds, _In_ const ImageTensorDescription& tensor_description, _Out_ BYTE* output_CPU_tensor); @@ -67,16 +67,16 @@ class VideoFrameToTensorConverter : IVideoFrameToTensorConverter, public ImageCo void ConvertSoftwareBitmapToGPUTensor( _In_ const UINT32 batch_index, - _In_ const winrt::Windows::Media::IVideoFrame& videoFrame, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wm::IVideoFrame& videoFrame, + _In_ _winml::D3DDeviceCache& device_cache, + _In_ const wgi::BitmapBounds& input_bounds, _In_ const ImageTensorDescription& tensor_description, _Inout_ ID3D12Resource* pOutputResource); void ConvertDX12TextureToGPUTensor( _In_ const UINT32 batch_index, _In_ ID3D12Resource* pInputResource, - _In_ winrt::Windows::AI::MachineLearning::implementation::D3DDeviceCache& device_cache, + _In_ _winml::D3DDeviceCache& device_cache, _In_ const ImageTensorDescription& tensor_description, _Inout_ ID3D12Resource* output_resource); @@ -86,9 +86,9 @@ class VideoFrameToTensorConverter : IVideoFrameToTensorConverter, public ImageCo const ImageTensorDescription& description); static void VideoFrameToTensorConverter::ConvertSoftwareBitmapToCPUTensor( - _In_ const winrt::Windows::Graphics::Imaging::SoftwareBitmap& software_bitmap, + _In_ const wgi::SoftwareBitmap& software_bitmap, _In_ const ImageTensorDescription& tensor_description, - _In_ const winrt::Windows::Graphics::Imaging::BitmapBounds& input_bounds, + _In_ const wgi::BitmapBounds& input_bounds, _Inout_ void* CPU_tensor); }; -} // namespace Windows::AI::MachineLearning::Internal +} // namespace _winml diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp index 32796ed45b4a6..aee8a7e69a7ff 100644 --- a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.cpp @@ -7,7 +7,7 @@ #include "OnnxruntimeEngine.h" #include "OnnxruntimeErrors.h" -using namespace Windows::AI::MachineLearning; +using namespace _winml; HRESULT OnnxruntimeCpuSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { engine_factory_ = engine_factory; diff --git a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h index e1ad853429952..fe531ac7eb641 100644 --- a/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeCpuSessionBuilder.h @@ -5,7 +5,7 @@ #include "OnnxruntimeSessionBuilder.h" -namespace Windows::AI::MachineLearning { +namespace _winml { class OnnxruntimeEngineFactory; @@ -29,4 +29,4 @@ class OnnxruntimeCpuSessionBuilder : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::ComPtr engine_factory_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp index 4a4c81352ec3f..3e1ae7de3a868 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.cpp @@ -18,8 +18,6 @@ #include "OnnxruntimeErrors.h" -using namespace winrt::Windows::AI::MachineLearning; - // BitmapPixelFormat constants static const char* c_bitmap_pixel_format_key = "Image.BitmapPixelFormat"; static const char* c_supported_pixel_formats[] = @@ -44,7 +42,7 @@ static const char* c_supported_nominal_ranges[] = { "NominalRange_0_255"}; -namespace Windows::AI::MachineLearning { +namespace _winml { // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor @@ -53,109 +51,109 @@ CreateFeatureDescriptor( const OnnxruntimeValueInfoWrapper* feature_descriptor, const std::unordered_map& metadata); -static TensorKind +static winml::TensorKind TensorKindFromONNXTensorElementDataType(ONNXTensorElementDataType dataType) { switch (dataType) { case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - return TensorKind::Boolean; + return winml::TensorKind::Boolean; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { - return TensorKind::String; + return winml::TensorKind::String; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - return TensorKind::Float16; + return winml::TensorKind::Float16; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - return TensorKind::Float; + return winml::TensorKind::Float; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - return TensorKind::Double; + return winml::TensorKind::Double; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - return TensorKind::Int8; + return winml::TensorKind::Int8; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - return TensorKind::Int16; + return winml::TensorKind::Int16; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - return TensorKind::Int32; + return winml::TensorKind::Int32; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - return TensorKind::Int64; + return winml::TensorKind::Int64; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - return TensorKind::UInt8; + return winml::TensorKind::UInt8; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - return TensorKind::UInt16; + return winml::TensorKind::UInt16; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - return TensorKind::UInt32; + return winml::TensorKind::UInt32; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - return TensorKind::UInt64; + return winml::TensorKind::UInt64; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: { - return TensorKind::Complex64; + return winml::TensorKind::Complex64; } case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { - return TensorKind::Complex128; + return winml::TensorKind::Complex128; } default: { - return TensorKind::Undefined; + return winml::TensorKind::Undefined; } } } static std::string -TensorKindToString(TensorKind tensorKind) { +TensorKindToString(winml::TensorKind tensorKind) { switch (tensorKind) { - case TensorKind::Float: { + case winml::TensorKind::Float: { return "float"; } - case TensorKind::UInt8: { + case winml::TensorKind::UInt8: { return "uint8"; } - case TensorKind::Int8: { + case winml::TensorKind::Int8: { return "int8"; } - case TensorKind::UInt16: { + case winml::TensorKind::UInt16: { return "uint16"; } - case TensorKind::Int16: { + case winml::TensorKind::Int16: { return "int16"; } - case TensorKind::Int32: { + case winml::TensorKind::Int32: { return "int32"; } - case TensorKind::Int64: { + case winml::TensorKind::Int64: { return "int64"; } - case TensorKind::String: { + case winml::TensorKind::String: { return "string"; } - case TensorKind::Boolean: { + case winml::TensorKind::Boolean: { return "boolean"; } - case TensorKind::Float16: { + case winml::TensorKind::Float16: { return "float16"; } - case TensorKind::Double: { + case winml::TensorKind::Double: { return "double"; } - case TensorKind::UInt32: { + case winml::TensorKind::UInt32: { return "uint32"; } - case TensorKind::UInt64: { + case winml::TensorKind::UInt64: { return "uint64"; } - case TensorKind::Complex64: { + case winml::TensorKind::Complex64: { return "complex64"; } - case TensorKind::Complex128: { + case winml::TensorKind::Complex128: { return "complex128"; } - case TensorKind::Undefined: + case winml::TensorKind::Undefined: default: { return "undefined"; } @@ -264,44 +262,40 @@ CreateBitmapPixelFormatAndAlphaModeInfo( static winmlp::ImageColorSpaceGamma CreateImageColorSpaceGamma(const char* color_space_gamma) { - using namespace winmlp; - if (color_space_gamma) { auto comparator = std::bind(std::strcmp, color_space_gamma, std::placeholders::_1); if (0 == comparator("Linear")) { - return ImageColorSpaceGamma::ImageColorSpaceGamma_Linear; + return winmlp::ImageColorSpaceGamma::ImageColorSpaceGamma_Linear; } else if (0 == comparator("SRGB")) { - return ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB; + return winmlp::ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB; } } // default value, non conforming values are overridden to SRGB - return ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB; + return winmlp::ImageColorSpaceGamma::ImageColorSpaceGamma_SRGB; } static winmlp::ImageNominalPixelRange CreateImageNominalPixelRange(const char* nominal_range) { - using namespace winmlp; - if (nominal_range) { auto comparator = std::bind(std::strcmp, nominal_range, std::placeholders::_1); if (0 == comparator("NominalRange_0_255")) { - return ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255; + return winmlp::ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255; } else if (0 == comparator("Normalized_0_1")) { - return ImageNominalPixelRange::ImageNominalPixelRange_Normalized_0_1; + return winmlp::ImageNominalPixelRange::ImageNominalPixelRange_Normalized_0_1; } else if (0 == comparator("Normalized_1_1")) { - return ImageNominalPixelRange::ImageNominalPixelRange_Normalized_1_1; + return winmlp::ImageNominalPixelRange::ImageNominalPixelRange_Normalized_1_1; } else if (0 == comparator("NominalRange_16_235")) { - return ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_16_235; + return winmlp::ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_16_235; } } // default value, non conforming values are overridden to NominalRange_0_255 - return ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255; + return winmlp::ImageNominalPixelRange::ImageNominalPixelRange_NominalRange_0_255; } enum class TensorType { Tensor_Data, @@ -338,8 +332,8 @@ GetTensorType( THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), engine_factory->UseOrtApi()); - auto tensor_kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); - auto is_float_tensor = tensor_kind == TensorKind::Float; + auto tensor_kind = _winml::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); + auto is_float_tensor = tensor_kind == winml::TensorKind::Float; if (!is_float_tensor) { log_stream << "Unsupported image with " << TensorKindToString(tensor_kind) << " found." << std::endl; @@ -418,7 +412,7 @@ CreateTensorFeatureDescriptor( THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), engine_factory->UseOrtApi()); - auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); + auto kind = _winml::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); auto descriptor = winrt::make( feature_descriptor->name_, @@ -453,7 +447,7 @@ CreateImageFeatureDescriptor( ONNXTensorElementDataType tensor_element_data_type; THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetTensorElementType(tensor_info, &tensor_element_data_type), engine_factory->UseOrtApi()); - auto kind = WinML::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); + auto kind = _winml::TensorKindFromONNXTensorElementDataType(tensor_element_data_type); // pixel format and alpha auto pixel_format_value = FetchMetadataValueOrNull(metadata, c_bitmap_pixel_format_key); @@ -506,7 +500,7 @@ CreateMapFeatureDescriptor( THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetMapKeyType(map_info, &map_key_data_type), engine_factory->UseOrtApi()); - auto key_kind = WinML::TensorKindFromONNXTensorElementDataType(map_key_data_type); + auto key_kind = _winml::TensorKindFromONNXTensorElementDataType(map_key_data_type); OrtTypeInfo* map_value_type_info; THROW_IF_NOT_OK_MSG(engine_factory->UseOrtApi()->GetMapValueType(map_info, &map_value_type_info), @@ -624,10 +618,10 @@ OnnxruntimeDescriptorConverter::ConvertToLearningModelDescriptors(const std::vec auto features = winrt::single_threaded_vector(); for (const auto& descriptor : descriptors) { - auto learning_model_descriptor = WinML::CreateFeatureDescriptor(engine_factory_.Get(), &descriptor, metadata_); + auto learning_model_descriptor = _winml::CreateFeatureDescriptor(engine_factory_.Get(), &descriptor, metadata_); features.Append(learning_model_descriptor); } return features; } -} // namespace Windows::AI::MachineLearning +} // namespace _winml diff --git a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h index 1152f7f0ec530..dda03d431f7c7 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h +++ b/winml/lib/Api.Ort/OnnxruntimeDescriptorConverter.h @@ -4,7 +4,7 @@ #include "pch.h" -namespace Windows::AI::MachineLearning { +namespace _winml { struct OnnxruntimeValueInfoWrapper { OnnxruntimeValueInfoWrapper() : type_info_(UniqueOrtTypeInfo(nullptr, nullptr)) {} @@ -30,4 +30,4 @@ struct OnnxruntimeDescriptorConverter { const std::unordered_map& metadata_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp index 7e55919ca8d21..71c470fd1c248 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp @@ -10,7 +10,7 @@ #include "OnnxruntimeErrors.h" #include "LearningModelDevice.h" -using namespace Windows::AI::MachineLearning; +using namespace _winml; HRESULT OnnxruntimeDmlSessionBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue, bool metacommands_enabled) { engine_factory_ = engine_factory; diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h index a0742e4de262a..3a5f78ee17810 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h @@ -5,7 +5,7 @@ #include "OnnxruntimeSessionBuilder.h" -namespace Windows::AI::MachineLearning { +namespace _winml { class OnnxruntimeEngineFactory; @@ -32,4 +32,4 @@ class OnnxruntimeDmlSessionBuilder : public Microsoft::WRL::RuntimeClass< bool metacommands_enabled_ = true; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index 0ed27cdb8feb1..b572442307def 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -11,7 +11,7 @@ #include "OnnxruntimeSessionBuilder.h" #include "OnnxruntimeErrors.h" -using namespace WinML; +using namespace _winml; static const OrtApi* GetVersionedOrtApi() { static const uint32_t ort_version = 2; @@ -175,7 +175,7 @@ static auto GetStrings(const OrtApi* ort_api, const OrtValue* ort_value, return std::make_shared>(std::move(strings), std::move(buffer)); } -HRESULT OnnxruntimeValue::GetResource(WinML::Resource& out) { +HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) { auto ort_api = engine_->GetEngineFactory()->UseOrtApi(); auto winml_adapter_api = engine_->GetEngineFactory()->UseWinmlAdapterApi(); @@ -193,13 +193,13 @@ HRESULT OnnxruntimeValue::GetResource(WinML::Resource& out) { RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data, reinterpret_cast(&resource)), ort_api); - out = WinML::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! */ }); + out = _winml::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! */ }); } else { int is_tensor; RETURN_HR_IF_NOT_OK_MSG(ort_api->IsTensor(value_.get(), &is_tensor), ort_api); if (is_tensor == 0) { - out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); + out = _winml::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); return S_OK; } @@ -215,9 +215,9 @@ HRESULT OnnxruntimeValue::GetResource(WinML::Resource& out) { if (data_type == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { auto strings = GetStrings(ort_api, value_.get(), info); auto string_data = strings->first.data(); - out = WinML::Resource(string_data, [capture_strings = strings](void*) { /*This deleter does nothing but capture the strings, which extends the lifetime of the returned strings.*/ }); + out = _winml::Resource(string_data, [capture_strings = strings](void*) { /*This deleter does nothing but capture the strings, which extends the lifetime of the returned strings.*/ }); } else { - out = WinML::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); + out = _winml::Resource(mutable_data, [](void*) { /*do nothing, as this pointer is actually owned elsewhere in ORT! */ }); } } return S_OK; @@ -696,7 +696,7 @@ HRESULT OnnxruntimeEngine::CreateStringTensorValueFromDataWithCopy(const char* c RETURN_IF_FAILED(CreateTensorValueFromDefaultAllocator(shape, count, winml::TensorKind::String, out)); - auto ort_value = reinterpret_cast(*out)->UseOrtValue(); + auto ort_value = reinterpret_cast<_winml::OnnxruntimeValue*>(*out)->UseOrtValue(); RETURN_HR_IF_NOT_OK_MSG(ort_api->FillStringTensor(ort_value, reinterpret_cast(data), num_elements), ort_api); return S_OK; @@ -791,7 +791,7 @@ typename auto CppwinrtTypeToOrtType(TCppwinrtType raw) { template <> typename auto CppwinrtTypeToOrtType(winrt::hstring raw) { - return WinML::Strings::UTF8FromHString(raw); + return _winml::Strings::UTF8FromHString(raw); } template @@ -801,7 +801,7 @@ typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::Resourc template <> typename auto ResourceTypeToCppwinrtType(typename AbiTypeInfo::ResourceType value) { - return WinML::Strings::HStringFromUTF8(value.data(), value.size()); + return _winml::Strings::HStringFromUTF8(value.data(), value.size()); } template @@ -809,8 +809,8 @@ auto CastToWinrtMap(IInspectable* map_insp) { using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; - ::winrt::Windows::Foundation::IInspectable map_inspectable; - ::winrt::Windows::Foundation::Collections::IMap map; + wf::IInspectable map_inspectable; + wfc::IMap map; winrt::copy_from_abi(map_inspectable, map_insp); map_inspectable.as(map); return map; @@ -821,10 +821,10 @@ auto CastToWinrtSequenceOfMaps(IInspectable* sequence_insp) { using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; - using cppwinrt_element_map_type = ::winrt::Windows::Foundation::Collections::IMap; - using cppwinrt_sequence_type = ::winrt::Windows::Foundation::Collections::IVector; + using cppwinrt_element_map_type = wfc::IMap; + using cppwinrt_sequence_type = wfc::IVector; cppwinrt_sequence_type sequence; - ::winrt::Windows::Foundation::IInspectable sequence_inspectable; + wf::IInspectable sequence_inspectable; winrt::copy_from_abi(sequence_inspectable, sequence_insp); sequence_inspectable.as(sequence); return sequence; @@ -950,11 +950,11 @@ HRESULT CreateMapValue(OnnxruntimeEngine* engine, IInspectable* map_insp, winml: auto map = CastToWinrtMap(map_insp); std::vector shape = {static_cast(map.Size())}; - winrt::com_ptr key_value; + winrt::com_ptr<_winml::IValue> key_value; RETURN_IF_FAILED(engine->CreateTensorValueFromDefaultAllocator(shape.data(), shape.size(), key_kind, key_value.put())); auto keys_ort_value = static_cast(key_value.get())->UseOrtValue(); - winrt::com_ptr value_value; + winrt::com_ptr<_winml::IValue> value_value; RETURN_IF_FAILED(engine->CreateTensorValueFromDefaultAllocator(shape.data(), shape.size(), value_kind, value_value.put())); auto values_ort_value = static_cast(value_value.get())->UseOrtValue(); @@ -1004,9 +1004,9 @@ HRESULT CreateSequenceOfMapsValue(OnnxruntimeEngine* engine, IInspectable* seque auto ort_api = engine->UseOrtApi(); auto sequence = CastToWinrtSequenceOfMaps(sequence_insp); - std::vector> element_values; + std::vector> element_values; for (auto element : sequence) { - winrt::com_ptr element_value; + winrt::com_ptr<_winml::IValue> element_value; engine->CreateMapValue(reinterpret_cast(winrt::get_abi(element)), key_kind, value_kind, element_value.put()); element_values.push_back(element_value); } @@ -1045,12 +1045,12 @@ HRESULT OnnxruntimeEngine::CreateSequenceOfMapsValue(IInspectable* sequence, win } template -static HRESULT FillAbiSequence(IInspectable* sequence_insp, std::vector<::winrt::Windows::Foundation::IInspectable>& elements) { +static HRESULT FillAbiSequence(IInspectable* sequence_insp, std::vector& elements) { using cppwinrt_key_type = typename AbiTypeInfo::CppWinRTType; using cppwinrt_value_type = typename AbiTypeInfo::CppWinRTType; auto sequence = CastToWinrtSequenceOfMaps(sequence_insp); for (auto element : elements) { - ::winrt::Windows::Foundation::Collections::IMap map_element; + wfc::IMap map_element; element.as(map_element); sequence.Append(map_element); } @@ -1067,8 +1067,8 @@ static auto GetAbiSequenceFiller(winml::TensorKind key_kind, winml::TensorKind v THROW_HR(E_NOTIMPL); } -static winrt::Windows::Foundation::IInspectable CreateMap(winml::TensorKind key_kind, winml::TensorKind value_kind) { - winrt::Windows::Foundation::IInspectable map_insp; +static wf::IInspectable CreateMap(winml::TensorKind key_kind, winml::TensorKind value_kind) { + wf::IInspectable map_insp; if (key_kind == winml::TensorKind::String && value_kind == winml::TensorKind::Float) { auto map = winrt::single_threaded_map(); map.as(map_insp); @@ -1092,7 +1092,7 @@ HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValueCount(ort_sequence_value, &num_elements), ort_api); // get the elements - std::vector<::winrt::Windows::Foundation::IInspectable> element_map_inspectables; + std::vector element_map_inspectables; for (size_t index = 0; index < num_elements; index++) { OrtValue* elements_ort_value = nullptr; RETURN_HR_IF_NOT_OK_MSG(ort_api->GetValue(ort_sequence_value, static_cast(index), ort_allocator, &elements_ort_value), ort_api); @@ -1101,7 +1101,7 @@ HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml winrt::com_ptr element_value; RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(element_value.put(), this, std::move(unique_element_value), UniqueOrtAllocator(nullptr, nullptr))); - ::winrt::Windows::Foundation::IInspectable map_inspectable = CreateMap(key_kind, value_kind); + wf::IInspectable map_inspectable = CreateMap(key_kind, value_kind); RETURN_IF_FAILED(FillFromMapValue(reinterpret_cast(winrt::get_abi(map_inspectable)), key_kind, value_kind, element_value.get())); element_map_inspectables.push_back(map_inspectable); } @@ -1110,7 +1110,7 @@ HRESULT OnnxruntimeEngine::FillSequenceOfMapsValue(IInspectable* sequence, winml return S_OK; } -HRESULT OnnxruntimeEngine::GetSequenceOfTensorValues(_In_ WinML::IValue* sequence_value, _Out_ std::vector>& out_values) { +HRESULT OnnxruntimeEngine::GetSequenceOfTensorValues(_In_ _winml::IValue* sequence_value, _Out_ std::vector>& out_values) { auto ort_api = engine_factory_->UseOrtApi(); auto onnxruntime_squence_value = static_cast(sequence_value); auto ort_sequence_value = onnxruntime_squence_value->UseOrtValue(); @@ -1285,9 +1285,9 @@ HRESULT OnnxruntimeEngine::FillFromMapValue(IInspectable* map, winml::TensorKind std::vector keys_shape; keys_value->GetTensorShape(keys_shape); - WinML::Resource keys_data; + _winml::Resource keys_data; RETURN_IF_FAILED(keys_value->GetResource(keys_data)); - WinML::Resource values_data; + _winml::Resource values_data; RETURN_IF_FAILED(values_value->GetResource(values_data)); auto num_elements = static_cast(ShapeSize(keys_shape.data(), keys_shape.size())); @@ -1336,7 +1336,7 @@ STDMETHODIMP OnnxruntimeEngineFactory::CreateModel(_In_ void* data, _In_ size_t return S_OK; } -STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ Windows::AI::MachineLearning::IEngineBuilder** out) { +STDMETHODIMP OnnxruntimeEngineFactory::CreateEngineBuilder(_Outptr_ _winml::IEngineBuilder** out) { RETURN_IF_FAILED(EnsureEnvironment()); Microsoft::WRL::ComPtr onnxruntime_engine_builder; RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_builder, this)); @@ -1370,7 +1370,7 @@ HRESULT OnnxruntimeEngineFactory::CreateCustomRegistry(IMLOperatorRegistry** reg return S_OK; } -STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory) { +STDAPI CreateOnnxruntimeEngineFactory(_Out_ _winml::IEngineFactory** engine_factory) { Microsoft::WRL::ComPtr onnxruntime_engine_factory; RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize(&onnxruntime_engine_factory)); RETURN_IF_FAILED(onnxruntime_engine_factory.CopyTo(engine_factory)); diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.h b/winml/lib/Api.Ort/OnnxruntimeEngine.h index 74a5945ce67e8..37e8affb3d660 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.h @@ -5,7 +5,7 @@ #include -namespace Windows::AI::MachineLearning { +namespace _winml { class OnnxruntimeEngineBuilder; class OnnxruntimeEngineFactory; @@ -29,7 +29,7 @@ class OnnxruntimeValue : public Microsoft::WRL::RuntimeClass< STDMETHOD(IsCpu) (bool* out) override; STDMETHOD(GetResource) - (WinML::Resource& resource) override; + (_winml::Resource& resource) override; STDMETHOD(IsTensor) (bool* out) override; STDMETHOD(IsOfTensorType) @@ -111,7 +111,7 @@ class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass< (IInspectable* sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue* value) override; STDMETHOD(GetSequenceOfTensorValues) - (WinML::IValue* sequence_value, _Out_ std::vector>& out_values) override; + (_winml::IValue* sequence_value, _Out_ std::vector>& out_values) override; OrtSession* UseOrtSession(); const OrtApi* UseOrtApi(); @@ -152,4 +152,4 @@ class OnnxruntimeEngineFactory : public Microsoft::WRL::RuntimeClass< std::mutex mutex_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp index 4813b9659d1fa..62841ef65e1b3 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp @@ -12,14 +12,14 @@ #endif #include "OnnxruntimeErrors.h" -using namespace WinML; +using namespace _winml; HRESULT OnnxruntimeEngineBuilder::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_factory) { engine_factory_ = engine_factory; return S_OK; } -STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(Windows::AI::MachineLearning::IEngine** out) { +STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(_winml::IEngine** out) { auto ort_api = engine_factory_->UseOrtApi(); Microsoft::WRL::ComPtr onnxruntime_session_builder; diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h index 14597ed546b45..058e02c71c82c 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h @@ -3,7 +3,7 @@ #include "iengine.h" -namespace Windows::AI::MachineLearning { +namespace _winml { class OnnxruntimeEngineBuilder : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, @@ -37,4 +37,4 @@ class OnnxruntimeEngineBuilder : public Microsoft::WRL::RuntimeClass< std::optional batch_size_override_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp index f62d077de711e..af0188e827620 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.cpp @@ -7,7 +7,7 @@ #include "core/platform/windows/TraceLoggingConfig.h" #include -using namespace Windows::AI ::MachineLearning; +using namespace _winml; static bool debug_output_ = false; diff --git a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h index 2d81579ce2ad5..5d47266cf9a6d 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEnvironment.h +++ b/winml/lib/Api.Ort/OnnxruntimeEnvironment.h @@ -6,7 +6,7 @@ #pragma warning(push) #pragma warning(disable : 4505) -namespace Windows::AI ::MachineLearning { +namespace _winml { class OnnxruntimeEnvironment { public: @@ -19,6 +19,6 @@ class OnnxruntimeEnvironment { UniqueOrtEnv ort_env_; }; -} // namespace Windows::AI::MachineLearning +} // namespace _winml #pragma warning(pop) \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeErrors.h b/winml/lib/Api.Ort/OnnxruntimeErrors.h index 7d041870a3bdc..8af3e23506924 100644 --- a/winml/lib/Api.Ort/OnnxruntimeErrors.h +++ b/winml/lib/Api.Ort/OnnxruntimeErrors.h @@ -46,7 +46,7 @@ inline HRESULT OrtErrorCodeToHRESULT(OrtErrorCode status) noexcept { auto error_message = ort_api->GetErrorMessage(_status); \ HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ - auto message = WinML::Strings::HStringFromUTF8(error_message); \ + auto message = _winml::Strings::HStringFromUTF8(error_message); \ RoOriginateError(hresult, reinterpret_cast(winrt::get_abi(message))); \ return hresult; \ } \ @@ -60,7 +60,7 @@ inline HRESULT OrtErrorCodeToHRESULT(OrtErrorCode status) noexcept { auto error_message = ort_api->GetErrorMessage(_status); \ HRESULT hresult = OrtErrorCodeToHRESULT(error_code); \ telemetry_helper.LogRuntimeError(hresult, std::string(error_message), __FILE__, __FUNCTION__, __LINE__); \ - auto message = WinML::Strings::HStringFromUTF8(error_message); \ + auto message = _winml::Strings::HStringFromUTF8(error_message); \ throw winrt::hresult_error(hresult, message); \ } \ } while (0) diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.cpp b/winml/lib/Api.Ort/OnnxruntimeModel.cpp index 562bf505d86f3..5e0f76ae912c1 100644 --- a/winml/lib/Api.Ort/OnnxruntimeModel.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeModel.cpp @@ -10,7 +10,7 @@ #include "OnnxruntimeEngine.h" #include "OnnxruntimeErrors.h" -using namespace Windows::AI::MachineLearning; +using namespace _winml; struct winml_adapter_api_model_feature_helper { decltype(WinmlAdapterApi::ModelGetInputCount) GetCount; @@ -70,7 +70,7 @@ HRESULT ModelInfo::RuntimeClassInitialize(OnnxruntimeEngineFactory* engine_facto std::string(metadata_value, metadata_value_len)); } - WinML::OnnxruntimeDescriptorConverter converter(engine_factory, model_metadata_); + _winml::OnnxruntimeDescriptorConverter converter(engine_factory, model_metadata_); static const winml_adapter_api_model_feature_helper input_helpers = { winml_adapter_api->ModelGetInputCount, @@ -151,8 +151,8 @@ STDMETHODIMP ModelInfo::GetVersion(int64_t* out) { STDMETHODIMP ModelInfo::GetModelMetadata(ABI::Windows::Foundation::Collections::IMapView** metadata) { std::unordered_map map_copy; for (auto& pair : model_metadata_) { - auto metadata_key = WinML::Strings::HStringFromUTF8(pair.first); - auto metadata_value = WinML::Strings::HStringFromUTF8(pair.second); + auto metadata_key = _winml::Strings::HStringFromUTF8(pair.first); + auto metadata_value = _winml::Strings::HStringFromUTF8(pair.second); map_copy.emplace(std::move(metadata_key), std::move(metadata_value)); } auto map = winrt::single_threaded_map(std::move(map_copy)); diff --git a/winml/lib/Api.Ort/OnnxruntimeModel.h b/winml/lib/Api.Ort/OnnxruntimeModel.h index 47325780221aa..1c34651bb06e3 100644 --- a/winml/lib/Api.Ort/OnnxruntimeModel.h +++ b/winml/lib/Api.Ort/OnnxruntimeModel.h @@ -5,7 +5,7 @@ #include "iengine.h" -namespace Windows::AI::MachineLearning { +namespace _winml { class OnnxruntimeEngineFactory; @@ -77,4 +77,4 @@ class OnnruntimeModel : public Microsoft::WRL::RuntimeClass< std::optional> metadata_cache_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h index 9924e96cc345e..dbf5b086c19d9 100644 --- a/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeSessionBuilder.h @@ -3,7 +3,7 @@ #pragma once -namespace Windows::AI::MachineLearning { +namespace _winml { // The IOrtSessionBuilder offers an abstraction over the creation of // InferenceSession, that enables the creation of the session based on a device (CPU/DML). @@ -20,4 +20,4 @@ IOrtSessionBuilder : IUnknown { OrtSession * session) = 0; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h index d7e3cbfa2f6fc..e22864c1b47da 100644 --- a/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h +++ b/winml/lib/Api.Ort/inc/OnnxruntimeProvider.h @@ -5,4 +5,4 @@ #include "iengine.h" -STDAPI CreateOnnxruntimeEngineFactory(_Out_ Windows::AI::MachineLearning::IEngineFactory** engine_factory); \ No newline at end of file +STDAPI CreateOnnxruntimeEngineFactory(_Out_ _winml::IEngineFactory** engine_factory); \ No newline at end of file diff --git a/winml/lib/Api/FeatureValues.h b/winml/lib/Api/FeatureValues.h index df68a8fe65a72..b2420e32a9986 100644 --- a/winml/lib/Api/FeatureValues.h +++ b/winml/lib/Api/FeatureValues.h @@ -30,15 +30,15 @@ // CREATE_TENSOR is used by data tensor types to implement common functionality #define CREATE_TENSOR(type, element_type, element_view_type) \ - namespace winrt::Windows::AI::MachineLearning::implementation { \ - struct type : public WinML::TensorBase< \ + namespace WINMLP { \ + struct type : public _winml::TensorBase< \ element_type, \ element_view_type, \ type, \ I##type, \ type##T> { \ + _winml::ILotusValueProviderPrivate>> { \ using Base = \ TensorBase< \ element_type, \ @@ -48,7 +48,7 @@ type##T< \ type, \ ITensorNative, \ - WinML::ILotusValueProviderPrivate>>; \ + _winml::ILotusValueProviderPrivate>>; \ \ type() = default; \ \ @@ -59,7 +59,7 @@ type(std::vector const& shape, ID3D12Resource* pResource) : Base(shape, pResource){}; \ }; \ } \ - namespace winrt::Windows::AI::MachineLearning::factory_implementation { \ + namespace WINML::factory_implementation { \ struct type : type##T { \ STDMETHOD(CreateFromD3D12Resource) \ (ID3D12Resource * value, __int64* shape, int shapeSize, IUnknown** result) { \ @@ -83,7 +83,7 @@ CREATE_TENSOR(TensorUInt32Bit, uint32_t, uint32_t) CREATE_TENSOR(TensorInt32Bit, int32_t, int32_t) CREATE_TENSOR(TensorUInt64Bit, uint64_t, uint64_t) CREATE_TENSOR(TensorInt64Bit, int64_t, int64_t) -CREATE_TENSOR(TensorFloat16Bit, WinML::Half, float) +CREATE_TENSOR(TensorFloat16Bit, _winml::Half, float) #pragma warning(push) #pragma warning(disable : 4702) // Unreachable code (one of TensorBase's constructor unconditionally throws for @@ -93,8 +93,8 @@ CREATE_TENSOR(TensorString, std::string, winrt::hstring) // CREATE_MAP is used by map types to implement common functionality #define CREATE_MAP(type, key_type, value_type) \ - namespace winrt::Windows::AI::MachineLearning::implementation { \ - struct type : public WinML::MapBase { \ + namespace WINMLP { \ + struct type : public _winml::MapBase { \ type(wfc::IMap const& data) : MapBase(data){}; \ }; \ } @@ -110,8 +110,8 @@ CREATE_MAP(MapStringToString, hstring, hstring) // CREATE_SEQUENCE is used by sequence types to implement common functionality #define CREATE_SEQUENCE(type, element_type, raw_type) \ - namespace winrt::Windows::AI::MachineLearning::implementation { \ - struct type : public WinML::SequenceBase { \ + namespace WINMLP { \ + struct type : public _winml::SequenceBase { \ type(wfc::IIterable const& data) : SequenceBase(data){}; \ }; \ } @@ -132,24 +132,24 @@ CREATE_SEQUENCE(SequenceTensorUInt32Bit, winml::TensorUInt32Bit, uint32_t) CREATE_SEQUENCE(SequenceTensorInt32Bit, winml::TensorInt32Bit, int32_t) CREATE_SEQUENCE(SequenceTensorUInt64Bit, winml::TensorUInt64Bit, uint64_t) CREATE_SEQUENCE(SequenceTensorInt64Bit, winml::TensorInt64Bit, int64_t) -CREATE_SEQUENCE(SequenceTensorFloat16Bit, winml::TensorFloat16Bit, WinML::Half) +CREATE_SEQUENCE(SequenceTensorFloat16Bit, winml::TensorFloat16Bit, _winml::Half) CREATE_SEQUENCE(SequenceTensorString, winml::TensorString, std::string) -namespace Windows::AI::MachineLearning { +namespace _winml { template -inline winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue CreateTensorValueFromInspectable( - WinML::BindingType bindingType, +inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable( + _winml::BindingType bindingType, const wf::IInspectable& inspectable, const winml::ITensorFeatureDescriptor& descriptor) { - if (descriptor.TensorKind() == WinML::TensorKindFrom::Type) { + if (descriptor.TensorKind() == _winml::TensorKindFrom::Type) { if (auto vector = inspectable.try_as>()) { return TValueType::CreateFromIterable(descriptor.Shape(), vector); } - if (bindingType == WinML::BindingType::kInput) { + if (bindingType == _winml::BindingType::kInput) { // Feature inputs should be more permissive, and allow for views to be bound since they are read only if (auto vectorView = inspectable.try_as>()) { return TValueType::CreateFromIterable(descriptor.Shape(), vectorView); @@ -161,7 +161,7 @@ inline winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue CreateTen template <> inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable( - WinML::BindingType bindingType, + _winml::BindingType bindingType, const wf::IInspectable& inspectable, const winml::ITensorFeatureDescriptor& descriptor) { @@ -170,7 +170,7 @@ inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable>()) { return winmlp::TensorInt8Bit::CreateFromIterable(descriptor.Shape(), vectorView); @@ -182,8 +182,8 @@ inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable( - WinML::BindingType bindingType, - const winrt::Windows::Foundation::IInspectable& inspectable, + _winml::BindingType bindingType, + const wf::IInspectable& inspectable, const winml::ITensorFeatureDescriptor& descriptor) { if (descriptor.TensorKind() == winml::TensorKind::Float16) { @@ -191,7 +191,7 @@ inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable>()) { return winmlp::TensorFloat16Bit::CreateFromIterable(descriptor.Shape(), vectorView); @@ -202,7 +202,7 @@ inline winml::ILearningModelFeatureValue CreateTensorValueFromInspectable(videoFrames); } - if (bindingType == WinML::BindingType::kInput) { + if (bindingType == _winml::BindingType::kInput) { // Allows to bind IVectorView as input. if (auto videoFrames = inspectable.try_as>()) { return (0 == videoFrames.Size()) ? nullptr : winrt::make(videoFrames); @@ -253,7 +253,7 @@ inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable( return winmlp::MapInt64BitToString::Create(map); } - if (bindingType == WinML::BindingType::kInput) { + if (bindingType == _winml::BindingType::kInput) { // Feature inputs should be more permissive, and allow for views to be bound since they are read only if (auto map = inspectable.try_as>()) { return winmlp::MapStringToFloat::Create(map); @@ -329,7 +329,7 @@ inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable( return winmlp::SequenceTensorString::Create(sequence); } - if (bindingType == WinML::BindingType::kInput) { + if (bindingType == _winml::BindingType::kInput) { // Feature inputs should be more permissive, and allow for views to be bound since they are read only if (auto sequence = inspectable.try_as>>()) { return winmlp::SequenceMapStringFloat::Create(sequence); @@ -410,4 +410,4 @@ inline winml::ILearningModelFeatureValue CreateFeatureValueFromInspectable( return nullptr; } -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/ImageFeatureDescriptor.cpp b/winml/lib/Api/ImageFeatureDescriptor.cpp index 211a1da3d3383..0651e9258e3ba 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.cpp +++ b/winml/lib/Api/ImageFeatureDescriptor.cpp @@ -7,7 +7,7 @@ #include -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { ImageFeatureDescriptor::ImageFeatureDescriptor( const char* name, const char* description, @@ -19,8 +19,8 @@ ImageFeatureDescriptor::ImageFeatureDescriptor( uint32_t width, uint32_t height, ImageNominalPixelRange nominal_pixel_range, - ImageColorSpaceGamma color_space_gamma) : name_(WinML::Strings::HStringFromUTF8(name)), - description_(WinML::Strings::HStringFromUTF8(description)), + ImageColorSpaceGamma color_space_gamma) : name_(_winml::Strings::HStringFromUTF8(name)), + description_(_winml::Strings::HStringFromUTF8(description)), tensor_kind_(tensor_kind), shape_(shape), is_required_(is_required), @@ -120,4 +120,4 @@ ImageColorSpaceGamma ImageFeatureDescriptor::GetColorSpaceGamma() { return color_space_gamma_; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/ImageFeatureDescriptor.h b/winml/lib/Api/ImageFeatureDescriptor.h index c038987827add..e3043dd511fb3 100644 --- a/winml/lib/Api/ImageFeatureDescriptor.h +++ b/winml/lib/Api/ImageFeatureDescriptor.h @@ -5,7 +5,8 @@ #include "ImageFeatureDescriptor.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { + enum class ImageNominalPixelRange { ImageNominalPixelRange_NominalRange_0_255, ImageNominalPixelRange_Normalized_0_1, @@ -93,4 +94,5 @@ struct ImageFeatureDescriptor : ImageFeatureDescriptorT< ImageNominalPixelRange nominal_pixel_range_; ImageColorSpaceGamma color_space_gamma_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file + +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/ImageFeatureValue.cpp b/winml/lib/Api/ImageFeatureValue.cpp index 0fe90a726c6cd..206b890961a6d 100644 --- a/winml/lib/Api/ImageFeatureValue.cpp +++ b/winml/lib/Api/ImageFeatureValue.cpp @@ -20,35 +20,28 @@ #include "D3DDeviceCache.h" #include "TensorFeatureDescriptor.h" -using namespace WinML; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Graphics::DirectX; -using namespace Windows::AI::MachineLearning::Internal; -using namespace winrt::Windows::Foundation::Collections; - -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct ImageFeatureValue::ImageResourceMetadata { - std::vector Bounds; - ::Windows::AI::MachineLearning::Internal::ImageTensorDescription TensorDescriptor; + std::vector Bounds; + _winml::ImageTensorDescription TensorDescriptor; }; -Windows::AI::MachineLearning::ImageFeatureValue ImageFeatureValue::Create( +winml::ImageFeatureValue ImageFeatureValue::Create( uint32_t batchSize, - BitmapPixelFormat format, + wgi::BitmapPixelFormat format, uint32_t width, uint32_t height) { - std::vector videoFrames = {}; + std::vector videoFrames = {}; for (uint32_t i = 0; i < batchSize; ++i) { - SoftwareBitmap bitmap(format, width, height); - Windows::Media::VideoFrame frame = Windows::Media::VideoFrame::CreateWithSoftwareBitmap(bitmap); + wgi::SoftwareBitmap bitmap(format, width, height); + wm::VideoFrame frame = wm::VideoFrame::CreateWithSoftwareBitmap(bitmap); videoFrames.emplace_back(frame); } return make(winrt::single_threaded_vector(std::move(videoFrames))); } -Windows::AI::MachineLearning::ImageFeatureValue ImageFeatureValue::CreateFromVideoFrame(Windows::Media::VideoFrame const& image) try { +winml::ImageFeatureValue ImageFeatureValue::CreateFromVideoFrame(wm::VideoFrame const& image) try { return make(image); } WINML_CATCH_ALL @@ -58,29 +51,29 @@ void ImageFeatureValue::Initialize() { for (auto videoFrame : m_videoFrames) { // TODO: Check all videoFrames come from either CPU or GPU. if (auto surface = videoFrame.Direct3DSurface()) { - Direct3DSurfaceDescription description = surface.Description(); + wgdx::Direct3D11::Direct3DSurfaceDescription description = surface.Description(); m_widths.emplace_back(description.Width); m_heights.emplace_back(description.Height); } else { - ISoftwareBitmap softwarebitmap(videoFrame.SoftwareBitmap()); + wgi::ISoftwareBitmap softwarebitmap(videoFrame.SoftwareBitmap()); m_widths.emplace_back(softwarebitmap.PixelWidth()); m_heights.emplace_back(softwarebitmap.PixelHeight()); } } } -ImageFeatureValue::ImageFeatureValue(Windows::Media::VideoFrame const& image) { - std::vector frame = {image}; +ImageFeatureValue::ImageFeatureValue(wm::VideoFrame const& image) { + std::vector frame = {image}; m_videoFrames = winrt::single_threaded_vector(std::move(frame)); Initialize(); } -ImageFeatureValue::ImageFeatureValue(IVector const& images) : m_videoFrames(images) { +ImageFeatureValue::ImageFeatureValue(wfc::IVector const& images) : m_videoFrames(images) { Initialize(); } -ImageFeatureValue::ImageFeatureValue(IVectorView const& images) { - std::vector videoFrames = {}; +ImageFeatureValue::ImageFeatureValue(wfc::IVectorView const& images) { + std::vector videoFrames = {}; for (uint32_t i = 0; i < images.Size(); ++i) { videoFrames.emplace_back(images.GetAt(i)); } @@ -88,16 +81,16 @@ ImageFeatureValue::ImageFeatureValue(IVectorView con Initialize(); } -static std::optional GetBitmapPixelFormatFromMetadata(const IPropertySet& properties) { +static std::optional GetBitmapPixelFormatFromMetadata(const wfc::IPropertySet& properties) { if (properties != nullptr && properties.HasKey(L"BitmapPixelFormat")) { if (auto pixelFormatInspectable = properties.Lookup(L"BitmapPixelFormat")) { - auto pixelFormatValue = pixelFormatInspectable.as(); - auto pixelFormat = static_cast(pixelFormatValue.GetInt32()); + auto pixelFormatValue = pixelFormatInspectable.as(); + auto pixelFormat = static_cast(pixelFormatValue.GetInt32()); WINML_THROW_HR_IF_FALSE_MSG( WINML_ERR_INVALID_BINDING, - pixelFormat == BitmapPixelFormat::Rgba8 || - pixelFormat == BitmapPixelFormat::Bgra8 || - pixelFormat == BitmapPixelFormat::Gray8, + pixelFormat == wgi::BitmapPixelFormat::Rgba8 || + pixelFormat == wgi::BitmapPixelFormat::Bgra8 || + pixelFormat == wgi::BitmapPixelFormat::Gray8, "BitmapPixelFormat must be either Rgba8, Bgra8, or Gray8"); return pixelFormat; @@ -107,13 +100,13 @@ static std::optional GetBitmapPixelFormatFromMetadata(const I return {}; } -static std::optional GetBoundsFromMetadata(const IPropertySet& properties) { +static std::optional GetBoundsFromMetadata(const wfc::IPropertySet& properties) { if (properties != nullptr && properties.HasKey(L"BitmapBounds")) { if (auto boundsInspectable = properties.Lookup(L"BitmapBounds")) { - auto boundsPropertyValue = boundsInspectable.as(); + auto boundsPropertyValue = boundsInspectable.as(); WINML_THROW_HR_IF_FALSE_MSG( WINML_ERR_INVALID_BINDING, - boundsPropertyValue.Type() == Windows::Foundation::PropertyType::UInt32Array, + boundsPropertyValue.Type() == wf::PropertyType::UInt32Array, "BitmapBounds must reference a property value with type UInt32Array with 4 elements."); com_array bounds; @@ -123,18 +116,18 @@ static std::optional GetBoundsFromMetadata(const IPropertySet& pro bounds.size() == 4, "BitmapBounds must reference a property value with type UInt32Array with 4 elements."); - return Windows::Graphics::Imaging::BitmapBounds{bounds[0], bounds[1], bounds[2], bounds[3]}; + return wgi::BitmapBounds{bounds[0], bounds[1], bounds[2], bounds[3]}; } } return {}; } -BitmapBounds ImageFeatureValue::CenterAndCropBounds( +wgi::BitmapBounds ImageFeatureValue::CenterAndCropBounds( uint32_t idx, uint32_t desiredWidth, uint32_t desiredHeight) { - BitmapBounds bounds = {}; + wgi::BitmapBounds bounds = {}; float RequiredAspectRatio = static_cast(desiredWidth) / static_cast(desiredHeight); // crop to center while maintaining size @@ -155,12 +148,12 @@ BitmapBounds ImageFeatureValue::CenterAndCropBounds( return bounds; } -static ImageTensorDataType GetTensorDataTypeFromTensorKind(TensorKind kind) { +static _winml::ImageTensorDataType GetTensorDataTypeFromTensorKind(winml::TensorKind kind) { switch (kind) { - case TensorKind::Float: - return kImageTensorDataTypeFloat32; - case TensorKind::Float16: - return kImageTensorDataTypeFloat16; + case winml::TensorKind::Float: + return _winml::ImageTensorDataType::kImageTensorDataTypeFloat32; + case winml::TensorKind::Float16: + return _winml::ImageTensorDataType::kImageTensorDataTypeFloat16; default: WINML_THROW_HR_IF_FALSE_MSG(WINML_ERR_INVALID_BINDING, false, "Model image inputs must have tensor type of Float or Float16."); } @@ -168,11 +161,11 @@ static ImageTensorDataType GetTensorDataTypeFromTensorKind(TensorKind kind) { FAIL_FAST_HR(E_INVALIDARG); } -static unsigned GetSizeFromTensorDataType(ImageTensorDataType type) { +static unsigned GetSizeFromTensorDataType(_winml::ImageTensorDataType type) { switch (type) { - case kImageTensorDataTypeFloat32: + case _winml::ImageTensorDataType::kImageTensorDataTypeFloat32: return sizeof(float); - case kImageTensorDataTypeFloat16: + case _winml::ImageTensorDataType::kImageTensorDataTypeFloat16: return sizeof(uint16_t); default: WINML_THROW_HR_IF_FALSE_MSG(WINML_ERR_INVALID_BINDING, false, "Model image inputs must have tensor type of Float or Float16."); @@ -181,19 +174,20 @@ static unsigned GetSizeFromTensorDataType(ImageTensorDataType type) { FAIL_FAST_HR(E_INVALIDARG); } -static ImageTensorDescription CreateImageTensorDescriptor(TensorKind tensorKind, BitmapPixelFormat pixelFormat, uint32_t batchSize, uint32_t width, uint32_t height) { - ImageTensorDescription tensorDescription = {}; +static _winml::ImageTensorDescription CreateImageTensorDescriptor(winml::TensorKind tensorKind, wgi::BitmapPixelFormat pixelFormat, + uint32_t batchSize, uint32_t width, uint32_t height) { + _winml::ImageTensorDescription tensorDescription = {}; tensorDescription.dataType = GetTensorDataTypeFromTensorKind(tensorKind); tensorDescription.sizes[0] = batchSize; - if (pixelFormat == Windows::Graphics::Imaging::BitmapPixelFormat::Rgba8) { - tensorDescription.channelType = kImageTensorChannelTypeRGB8; + if (pixelFormat == wgi::BitmapPixelFormat::Rgba8) { + tensorDescription.channelType = _winml::ImageTensorChannelType::kImageTensorChannelTypeRGB8; tensorDescription.sizes[1] = 3; - } else if (pixelFormat == Windows::Graphics::Imaging::BitmapPixelFormat::Bgra8) { - tensorDescription.channelType = kImageTensorChannelTypeBGR8; + } else if (pixelFormat == wgi::BitmapPixelFormat::Bgra8) { + tensorDescription.channelType = _winml::ImageTensorChannelType::kImageTensorChannelTypeBGR8; tensorDescription.sizes[1] = 3; - } else if (pixelFormat == Windows::Graphics::Imaging::BitmapPixelFormat::Gray8) { - tensorDescription.channelType = kImageTensorChannelTypeGRAY8; + } else if (pixelFormat == wgi::BitmapPixelFormat::Gray8) { + tensorDescription.channelType = _winml::ImageTensorChannelType::kImageTensorChannelTypeGRAY8; tensorDescription.sizes[1] = 1; } else { THROW_HR(E_NOTIMPL); @@ -205,20 +199,20 @@ static ImageTensorDescription CreateImageTensorDescriptor(TensorKind tensorKind, } static void CPUTensorize( - Windows::Media::IVideoFrame videoFrame, - BitmapBounds bounds, - ImageTensorDescription tensorDescriptor, + wm::IVideoFrame videoFrame, + wgi::BitmapBounds bounds, + _winml::ImageTensorDescription tensorDescriptor, com_ptr spSession, void* pResource) { auto spDevice = spSession->Device().as(); - ConverterResourceDescription descriptor = {}; - descriptor.pixel_format = static_cast(BitmapPixelFormat::Bgra8); + _winml::ConverterResourceDescription descriptor = {}; + descriptor.pixel_format = static_cast(wgi::BitmapPixelFormat::Bgra8); descriptor.width = static_cast(tensorDescriptor.sizes[3]); descriptor.height = static_cast(tensorDescriptor.sizes[2]); descriptor.luid = {}; // Converted image on CPU - auto pooledConverter = PoolObjectWrapper::Create(spDevice->TensorizerStore()->Fetch(descriptor)); + auto pooledConverter = _winml::PoolObjectWrapper::Create(spDevice->TensorizerStore()->Fetch(descriptor)); //apply tensorization pooledConverter->Get()->Tensorizer->VideoFrameToSoftwareTensor( @@ -233,9 +227,9 @@ static void CPUTensorize( } static void CPUTensorize( - IVector videoFrames, - std::vector bounds, - ImageTensorDescription tensorDescriptor, + wfc::IVector videoFrames, + std::vector bounds, + _winml::ImageTensorDescription tensorDescriptor, com_ptr spSession, BYTE* resource, unsigned int singleFrameBufferSize) { @@ -247,26 +241,26 @@ static void CPUTensorize( } static void GPUTensorize( - IVector videoFrames, - std::vector bounds, - ImageTensorDescription tensorDescriptor, + wfc::IVector videoFrames, + std::vector bounds, + _winml::ImageTensorDescription tensorDescriptor, com_ptr spSession, ID3D12Resource* d3dResource, - WinML::BindingContext& context) { + _winml::BindingContext& context) { auto spDevice = spSession->Device().as(); - ConverterResourceDescription descriptor = {}; - descriptor.pixel_format = static_cast(DirectXPixelFormat::B8G8R8X8UIntNormalized); + _winml::ConverterResourceDescription descriptor = {}; + descriptor.pixel_format = static_cast(wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized); descriptor.width = static_cast(tensorDescriptor.sizes[3]); descriptor.height = static_cast(tensorDescriptor.sizes[2]); descriptor.luid = spDevice->GetD3DDevice()->GetAdapterLuid(); // Converted image on GPU // Tensorize video frames one by one without extra copy. for (uint32_t batchIdx = 0; batchIdx < videoFrames.Size(); ++batchIdx) { - auto pooledConverter = PoolObjectWrapper::Create(spDevice->TensorizerStore()->Fetch(descriptor)); + auto pooledConverter = _winml::PoolObjectWrapper::Create(spDevice->TensorizerStore()->Fetch(descriptor)); { // Apply tensorization - auto session = spSession.as(); + auto session = spSession.as(); pooledConverter->Get()->Tensorizer->VideoFrameToDX12Tensor( batchIdx, session, @@ -288,11 +282,11 @@ static void GPUTensorize( } } -std::optional ImageFeatureValue::GetInputMetadata(const WinML::BindingContext& context) { +std::optional ImageFeatureValue::GetInputMetadata(const _winml::BindingContext& context) { uint32_t descriptorWidth; uint32_t descriptorHeight; - TensorKind tensorKind = TensorKind::Undefined; + auto tensorKind = winml::TensorKind::Undefined; auto spImageDescriptor = context.descriptor.try_as(); auto spTensorDescriptor = context.descriptor.try_as(); @@ -343,18 +337,18 @@ std::optional ImageFeatureValue::GetIn // Set up BitmapBounds // For batch of images with different sizes, like { {1, 3, 1080, 1080}, {1, 3, 720, 720} }, // a vector of bounds is to record the result after cropped. - std::vector bounds = {}; + std::vector bounds = {}; for (uint32_t i = 0; i < m_batchSize; ++i) { auto tempBounds = GetBoundsFromMetadata(context.properties); if (!tempBounds.has_value()) { // If the user has not specified bounds, we need to infer the bounds // from the combination of descriptor, and input value or output value - if (context.type == BindingType::kInput) { + if (context.type == _winml::BindingType::kInput) { // If unspecified output, get the crop with correct aspect ratio tempBounds = CenterAndCropBounds(i, descriptorWidth, descriptorHeight); } else { // If given an unspecified output region, write into the top left portion of the output image. - tempBounds = BitmapBounds{0, 0, m_widths[i], m_heights[i]}; + tempBounds = wgi::BitmapBounds{0, 0, m_widths[i], m_heights[i]}; } } bounds.emplace_back(tempBounds.value()); @@ -363,7 +357,7 @@ std::optional ImageFeatureValue::GetIn // Set up BitmapPixelFormat - auto pixelFormat = std::optional{}; + auto pixelFormat = std::optional{}; pixelFormat = GetBitmapPixelFormatFromMetadata(context.properties); if (!pixelFormat.has_value() && spImageDescriptor) { pixelFormat = spImageDescriptor->BitmapPixelFormat(); @@ -372,11 +366,11 @@ std::optional ImageFeatureValue::GetIn int channelCount = static_cast(shape.GetAt(1)); if (channelCount == 1) { // Assume Gray if no image descriptor is given and channelcount 1 - pixelFormat = BitmapPixelFormat::Gray8; + pixelFormat = wgi::BitmapPixelFormat::Gray8; } else if (channelCount == 3) { // Assume Bgra8 if no image descriptor is given - pixelFormat = BitmapPixelFormat::Bgra8; + pixelFormat = wgi::BitmapPixelFormat::Bgra8; } else { THROW_HR(WINML_ERR_SIZE_MISMATCH); } @@ -387,7 +381,7 @@ std::optional ImageFeatureValue::GetIn return ImageResourceMetadata{bounds, imageTensorDescriptor}; } -HRESULT ImageFeatureValue::GetValue(WinML::BindingContext& context, IValue** out) try { +HRESULT ImageFeatureValue::GetValue(_winml::BindingContext& context, _winml::IValue** out) try { FAIL_FAST_IF(!(std::all_of(m_widths.begin(), m_widths.end(), [](int i) { return i != 0; }))); FAIL_FAST_IF(!(std::all_of(m_heights.begin(), m_heights.end(), [](int i) { return i != 0; }))); @@ -402,18 +396,19 @@ HRESULT ImageFeatureValue::GetValue(WinML::BindingContext& context, IValue** out auto engine = spSession->GetEngine(); // create the OrtValue - winrt::com_ptr value; + winrt::com_ptr<_winml::IValue> value; RETURN_IF_FAILED(engine->CreateTensorValue( resourceMetadata.TensorDescriptor.sizes, sizeof(resourceMetadata.TensorDescriptor.sizes) / sizeof(resourceMetadata.TensorDescriptor.sizes[0]), - resourceMetadata.TensorDescriptor.dataType == kImageTensorDataTypeFloat32 ? winml::TensorKind::Float : winml::TensorKind::Float16, + resourceMetadata.TensorDescriptor.dataType == _winml::ImageTensorDataType::kImageTensorDataTypeFloat32 ? + winml::TensorKind::Float : winml::TensorKind::Float16, value.put())); // Get the tensor raw data - WinML::Resource void_resource; + _winml::Resource void_resource; RETURN_IF_FAILED(value->GetResource(void_resource)); - if (context.type == BindingType::kInput) { + if (context.type == _winml::BindingType::kInput) { // Only tensorize inputs auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast(1), std::multiplies()); auto bufferByteSize = GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize; @@ -438,29 +433,29 @@ HRESULT ImageFeatureValue::IsPlaceholder(bool* pIsPlaceHolder) { return S_OK; } -HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, IValue* value) try { +HRESULT ImageFeatureValue::UpdateSourceResourceData(_winml::BindingContext& context, _winml::IValue* value) try { // Get the device auto spSession = context.session.as(); auto spDevice = spSession->Device().as(); // Get the output tensor raw data - WinML::Resource void_resource; + _winml::Resource void_resource; RETURN_IF_FAILED(value->GetResource(void_resource)); // Get the run context auto metadata = GetInputMetadata(context); ImageResourceMetadata resourceMetadata = metadata.value(); - ConverterResourceDescription descriptor = {}; + _winml::ConverterResourceDescription descriptor = {}; descriptor.width = static_cast(resourceMetadata.TensorDescriptor.sizes[3]); descriptor.height = static_cast(resourceMetadata.TensorDescriptor.sizes[2]); bool out; if (SUCCEEDED(value->IsCpu(&out)) && out) { - descriptor.pixel_format = static_cast(BitmapPixelFormat::Bgra8); + descriptor.pixel_format = static_cast(wgi::BitmapPixelFormat::Bgra8); descriptor.luid = {}; // Converted image on CPU - auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); + auto pooledConverter = _winml::PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); auto bufferSize = std::accumulate(std::begin(resourceMetadata.TensorDescriptor.sizes), std::end(resourceMetadata.TensorDescriptor.sizes), static_cast(1), std::multiplies()); auto bufferByteSize = GetSizeFromTensorDataType(resourceMetadata.TensorDescriptor.dataType) * bufferSize / m_batchSize; @@ -473,10 +468,10 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, IVa resource += bufferByteSize; } } else { - descriptor.pixel_format = static_cast(DirectXPixelFormat::B8G8R8X8UIntNormalized); + descriptor.pixel_format = static_cast(wgdx::DirectXPixelFormat::B8G8R8X8UIntNormalized); descriptor.luid = spDevice->GetD3DDevice()->GetAdapterLuid(); // Converted image on GPU - auto pooledConverter = PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); + auto pooledConverter = _winml::PoolObjectWrapper::Create(spDevice->DetensorizerStore()->Fetch(descriptor)); auto d3dResource = reinterpret_cast(void_resource.get()); @@ -501,13 +496,13 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, IVa } WINML_CATCH_ALL_COM -HRESULT ImageFeatureValue::AbiRepresentation(winrt::Windows::Foundation::IInspectable& abiRepresentation) { +HRESULT ImageFeatureValue::AbiRepresentation(wf::IInspectable& abiRepresentation) { if (IsBatch()) { m_videoFrames.as(abiRepresentation); } else { - winrt::Windows::AI::MachineLearning::ImageFeatureValue to = nullptr; + winml::ImageFeatureValue to = nullptr; RETURN_IF_FAILED(this->QueryInterface( - winrt::guid_of(), + winrt::guid_of(), reinterpret_cast(winrt::put_abi(to)))); to.as(abiRepresentation); @@ -515,18 +510,18 @@ HRESULT ImageFeatureValue::AbiRepresentation(winrt::Windows::Foundation::IInspec return S_OK; } -Windows::AI::MachineLearning::LearningModelFeatureKind ImageFeatureValue::Kind() try { - return LearningModelFeatureKind::Image; +winml::LearningModelFeatureKind ImageFeatureValue::Kind() try { + return winml::LearningModelFeatureKind::Image; } WINML_CATCH_ALL -Windows::Media::VideoFrame ImageFeatureValue::VideoFrame() try { +wm::VideoFrame ImageFeatureValue::VideoFrame() try { return m_videoFrames.GetAt(0); } WINML_CATCH_ALL -IIterable ImageFeatureValue::VideoFrames() try { - return m_videoFrames.try_as>(); +wfc::IIterable ImageFeatureValue::VideoFrames() try { + return m_videoFrames.try_as>(); } WINML_CATCH_ALL -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/ImageFeatureValue.h b/winml/lib/Api/ImageFeatureValue.h index 4c6292fb677b9..e7785283e4e50 100644 --- a/winml/lib/Api/ImageFeatureValue.h +++ b/winml/lib/Api/ImageFeatureValue.h @@ -7,60 +7,61 @@ #include "inc/ILotusValueProviderPrivate.h" -namespace winrt::Windows::AI::MachineLearning::implementation { -struct ImageFeatureValue : ImageFeatureValueT { +namespace WINMLP { + +struct ImageFeatureValue : ImageFeatureValueT { // Metadata about the resource which helps in finding // compatible cached resources struct ImageResourceMetadata; ImageFeatureValue() = delete; ImageFeatureValue(Windows::Media::VideoFrame const& image); - ImageFeatureValue(winrt::Windows::Foundation::Collections::IVector const& images); - ImageFeatureValue(winrt::Windows::Foundation::Collections::IVectorView const& images); + ImageFeatureValue(wfc::IVector const& images); + ImageFeatureValue(wfc::IVectorView const& images); Windows::Media::VideoFrame VideoFrame(); - winrt::Windows::Foundation::Collections::IIterable VideoFrames(); - Windows::AI::MachineLearning::LearningModelFeatureKind Kind(); + wfc::IIterable VideoFrames(); + winml::LearningModelFeatureKind Kind(); - static Windows::AI::MachineLearning::ImageFeatureValue ImageFeatureValue::Create( + static winml::ImageFeatureValue ImageFeatureValue::Create( uint32_t batchSize, Windows::Graphics::Imaging::BitmapPixelFormat format, uint32_t width, uint32_t height); - static Windows::AI::MachineLearning::ImageFeatureValue CreateFromVideoFrame(Windows::Media::VideoFrame const& image); + static winml::ImageFeatureValue CreateFromVideoFrame(Windows::Media::VideoFrame const& image); - std::optional GetInputMetadata(const WinML::BindingContext& context); + std::optional GetInputMetadata(const _winml::BindingContext& context); // ILotusValueProviderPrivate implementation STDMETHOD(GetValue) - (WinML::BindingContext& context, WinML::IValue** out); + (_winml::BindingContext& context, _winml::IValue** out); STDMETHOD(IsPlaceholder) (bool* pIsPlaceHolder); STDMETHOD(UpdateSourceResourceData) - (WinML::BindingContext& context, WinML::IValue* value); + (_winml::BindingContext& context, _winml::IValue* value); STDMETHOD(AbiRepresentation) - (winrt::Windows::Foundation::IInspectable& abiRepresentation); + (wf::IInspectable& abiRepresentation); std::vector Widths() { return m_widths; } std::vector Heights() { return m_heights; } bool IsBatch() { return m_batchSize > 1; } private: - winrt::Windows::Foundation::Collections::IVector m_videoFrames; + wfc::IVector m_videoFrames; std::vector m_widths = {}; std::vector m_heights = {}; uint32_t m_batchSize = 1; // Crop the image with desired aspect ratio. // This function does not crop image to desried width and height, but crops to center for desired ratio - Windows::Graphics::Imaging::BitmapBounds CenterAndCropBounds( + wgi::BitmapBounds CenterAndCropBounds( uint32_t idx, uint32_t desiredWidth, uint32_t desiredHeight); void Initialize(); }; -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +namespace WINML::factory_implementation { struct ImageFeatureValue : ImageFeatureValueT { }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINMLP diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index e732d86625832..68432cec6bf51 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -14,10 +14,10 @@ #include -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { LearningModel::LearningModel( const hstring& path, - const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(WinML::Strings::UTF8FromHString(path), + const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(_winml::Strings::UTF8FromHString(path), op_provider) { } WINML_CATCH_ALL @@ -34,9 +34,9 @@ LearningModel::LearningModel( WINML_CATCH_ALL static HRESULT CreateModelFromStream( - WinML::IEngineFactory* engine_factory, + _winml::IEngineFactory* engine_factory, const wss::IRandomAccessStreamReference stream, - WinML::IModel** model) { + _winml::IModel** model) { auto content = stream.OpenReadAsync().get(); wss::Buffer buffer(static_cast(content.Size())); @@ -74,7 +74,7 @@ LearningModel::Author() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetAuthor(&out, &len)); - return WinML::Strings::HStringFromUTF8(out); + return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL @@ -83,7 +83,7 @@ LearningModel::Name() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetName(&out, &len)); - return WinML::Strings::HStringFromUTF8(out); + return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL @@ -92,7 +92,7 @@ LearningModel::Domain() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetDomain(&out, &len)); - return WinML::Strings::HStringFromUTF8(out); + return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL @@ -101,7 +101,7 @@ LearningModel::Description() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetDescription(&out, &len)); - return WinML::Strings::HStringFromUTF8(out); + return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL @@ -227,9 +227,9 @@ LearningModel::LoadFromStream( } WINML_CATCH_ALL -WinML::IModel* +_winml::IModel* LearningModel::DetachModel() { - com_ptr detached_model; + com_ptr<_winml::IModel> detached_model; if (model_ != nullptr) { detached_model.attach(model_.detach()); @@ -239,26 +239,26 @@ LearningModel::DetachModel() { return detached_model.detach(); } -WinML::IModel* +_winml::IModel* LearningModel::CloneModel() { if (model_ == nullptr) { return nullptr; } - com_ptr model_copy; + com_ptr<_winml::IModel> model_copy; WINML_THROW_IF_FAILED(model_->CloneModel(model_copy.put())); return model_copy.detach(); } -WinML::IEngineFactory* +_winml::IEngineFactory* LearningModel::GetEngineFactory() { return engine_factory_.get(); } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +namespace WINML::factory_implementation { // copied from cppwinrt magic to create abi wrappers. Need to do it this way // since peeps underneath (like the constructor) will throw HRESULT @@ -271,11 +271,11 @@ __stdcall LearningModel::Load( WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, model_path_size > 0, "Failed to create LearningModel. Ivalid argument model_path_size."); WINML_THROW_HR_IF_NULL_MSG(E_INVALIDARG, pp_model_unk, "Failed to create LearningModel. Ivalid argument pp_model_unk."); - auto path = WinML::Strings::UTF8FromUnicode(p_model_path, model_path_size); + auto path = _winml::Strings::UTF8FromUnicode(p_model_path, model_path_size); auto model = make(path, nullptr); *pp_model_unk = model.as().detach(); return S_OK; } WINML_CATCH_ALL_COM } -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index 66f940a0bd7ce..a2dac15122bdd 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -5,13 +5,13 @@ #include "LearningModel.g.h" -namespace Windows::AI::MachineLearning { +namespace _winml { struct IEngineFactory; struct IModel; struct IModelInfo; -} // namespace Windows::AI::MachineLearning +} // namespace _winml -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct LearningModel : LearningModelT { /* LearningModel constructors (MachineLearningContract 1). */ @@ -60,11 +60,11 @@ struct LearningModel : LearningModelT { /* LearningModel static methods (MachineLearningContract 1). */ static wf::IAsyncOperation LoadFromStorageFileAsync( - Windows::Storage::IStorageFile const model_file); + ws::IStorageFile const model_file); static wf::IAsyncOperation LoadFromStorageFileAsync( - Windows::Storage::IStorageFile const model_file, + ws::IStorageFile const model_file, winml::ILearningModelOperatorProvider const operator_provider); static wf::IAsyncOperation @@ -98,25 +98,24 @@ struct LearningModel : LearningModelT { /* Non-ABI methods */ bool IsDisposed(); IMLOperatorRegistry* GetOperatorRegistry(); - WinML::IModel* DetachModel(); - WinML::IModel* CloneModel(); - WinML::IEngineFactory* GetEngineFactory(); + _winml::IModel* DetachModel(); + _winml::IModel* CloneModel(); + _winml::IEngineFactory* GetEngineFactory(); private: - com_ptr engine_factory_; - com_ptr model_; - com_ptr model_info_; + com_ptr<_winml::IEngineFactory> engine_factory_; + com_ptr<_winml::IModel> model_; + com_ptr<_winml::IModelInfo> model_info_; ILearningModelOperatorProvider operator_provider_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation - -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +} // namespace WINMLP +namespace WINML::factory_implementation { struct LearningModel : LearningModelT { STDMETHOD(Load) (const wchar_t* p_model_path, UINT32 model_path_size, IUnknown** pp_model_unk); }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModelBinding.cpp b/winml/lib/Api/LearningModelBinding.cpp index 908689c82e947..55007b7297042 100644 --- a/winml/lib/Api/LearningModelBinding.cpp +++ b/winml/lib/Api/LearningModelBinding.cpp @@ -8,12 +8,9 @@ #include "LearningModelBinding.h" #include "LearningModelSession.h" #include "TelemetryEvent.h" -#include #include "LearningModel.h" -using namespace WinML; - -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { LearningModelBinding::~LearningModelBinding() { Clear(); @@ -43,15 +40,15 @@ static winml::ILearningModelFeatureDescriptor FindValidBinding( return nullptr; } -using NullableBindingPort = std::optional>; +using NullableBindingPort = std::optional>; static NullableBindingPort FindValidBinding( winml::LearningModel model, const std::wstring& name) { if (auto descriptor = FindValidBinding(model.InputFeatures(), name)) { - return std::make_pair(descriptor, BindingType::kInput); + return std::make_pair(descriptor, _winml::BindingType::kInput); } else if (auto output_descriptor = FindValidBinding(model.OutputFeatures(), name)) { - return std::make_pair(output_descriptor, BindingType::kOutput); + return std::make_pair(output_descriptor, _winml::BindingType::kOutput); } return {}; @@ -63,13 +60,13 @@ void LearningModelBinding::CacheProvider( m_providers[name] = providerInfo; } -std::tuple, BindingType> LearningModelBinding::CreateBinding( +std::tuple, _winml::BindingType> LearningModelBinding::CreateBinding( const std::string& name, const wf::IInspectable& inspectable, wfc::IPropertySet const& properties) { // Given a known type, validate against the model auto model = m_session.Model(); - auto bindingPort = FindValidBinding(model, WinML::Strings::WStringFromString(name)); + auto bindingPort = FindValidBinding(model, _winml::Strings::WStringFromString(name)); WINML_THROW_HR_IF_FALSE_MSG( WINML_ERR_INVALID_BINDING, @@ -82,7 +79,7 @@ std::tuple, BindingType> LearningMode auto bindingType = bindingPort->second; // Create a feature value from the iinspectable input - auto featureValue = WinML::CreateFeatureValueFromInspectable(bindingType, inspectable, descriptor); + auto featureValue = _winml::CreateFeatureValueFromInspectable(bindingType, inspectable, descriptor); WINML_THROW_HR_IF_NULL_MSG( WINML_ERR_INVALID_BINDING, featureValue, @@ -90,10 +87,10 @@ std::tuple, BindingType> LearningMode name.c_str()); // Validate that the feature value is compatible with the descriptor - WinML::VerifyFeatureValueCompatibleWithDescriptor(featureValue, descriptor); + _winml::VerifyFeatureValueCompatibleWithDescriptor(featureValue, descriptor); // Create the Binding Context to pass to the feature value - BindingContext context{ + _winml::BindingContext context{ bindingType, m_session, descriptor, @@ -102,10 +99,10 @@ std::tuple, BindingType> LearningMode }; // Get the bound tensor - winrt::com_ptr value; + winrt::com_ptr<_winml::IValue> value; // Get the native interface for the given bind value - auto spLotusValueProvider = featureValue.as(); + auto spLotusValueProvider = featureValue.as<_winml::ILotusValueProviderPrivate>(); auto spSession = m_session.as(); @@ -118,7 +115,7 @@ std::tuple, BindingType> LearningMode // This enables the chaining scenario. auto spDevice = m_session.Device().as(); auto isGpuSession = !spDevice->IsCpuDevice(); - auto spTensor = featureValue.try_as(); + auto spTensor = featureValue.try_as(); auto isTensorWithShape = spTensor != nullptr && spTensor.Shape().Size() != 0; auto shouldAlwaysTensorize = isTensorWithShape && isGpuSession; @@ -131,7 +128,7 @@ std::tuple, BindingType> LearningMode } else { WINML_THROW_HR_IF_TRUE_MSG( WINML_ERR_INVALID_BINDING, - isPlaceHolder && bindingType == BindingType::kInput, + isPlaceHolder && bindingType == _winml::BindingType::kInput, "The model variable %s is an input, but has no associated resources to bind.", name.c_str()); @@ -165,16 +162,16 @@ void LearningModelBinding::Bind( _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType binding_type; + _winml::BindingType binding_type; std::string binding_name; - winrt::com_ptr binding_value = nullptr; - auto featureName = WinML::Strings::UTF8FromHString(name); + winrt::com_ptr<_winml::IValue> binding_value = nullptr; + auto featureName = _winml::Strings::UTF8FromHString(name); std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, value, properties); switch (binding_type) { - case BindingType::kInput: + case _winml::BindingType::kInput: WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; - case BindingType::kOutput: + case _winml::BindingType::kOutput: WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: @@ -203,7 +200,7 @@ wfc::IIterator LearningModelBinding::First() std::unordered_map bindingsMap; for (auto mergedBindings : m_providers) { - auto name = WinML::Strings::HStringFromUTF8(mergedBindings.first); + auto name = _winml::Strings::HStringFromUTF8(mergedBindings.first); bindingsMap[name] = mergedBindings.second.CallerSpecifiedFeatureValue; } @@ -211,7 +208,7 @@ wfc::IIterator LearningModelBinding::First() } wf::IInspectable LearningModelBinding::Lookup(hstring const& key) { - auto utf8_name = WinML::Strings::UTF8FromHString(key); + auto utf8_name = _winml::Strings::UTF8FromHString(key); auto foundIt = m_providers.find(utf8_name); WINML_THROW_HR_IF_FALSE_MSG( @@ -229,7 +226,7 @@ uint32_t LearningModelBinding::Size() { } bool LearningModelBinding::HasKey(hstring const& key) { - auto utf8_name = WinML::Strings::UTF8FromHString(key); + auto utf8_name = _winml::Strings::UTF8FromHString(key); return m_providers.find(utf8_name) != m_providers.end(); } @@ -243,15 +240,14 @@ void LearningModelBinding::Split( } ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue( - const winrt::com_ptr value, + const winrt::com_ptr<_winml::IValue> value, ILearningModelFeatureDescriptor& descriptor) { bool out; if (SUCCEEDED(value->IsTensor(&out)) && out) { if (SUCCEEDED(value->IsOfTensorType(TensorKind::Float, &out)) && out) { if (descriptor.Kind() == LearningModelFeatureKind::Image) { - using namespace Windows::Graphics::Imaging; // TODO: this format for unbound output needs more discussion - BitmapPixelFormat format = descriptor.as()->BitmapPixelFormat(); + wgi::BitmapPixelFormat format = descriptor.as()->BitmapPixelFormat(); std::vector shape; value->GetTensorShape(shape); uint32_t width = static_cast(shape[3]); @@ -372,7 +368,7 @@ ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue( return winmlp::SequenceTensorFloat16Bit::Create(); } - auto utf8_name = WinML::Strings::UTF8FromHString(descriptor.Name()); + auto utf8_name = _winml::Strings::UTF8FromHString(descriptor.Name()); WINML_THROW_HR_IF_TRUE_MSG( E_UNEXPECTED, true, @@ -384,11 +380,11 @@ ILearningModelFeatureValue LearningModelBinding::CreateUnboundOuputFeatureValue( wf::IInspectable LearningModelBinding::CreateUnboundOutput( const std::string& name, - winrt::com_ptr value) { + winrt::com_ptr<_winml::IValue> value) { // Find valid binding port auto bindingPort = FindValidBinding( m_session.Model(), - WinML::Strings::WStringFromString(name)); + _winml::Strings::WStringFromString(name)); WINML_THROW_HR_IF_FALSE_MSG( E_UNEXPECTED, @@ -401,12 +397,12 @@ wf::IInspectable LearningModelBinding::CreateUnboundOutput( auto bindingType = bindingPort->second; WINML_THROW_HR_IF_FALSE_MSG( E_UNEXPECTED, - bindingType == BindingType::kOutput, + bindingType == _winml::BindingType::kOutput, "The engine produced an unexpected evaluation output %s, that is not a model variable output.", name.c_str()); // Create a binding context - BindingContext context{ + _winml::BindingContext context{ bindingType, m_session, descriptor, @@ -418,7 +414,7 @@ wf::IInspectable LearningModelBinding::CreateUnboundOutput( auto featureValue = CreateUnboundOuputFeatureValue(value, descriptor); // Update feature value - auto spLotusValueProvider = featureValue.as(); + auto spLotusValueProvider = featureValue.as<_winml::ILotusValueProviderPrivate>(); WINML_THROW_IF_FAILED_MSG( spLotusValueProvider->UpdateSourceResourceData(context, value.get()), "Failed to update bound object for model variable output %s", @@ -487,22 +483,22 @@ STDMETHODIMP LearningModelBinding::Bind( CWinMLAutoLock lock(!device->IsCpuDevice() ? session->GetDMLEPLock() : nullptr); _winmlt::TelemetryEvent binding_event(_winmlt::EventCategory::kBinding); - BindingType binding_type; + _winml::BindingType binding_type; std::string binding_name; - winrt::com_ptr binding_value; + winrt::com_ptr<_winml::IValue> binding_value; wf::IInspectable to; RETURN_IF_FAILED(value->QueryInterface( winrt::guid_of(), reinterpret_cast(winrt::put_abi(to)))); - auto featureName = WinML::Strings::UTF8FromUnicode(name, cchName); + auto featureName = _winml::Strings::UTF8FromUnicode(name, cchName); std::tie(binding_name, binding_value, binding_type) = CreateBinding(featureName, to, nullptr); switch (binding_type) { - case BindingType::kInput: + case _winml::BindingType::kInput: WINML_THROW_IF_FAILED(BindInput(binding_name, binding_value)); break; - case BindingType::kOutput: + case _winml::BindingType::kOutput: WINML_THROW_IF_FAILED(BindOutput(binding_name, binding_value)); break; default: @@ -522,13 +518,13 @@ static std::pair Contains(const std::vector& names, c } // This method releases control of memory of ml_value from caller of BindInput -HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr value) { +HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr<_winml::IValue> value) { bool exists; size_t index; std::tie(exists, index) = Contains(input_names_, name); auto engine = m_session.as()->GetEngine(); - winrt::com_ptr device_value; + winrt::com_ptr<_winml::IValue> device_value; WINML_THROW_IF_FAILED(engine->CreateOneInputAcrossDevices(name.c_str(), value.get(), device_value.put())); // an input will always be copied on device mismatch if (exists) { @@ -541,7 +537,7 @@ HRESULT LearningModelBinding::BindInput(const std::string& name, winrt::com_ptr< return S_OK; } -HRESULT LearningModelBinding::BindOutput(const std::string& name, winrt::com_ptr value) { +HRESULT LearningModelBinding::BindOutput(const std::string& name, winrt::com_ptr<_winml::IValue> value) { bool exists; size_t index; std::tie(exists, index) = Contains(output_names_, name); @@ -564,11 +560,11 @@ const std::vector& LearningModelBinding::GetInputNames() const { return input_names_; } -std::vector>& LearningModelBinding::GetOutputs() { +std::vector>& LearningModelBinding::GetOutputs() { return outputs_; } -const std::vector>& LearningModelBinding::GetInputs() const { +const std::vector>& LearningModelBinding::GetInputs() const { return inputs_; } @@ -596,7 +592,7 @@ void LearningModelBinding::BindUnboundOutputs() { const wchar_t* p_name; uint32_t size; WINML_THROW_IF_FAILED(descriptor_native->GetName(&p_name, &size)); - return WinML::Strings::UTF8FromUnicode(p_name, size); + return _winml::Strings::UTF8FromUnicode(p_name, size); }); // Find the set difference to determine if there are any unbound output features @@ -612,10 +608,10 @@ void LearningModelBinding::BindUnboundOutputs() { for (const auto& unbound_output : unbound_output_names) { auto engine = m_session.as()->GetEngine(); - winrt::com_ptr value; + winrt::com_ptr<_winml::IValue> value; WINML_THROW_IF_FAILED(engine->CreateNullValue(value.put())); WINML_THROW_IF_FAILED(BindOutput(unbound_output, value)); } } -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/LearningModelBinding.h b/winml/lib/Api/LearningModelBinding.h index dc92a43a8e5c8..ed425001dc899 100644 --- a/winml/lib/Api/LearningModelBinding.h +++ b/winml/lib/Api/LearningModelBinding.h @@ -8,76 +8,75 @@ #include "inc/ILotusValueProviderPrivate.h" #include "core/providers/winml/winml_provider_factory.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct LearningModelBinding : LearningModelBindingT { struct ProviderInfo { - Windows::Foundation::IInspectable CallerSpecifiedFeatureValue = nullptr; - winrt::com_ptr Provider = nullptr; - WinML::BindingContext Context = {}; + wf::IInspectable CallerSpecifiedFeatureValue = nullptr; + winrt::com_ptr<_winml::ILotusValueProviderPrivate> Provider = nullptr; + _winml::BindingContext Context = {}; }; public: - using KeyValuePair = - Windows::Foundation::Collections::IKeyValuePair; + using KeyValuePair = wfc::IKeyValuePair; ~LearningModelBinding(); LearningModelBinding() = delete; - LearningModelBinding(Windows::AI::MachineLearning::LearningModelSession const& session); + LearningModelBinding(winml::LearningModelSession const& session); - void Bind(hstring const& name, Windows::Foundation::IInspectable const& value); - void Bind(hstring const& name, Windows::Foundation::IInspectable const& value, Windows::Foundation::Collections::IPropertySet const& properties); + void Bind(hstring const& name, wf::IInspectable const& value); + void Bind(hstring const& name, wf::IInspectable const& value, wfc::IPropertySet const& properties); STDMETHOD(Bind)(const wchar_t* name, UINT32 cchName, IUnknown* value); void Clear(); - Windows::Foundation::Collections::IIterator First(); - Windows::Foundation::IInspectable Lookup(hstring const& key); + wfc::IIterator First(); + wf::IInspectable Lookup(hstring const& key); uint32_t Size(); bool HasKey(hstring const& key); void Split( - Windows::Foundation::Collections::IMapView& first, - Windows::Foundation::Collections::IMapView& second); + wfc::IMapView& first, + wfc::IMapView& second); - std::tuple, WinML::BindingType> CreateBinding( + std::tuple, _winml::BindingType> CreateBinding( const std::string& name, - const Windows::Foundation::IInspectable& value, - Windows::Foundation::Collections::IPropertySet const& properties); + const wf::IInspectable& value, + wfc::IPropertySet const& properties); - std::unordered_map UpdateProviders(); + std::unordered_map UpdateProviders(); - const Windows::AI::MachineLearning::LearningModelSession& GetSession() { return m_session; } + const winml::LearningModelSession& GetSession() { return m_session; } const std::vector& GetInputNames() const; const std::vector& GetOutputNames() const; - const std::vector>& GetInputs() const; - std::vector>& GetOutputs(); + const std::vector>& GetInputs() const; + std::vector>& GetOutputs(); - HRESULT BindOutput(const std::string& name, winrt::com_ptr value); + HRESULT BindOutput(const std::string& name, winrt::com_ptr<_winml::IValue> value); void BindUnboundOutputs(); private: void CacheProvider(std::string name, ProviderInfo& spProvider); - Windows::Foundation::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr value); + wf::IInspectable CreateUnboundOutput(const std::string& name, winrt::com_ptr<_winml::IValue> value); ILearningModelFeatureValue CreateUnboundOuputFeatureValue( - const winrt::com_ptr value, + const winrt::com_ptr<_winml::IValue> value, ILearningModelFeatureDescriptor& descriptor); - HRESULT BindInput(const std::string& name, winrt::com_ptr value); + HRESULT BindInput(const std::string& name, winrt::com_ptr<_winml::IValue> value); private: - const Windows::AI::MachineLearning::LearningModelSession m_session; + const winml::LearningModelSession m_session; std::unordered_map m_providers; std::vector input_names_; - std::vector> inputs_; + std::vector> inputs_; std::vector output_names_; - std::vector> outputs_; + std::vector> outputs_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +namespace WINML::factory_implementation { struct LearningModelBinding : LearningModelBindingT { }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModelDevice.cpp b/winml/lib/Api/LearningModelDevice.cpp index 011ca06fdfa36..e23f19a76f85e 100644 --- a/winml/lib/Api/LearningModelDevice.cpp +++ b/winml/lib/Api/LearningModelDevice.cpp @@ -10,19 +10,20 @@ #include "ConverterResourceStore.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { + /*static*/ void LearningModelDevice::DllUnload() { } -Windows::Graphics::DisplayAdapterId LearningModelDevice::AdapterId() try { - Windows::Graphics::DisplayAdapterId id; +wg::DisplayAdapterId LearningModelDevice::AdapterId() try { + wg::DisplayAdapterId id; id.LowPart = m_deviceCache->GetDeviceLuid().LowPart; id.HighPart = m_deviceCache->GetDeviceLuid().HighPart; return id; } WINML_CATCH_ALL -LearningModelDevice::LearningModelDevice(Windows::AI::MachineLearning::LearningModelDeviceKind const& deviceKind) try : m_deviceCache(std::make_unique(deviceKind)) { +LearningModelDevice::LearningModelDevice(winml::LearningModelDeviceKind const& deviceKind) try : m_deviceCache(std::make_unique<_winml::D3DDeviceCache>(deviceKind)) { m_deviceKind = deviceKind; m_isCpuDevice = m_deviceKind == LearningModelDeviceKind::Cpu || m_deviceKind == LearningModelDeviceKind::Default; if (m_isCpuDevice) { @@ -31,14 +32,14 @@ LearningModelDevice::LearningModelDevice(Windows::AI::MachineLearning::LearningM } WINML_CATCH_ALL -LearningModelDevice::LearningModelDevice(Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice const& device) try : m_deviceCache(std::make_unique(device)) { +LearningModelDevice::LearningModelDevice(wgdx::Direct3D11::IDirect3DDevice const& device) try : m_deviceCache(std::make_unique<_winml::D3DDeviceCache>(device)) { m_deviceKind = LearningModelDeviceKind::DirectX; m_isCpuDevice = false; } WINML_CATCH_ALL LearningModelDevice::LearningModelDevice(ID3D12CommandQueue* queue) try : m_deviceKind(LearningModelDeviceKind::DirectX), - m_deviceCache(std::make_unique(queue)) { + m_deviceCache(std::make_unique<_winml::D3DDeviceCache>(queue)) { m_isCpuDevice = false; } WINML_CATCH_ALL @@ -47,18 +48,18 @@ LearningModelDevice::~LearningModelDevice() { // needed for shared ptr destruction } -Windows::AI::MachineLearning::LearningModelDevice LearningModelDevice::CreateFromDirect3D11Device(Windows::Graphics::DirectX::Direct3D11::IDirect3DDevice const& device) try { +winml::LearningModelDevice LearningModelDevice::CreateFromDirect3D11Device(wgdx::Direct3D11::IDirect3DDevice const& device) try { return make(device); } WINML_CATCH_ALL -std::shared_ptr<::Windows::AI::MachineLearning::ConverterResourceStore> LearningModelDevice::TensorizerStore() { - std::call_once(m_tensorizerStoreInitialized, [this](){ m_tensorizerStore = ::Windows::AI::MachineLearning::ConverterResourceStore::Create(5); }); +std::shared_ptr<_winml::ConverterResourceStore> LearningModelDevice::TensorizerStore() { + std::call_once(m_tensorizerStoreInitialized, [this]() { m_tensorizerStore = _winml::ConverterResourceStore::Create(5); }); return m_tensorizerStore; } -std::shared_ptr<::Windows::AI::MachineLearning::ConverterResourceStore> LearningModelDevice::DetensorizerStore() { - std::call_once(m_detensorizerStoreInitialized, [this](){ m_detensorizerStore = ::Windows::AI::MachineLearning::ConverterResourceStore::Create(5); }); +std::shared_ptr<_winml::ConverterResourceStore> LearningModelDevice::DetensorizerStore() { + std::call_once(m_detensorizerStoreInitialized, [this]() { m_detensorizerStore = _winml::ConverterResourceStore::Create(5); }); return m_detensorizerStore; } @@ -76,7 +77,7 @@ LearningModelDevice::GetDeviceLuid() { return m_deviceCache->GetDeviceLuid(); } -D3DDeviceCache* +_winml::D3DDeviceCache* LearningModelDevice::GetD3DDeviceCache() { return m_deviceCache.get(); } @@ -111,9 +112,11 @@ STDMETHODIMP_(boolean) LearningModelDevice::SharedHandleInitialized() { return m_deviceCache->SharedHandleInitialized(); } -} // namespace winrt::Windows::AI::MachineLearning::implementation -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +} // namespace WINMLP + +namespace WINML::factory_implementation { + // copied from cppwinrt magic to create abi wrappers. Need to do it this way // since peeps underneath (like the constructor) will throw HRESULT __stdcall LearningModelDevice::CreateFromD3D12CommandQueue( @@ -129,4 +132,5 @@ HRESULT __stdcall LearningModelDevice::CreateFromD3D12CommandQueue( } WINML_CATCH_ALL_COM } -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation + +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModelDevice.h b/winml/lib/Api/LearningModelDevice.h index 26639d7c7efb6..b02798e6b4caf 100644 --- a/winml/lib/Api/LearningModelDevice.h +++ b/winml/lib/Api/LearningModelDevice.h @@ -5,12 +5,12 @@ #include "LearningModelDevice.g.h" -namespace Windows::AI::MachineLearning { +namespace _winml { class ConverterResourceStore; +class D3DDeviceCache; } -namespace winrt::Windows::AI::MachineLearning::implementation { -class D3DDeviceCache; +namespace WINMLP { struct LearningModelDevice : LearningModelDeviceT { public: @@ -56,7 +56,7 @@ struct LearningModelDevice : LearningModelDeviceT + std::shared_ptr<_winml::ConverterResourceStore> TensorizerStore(); - std::shared_ptr + std::shared_ptr<_winml::ConverterResourceStore> DetensorizerStore(); private: @@ -83,17 +83,17 @@ struct LearningModelDevice : LearningModelDeviceT m_detensorizerStore; + std::shared_ptr<_winml::ConverterResourceStore> m_detensorizerStore; std::once_flag m_detensorizerStoreInitialized; - std::shared_ptr m_tensorizerStore; + std::shared_ptr<_winml::ConverterResourceStore> m_tensorizerStore; std::once_flag m_tensorizerStoreInitialized; - std::unique_ptr m_deviceCache; + std::unique_ptr<_winml::D3DDeviceCache> m_deviceCache; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +namespace WINML::factory_implementation { struct LearningModelDevice : LearningModelDeviceT { HRESULT __stdcall CreateFromD3D12CommandQueue(ID3D12CommandQueue* queue, IUnknown** device) noexcept final; }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModelEvaluationResult.cpp b/winml/lib/Api/LearningModelEvaluationResult.cpp index 776bd79422cb6..1df705b332ef1 100644 --- a/winml/lib/Api/LearningModelEvaluationResult.cpp +++ b/winml/lib/Api/LearningModelEvaluationResult.cpp @@ -4,7 +4,7 @@ #include "pch.h" #include "LearningModelEvaluationResult.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { hstring LearningModelEvaluationResult::CorrelationId() try { return m_correlationId; } @@ -36,7 +36,7 @@ Windows::Foundation::Collections::IMapView outputs; for (auto& output : m_outputs) { - auto key = WinML::Strings::HStringFromUTF8(output.first); + auto key = _winml::Strings::HStringFromUTF8(output.first); auto value = output.second; outputs.emplace(key, value); } @@ -49,7 +49,7 @@ void LearningModelEvaluationResult::Outputs(Windows::Foundation::Collections::IM m_outputs.clear(); for (auto pair : outputs) { - auto key = WinML::Strings::UTF8FromHString(pair.Key()); + auto key = _winml::Strings::UTF8FromHString(pair.Key()); auto value = pair.Value(); m_outputs.emplace(key, value); } @@ -61,7 +61,7 @@ HRESULT LearningModelEvaluationResult::GetOutput( IUnknown** result) { *result = nullptr; - auto outputName = WinML::Strings::UTF8FromUnicode(name, cchName); + auto outputName = _winml::Strings::UTF8FromUnicode(name, cchName); auto foundIt = m_outputs.find(outputName); if (foundIt == std::end(m_outputs)) { @@ -80,4 +80,4 @@ HRESULT LearningModelEvaluationResult::SetOutputs( return S_OK; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/LearningModelEvaluationResult.h b/winml/lib/Api/LearningModelEvaluationResult.h index b2d8cedac10e4..b26c3fb0d8c8b 100644 --- a/winml/lib/Api/LearningModelEvaluationResult.h +++ b/winml/lib/Api/LearningModelEvaluationResult.h @@ -5,7 +5,7 @@ #include "LearningModelEvaluationResult.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct LearningModelEvaluationResult : LearningModelEvaluationResultT< LearningModelEvaluationResult, ILearningModelEvaluationResultNative> { @@ -20,8 +20,8 @@ struct LearningModelEvaluationResult : LearningModelEvaluationResultT< bool Succeeded(); void Succeeded(bool succeeded); - Windows::Foundation::Collections::IMapView Outputs(); - void Outputs(Windows::Foundation::Collections::IMapView outputs); + wfc::IMapView Outputs(); + void Outputs(wfc::IMapView outputs); // ILearningModelEvaluationResultNative STDMETHOD(GetOutput) @@ -30,12 +30,12 @@ struct LearningModelEvaluationResult : LearningModelEvaluationResultT< UINT32 cchName, IUnknown** result); - HRESULT SetOutputs(std::unordered_map&& outputs); + HRESULT SetOutputs(std::unordered_map&& outputs); private: hstring m_correlationId; int32_t m_errorStatus = 0; bool m_succeeded = false; - std::unordered_map m_outputs; + std::unordered_map m_outputs; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index a89faacd95aad..fd8d0d9c22c6c 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -25,7 +25,7 @@ struct __declspec(uuid("D113B493-BBA2-4993-8608-D706A73B91CE")) __declspec(novta } // namespace guid_details static const GUID WINML_PIX_EVAL_CAPTURABLE_WORK_GUID = __uuidof(guid_details::WINML_PIX_EVAL_CAPTURABLE_WORK_GUID); -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { LearningModelSession::LearningModelSession( winml::LearningModel const& model) try : LearningModelSession(model, @@ -50,7 +50,7 @@ LearningModelSession::LearningModelSession( } WINML_CATCH_ALL -WinML::IModel* +_winml::IModel* LearningModelSession::GetOptimizedModel() { // Get the model proto @@ -61,9 +61,9 @@ LearningModelSession::GetOptimizedModel() { return GetOptimizedModel(should_close_model); } -WinML::IModel* +_winml::IModel* LearningModelSession::GetOptimizedModel(bool should_close_model) { - com_ptr model; + com_ptr<_winml::IModel> model; { // Lock the model detach/copy since multiple threads can access concurrently @@ -93,7 +93,7 @@ void LearningModelSession::Initialize() { _winmlt::TelemetryEvent session_creation_event( _winmlt::EventCategory::kSessionCreation); // Get the optimized model proto from the learning model - com_ptr model; + com_ptr<_winml::IModel> model; model.attach(GetOptimizedModel()); // Create the session builder @@ -102,7 +102,7 @@ void LearningModelSession::Initialize() { engine_factory_.copy_from(model_impl->GetEngineFactory()); - com_ptr engine_builder; + com_ptr<_winml::IEngineBuilder> engine_builder; engine_factory_->CreateEngineBuilder(engine_builder.put()); if (device_impl->IsCpuDevice() == false) { @@ -115,7 +115,7 @@ void LearningModelSession::Initialize() { engine_builder->SetBatchSizeOverride(session_options_.BatchSizeOverride()); } - com_ptr engine; + com_ptr<_winml::IEngine> engine; WINML_THROW_IF_FAILED(engine_builder->CreateEngine(engine.put())); // Register the custom operator registry @@ -206,7 +206,7 @@ uint64_t LearningModelSession::Run(winrt::com_ptr [&](auto& name) { return name.c_str(); }); auto& inputs = binding_impl->GetInputs(); - std::vector inputs_raw; + std::vector<_winml::IValue*> inputs_raw; std::transform( std::begin(inputs), std::end(inputs), @@ -222,7 +222,7 @@ uint64_t LearningModelSession::Run(winrt::com_ptr [&](auto& name) { return name.c_str(); }); auto outputs = binding_impl->GetOutputs(); - std::vector outputs_raw; + std::vector<_winml::IValue*> outputs_raw; std::transform( std::begin(outputs), std::end(outputs), @@ -408,7 +408,7 @@ void LearningModelSession::ToggleProfiler() { } } -WinML::IEngine* +_winml::IEngine* LearningModelSession::GetEngine() { return engine_.get(); } @@ -418,4 +418,4 @@ void LearningModelSession::CheckClosed() { WINML_THROW_HR(RO_E_CLOSED); } } -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/LearningModelSession.h b/winml/lib/Api/LearningModelSession.h index 1a27784b49ba8..f131acb781538 100644 --- a/winml/lib/Api/LearningModelSession.h +++ b/winml/lib/Api/LearningModelSession.h @@ -11,7 +11,7 @@ #include "core/providers/winml/winml_provider_factory.h" #include "iengine.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct LearningModelSession : LearningModelSessionT { /* LearningModelSession constructors (MachineLearningContract 1). */ @@ -68,7 +68,7 @@ struct LearningModelSession : LearningModelSessionT { public: /* Non-ABI methods */ - WinML::IEngine* + _winml::IEngine* GetEngine(); void @@ -84,10 +84,10 @@ struct LearningModelSession : LearningModelSessionT { void Initialize(); - WinML::IModel* + _winml::IModel* GetOptimizedModel(); - WinML::IModel* + _winml::IModel* GetOptimizedModel(bool should_close_model); uint64_t @@ -107,8 +107,8 @@ struct LearningModelSession : LearningModelSessionT { ToggleProfiler(); private: - com_ptr engine_factory_; - com_ptr engine_; + com_ptr<_winml::IEngineFactory> engine_factory_; + com_ptr<_winml::IEngine> engine_; using MLOperatorRegistry = std::unique_ptr; MLOperatorRegistry operator_registry_; @@ -129,11 +129,11 @@ struct LearningModelSession : LearningModelSessionT { }; -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP -namespace winrt::Windows::AI::MachineLearning::factory_implementation { +namespace WINML::factory_implementation { struct LearningModelSession : LearningModelSessionT { }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/LearningModelSessionOptions.cpp b/winml/lib/Api/LearningModelSessionOptions.cpp index c74880fe16279..c2f7a2e55e582 100644 --- a/winml/lib/Api/LearningModelSessionOptions.cpp +++ b/winml/lib/Api/LearningModelSessionOptions.cpp @@ -4,7 +4,7 @@ #include "pch.h" #include "LearningModelSessionOptions.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { LearningModelSessionOptions::LearningModelSessionOptions(const LearningModelSessionOptions& options) : batch_size_override_(options.batch_size_override_), close_model_on_session_creation_(options.close_model_on_session_creation_) {} @@ -23,4 +23,4 @@ bool LearningModelSessionOptions::CloseModelOnSessionCreation() { void LearningModelSessionOptions::CloseModelOnSessionCreation(bool value) { close_model_on_session_creation_ = value; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/LearningModelSessionOptions.h b/winml/lib/Api/LearningModelSessionOptions.h index 8e88c7264ea73..59c8a77d38aea 100644 --- a/winml/lib/Api/LearningModelSessionOptions.h +++ b/winml/lib/Api/LearningModelSessionOptions.h @@ -5,7 +5,7 @@ #include "LearningModelSessionOptions.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct LearningModelSessionOptions : LearningModelSessionOptionsT { LearningModelSessionOptions() = default; @@ -43,9 +43,9 @@ struct LearningModelSessionOptions : LearningModelSessionOptionsT { }; -} // namespace winrt::Windows::AI::MachineLearning::factory_implementation +} // namespace WINML::factory_implementation diff --git a/winml/lib/Api/MapFeatureDescriptor.cpp b/winml/lib/Api/MapFeatureDescriptor.cpp index 9563ec8e3a33a..6614150a45c76 100644 --- a/winml/lib/Api/MapFeatureDescriptor.cpp +++ b/winml/lib/Api/MapFeatureDescriptor.cpp @@ -5,14 +5,14 @@ #include "MapFeatureDescriptor.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { MapFeatureDescriptor::MapFeatureDescriptor( const char* name, const char* description, bool is_required, winml::TensorKind key_kind, - winml::ILearningModelFeatureDescriptor value_kind) : name_(WinML::Strings::HStringFromUTF8(name)), - description_(WinML::Strings::HStringFromUTF8(description)), + winml::ILearningModelFeatureDescriptor value_kind) : name_(_winml::Strings::HStringFromUTF8(name)), + description_(_winml::Strings::HStringFromUTF8(description)), is_required_(is_required), key_kind_(key_kind), value_kind_(value_kind) { @@ -70,4 +70,4 @@ MapFeatureDescriptor::GetDescription( *cchDescription = static_cast(description_.size()); return S_OK; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/MapFeatureDescriptor.h b/winml/lib/Api/MapFeatureDescriptor.h index 8faa4292b0ce9..e22d70f9040c8 100644 --- a/winml/lib/Api/MapFeatureDescriptor.h +++ b/winml/lib/Api/MapFeatureDescriptor.h @@ -5,7 +5,7 @@ #include "MapFeatureDescriptor.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct MapFeatureDescriptor : MapFeatureDescriptorT< MapFeatureDescriptor, ILearningModelFeatureDescriptorNative> { @@ -55,4 +55,4 @@ struct MapFeatureDescriptor : MapFeatureDescriptorT< winml::TensorKind key_kind_; winml::ILearningModelFeatureDescriptor value_kind_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/SequenceFeatureDescriptor.cpp b/winml/lib/Api/SequenceFeatureDescriptor.cpp index 2c062c78f2046..30be69502d61c 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.cpp +++ b/winml/lib/Api/SequenceFeatureDescriptor.cpp @@ -5,15 +5,13 @@ #include "SequenceFeatureDescriptor.h" -using namespace Windows::AI::MachineLearning; - -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { SequenceFeatureDescriptor::SequenceFeatureDescriptor( const char* name, const char* description, bool is_required, - winml::ILearningModelFeatureDescriptor descriptor) : name_(WinML::Strings::HStringFromUTF8(name)), - description_(WinML::Strings::HStringFromUTF8(description)), + winml::ILearningModelFeatureDescriptor descriptor) : name_(_winml::Strings::HStringFromUTF8(name)), + description_(_winml::Strings::HStringFromUTF8(description)), is_required_(is_required), element_descriptor_(descriptor) {} @@ -64,4 +62,4 @@ SequenceFeatureDescriptor::GetDescription( *cchDescription = static_cast(description_.size()); return S_OK; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/SequenceFeatureDescriptor.h b/winml/lib/Api/SequenceFeatureDescriptor.h index 7f60691691c9c..4627b8de0e118 100644 --- a/winml/lib/Api/SequenceFeatureDescriptor.h +++ b/winml/lib/Api/SequenceFeatureDescriptor.h @@ -5,7 +5,7 @@ #include "SequenceFeatureDescriptor.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< SequenceFeatureDescriptor, ILearningModelFeatureDescriptorNative> { @@ -48,4 +48,4 @@ struct SequenceFeatureDescriptor : SequenceFeatureDescriptorT< bool is_required_; winml::ILearningModelFeatureDescriptor element_descriptor_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/TensorFeatureDescriptor.cpp b/winml/lib/Api/TensorFeatureDescriptor.cpp index 192fae7287d6d..ee91887f25280 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.cpp +++ b/winml/lib/Api/TensorFeatureDescriptor.cpp @@ -7,15 +7,15 @@ #include "TensorFeatureDescriptor.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { TensorFeatureDescriptor::TensorFeatureDescriptor( const char* name, const char* description, winml::TensorKind tensor_kind, const std::vector& shape, bool is_required, - bool has_unsupported_image_metadata) : name_(WinML::Strings::HStringFromUTF8(name)), - description_(WinML::Strings::HStringFromUTF8(description)), + bool has_unsupported_image_metadata) : name_(_winml::Strings::HStringFromUTF8(name)), + description_(_winml::Strings::HStringFromUTF8(description)), tensor_kind_(tensor_kind), shape_(shape), is_required_(is_required), @@ -83,4 +83,4 @@ TensorFeatureDescriptor::GetDescription( *cchDescription = static_cast(description_.size()); return S_OK; } -} // namespace winrt::Windows::AI::MachineLearning::implementation +} // namespace WINMLP diff --git a/winml/lib/Api/TensorFeatureDescriptor.h b/winml/lib/Api/TensorFeatureDescriptor.h index dc233bb3aa4fa..eb83d60bc7466 100644 --- a/winml/lib/Api/TensorFeatureDescriptor.h +++ b/winml/lib/Api/TensorFeatureDescriptor.h @@ -5,7 +5,7 @@ #include "TensorFeatureDescriptor.g.h" -namespace winrt::Windows::AI::MachineLearning::implementation { +namespace WINMLP { struct TensorFeatureDescriptor : TensorFeatureDescriptorT< TensorFeatureDescriptor, ILearningModelFeatureDescriptorNative> { @@ -59,4 +59,4 @@ struct TensorFeatureDescriptor : TensorFeatureDescriptorT< bool is_required_; bool has_unsupported_image_metadata_; }; -} // namespace winrt::Windows::AI::MachineLearning::implementation \ No newline at end of file +} // namespace WINMLP \ No newline at end of file diff --git a/winml/lib/Api/impl/FeatureCompatibility.h b/winml/lib/Api/impl/FeatureCompatibility.h index 53f089d26d2b6..5338623a9e0ed 100644 --- a/winml/lib/Api/impl/FeatureCompatibility.h +++ b/winml/lib/Api/impl/FeatureCompatibility.h @@ -8,11 +8,11 @@ #include "IMapFeatureValue.h" #include "ISequenceFeatureValue.h" #include "TensorFeatureDescriptor.h" +#include "NamespaceAliases.h" -namespace Windows::AI::MachineLearning { +namespace _winml { namespace error_strings { -using namespace winrt::Windows::AI::MachineLearning; // This must be kept in sync with the TensorKind enum in Windows.AI.MachineLearning.idl const char* SzTensorKind[] = @@ -35,7 +35,7 @@ const char* SzTensorKind[] = "Complex128", }; -static std::string ToString(ILearningModelFeatureDescriptor descriptor); +static std::string ToString(winml::ILearningModelFeatureDescriptor descriptor); static std::string ToString(const std::vector& shape) { std::ostringstream stream; @@ -46,36 +46,36 @@ static std::string ToString(const std::vector& shape) { return stream.str(); } -static std::string ToString(winrt::Windows::Foundation::Collections::IVectorView shape) { +static std::string ToString(wfc::IVectorView shape) { auto shapeVec = std::vector(begin(shape), end(shape)); return ToString(shapeVec); } static std::string ToString( - TensorKind kind, - winrt::Windows::Foundation::Collections::IVectorView shape) { - FAIL_FAST_IF_MSG(kind == TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(kind == TensorKind::Complex64, "Unexpected TensorKind Complex64."); - FAIL_FAST_IF_MSG(kind == TensorKind::Undefined, "Unexpected TensorKind Undefined."); + winml::TensorKind kind, + wfc::IVectorView shape) { + FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); + FAIL_FAST_IF_MSG(kind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); + FAIL_FAST_IF_MSG(kind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); std::ostringstream stream; stream << SzTensorKind[static_cast(kind)] << ToString(shape); return stream.str(); } -static std::string ToString(ITensorFeatureDescriptor descriptor) { +static std::string ToString(winml::ITensorFeatureDescriptor descriptor) { return ToString(descriptor.TensorKind(), descriptor.Shape()); } -static std::string ToString(ITensor value) { +static std::string ToString(winml::ITensor value) { return ToString(value.TensorKind(), value.Shape()); } -static std::string ToString(IMapFeatureDescriptor descriptor) { +static std::string ToString(winml::IMapFeatureDescriptor descriptor) { auto keyKind = descriptor.KeyKind(); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Complex64, "Unexpected TensorKind Complex64."); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Undefined, "Unexpected TensorKind Undefined."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); auto valueDescriptor = descriptor.ValueDescriptor(); std::ostringstream stream; @@ -83,29 +83,29 @@ static std::string ToString(IMapFeatureDescriptor descriptor) { return stream.str(); } -static std::string ToString(winrt::com_ptr value) { - TensorKind keyKind; +static std::string ToString(winrt::com_ptr<_winml::IMapFeatureValue> value) { + winml::TensorKind keyKind; FAIL_FAST_IF_FAILED(value->get_KeyKind(&keyKind)); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Complex128, "Unexpected TensorKind Complex128."); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Complex64, "Unexpected TensorKind Complex64."); - FAIL_FAST_IF_MSG(keyKind == TensorKind::Undefined, "Unexpected TensorKind Undefined."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex128, "Unexpected TensorKind Complex128."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Complex64, "Unexpected TensorKind Complex64."); + FAIL_FAST_IF_MSG(keyKind == winml::TensorKind::Undefined, "Unexpected TensorKind Undefined."); - ILearningModelFeatureDescriptor valueDescriptor; + winml::ILearningModelFeatureDescriptor valueDescriptor; FAIL_FAST_IF_FAILED(value->get_ValueDescriptor(&valueDescriptor)); std::ostringstream stream; stream << "Map<" << SzTensorKind[static_cast(keyKind)] << "," << ToString(valueDescriptor) << ">"; return stream.str(); } -static std::string ToString(ISequenceFeatureDescriptor descriptor) { +static std::string ToString(winml::ISequenceFeatureDescriptor descriptor) { auto elementDescriptor = descriptor.ElementDescriptor(); std::ostringstream stream; stream << "Sequence<" << ToString(elementDescriptor) << ">"; return stream.str(); } -static std::string ToString(winrt::com_ptr value) { - ILearningModelFeatureDescriptor elementDescriptor; +static std::string ToString(winrt::com_ptr<_winml::ISequenceFeatureValue> value) { + winml::ILearningModelFeatureDescriptor elementDescriptor; FAIL_FAST_IF_FAILED(value->get_ElementDescriptor(&elementDescriptor)); std::ostringstream stream; @@ -113,49 +113,49 @@ static std::string ToString(winrt::com_ptr value) { return stream.str().c_str(); } -static std::string ToString(IImageFeatureDescriptor descriptor) { +static std::string ToString(winml::IImageFeatureDescriptor descriptor) { std::ostringstream stream; stream << "Image[" << descriptor.Width() << "x" << descriptor.Height() << "]"; return stream.str(); } -static std::string ToString(winrt::com_ptr value) { +static std::string ToString(winrt::com_ptr value) { std::ostringstream stream; stream << "Image[" << value->Widths()[0] << "x" << value->Heights()[0] << "]"; return stream.str(); } -static std::string ToString(ILearningModelFeatureDescriptor descriptor) { +static std::string ToString(winml::ILearningModelFeatureDescriptor descriptor) { switch (descriptor.Kind()) { - case LearningModelFeatureKind::Image: - return ToString(descriptor.as()); + case winml::LearningModelFeatureKind::Image: + return ToString(descriptor.as()); break; - case LearningModelFeatureKind::Map: - return ToString(descriptor.as()); + case winml::LearningModelFeatureKind::Map: + return ToString(descriptor.as()); break; - case LearningModelFeatureKind::Sequence: - return ToString(descriptor.as()); + case winml::LearningModelFeatureKind::Sequence: + return ToString(descriptor.as()); break; - case LearningModelFeatureKind::Tensor: - return ToString(descriptor.as()); + case winml::LearningModelFeatureKind::Tensor: + return ToString(descriptor.as()); default: FAIL_FAST_MSG("Unexpected descriptor LearningModelFeatureKind."); } } -static std::string ToString(ILearningModelFeatureValue value) { +static std::string ToString(winml::ILearningModelFeatureValue value) { switch (value.Kind()) { - case LearningModelFeatureKind::Image: - return ToString(value.as()); + case winml::LearningModelFeatureKind::Image: + return ToString(value.as()); break; - case LearningModelFeatureKind::Map: - return ToString(value.as()); + case winml::LearningModelFeatureKind::Map: + return ToString(value.as<_winml::IMapFeatureValue>()); break; - case LearningModelFeatureKind::Sequence: - return ToString(value.as()); + case winml::LearningModelFeatureKind::Sequence: + return ToString(value.as<_winml::ISequenceFeatureValue>()); break; - case LearningModelFeatureKind::Tensor: - return ToString(value.as()); + case winml::LearningModelFeatureKind::Tensor: + return ToString(value.as()); default: FAIL_FAST_MSG("Unexpected descriptor LearningModelFeatureKind."); } @@ -170,12 +170,12 @@ static std::string ToString(ILearningModelFeatureValue value) { // This matrix is indexed by Kind, and is a group of function pointers which accept // a feature value and descriptr, and return whether they are compatible. namespace compatibility_details { -using namespace winrt::Windows::AI::MachineLearning; -using K = LearningModelFeatureKind; +using K = winml::LearningModelFeatureKind; -static void not_compatible_hr(HRESULT hr, ILearningModelFeatureValue value, ILearningModelFeatureDescriptor descriptor) { - auto name = WinML::Strings::UTF8FromHString(descriptor.Name()); +static void not_compatible_hr(HRESULT hr, winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { + auto name = _winml::Strings::UTF8FromHString(descriptor.Name()); WINML_THROW_IF_FAILED_MSG( hr, @@ -185,29 +185,29 @@ static void not_compatible_hr(HRESULT hr, ILearningModelFeatureValue value, ILea error_strings::ToString(value).c_str()); } -static void not_compatible(ILearningModelFeatureValue value, ILearningModelFeatureDescriptor descriptor) { +static void not_compatible(winml::ILearningModelFeatureValue value, winml::ILearningModelFeatureDescriptor descriptor) { not_compatible_hr(WINML_ERR_INVALID_BINDING, value, descriptor); } // This method is used in validating sequeance and map type's internal element, key and value types. -static HRESULT verify(ILearningModelFeatureDescriptor first, ILearningModelFeatureDescriptor second) { +static HRESULT verify(winml::ILearningModelFeatureDescriptor first, winml::ILearningModelFeatureDescriptor second) { RETURN_HR_IF(WINML_ERR_INVALID_BINDING, first.Kind() != second.Kind()); - if (auto mapFirst = first.try_as()) { - auto mapSecond = second.try_as(); + if (auto mapFirst = first.try_as()) { + auto mapSecond = second.try_as(); RETURN_HR_IF_NULL(WINML_ERR_INVALID_BINDING, mapSecond); RETURN_HR_IF(WINML_ERR_INVALID_BINDING, mapFirst.KeyKind() != mapSecond.KeyKind()); return verify(mapFirst.ValueDescriptor(), mapSecond.ValueDescriptor()); } - if (auto sequenceFirst = first.try_as()) { - auto sequenceSecond = second.try_as(); + if (auto sequenceFirst = first.try_as()) { + auto sequenceSecond = second.try_as(); RETURN_HR_IF_NULL(WINML_ERR_INVALID_BINDING, sequenceSecond); return verify(sequenceFirst.ElementDescriptor(), sequenceSecond.ElementDescriptor()); } - if (auto tensorFirst = first.try_as()) { - auto tensorSecond = second.try_as(); + if (auto tensorFirst = first.try_as()) { + auto tensorSecond = second.try_as(); RETURN_HR_IF_NULL(WINML_ERR_INVALID_BINDING, tensorSecond); RETURN_HR_IF(WINML_ERR_INVALID_BINDING, tensorFirst.TensorKind() != tensorSecond.TensorKind()); return S_OK; @@ -222,22 +222,22 @@ static HRESULT verify(ILearningModelFeatureDescriptor first, ILearningModelFeatu TFeatureDescriptor: feature description from model */ template -void verify(ILearningModelFeatureValue value, ILearningModelFeatureDescriptor descriptor) { +void verify(winml::ILearningModelFeatureValue value, winml::ILearningModelFeatureDescriptor descriptor) { not_compatible(value, descriptor); } template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { thrower fail = std::bind(not_compatible_hr, std::placeholders::_1, value, descriptor); enforce check = std::bind(enforce_not_false, std::placeholders::_1, std::placeholders::_2, fail); - auto tensorValue = value.as(); - auto tensorDescriptor = descriptor.as(); + auto tensorValue = value.as(); + auto tensorDescriptor = descriptor.as(); check(WINML_ERR_INVALID_BINDING, tensorValue.TensorKind() == tensorDescriptor.TensorKind()); - auto spValueProvider = tensorValue.as(); + auto spValueProvider = tensorValue.as<_winml::ILotusValueProviderPrivate>(); bool isPlaceHolder; if (SUCCEEDED(spValueProvider->IsPlaceholder(&isPlaceHolder)) && !isPlaceHolder) { @@ -259,20 +259,20 @@ void verify( template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { thrower fail = std::bind(not_compatible_hr, std::placeholders::_1, value, descriptor); enforce check = std::bind(enforce_not_false, std::placeholders::_1, std::placeholders::_2, fail); enforce_succeeded check_succeeded = std::bind(enforce_not_failed, std::placeholders::_1, fail); - auto spMapFeatureValue = value.as(); - auto mapDescriptor = descriptor.as(); + auto spMapFeatureValue = value.as<_winml::IMapFeatureValue>(); + auto mapDescriptor = descriptor.as(); - TensorKind valueKeyKind; + winml::TensorKind valueKeyKind; check_succeeded(spMapFeatureValue->get_KeyKind(&valueKeyKind)); check(WINML_ERR_INVALID_BINDING, valueKeyKind == mapDescriptor.KeyKind()); - ILearningModelFeatureDescriptor valueValueDescriptor; + winml::ILearningModelFeatureDescriptor valueValueDescriptor; check_succeeded(spMapFeatureValue->get_ValueDescriptor(&valueValueDescriptor)); check_succeeded(verify(valueValueDescriptor, mapDescriptor.ValueDescriptor())); @@ -280,15 +280,15 @@ void verify( template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { thrower fail = std::bind(not_compatible_hr, std::placeholders::_1, value, descriptor); enforce_succeeded check_succeeded = std::bind(enforce_not_failed, std::placeholders::_1, fail); - auto spSequenceFeatureValue = value.as(); - auto sequenceDescriptor = descriptor.as(); + auto spSequenceFeatureValue = value.as<_winml::ISequenceFeatureValue>(); + auto sequenceDescriptor = descriptor.as(); - ILearningModelFeatureDescriptor valueElementDescriptor; + winml::ILearningModelFeatureDescriptor valueElementDescriptor; check_succeeded(spSequenceFeatureValue->get_ElementDescriptor(&valueElementDescriptor)); check_succeeded(verify(valueElementDescriptor, sequenceDescriptor.ElementDescriptor())); @@ -296,8 +296,8 @@ void verify( template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { // No check is needed here. Because: // For batchSize==1, no matter what shape the input has (smaller or larger), we support to bind it. // For batchSize > 1, @@ -310,16 +310,16 @@ void verify( template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { thrower fail = std::bind(not_compatible_hr, std::placeholders::_1, value, descriptor); enforce check = std::bind(enforce_not_false, std::placeholders::_1, std::placeholders::_2, fail); enforce_succeeded check_succeeded = std::bind(enforce_not_failed, std::placeholders::_1, fail); - auto tensorValue = value.as(); - auto imageDescriptor = descriptor.as(); + auto tensorValue = value.as(); + auto imageDescriptor = descriptor.as(); - check(WINML_ERR_INVALID_BINDING, tensorValue.TensorKind() == TensorKind::Float); + check(WINML_ERR_INVALID_BINDING, tensorValue.TensorKind() == winml::TensorKind::Float); auto spValueProvider = tensorValue.as(); @@ -354,13 +354,13 @@ void verify( */ template <> void verify( - ILearningModelFeatureValue value, - ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { thrower fail = std::bind(not_compatible_hr, std::placeholders::_1, value, descriptor); enforce check = std::bind(enforce_not_false, std::placeholders::_1, std::placeholders::_2, fail); - auto imageValue = value.as(); - auto tensorDescriptor = descriptor.as(); + auto imageValue = value.as(); + auto tensorDescriptor = descriptor.as(); check(WINML_ERR_INVALID_BINDING, !tensorDescriptor->IsUnsupportedMetaData()); // NCHW: images must be 4 dimensions @@ -368,7 +368,7 @@ void verify( check(WINML_ERR_SIZE_MISMATCH, 4 == tensorDescriptorShape.Size()); } -static void (*FeatureKindCompatibilityMatrix[4][4])(ILearningModelFeatureValue, ILearningModelFeatureDescriptor) = +static void (*FeatureKindCompatibilityMatrix[4][4])(winml::ILearningModelFeatureValue, winml::ILearningModelFeatureDescriptor) = { // Tensor, Sequence, Map, Image /* Tensor */ {verify, not_compatible, not_compatible, verify}, @@ -378,8 +378,8 @@ static void (*FeatureKindCompatibilityMatrix[4][4])(ILearningModelFeatureValue, } // namespace compatibility_details inline void VerifyFeatureValueCompatibleWithDescriptor( - winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue value, - winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor descriptor) { + winml::ILearningModelFeatureValue value, + winml::ILearningModelFeatureDescriptor descriptor) { using namespace compatibility_details; auto pfnAreKindsCompatible = @@ -389,4 +389,4 @@ inline void VerifyFeatureValueCompatibleWithDescriptor( pfnAreKindsCompatible(value, descriptor); } -} // namespace Windows::AI::MachineLearning +} // namespace _winml diff --git a/winml/lib/Api/impl/IMapFeatureValue.h b/winml/lib/Api/impl/IMapFeatureValue.h index 630befcc40eec..05d8cbfd45cf4 100644 --- a/winml/lib/Api/impl/IMapFeatureValue.h +++ b/winml/lib/Api/impl/IMapFeatureValue.h @@ -3,17 +3,17 @@ #pragma once -namespace Windows::AI::MachineLearning { +namespace _winml { /* [uuid("3e4d4350-0b61-4517-aa6d-79d49bf164b4"), feature, contract, object, exclusiveto] */ MIDL_INTERFACE("3e4d4350-0b61-4517-aa6d-79d49bf164b4") IMapFeatureValue : public ::IUnknown { public: /* [propget] */ virtual HRESULT STDMETHODCALLTYPE get_KeyKind( - /* [out, retval] */ winrt::Windows::AI::MachineLearning::TensorKind * kind) = 0; + /* [out, retval] */ winml::TensorKind * kind) = 0; /* [propget] */ virtual HRESULT STDMETHODCALLTYPE get_ValueDescriptor( - /* [out, retval] */ winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor * result) = 0; + /* [out, retval] */ winml::ILearningModelFeatureDescriptor * result) = 0; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/ISequenceFeatureValue.h b/winml/lib/Api/impl/ISequenceFeatureValue.h index 131a3a4814dbd..237bf0c5dae20 100644 --- a/winml/lib/Api/impl/ISequenceFeatureValue.h +++ b/winml/lib/Api/impl/ISequenceFeatureValue.h @@ -3,14 +3,14 @@ #pragma once -namespace Windows::AI::MachineLearning { +namespace _winml { /* [uuid("529d0bca-4c6c-48c1-9bd3-e1ea2e816348"), feature, contract, object, exclusiveto] */ MIDL_INTERFACE("529d0bca-4c6c-48c1-9bd3-e1ea2e816348") ISequenceFeatureValue : public ::IUnknown { public: /* [propget] */ virtual HRESULT STDMETHODCALLTYPE get_ElementDescriptor( - /* [out, retval] */ winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor * result) = 0; + /* [out, retval] */ winml::ILearningModelFeatureDescriptor * result) = 0; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/MapBase.h b/winml/lib/Api/impl/MapBase.h index 83d8f112a87cd..ac689639a5687 100644 --- a/winml/lib/Api/impl/MapBase.h +++ b/winml/lib/Api/impl/MapBase.h @@ -8,7 +8,7 @@ #include "MapFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" -namespace Windows::AI::MachineLearning { +namespace _winml { // // MapBase @@ -25,9 +25,9 @@ template < typename TValue> struct MapBase : winrt::implements< MapBase, - winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue, - WinML::IMapFeatureValue, - WinML::ILotusValueProviderPrivate> { + winml::ILearningModelFeatureValue, + _winml::IMapFeatureValue, + _winml::ILotusValueProviderPrivate> { static_assert( std::is_same::value || std::is_same::value, @@ -40,21 +40,21 @@ struct MapBase : winrt::implements< std::is_same::value, "Map values must be int64_t, double, float, or winrt::hstring!"); - using ABIMap = ::winrt::Windows::Foundation::Collections::IMap; - using ABIMapView = ::winrt::Windows::Foundation::Collections::IMapView; + using ABIMap = wfc::IMap; + using ABIMapView = wfc::IMapView; MapBase(ABIMap const& data) : data_(data) {} - static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create() { + static winml::ILearningModelFeatureValue Create() { auto abiMap = winrt::single_threaded_map(); return winrt::make(abiMap); } - static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMap& data) { + static winml::ILearningModelFeatureValue Create(const ABIMap& data) { return winrt::make(data); } - static winrt::Windows::AI::MachineLearning::ILearningModelFeatureValue Create(const ABIMapView& data) { + static winml::ILearningModelFeatureValue Create(const ABIMapView& data) { auto abiMap = winrt::single_threaded_map(); for (const auto& pair : data) { auto key = pair.Key(); @@ -65,19 +65,19 @@ struct MapBase : winrt::implements< return winrt::make(abiMap); } // ILearningModelFeatureValue implementation - winrt::Windows::AI::MachineLearning::LearningModelFeatureKind Kind() { - return winrt::Windows::AI::MachineLearning::LearningModelFeatureKind::Map; + winml::LearningModelFeatureKind Kind() { + return winml::LearningModelFeatureKind::Map; } STDMETHOD(get_KeyKind) - (winrt::Windows::AI::MachineLearning::TensorKind* kind) { + (winml::TensorKind* kind) { FAIL_FAST_IF_NULL(kind); *kind = TensorKindFrom::Type; return S_OK; } STDMETHOD(get_ValueDescriptor) - (winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor* result) { + (winml::ILearningModelFeatureDescriptor* result) { FAIL_FAST_IF_NULL(result); *result = TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); @@ -86,11 +86,11 @@ struct MapBase : winrt::implements< } STDMETHOD(GetValue) - (WinML::BindingContext& context, IValue** out) { - auto session = context.session.as(); + (_winml::BindingContext& context, IValue** out) { + auto session = context.session.as(); auto engine = session->GetEngine(); - if (context.type == WinML::BindingType::kInput) { + if (context.type == _winml::BindingType::kInput) { RETURN_IF_FAILED(engine->CreateMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, out)); } else { RETURN_IF_FAILED(engine->CreateNullValue(out)); @@ -108,7 +108,7 @@ struct MapBase : winrt::implements< STDMETHOD(UpdateSourceResourceData) (BindingContext& context, IValue* value) { data_.Clear(); - auto session = context.session.as(); + auto session = context.session.as(); auto engine = session->GetEngine(); RETURN_IF_FAILED(engine->FillFromMapValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), TensorKindFrom::Type, TensorKindFrom::Type, value)); return S_OK; @@ -116,7 +116,7 @@ struct MapBase : winrt::implements< STDMETHOD(AbiRepresentation) ( - winrt::Windows::Foundation::IInspectable& abiRepresentation) { + wf::IInspectable& abiRepresentation) { data_.as(abiRepresentation); return S_OK; } @@ -125,4 +125,4 @@ struct MapBase : winrt::implements< ABIMap data_; }; -} // namespace Windows::AI::MachineLearning +} // namespace _winml diff --git a/winml/lib/Api/impl/SequenceBase.h b/winml/lib/Api/impl/SequenceBase.h index 20044f55c72e6..fe5aee937d950 100644 --- a/winml/lib/Api/impl/SequenceBase.h +++ b/winml/lib/Api/impl/SequenceBase.h @@ -7,7 +7,7 @@ #include "SequenceFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" -namespace Windows::AI::MachineLearning { +namespace _winml { // SequenceBase // @@ -20,8 +20,8 @@ template struct SequenceBase : public winrt::implements< SequenceBase, winml::ILearningModelFeatureValue, - WinML::ISequenceFeatureValue, - WinML::ILotusValueProviderPrivate> { + _winml::ISequenceFeatureValue, + _winml::ILotusValueProviderPrivate> { using ABISequence = wfc::IIterable; using AbiMapStringToFloat = wfc::IMap; using AbiMapInt64BitToFloat = wfc::IMap; @@ -40,7 +40,7 @@ struct SequenceBase : public winrt::implements< std::is_same::value || std::is_same::value || std::is_same::value || - std::is_same::value || + std::is_same::value || std::is_same::value, "Only sequences of of map, map and tensor are supported."); @@ -61,7 +61,7 @@ struct SequenceBase : public winrt::implements< void GetElementDescriptor(winml::ILearningModelFeatureDescriptor* result) { - *result = WinML::TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); + *result = _winml::TensorFeatureDescriptorFrom::CreateAnonymous(std::vector{}); } template <> @@ -70,7 +70,7 @@ struct SequenceBase : public winrt::implements< winml::ILearningModelFeatureDescriptor* result) { // zero dimensional tensor has empty shape auto value_descriptor = - WinML::TensorFeatureDescriptorFrom::CreateAnonymous( + _winml::TensorFeatureDescriptorFrom::CreateAnonymous( std::vector{}); *result = winrt::make( @@ -87,7 +87,7 @@ struct SequenceBase : public winrt::implements< winml::ILearningModelFeatureDescriptor* result) { // zero dimensional tensor has empty shape auto value_descriptor = - WinML::TensorFeatureDescriptorFrom::CreateAnonymous( + _winml::TensorFeatureDescriptorFrom::CreateAnonymous( std::vector{}); *result = winrt::make( @@ -128,12 +128,12 @@ struct SequenceBase : public winrt::implements< } STDMETHOD(GetValue)( - WinML::BindingContext& context, + _winml::BindingContext& context, IValue** out) { - auto session = context.session.as(); + auto session = context.session.as(); auto engine = session->GetEngine(); - if (context.type == WinML::BindingType::kInput) { + if (context.type == _winml::BindingType::kInput) { winml::ILearningModelFeatureDescriptor descriptor(nullptr); GetElementDescriptor(&descriptor); @@ -153,17 +153,17 @@ struct SequenceBase : public winrt::implements< // GetValue and delegating tensorization to each of those objects. // // The resulting tensors are collected into a vector. - std::vector> sequence; + std::vector> sequence; for (auto tensor : data_) { - auto value_provider = tensor.as(); - winrt::com_ptr out_value; + auto value_provider = tensor.as<_winml::ILotusValueProviderPrivate>(); + winrt::com_ptr<_winml::IValue> out_value; RETURN_IF_FAILED(value_provider->GetValue(context, out_value.put())); sequence.push_back(out_value); } // The collection of IValues needs wrapped into a single IValue // which represents the sequence value. - std::vector sequence_values; + std::vector<_winml::IValue*> sequence_values; std::transform( std::begin(sequence), std::end(sequence), @@ -208,9 +208,9 @@ struct SequenceBase : public winrt::implements< template <> auto CreatePlaceholderTensor() { return winmlp::TensorString::Create(); } void AppendValue( - WinML::BindingContext& context, wfc::IVector data, winrt::com_ptr value) { + _winml::BindingContext& context, wfc::IVector data, winrt::com_ptr<_winml::IValue> value) { auto tensor = CreatePlaceholderTensor(); - auto value_provider = tensor.as(); + auto value_provider = tensor.as<_winml::ILotusValueProviderPrivate>(); WINML_THROW_IF_FAILED(value_provider->UpdateSourceResourceData(context, value.get())); data.Append(tensor); } @@ -220,7 +220,7 @@ struct SequenceBase : public winrt::implements< auto writable_vector = data_.as>(); writable_vector.Clear(); - auto session = context.session.as(); + auto session = context.session.as(); auto engine = session->GetEngine(); winml::ILearningModelFeatureDescriptor descriptor(nullptr); @@ -231,7 +231,7 @@ struct SequenceBase : public winrt::implements< RETURN_IF_FAILED(engine->FillSequenceOfMapsValue(reinterpret_cast<::IInspectable*>(winrt::get_abi(data_)), SequenceAbiTypeInfo::Key, SequenceAbiTypeInfo::Value, out)); } else if (descriptor.Kind() == winml::LearningModelFeatureKind::Tensor) { // In opset 11, operators that require seq> were added. - std::vector> tensor_values; + std::vector> tensor_values; RETURN_IF_FAILED(engine->GetSequenceOfTensorValues(out, tensor_values)); for (auto tensor_value : tensor_values) { @@ -254,4 +254,4 @@ struct SequenceBase : public winrt::implements< ABISequence data_; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/Tensor.h b/winml/lib/Api/impl/Tensor.h index b5f694590efe9..fd654b1b0d93d 100644 --- a/winml/lib/Api/impl/Tensor.h +++ b/winml/lib/Api/impl/Tensor.h @@ -10,7 +10,8 @@ // TensorBase contains one of these to represent the raw memory // GetCpuResource() returns it // -namespace Windows::AI::MachineLearning { +namespace _winml { + template class Tensor { private: @@ -25,16 +26,16 @@ class Tensor { Tensor( std::vector const& shape, - winrt::Windows::Storage::Streams::IBuffer buffer) : shape_(shape), - m_buffer( - TensorBuffer::Create( - static_cast( - std::accumulate( - std::begin(shape), - std::end(shape), - static_cast(1), - std::multiplies())), - buffer)) { + wss::IBuffer buffer) : shape_(shape), + m_buffer( + TensorBuffer::Create( + static_cast( + std::accumulate( + std::begin(shape), + std::end(shape), + static_cast(1), + std::multiplies())), + buffer)) { } Tensor( @@ -89,4 +90,4 @@ class Tensor { return m_buffer; } }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/TensorBase.h b/winml/lib/Api/impl/TensorBase.h index 940546add4741..e3ab1cdcd48c3 100644 --- a/winml/lib/Api/impl/TensorBase.h +++ b/winml/lib/Api/impl/TensorBase.h @@ -14,7 +14,7 @@ #include "core/session/onnxruntime_c_api.h" -namespace Windows::AI::MachineLearning { +namespace _winml { // TensorBase // @@ -72,14 +72,14 @@ struct TensorBase : TBase { TensorBase() : m_resources(std::make_shared>()) { } - TensorBase(winrt::Windows::Foundation::Collections::IIterable const& shape) : shape_(begin(shape), end(shape)), - m_resources(std::make_shared>()) { - GetCpuResource() = std::make_shared>(shape_); + TensorBase(wfc::IIterable const& shape) : shape_(begin(shape), end(shape)), + m_resources(std::make_shared>()) { + GetCpuResource() = std::make_shared<_winml::Tensor>(shape_); } TensorBase(std::vector const& shape) : shape_(shape), m_resources(std::make_shared>()) { - GetCpuResource() = std::make_shared>(shape_); + GetCpuResource() = std::make_shared<_winml::Tensor>(shape_); } TensorBase(std::vector const& shape, ID3D12Resource* resource) : shape_(shape), @@ -96,8 +96,8 @@ struct TensorBase : TBase { HRESULT CreateGPUMLValue(ID3D12Resource* resource, BindingContext& context, IValue** out) { THROW_HR_IF_NULL(E_INVALIDARG, resource); - auto session = context.session.as(); - auto device = session->Device().as(); + auto session = context.session.as(); + auto device = session->Device().as(); WINML_THROW_HR_IF_TRUE_MSG(WINML_ERR_INVALID_BINDING, device->IsCpuDevice(), "Cannot create GPU tensor on CPU device"); @@ -107,8 +107,8 @@ struct TensorBase : TBase { return S_OK; } - HRESULT CPUTensorize(WinML::BindingContext& context, IValue** out) { - auto session = context.session.as(); + HRESULT CPUTensorize(_winml::BindingContext& context, IValue** out) { + auto session = context.session.as(); auto engine = session->GetEngine(); if (GetCpuResource() != nullptr) { @@ -123,13 +123,13 @@ struct TensorBase : TBase { WINML_THROW_HR(WINML_ERR_INVALID_BINDING); } - HRESULT GPUTensorize(WinML::BindingContext& context, IValue** out) { + HRESULT GPUTensorize(_winml::BindingContext& context, IValue** out) { if (GetGpuResource() != nullptr) { return CreateGPUMLValue(GetGpuResource().get(), context, out); } // Get engine - auto session = context.session.as(); + auto session = context.session.as(); auto engine = session->GetEngine(); // If there is no matching gpu resource, then fallback to a cpu resource @@ -137,10 +137,10 @@ struct TensorBase : TBase { return CreateTensorValueFromExternalBuffer(engine, out); } - if (TensorKind() == winrt::Windows::AI::MachineLearning::TensorKind::String) { + if (TensorKind() == winml::TensorKind::String) { // Lazily allocate the cpu TensorString resource // TensorStrings are CPU only, and so a gpu resource cannot be allocated for them. - GetCpuResource() = std::make_shared>(shape_); + GetCpuResource() = std::make_shared<_winml::Tensor>(shape_); return CreateTensorValueFromExternalBuffer(engine, out); } else { // Try to allocate the backing memory for the caller @@ -170,7 +170,7 @@ struct TensorBase : TBase { D3D12_TEXTURE_LAYOUT_ROW_MAJOR, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - auto device = session->Device().as(); + auto device = session->Device().as(); winrt::com_ptr gpu_resource = nullptr; device->GetD3DDevice()->CreateCommittedResource( @@ -200,7 +200,7 @@ struct TensorBase : TBase { // ILotusValueProviderPrivate::GetOrtValue STDMETHOD(GetValue) - (WinML::BindingContext& context, IValue** out) { + (_winml::BindingContext& context, IValue** out) { RETURN_HR_IF_NULL_MSG( WINML_ERR_INVALID_BINDING, m_resources, @@ -208,8 +208,8 @@ struct TensorBase : TBase { EnsureBufferNotInUse(); - auto spSession = context.session.as(); - auto spDevice = spSession->Device().as(); + auto spSession = context.session.as(); + auto spDevice = spSession->Device().as(); if (spDevice->IsCpuDevice()) { RETURN_IF_FAILED(CPUTensorize(context, out)); @@ -250,7 +250,7 @@ struct TensorBase : TBase { } template - HRESULT CreateTensorValueFromExternalBuffer(WinML::IEngine* engine, IValue** value) { + HRESULT CreateTensorValueFromExternalBuffer(_winml::IEngine* engine, IValue** value) { // This adds compile time checks that ensure that the API can only be called when // the conditions of ASSERT_TEMPLATE_PARAMETERS_EXACT() are met. ASSERT_TEMPLATE_PARAMETERS(); @@ -263,7 +263,7 @@ struct TensorBase : TBase { } template <> - HRESULT CreateTensorValueFromExternalBuffer(WinML::IEngine* engine, IValue** value) { + HRESULT CreateTensorValueFromExternalBuffer(_winml::IEngine* engine, IValue** value) { // Ensure that this call is being called with the correct template parameters ASSERT_TEMPLATE_PARAMETERS(); @@ -290,7 +290,7 @@ struct TensorBase : TBase { m_resources, "The tensor has been closed and its resources have been detached during evaluation!"); - WinML::Resource updated_resource; + _winml::Resource updated_resource; RETURN_IF_FAILED(value->GetResource(updated_resource)); // get the shape @@ -298,7 +298,7 @@ struct TensorBase : TBase { // make sure we always have a CPU resource if (GetCpuResource() == nullptr) { - GetCpuResource() = std::make_shared>(shape_); + GetCpuResource() = std::make_shared<_winml::Tensor>(shape_); } bool is_cpu; @@ -320,7 +320,7 @@ struct TensorBase : TBase { // We don't need to copy the engine provided dx resource into a local copy since we always preallocate gpu // resources for tensors. Therefore we are certain that the returned dxresource is the same as the one we passed in // and was updated in place. - auto spSession = context.session.as(); + auto spSession = context.session.as(); auto engine = spSession->GetEngine(); winrt::com_ptr dest; @@ -344,7 +344,7 @@ struct TensorBase : TBase { // ITensor::Create static typename TBase::class_type Create( - winrt::Windows::Foundation::Collections::IIterable const& shape) try { + wfc::IIterable const& shape) try { typename TBase::class_type tensorValue = winrt::make(); auto tensorValueImpl = tensorValue.as(); tensorValueImpl->shape_ = std::vector(begin(shape), end(shape)); @@ -354,8 +354,8 @@ struct TensorBase : TBase { // ITensor::CreateFromIterable static typename TBase::class_type CreateFromIterable( - winrt::Windows::Foundation::Collections::IIterable shape, - winrt::Windows::Foundation::Collections::IIterable const& data) try { + wfc::IIterable shape, + wfc::IIterable const& data) try { std::vector vecShape(begin(shape), end(shape)); if (HasFreeDimensions(vecShape)) { // If the tensor is being created with a free dimension, the data needs to @@ -363,7 +363,7 @@ struct TensorBase : TBase { // In the case of IIterable, there is no Size accessor, and so we require that // in this case the underlying object also implement IVectorView, so that we may // efficiently query the size of the data. - if (auto vectorView = data.try_as>()) { + if (auto vectorView = data.try_as>()) { vecShape = GetAdjustedShape(vecShape, vectorView.Size()); } } @@ -377,7 +377,7 @@ struct TensorBase : TBase { // ITensor::CreateFromArray static typename TBase::class_type CreateFromArray( - winrt::Windows::Foundation::Collections::IIterable shape, + wfc::IIterable shape, winrt::array_view data) try { std::vector vecShape(begin(shape), end(shape)); return CreateFromArrayInternal(vecShape, data); @@ -409,12 +409,12 @@ struct TensorBase : TBase { // ITensor::CreateFromBuffer static typename TBase::class_type CreateFromBuffer( winrt::array_view shape, - winrt::Windows::Storage::Streams::IBuffer const& buffer) try { + wss::IBuffer const& buffer) try { std::vector vecShape(shape.begin(), shape.end()); typename TBase::class_type tensorValue = winrt::make(); auto tensorValueImpl = tensorValue.as(); tensorValueImpl->shape_ = vecShape; - tensorValueImpl->GetCpuResource() = std::make_shared>(vecShape, buffer); + tensorValueImpl->GetCpuResource() = std::make_shared<_winml::Tensor>(vecShape, buffer); return tensorValue; } WINML_CATCH_ALL @@ -501,7 +501,7 @@ struct TensorBase : TBase { /// // IMemoryBuffer::CreateReference - winrt::Windows::Foundation::IMemoryBufferReference CreateReference() try { + wf::IMemoryBufferReference CreateReference() try { // Create a TensorMemoryBufferReference // Per IMemoryBuffer.CreateReference (https://docs.microsoft.com/en-us/uwp/api/windows.foundation.imemorybuffer.createreference) @@ -563,7 +563,7 @@ struct TensorBase : TBase { // ITensor::GetAsVectorView template - winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { + wfc::IVectorView GetAsVectorView() try { // This adds compile time checks that ensure that the API can only be called when: // 1) the conditions of ASSERT_TEMPLATE_PARAMETERS_EXACT() are met. // 2) the signature of the method conforms to the ABI signature and the return value matches the ABI Return Type (ViewT). @@ -588,12 +588,12 @@ struct TensorBase : TBase { // Specialized version to convert float16 to float template <> - winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { + wfc::IVectorView GetAsVectorView<_winml::Half, float>() try { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS<_winml::Half, float>(); uint32_t size; - WinML::Half* pBuffer; + _winml::Half* pBuffer; // Get the data pointer and size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -604,7 +604,7 @@ struct TensorBase : TBase { floatValue.data(), sizeof(float) /* output stride */, reinterpret_cast(pBuffer), - sizeof(WinML::Half) /* input stride */, + sizeof(_winml::Half) /* input stride */, size); // Create IVectorView from copied data. @@ -614,7 +614,7 @@ struct TensorBase : TBase { // Specialized version to convert string to hstring template <> - winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { + wfc::IVectorView GetAsVectorView() try { // Ensure that this call is being called with the correct template parameters ASSERT_TEMPLATE_PARAMETERS(); @@ -627,7 +627,7 @@ struct TensorBase : TBase { copy.begin(), copy.end(), [n = 0, &pData]() mutable { - return WinML::Strings::HStringFromUTF8(pData[n++]); + return _winml::Strings::HStringFromUTF8(pData[n++]); }); return winrt::single_threaded_vector(std::move(copy)).GetView(); @@ -636,7 +636,7 @@ struct TensorBase : TBase { // Specialized version to convert int8_t to uint8_t template <> - winrt::Windows::Foundation::Collections::IVectorView GetAsVectorView() try { + wfc::IVectorView GetAsVectorView() try { ASSERT_TEMPLATE_PARAMETERS(); uint32_t size; @@ -658,19 +658,19 @@ struct TensorBase : TBase { /// // ILearningModelFeatureValue implementation - winrt::Windows::AI::MachineLearning::LearningModelFeatureKind Kind() try { - return winrt::Windows::AI::MachineLearning::LearningModelFeatureKind::Tensor; + winml::LearningModelFeatureKind Kind() try { + return winml::LearningModelFeatureKind::Tensor; } WINML_CATCH_ALL // ITensor::TensorKind - winrt::Windows::AI::MachineLearning::TensorKind TensorKind() try { + winml::TensorKind TensorKind() try { return TensorKindFrom::Type; } WINML_CATCH_ALL // ITensor::Shape - winrt::Windows::Foundation::Collections::IVectorView Shape() try { + wfc::IVectorView Shape() try { std::vector copy(shape_.cbegin(), shape_.cend()); return winrt::single_threaded_vector(std::move(copy)).GetView(); } @@ -678,7 +678,7 @@ struct TensorBase : TBase { // ILotusValueProviderPrivate::AbiRepresentation STDMETHOD(AbiRepresentation) - (winrt::Windows::Foundation::IInspectable& abiRepresentation) { + (wf::IInspectable& abiRepresentation) { using ABIType = typename TBase::class_type; ABIType to = nullptr; RETURN_IF_FAILED(this->QueryInterface( @@ -721,12 +721,12 @@ struct TensorBase : TBase { // Specialized version to convert floats to float16 template <> - void SetBufferFromArray(winrt::array_view data) { + void SetBufferFromArray<_winml::Half, float>(winrt::array_view data) { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS<_winml::Half, float>(); uint32_t size; - WinML::Half* pBuffer; + _winml::Half* pBuffer; // Get the data pointer and size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -734,7 +734,7 @@ struct TensorBase : TBase { THROW_HR_IF(E_UNEXPECTED, data.size() != size); DirectX::PackedVector::XMConvertFloatToHalfStream( reinterpret_cast(pBuffer), - sizeof(WinML::Half) /* output stride */, + sizeof(_winml::Half) /* output stride */, data.data(), sizeof(float) /* input stride */, data.size()); @@ -769,7 +769,7 @@ struct TensorBase : TBase { std::transform( data.begin(), data.end(), pBuffer, [](auto& element) mutable { - return WinML::Strings::UTF8FromHString(element); + return _winml::Strings::UTF8FromHString(element); }); } @@ -778,7 +778,7 @@ struct TensorBase : TBase { /// template void SetBufferFromIterable( - winrt::Windows::Foundation::Collections::IIterable const& data) { + wfc::IIterable const& data) { // This adds compile time checks that ensure that the API can only be called when // the conditions of ASSERT_TEMPLATE_PARAMETERS_EXACT() are met. ASSERT_TEMPLATE_PARAMETERS_EXACT(); @@ -797,13 +797,13 @@ struct TensorBase : TBase { // Specialized version to convert floats to float16 template <> - void SetBufferFromIterable( - winrt::Windows::Foundation::Collections::IIterable const& data) { + void SetBufferFromIterable<_winml::Half, float>( + wfc::IIterable const& data) { // Ensure that this call is being called with the correct template parameters - ASSERT_TEMPLATE_PARAMETERS(); + ASSERT_TEMPLATE_PARAMETERS<_winml::Half, float>(); uint32_t size; - WinML::Half* pBuffer; + _winml::Half* pBuffer; // Get the data pointer and size std::tie(size, pBuffer) = GetCpuResource()->buffer(); @@ -822,7 +822,7 @@ struct TensorBase : TBase { // Specialized version to convert uint8_t to int8_t template <> void SetBufferFromIterable( - winrt::Windows::Foundation::Collections::IIterable const& data) { + wfc::IIterable const& data) { // Ensure that this call is being called with the correct template parameters ASSERT_TEMPLATE_PARAMETERS(); @@ -838,7 +838,7 @@ struct TensorBase : TBase { // Specialized version to convert hstring to string template <> void SetBufferFromIterable( - winrt::Windows::Foundation::Collections::IIterable const& data) { + wfc::IIterable const& data) { // Ensure that this call is being called with the correct template parameters ASSERT_TEMPLATE_PARAMETERS(); @@ -850,11 +850,11 @@ struct TensorBase : TBase { // Convert and copy into the underlying buffer std::transform(begin(data), end(data), pBuffer, [](const auto& element) { - return WinML::Strings::UTF8FromHString(element); + return _winml::Strings::UTF8FromHString(element); }); } - std::shared_ptr>& GetCpuResource() { + std::shared_ptr<_winml::Tensor>& GetCpuResource() { WINML_THROW_HR_IF_NULL_MSG( E_ILLEGAL_METHOD_CALL, m_resources, @@ -879,6 +879,6 @@ struct TensorBase : TBase { bool m_isClosed = false; }; -} // namespace Windows::AI::MachineLearning +} // namespace _winml #pragma warning(pop) diff --git a/winml/lib/Api/impl/TensorBuffer.h b/winml/lib/Api/impl/TensorBuffer.h index bd8c101c3fa88..adb48faae4f62 100644 --- a/winml/lib/Api/impl/TensorBuffer.h +++ b/winml/lib/Api/impl/TensorBuffer.h @@ -6,10 +6,11 @@ #include "robuffer.h" #include "winrt/Windows.Storage.Streams.h" -namespace Windows::AI::MachineLearning { +namespace _winml { + class VectorBuffer : public winrt::implements< VectorBuffer, - winrt::Windows::Storage::Streams::IBuffer, + wss::IBuffer, Windows::Storage::Streams::IBufferByteAccess> { public: VectorBuffer(size_t size) : m_buffer(size) {} @@ -41,7 +42,7 @@ class VectorBuffer : public winrt::implements< template class TensorBuffer { - winrt::Windows::Storage::Streams::IBuffer m_buffer; + wss::IBuffer m_buffer; uint32_t m_size; TensorBuffer(uint32_t size) : m_size(size), @@ -57,7 +58,7 @@ class TensorBuffer { TensorBuffer( uint32_t size, - winrt::Windows::Storage::Streams::IBuffer buffer) : m_size(size), + wss::IBuffer buffer) : m_size(size), m_buffer(buffer) {} public: @@ -69,7 +70,7 @@ class TensorBuffer { static auto Create( uint32_t size, - winrt::Windows::Storage::Streams::IBuffer buffer) { + wss::IBuffer buffer) { return std::shared_ptr(new TensorBuffer(size, buffer)); } @@ -145,4 +146,4 @@ class TensorBuffer { std::copy(data, data + size, m_buffer.begin()); } }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/TensorKindFrom.h b/winml/lib/Api/impl/TensorKindFrom.h index 487b9cb3b2370..4d67099aa61c7 100644 --- a/winml/lib/Api/impl/TensorKindFrom.h +++ b/winml/lib/Api/impl/TensorKindFrom.h @@ -3,7 +3,7 @@ #pragma once -namespace Windows::AI::MachineLearning { +namespace _winml { // We need to define our own type for Half since DirectX::PackedVector::Half resolves to uint16_t per its typedef declaration. // Templates require an actual type name to resolve correctly. @@ -84,4 +84,4 @@ struct TensorFeatureDescriptorFrom { } }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Api/impl/TensorMemoryBufferReference.h b/winml/lib/Api/impl/TensorMemoryBufferReference.h index 3463e66f14294..3d963f5671079 100644 --- a/winml/lib/Api/impl/TensorMemoryBufferReference.h +++ b/winml/lib/Api/impl/TensorMemoryBufferReference.h @@ -8,7 +8,8 @@ #include -namespace Windows::AI::MachineLearning { +namespace _winml { + template struct TensorResources { // ITensorNative::GetBuffer @@ -29,7 +30,7 @@ struct TensorResources { // Lazily allocate the cpu resource on call to GetBuffer if (CpuResource == nullptr) { - CpuResource = std::make_shared>(shape); + CpuResource = std::make_shared<_winml::Tensor>(shape); } // Get the data pointer and size @@ -46,7 +47,7 @@ struct TensorResources { } // Theses are access directly by TensorMemoryBufferReference and TensorBase - std::shared_ptr> CpuResource; + std::shared_ptr<_winml::Tensor> CpuResource; winrt::com_ptr GpuResource; }; @@ -58,9 +59,9 @@ struct TensorResources { template class TensorMemoryBufferReference : public winrt::implements< TensorMemoryBufferReference, - winrt::Windows::Foundation::IMemoryBufferReference, + wf::IMemoryBufferReference, Windows::Foundation::IMemoryBufferByteAccess> { - using ClosedDelegate = winrt::Windows::Foundation::TypedEventHandler; + using ClosedDelegate = wf::TypedEventHandler; public: // winrt::Windows::Foundation::IMemoryBufferReference @@ -139,9 +140,9 @@ class TensorMemoryBufferReference : public winrt::implements< private: void FireClosed() { - winrt::Windows::Foundation::IMemoryBufferReference memoryBufferReference = nullptr; + wf::IMemoryBufferReference memoryBufferReference = nullptr; WINML_THROW_IF_FAILED(this->QueryInterface( - winrt::guid_of(), + winrt::guid_of(), reinterpret_cast(winrt::put_abi(memoryBufferReference)))); for (auto handler : m_handlers) { @@ -156,4 +157,4 @@ class TensorMemoryBufferReference : public winrt::implements< int64_t m_eventTokenCounter = 0; }; -} // namespace Windows::AI::MachineLearning +} // namespace _winml diff --git a/winml/lib/Api/inc/ILotusValueProviderPrivate.h b/winml/lib/Api/inc/ILotusValueProviderPrivate.h index 70ec0b4f0ba3d..0a487be263719 100644 --- a/winml/lib/Api/inc/ILotusValueProviderPrivate.h +++ b/winml/lib/Api/inc/ILotusValueProviderPrivate.h @@ -8,7 +8,7 @@ // ILotusValueProviderPrivate exposes a private Lotus interface to the engine so that it can retrieve tensor // resources stored in winrt structures. -namespace Windows::AI::MachineLearning { +namespace _winml { class PoolObjectWrapper; @@ -17,17 +17,17 @@ enum class BindingType { kInput, struct BindingContext { BindingType type = BindingType::kInput; - winrt::Windows::AI::MachineLearning::LearningModelSession session = nullptr; - winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor descriptor = nullptr; - winrt::Windows::Foundation::Collections::IPropertySet properties = nullptr; + winml::LearningModelSession session = nullptr; + winml::ILearningModelFeatureDescriptor descriptor = nullptr; + wfc::IPropertySet properties = nullptr; std::shared_ptr converter; }; struct __declspec(uuid("27e2f437-0112-4693-849e-e04323a620fb")) __declspec(novtable) ILotusValueProviderPrivate : IUnknown { - virtual HRESULT __stdcall GetValue(BindingContext& binding_context, WinML::IValue** out) = 0; + virtual HRESULT __stdcall GetValue(BindingContext& binding_context, _winml::IValue** out) = 0; virtual HRESULT __stdcall IsPlaceholder(bool* is_placeholder) = 0; - virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, WinML::IValue* value) = 0; - virtual HRESULT __stdcall AbiRepresentation(winrt::Windows::Foundation::IInspectable& abi_representation) = 0; + virtual HRESULT __stdcall UpdateSourceResourceData(BindingContext& binding_context, _winml::IValue* value) = 0; + virtual HRESULT __stdcall AbiRepresentation(wf::IInspectable& abi_representation) = 0; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Common/CommonDeviceHelpers.cpp b/winml/lib/Common/CommonDeviceHelpers.cpp index e23a7f475b31d..5dbeef7f5fec4 100644 --- a/winml/lib/Common/CommonDeviceHelpers.cpp +++ b/winml/lib/Common/CommonDeviceHelpers.cpp @@ -84,7 +84,7 @@ HRESULT GetDXCoreAdapterMetadata(ID3D12Device& device, bool& isMcdmAdapter, uint } #endif -HRESULT GetD3D12Device(const winrt::Windows::AI::MachineLearning::LearningModelDevice& device, ID3D12Device** outDevice) { +HRESULT GetD3D12Device(const winml::LearningModelDevice& device, ID3D12Device** outDevice) { _LUID id; id.LowPart = device.AdapterId().LowPart; id.HighPart = device.AdapterId().HighPart; @@ -139,7 +139,7 @@ constexpr uint32_t c_intelVendorId = 0x8086; constexpr uint32_t c_nvidiaVendorId = 0x10DE; constexpr uint32_t c_amdVendorId = 0x1002; -bool IsFloat16Supported(const winrt::Windows::AI::MachineLearning::LearningModelDevice& device) { +bool IsFloat16Supported(const winml::LearningModelDevice& device) { auto adapterId = device.AdapterId(); if (!adapterId.HighPart && !adapterId.LowPart) { // CPU device diff --git a/winml/lib/Common/inc/CommonDeviceHelpers.h b/winml/lib/Common/inc/CommonDeviceHelpers.h index 288202ae128d2..a14474a245254 100644 --- a/winml/lib/Common/inc/CommonDeviceHelpers.h +++ b/winml/lib/Common/inc/CommonDeviceHelpers.h @@ -45,5 +45,5 @@ HRESULT RunDelayLoadedApi(TFunc& tfunc, TArgs&&... args) { HRESULT GetAdapterEnumerationSupport(AdapterEnumerationSupport* support); bool IsFloat16Supported(ID3D12Device* device); -bool IsFloat16Supported(const winrt::Windows::AI::MachineLearning::LearningModelDevice& device); +bool IsFloat16Supported(const winml::LearningModelDevice& device); } // namespace CommonDeviceHelpers diff --git a/winml/lib/Common/inc/NamespaceAliases.h b/winml/lib/Common/inc/NamespaceAliases.h index 09424d7dafb05..13926ac49c30e 100644 --- a/winml/lib/Common/inc/NamespaceAliases.h +++ b/winml/lib/Common/inc/NamespaceAliases.h @@ -27,17 +27,19 @@ namespace ws = ::winrt::Windows::Storage; namespace winrt::Windows::Storage::Streams {} namespace wss = ::winrt::Windows::Storage::Streams; -namespace winrt::Windows::AI::MachineLearning {} -namespace winml = ::winrt::Windows::AI::MachineLearning; +#define WINML winrt::WINML_ROOT_NS::AI::MachineLearning +namespace WINML {} +namespace winml = WINML; -namespace winrt::Windows::AI::MachineLearning::implementation {} -namespace winmlp = ::winrt::Windows::AI::MachineLearning::implementation; +#define WINMLP winrt::WINML_ROOT_NS::AI::MachineLearning::implementation +namespace WINMLP {} +namespace winmlp = WINMLP; -namespace Windows::AI::MachineLearning::Adapter {} -namespace winmla = ::Windows::AI::MachineLearning::Adapter; +namespace _winml::Adapter {} +namespace winmla = ::_winml::Adapter; -namespace Windows::AI::MachineLearning {} -namespace WinML = ::Windows::AI::MachineLearning; +namespace _winml::Telemetry {} +namespace _winmlt = ::_winml::Telemetry; -namespace Windows::AI::MachineLearning::Telemetry {} -namespace _winmlt = ::Windows::AI::MachineLearning::Telemetry; +namespace _winml::Imaging {} +namespace _winmli = ::_winml::Imaging; diff --git a/winml/lib/Common/inc/StringHelpers.h b/winml/lib/Common/inc/StringHelpers.h index e8a4c5514aab3..989e92b8ddf6c 100644 --- a/winml/lib/Common/inc/StringHelpers.h +++ b/winml/lib/Common/inc/StringHelpers.h @@ -4,7 +4,7 @@ #pragma once // String Helpers -namespace Windows::AI::MachineLearning::Strings { +namespace _winml::Strings { struct HStringBuilder { HStringBuilder(HStringBuilder const&) = delete; HStringBuilder& operator=(HStringBuilder const&) = delete; @@ -89,4 +89,4 @@ static std::wstring WStringFromString(const std::string& string) { return woss.str(); } -} // namespace Windows::AI::MachineLearning::Strings +} // namespace _winml::Strings diff --git a/winml/lib/Common/inc/WinMLTelemetryHelper.h b/winml/lib/Common/inc/WinMLTelemetryHelper.h index a4bac4c78873e..8809082478fb4 100644 --- a/winml/lib/Common/inc/WinMLTelemetryHelper.h +++ b/winml/lib/Common/inc/WinMLTelemetryHelper.h @@ -44,7 +44,7 @@ class Profiler; // WinMLRuntime Telemetry Support // // {BCAD6AEE-C08D-4F66-828C-4C43461A033D} -#define WINML_PROVIDER_DESC "Microsoft.Windows.AI.MachineLearning" +#define WINML_PROVIDER_DESC "Windows AI Machine Learning" #define WINML_PROVIDER_GUID (0xbcad6aee, 0xc08d, 0x4f66, 0x82, 0x8c, 0x4c, 0x43, 0x46, 0x1a, 0x3, 0x3d) #define WINML_PROVIDER_KEYWORD_DEFAULT 0x1 #define WINML_PROVIDER_KEYWORD_LOTUS_PROFILING 0x2 diff --git a/winml/lib/Common/inc/errors.h b/winml/lib/Common/inc/errors.h index 067e3eaefde86..fc5a52de9d1ab 100644 --- a/winml/lib/Common/inc/errors.h +++ b/winml/lib/Common/inc/errors.h @@ -11,7 +11,7 @@ if (!_status.IsOK()) { \ HRESULT hresult = StatusCodeToHRESULT(static_cast(_status.Code())); \ telemetry_helper.LogRuntimeError(hresult, _status.ErrorMessage(), __FILE__, __FUNCTION__, __LINE__); \ - winrt::hstring errorMessage(WinML::Strings::HStringFromUTF8(_status.ErrorMessage())); \ + winrt::hstring errorMessage(_winml::Strings::HStringFromUTF8(_status.ErrorMessage())); \ throw winrt::hresult_error(hresult, errorMessage); \ } \ } while (0) @@ -28,7 +28,7 @@ char msg[1024]; \ sprintf_s(msg, message, __VA_ARGS__); \ telemetry_helper.LogRuntimeError(_hr, msg, __FILE__, __FUNCTION__, __LINE__); \ - winrt::hstring errorMessage(WinML::Strings::HStringFromUTF8(msg)); \ + winrt::hstring errorMessage(_winml::Strings::HStringFromUTF8(msg)); \ throw winrt::hresult_error(_hr, errorMessage); \ } \ } while (0) diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h index 26e35f8c5c48b..8314e049bf016 100644 --- a/winml/lib/Common/inc/iengine.h +++ b/winml/lib/Common/inc/iengine.h @@ -3,7 +3,7 @@ #pragma once -namespace Windows::AI::MachineLearning { +namespace _winml { MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown { @@ -55,7 +55,7 @@ IValue : IUnknown { (bool* out) PURE; STDMETHOD(GetResource) - (WinML::Resource & resource) PURE; + (_winml::Resource & resource) PURE; STDMETHOD(IsTensor) (bool* out) PURE; @@ -148,7 +148,7 @@ IEngine : IUnknown { (IInspectable * sequence, winml::TensorKind key_kind, winml::TensorKind value_kind, IValue * value) PURE; STDMETHOD(GetSequenceOfTensorValues) - (WinML::IValue* sequence_value, _Out_ std::vector>& out_values) PURE; + (_winml::IValue* sequence_value, _Out_ std::vector>& out_values) PURE; }; MIDL_INTERFACE("8ac0b6b9-4561-492b-b63d-a07bdd8292c6") @@ -190,4 +190,4 @@ IEngineFactory : IUnknown { (_Out_ IMLOperatorRegistry * *registry) PURE; }; -} // namespace Windows::AI::MachineLearning \ No newline at end of file +} // namespace _winml \ No newline at end of file diff --git a/winml/lib/Common/inc/winrt_headers.h b/winml/lib/Common/inc/winrt_headers.h index 0b31d32c44158..ed8de42124201 100644 --- a/winml/lib/Common/inc/winrt_headers.h +++ b/winml/lib/Common/inc/winrt_headers.h @@ -9,10 +9,20 @@ #include "winrt/windows.graphics.imaging.h" #include "winrt/windows.foundation.h" #include "winrt/windows.foundation.collections.h" -#include "comp_generated/winrt/windows.ai.machinelearning.h" + +#define STRINGIFY(x) #x +#define XSTRINGIFY(x) STRINGIFY(x) +#define CPPWINRT_HEADER(root_ns) comp_generated/winrt/##root_ns##.AI.MachineLearning.h +#define NATIVE_HEADER(root_ns) root_ns##.AI.MachineLearning.native.h +#define NATIVE_INTERNAL_HEADER(root_ns) root_ns##.AI.MachineLearning.native.internal.h +#define CREATE_CPPWINRT_COMPONENT_HEADER() XSTRINGIFY(CPPWINRT_HEADER(WINML_ROOT_NS)) +#define CREATE_NATIVE_HEADER() XSTRINGIFY(NATIVE_HEADER(WINML_ROOT_NS)) +#define CREATE_NATIVE_INTERNAL_HEADER() XSTRINGIFY(NATIVE_INTERNAL_HEADER(WINML_ROOT_NS)) + +#include CREATE_CPPWINRT_COMPONENT_HEADER() // WinML Native Headers -#include "Windows.AI.MachineLearning.Native.h" -#include "Windows.AI.MachineLearning.Native.Internal.h" +#include CREATE_NATIVE_HEADER() +#include CREATE_NATIVE_INTERNAL_HEADER() #include "Errors.h" \ No newline at end of file diff --git a/winml/lib/Telemetry/WinMLTelemetryHelper.cpp b/winml/lib/Telemetry/WinMLTelemetryHelper.cpp index a14caf8519e26..28cda97634685 100644 --- a/winml/lib/Telemetry/WinMLTelemetryHelper.cpp +++ b/winml/lib/Telemetry/WinMLTelemetryHelper.cpp @@ -17,12 +17,14 @@ WinMLTelemetryHelper::~WinMLTelemetryHelper() { } void WinMLTelemetryHelper::LogWinMLShutDown() { + std::string message = BINARY_NAME; + message += " is unloaded"; WinMLTraceLoggingWrite( provider_, "WinMLShutDown", TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT), TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), - TraceLoggingString("windows.ai.machinelearning.dll is unloaded", "message"), + TraceLoggingString(message.c_str(), "message"), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES)); } diff --git a/winml/lib/Telemetry/inc/TelemetryEvent.h b/winml/lib/Telemetry/inc/TelemetryEvent.h index 8376211818b2f..ff52eb941dde5 100644 --- a/winml/lib/Telemetry/inc/TelemetryEvent.h +++ b/winml/lib/Telemetry/inc/TelemetryEvent.h @@ -3,7 +3,7 @@ #pragma once -namespace Windows::AI::MachineLearning::Telemetry { +namespace _winml::Telemetry { enum class EventCategory { kModelLoad = 0, @@ -24,4 +24,4 @@ class TelemetryEvent { std::optional event_id_; }; -} // namespace Windows::AI::MachineLearning::Telemetry \ No newline at end of file +} // namespace _winml::Telemetry \ No newline at end of file diff --git a/winml/lib/Telemetry/pch.h b/winml/lib/Telemetry/pch.h index 421903725edb8..5ec2dd9e47d5a 100644 --- a/winml/lib/Telemetry/pch.h +++ b/winml/lib/Telemetry/pch.h @@ -5,3 +5,4 @@ #include "common.h" #include "TraceLoggingConfig.h" +#include "NamespaceAliases.h" \ No newline at end of file diff --git a/winml/test/adapter/adapter_test.cpp b/winml/test/adapter/adapter_test.cpp new file mode 100644 index 0000000000000..d20aee70a79d7 --- /dev/null +++ b/winml/test/adapter/adapter_test.cpp @@ -0,0 +1,331 @@ +#include "testPch.h" +#include "adapter_test.h" +#include "fileHelpers.h" +#include "winrt/Windows.Storage.h" +#include "winrt/Windows.Storage.Streams.h" + +using namespace ws; +using namespace wss; + +static void AdapterTestSetup() { + ort_api = OrtGetApiBase()->GetApi(2); + winml_adapter_api = OrtGetWinMLAdapter(ort_api); + + // for model tests + std::wstring module_path = FileHelpers::GetModulePath(); + std::string squeezenet_path = std::wstring_convert>().to_bytes(module_path + L"squeezenet_modifiedforruntimestests.onnx"); + std::string metadata_path = std::wstring_convert>().to_bytes(module_path + L"modelWith2MetaData.onnx"); + std::string float16_path = std::wstring_convert>().to_bytes(module_path + L"starry-night-fp16.onnx"); + winml_adapter_api->CreateModelFromPath(squeezenet_path.c_str(), squeezenet_path.size(), &squeezenet_model); + winml_adapter_api->CreateModelFromPath(metadata_path.c_str(), metadata_path.size(), &metadata_model); + winml_adapter_api->CreateModelFromPath(float16_path.c_str(), float16_path.size(), &float16_Model); +} + +static void AdapterTestTeardown() { + winml_adapter_api->ReleaseModel(squeezenet_model); + winml_adapter_api->ReleaseModel(metadata_model); + winml_adapter_api->ReleaseModel(float16_Model); +} + +static void CreateModelFromPath() { + WINML_EXPECT_TRUE(squeezenet_model != nullptr); + WINML_EXPECT_TRUE(metadata_model != nullptr); + WINML_EXPECT_TRUE(float16_Model != nullptr); +} + +static void CreateModelFromData() { + StorageFolder folder = StorageFolder::GetFolderFromPathAsync(FileHelpers::GetModulePath()).get(); + StorageFile file = folder.GetFileAsync(L"squeezenet_modifiedforruntimestests.onnx").get(); + IRandomAccessStream stream = file.OpenAsync(FileAccessMode::Read).get(); + DataReader data_reader(stream.GetInputStreamAt(0)); + data_reader.LoadAsync(static_cast(stream.Size())).get(); + IBuffer data_buffer = data_reader.DetachBuffer(); + OrtModel* squeezenet_model_from_data = nullptr; + winml_adapter_api->CreateModelFromData(data_buffer.data(), data_buffer.Length(), &squeezenet_model_from_data); + WINML_EXPECT_TRUE(squeezenet_model_from_data != nullptr); + // Verify a function in the model for thoroughness + const char* author; + size_t len; + winml_adapter_api->ModelGetAuthor(squeezenet_model_from_data, &author, &len); + std::string author_str(author); + WINML_EXPECT_EQUAL(author_str, "onnx-caffe2"); + winml_adapter_api->ReleaseModel(squeezenet_model_from_data); +} + +static void CloneModel() { + OrtModel* squeezenet_clone = nullptr; + winml_adapter_api->CloneModel(squeezenet_model, &squeezenet_clone); + WINML_EXPECT_TRUE(squeezenet_clone != nullptr); + // Verify a function in clone + const char* author; + size_t len; + winml_adapter_api->ModelGetAuthor(squeezenet_clone, &author, &len); + std::string author_str(author); + WINML_EXPECT_EQUAL(author_str, "onnx-caffe2"); +} + +static void ModelGetAuthor() { + const char* author; + size_t len; + winml_adapter_api->ModelGetAuthor(squeezenet_model, &author, &len); + std::string author_str(author); + WINML_EXPECT_EQUAL(author_str, "onnx-caffe2"); +} + +static void ModelGetName() { + const char* name; + size_t len; + winml_adapter_api->ModelGetName(squeezenet_model, &name, &len); + std::string name_str(name); + WINML_EXPECT_EQUAL(name_str, "squeezenet_old"); +} + +static void ModelGetDomain() { + const char* domain; + size_t len; + winml_adapter_api->ModelGetDomain(squeezenet_model, &domain, &len); + std::string domain_str(domain); + WINML_EXPECT_EQUAL(domain_str, "test-domain"); +} + +static void ModelGetDescription() { + const char* description; + size_t len; + winml_adapter_api->ModelGetDescription(squeezenet_model, &description, &len); + std::string description_str(description); + WINML_EXPECT_EQUAL(description_str, "test-doc_string"); +} + +static void ModelGetVersion() { + int64_t version; + winml_adapter_api->ModelGetVersion(squeezenet_model, &version); + WINML_EXPECT_EQUAL(version, 123456); +} + +static void ModelGetInputCount() { + size_t input_count; + winml_adapter_api->ModelGetInputCount(squeezenet_model, &input_count); + WINML_EXPECT_EQUAL(input_count, 1); +} + +static void ModelGetOutputCount() { + size_t output_count; + winml_adapter_api->ModelGetOutputCount(squeezenet_model, &output_count); + WINML_EXPECT_EQUAL(output_count, 1); +} + +static void ModelGetInputName() { + const char* input_name; + size_t count; + winml_adapter_api->ModelGetInputName(squeezenet_model, 0, &input_name, &count); + std::string input_name_str(input_name); + WINML_EXPECT_EQUAL(input_name_str, "data_0"); +} + +static void ModelGetOutputName() { + const char* output_name; + size_t count; + winml_adapter_api->ModelGetOutputName(squeezenet_model, 0, &output_name, &count); + std::string output_name_str(output_name); + WINML_EXPECT_EQUAL(output_name_str, "softmaxout_1"); +} + +static void ModelGetInputDescription() { + const char* input_description; + size_t count; + winml_adapter_api->ModelGetInputDescription(metadata_model, 0, &input_description, &count); + std::string input_description_str(input_description); + WINML_EXPECT_EQUAL(input_description_str, "this is a long input description!"); +} + +static void ModelGetOutputDescription() { + const char* output_description; + size_t count; + winml_adapter_api->ModelGetOutputDescription(metadata_model, 0, &output_description, &count); + std::string output_description_str(output_description); + WINML_EXPECT_EQUAL(output_description_str, "this is a long output description!"); +} + +static void ModelGetInputTypeInfo() { + OrtTypeInfo* input_type_info; + winml_adapter_api->ModelGetInputTypeInfo(squeezenet_model, 0, &input_type_info); + + ONNXType input_type; + ort_api->GetOnnxTypeFromTypeInfo(input_type_info, &input_type); + WINML_EXPECT_EQUAL(input_type, ONNX_TYPE_TENSOR); + + const OrtTensorTypeAndShapeInfo* tensor_info; + ort_api->CastTypeInfoToTensorInfo(input_type_info, &tensor_info); + + ONNXTensorElementDataType tensor_type; + ort_api->GetTensorElementType(tensor_info, &tensor_type); + WINML_EXPECT_EQUAL(tensor_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + size_t dim_count; + ort_api->GetDimensionsCount(tensor_info, &dim_count); + WINML_EXPECT_EQUAL(dim_count, 4); + + int64_t dim_values[4]; + ort_api->GetDimensions(tensor_info, dim_values, 4); + WINML_EXPECT_EQUAL(dim_values[0], 1); + WINML_EXPECT_EQUAL(dim_values[1], 3); + WINML_EXPECT_EQUAL(dim_values[2], 224); + WINML_EXPECT_EQUAL(dim_values[3], 224); + + ort_api->ReleaseTypeInfo(input_type_info); +} + +static void ModelGetOutputTypeInfo() { + OrtTypeInfo* output_type_info; + winml_adapter_api->ModelGetOutputTypeInfo(squeezenet_model, 0, &output_type_info); + + ONNXType output_type; + ort_api->GetOnnxTypeFromTypeInfo(output_type_info, &output_type); + WINML_EXPECT_EQUAL(output_type, ONNX_TYPE_TENSOR); + + const OrtTensorTypeAndShapeInfo* tensor_info; + ort_api->CastTypeInfoToTensorInfo(output_type_info, &tensor_info); + + ONNXTensorElementDataType tensor_type; + ort_api->GetTensorElementType(tensor_info, &tensor_type); + WINML_EXPECT_EQUAL(tensor_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + size_t dim_count; + ort_api->GetDimensionsCount(tensor_info, &dim_count); + WINML_EXPECT_EQUAL(dim_count, 4); + + int64_t dim_values[4]; + ort_api->GetDimensions(tensor_info, dim_values, 4); + WINML_EXPECT_EQUAL(dim_values[0], 1); + WINML_EXPECT_EQUAL(dim_values[1], 1000); + WINML_EXPECT_EQUAL(dim_values[2], 1); + WINML_EXPECT_EQUAL(dim_values[3], 1); + + ort_api->ReleaseTypeInfo(output_type_info); +} + +static void ModelGetMetadataCount() { + size_t metadata_count; + winml_adapter_api->ModelGetMetadataCount(metadata_model, &metadata_count); + WINML_EXPECT_EQUAL(metadata_count, 2); +} + +static void ModelGetMetadata() { + const char* metadata_key; + size_t metadata_key_len; + const char* metadata_value; + size_t metadata_value_len; + + winml_adapter_api->ModelGetMetadata(metadata_model, 0, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len); + WINML_EXPECT_EQUAL(std::string(metadata_key), "thisisalongkey"); + WINML_EXPECT_EQUAL(metadata_key_len, 14); + WINML_EXPECT_EQUAL(std::string(metadata_value), "thisisalongvalue"); + WINML_EXPECT_EQUAL(metadata_value_len, 16); + + winml_adapter_api->ModelGetMetadata(metadata_model, 1, &metadata_key, &metadata_key_len, &metadata_value, &metadata_value_len); + WINML_EXPECT_EQUAL(std::string(metadata_key), "key2"); + WINML_EXPECT_EQUAL(metadata_key_len, 4); + WINML_EXPECT_EQUAL(std::string(metadata_value), "val2"); + WINML_EXPECT_EQUAL(metadata_value_len, 4); +} + +static void ModelEnsureNoFloat16() { + OrtStatus* float16_error_status; + + float16_error_status = winml_adapter_api->ModelEnsureNoFloat16(squeezenet_model); + WINML_EXPECT_EQUAL(float16_error_status, nullptr); + + float16_error_status = winml_adapter_api->ModelEnsureNoFloat16(float16_Model); + WINML_EXPECT_NOT_EQUAL(float16_error_status, nullptr); + WINML_EXPECT_EQUAL(ort_api->GetErrorCode(float16_error_status), ORT_INVALID_GRAPH); +} + +static void __stdcall TestLoggingCallback(void* param, OrtLoggingLevel severity, const char* category, + const char* logger_id, const char* code_location, const char* message) noexcept { + UNREFERENCED_PARAMETER(param); + UNREFERENCED_PARAMETER(severity); + UNREFERENCED_PARAMETER(category); + UNREFERENCED_PARAMETER(logger_id); + UNREFERENCED_PARAMETER(code_location); + UNREFERENCED_PARAMETER(message); + logging_function_called = true; +} + +static void __stdcall TestProfileEventCallback(const OrtProfilerEventRecord* profiler_record) noexcept { + UNREFERENCED_PARAMETER(profiler_record); + profiling_function_called = true; +} + +static void EnvConfigureCustomLoggerAndProfiler() { + OrtEnv* ort_env = nullptr; + ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env); + winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env, + &TestLoggingCallback, &TestProfileEventCallback, nullptr, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env); + logging_function_called = false; + OrtSession* ort_session = nullptr; + std::wstring squeezenet_path = FileHelpers::GetModulePath() + L"relu.onnx"; + ort_api->CreateSession(ort_env, squeezenet_path.c_str(), nullptr, &ort_session); + WINML_EXPECT_TRUE(logging_function_called); + + size_t input_tensor_size = 5; + int64_t input_dimensions[] = {5}; + + std::vector input_tensor_values(input_tensor_size); + std::vector input_node_names = {"X"}; + std::vector output_node_names = {"Y"}; + + // initialize input data with values in [0.0, 1.0] + for (size_t i = 0; i < input_tensor_size; i++) + input_tensor_values[i] = (float)i / (input_tensor_size + 1); + + // create input tensor object from data values + OrtMemoryInfo* memory_info; + ort_api->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info); + OrtValue* input_tensor = nullptr; + ort_api->CreateTensorWithDataAsOrtValue(memory_info, input_tensor_values.data(), input_tensor_size * sizeof(float), input_dimensions, 1, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensor); + int is_tensor; + ort_api->IsTensor(input_tensor, &is_tensor); + assert(is_tensor); + ort_api->ReleaseMemoryInfo(memory_info); + OrtValue* output_tensor = nullptr; + winml_adapter_api->SessionStartProfiling(ort_env, ort_session); + profiling_function_called = false; + ort_api->Run(ort_session, nullptr, input_node_names.data(), (const OrtValue* const*)&input_tensor, 1, output_node_names.data(), 1, &output_tensor); + WINML_EXPECT_TRUE(profiling_function_called); + winml_adapter_api->SessionEndProfiling(ort_session); + + ort_api->ReleaseValue(output_tensor); + ort_api->ReleaseValue(input_tensor); + ort_api->ReleaseSession(ort_session); + ort_api->ReleaseEnv(ort_env); +} + +const AdapterTestApi& getapi() { + static constexpr AdapterTestApi api = + { + AdapterTestSetup, + AdapterTestTeardown, + CreateModelFromPath, + CreateModelFromData, + CloneModel, + ModelGetAuthor, + ModelGetName, + ModelGetDomain, + ModelGetDescription, + ModelGetVersion, + ModelGetInputCount, + ModelGetOutputCount, + ModelGetInputName, + ModelGetOutputName, + ModelGetInputDescription, + ModelGetOutputDescription, + ModelGetInputTypeInfo, + ModelGetOutputTypeInfo, + ModelGetMetadataCount, + ModelGetMetadata, + ModelEnsureNoFloat16, + EnvConfigureCustomLoggerAndProfiler, + }; + return api; +} \ No newline at end of file diff --git a/winml/test/adapter/adapter_test.h b/winml/test/adapter/adapter_test.h new file mode 100644 index 0000000000000..f4814afa0db3e --- /dev/null +++ b/winml/test/adapter/adapter_test.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test.h" +#include "core/providers/winml/winml_provider_factory.h" +#include "winml_adapter_c_api.h" + +struct AdapterTestApi +{ + SetupClass AdapterTestSetup; + TeardownClass AdapterTestTeardown; + VoidTest CreateModelFromPath; + VoidTest CreateModelFromData; + VoidTest CloneModel; + VoidTest ModelGetAuthor; + VoidTest ModelGetName; + VoidTest ModelGetDomain; + VoidTest ModelGetDescription; + VoidTest ModelGetVersion; + VoidTest ModelGetInputCount; + VoidTest ModelGetOutputCount; + VoidTest ModelGetInputName; + VoidTest ModelGetOutputName; + VoidTest ModelGetInputDescription; + VoidTest ModelGetOutputDescription; + VoidTest ModelGetInputTypeInfo; + VoidTest ModelGetOutputTypeInfo; + VoidTest ModelGetMetadataCount; + VoidTest ModelGetMetadata; + VoidTest ModelEnsureNoFloat16; + VoidTest EnvConfigureCustomLoggerAndProfiler; +}; +const AdapterTestApi& getapi(); +const WinmlAdapterApi* winml_adapter_api; +const OrtApi* ort_api; +OrtModel* squeezenet_model = nullptr; +OrtModel* metadata_model = nullptr; +OrtModel* float16_Model = nullptr; +static bool logging_function_called = false; +static bool profiling_function_called = false; + +WINML_TEST_CLASS_BEGIN(AdapterTest) +WINML_TEST_CLASS_SETUP_CLASS(AdapterTestSetup) +WINML_TEST_CLASS_TEARDOWN_CLASS(AdapterTestTeardown) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(AdapterTest, CreateModelFromPath) +WINML_TEST(AdapterTest, CreateModelFromData) +WINML_TEST(AdapterTest, CloneModel) +WINML_TEST(AdapterTest, ModelGetAuthor) +WINML_TEST(AdapterTest, ModelGetName) +WINML_TEST(AdapterTest, ModelGetDomain) +WINML_TEST(AdapterTest, ModelGetDescription) +WINML_TEST(AdapterTest, ModelGetVersion) +WINML_TEST(AdapterTest, ModelGetInputCount) +WINML_TEST(AdapterTest, ModelGetOutputCount) +WINML_TEST(AdapterTest, ModelGetInputName) +WINML_TEST(AdapterTest, ModelGetOutputName) +WINML_TEST(AdapterTest, ModelGetInputDescription) +WINML_TEST(AdapterTest, ModelGetOutputDescription) +WINML_TEST(AdapterTest, ModelGetInputTypeInfo) +WINML_TEST(AdapterTest, ModelGetOutputTypeInfo) +WINML_TEST(AdapterTest, ModelGetMetadataCount) +WINML_TEST(AdapterTest, ModelGetMetadata) +WINML_TEST(AdapterTest, ModelEnsureNoFloat16) +WINML_TEST(AdapterTest, EnvConfigureCustomLoggerAndProfiler) +WINML_TEST_CLASS_END() diff --git a/winml/test/api/APITest.h b/winml/test/api/APITest.h index dfa64f26830d1..0bff68e695dd2 100644 --- a/winml/test/api/APITest.h +++ b/winml/test/api/APITest.h @@ -2,22 +2,25 @@ // Licensed under the MIT License. #pragma once + #include "fileHelpers.h" +#include "winrt_headers.h" + namespace APITest { static void LoadModel(const std::wstring& modelPath, - winrt::Windows::AI::MachineLearning::LearningModel& learningModel) { + winml::LearningModel& learningModel) { std::wstring fullPath = FileHelpers::GetModulePath() + modelPath; - learningModel = winrt::Windows::AI::MachineLearning::LearningModel::LoadFromFilePath(fullPath); + learningModel = winml::LearningModel::LoadFromFilePath(fullPath); }; -static uint64_t GetAdapterIdQuadPart(winrt::Windows::AI::MachineLearning::LearningModelDevice& device) { +static uint64_t GetAdapterIdQuadPart(winml::LearningModelDevice& device) { LARGE_INTEGER id; id.LowPart = device.AdapterId().LowPart; id.HighPart = device.AdapterId().HighPart; return id.QuadPart; }; -static _LUID GetAdapterIdAsLUID(winrt::Windows::AI::MachineLearning::LearningModelDevice& device) { +static _LUID GetAdapterIdAsLUID(winml::LearningModelDevice& device) { _LUID id; id.LowPart = device.AdapterId().LowPart; id.HighPart = device.AdapterId().HighPart; diff --git a/winml/test/api/LearningModelAPITest.cpp b/winml/test/api/LearningModelAPITest.cpp index d07981fe28de0..a6cd9ff2d68f9 100644 --- a/winml/test/api/LearningModelAPITest.cpp +++ b/winml/test/api/LearningModelAPITest.cpp @@ -4,26 +4,21 @@ #include "testPch.h" #include "LearningModelAPITest.h" #include "APITest.h" -#include -#include -#include -#include using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Storage; -using namespace winrt::Windows::Storage::Streams; - -static void LearningModelAPITestSetup() { +using namespace winml; +using namespace wfc; +using namespace wgi; +using namespace wm; +using namespace ws; +using namespace wss; + +static void LearningModelAPITestsClassSetup() { init_apartment(); } -static void LearningModelAPITestGpuSetup() { +static void LearningModelAPITestsGpuMethodSetup() { GPUTEST; - init_apartment(); } static void CreateModelFromFilePath() { @@ -33,7 +28,7 @@ static void CreateModelFromFilePath() { static void CreateModelFromIStorage() { std::wstring path = FileHelpers::GetModulePath() + L"squeezenet_modifiedforruntimestests.onnx"; - auto storageFile = winrt::Windows::Storage::StorageFile::GetFileFromPathAsync(path).get(); + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); LearningModel learningModel = nullptr; WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStorageFileAsync(storageFile).get()); WINML_EXPECT_TRUE(learningModel != nullptr); @@ -45,7 +40,7 @@ static void CreateModelFromIStorage() { static void CreateModelFromIStorageOutsideCwd() { std::wstring path = FileHelpers::GetModulePath() + L"ModelSubdirectory\\ModelInSubdirectory.onnx"; - auto storageFile = winrt::Windows::Storage::StorageFile::GetFileFromPathAsync(path).get(); + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); LearningModel learningModel = nullptr; WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStorageFileAsync(storageFile).get()); WINML_EXPECT_TRUE(learningModel != nullptr); @@ -57,8 +52,8 @@ static void CreateModelFromIStorageOutsideCwd() { static void CreateModelFromIStream() { std::wstring path = FileHelpers::GetModulePath() + L"squeezenet_modifiedforruntimestests.onnx"; - auto storageFile = winrt::Windows::Storage::StorageFile::GetFileFromPathAsync(path).get(); - winrt::Windows::Storage::Streams::IRandomAccessStreamReference streamref; + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); + ws::Streams::IRandomAccessStreamReference streamref; storageFile.as(streamref); LearningModel learningModel = nullptr; WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStreamAsync(streamref).get()); @@ -254,11 +249,11 @@ static void CloseModelNoNewSessions() { }); } -const LearningModelApiTestApi& getapi() { - static constexpr LearningModelApiTestApi api = +const LearningModelApiTestsApi& getapi() { + static constexpr LearningModelApiTestsApi api = { - LearningModelAPITestSetup, - LearningModelAPITestGpuSetup, + LearningModelAPITestsClassSetup, + LearningModelAPITestsGpuMethodSetup, CreateModelFromFilePath, CreateModelFromIStorage, CreateModelFromIStorageOutsideCwd, diff --git a/winml/test/api/LearningModelAPITest.h b/winml/test/api/LearningModelAPITest.h index 46d815fc27579..376fb661f7175 100644 --- a/winml/test/api/LearningModelAPITest.h +++ b/winml/test/api/LearningModelAPITest.h @@ -2,10 +2,10 @@ // Licensed under the MIT License. #include "test.h" -struct LearningModelApiTestApi +struct LearningModelApiTestsApi { - SetupTest LearningModelAPITestSetup; - SetupTest LearningModelAPITestGpuSetup; + SetupClass LearningModelAPITestsClassSetup; + SetupTest LearningModelAPITestsGpuMethodSetup; VoidTest CreateModelFromFilePath; VoidTest CreateModelFromIStorage; VoidTest CreateModelFromIStorageOutsideCwd; @@ -21,24 +21,29 @@ struct LearningModelApiTestApi VoidTest CloseModelCheckEval; VoidTest CloseModelNoNewSessions; }; -const LearningModelApiTestApi& getapi(); +const LearningModelApiTestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelAPITest, LearningModelAPITestSetup) -WINML_TEST(LearningModelAPITest, CreateModelFromFilePath) -WINML_TEST(LearningModelAPITest, CreateModelFromIStorage) -WINML_TEST(LearningModelAPITest, CreateModelFromIStorageOutsideCwd) -WINML_TEST(LearningModelAPITest, CreateModelFromIStream) -WINML_TEST(LearningModelAPITest, ModelGetAuthor) -WINML_TEST(LearningModelAPITest, ModelGetName) -WINML_TEST(LearningModelAPITest, ModelGetDomain) -WINML_TEST(LearningModelAPITest, ModelGetDescription) -WINML_TEST(LearningModelAPITest, ModelGetVersion) -WINML_TEST(LearningModelAPITest, EnumerateInputs) -WINML_TEST(LearningModelAPITest, EnumerateOutputs) -WINML_TEST(LearningModelAPITest, CloseModelCheckMetadata) -WINML_TEST(LearningModelAPITest, CloseModelNoNewSessions) +WINML_TEST_CLASS_BEGIN(LearningModelAPITests) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelAPITestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelAPITests, CreateModelFromFilePath) +WINML_TEST(LearningModelAPITests, CreateModelFromIStorage) +WINML_TEST(LearningModelAPITests, CreateModelFromIStorageOutsideCwd) +WINML_TEST(LearningModelAPITests, CreateModelFromIStream) +WINML_TEST(LearningModelAPITests, ModelGetAuthor) +WINML_TEST(LearningModelAPITests, ModelGetName) +WINML_TEST(LearningModelAPITests, ModelGetDomain) +WINML_TEST(LearningModelAPITests, ModelGetDescription) +WINML_TEST(LearningModelAPITests, ModelGetVersion) +WINML_TEST(LearningModelAPITests, EnumerateInputs) +WINML_TEST(LearningModelAPITests, EnumerateOutputs) +WINML_TEST(LearningModelAPITests, CloseModelCheckMetadata) +WINML_TEST(LearningModelAPITests, CloseModelNoNewSessions) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelAPITestGpu, LearningModelAPITestGpuSetup) -WINML_TEST(LearningModelAPITestGpu, CloseModelCheckEval) +WINML_TEST_CLASS_BEGIN(LearningModelAPITestsGpu) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelAPITestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(LearningModelAPITestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelAPITestsGpu, CloseModelCheckEval) WINML_TEST_CLASS_END() \ No newline at end of file diff --git a/winml/test/api/LearningModelBindingAPITest.cpp b/winml/test/api/LearningModelBindingAPITest.cpp index b09ed33dbab25..b1a345a8c6ca2 100644 --- a/winml/test/api/LearningModelBindingAPITest.cpp +++ b/winml/test/api/LearningModelBindingAPITest.cpp @@ -7,24 +7,21 @@ #include "LearningModelBindingAPITest.h" #include "SqueezeNetValidator.h" -#include -#include -#include "winrt/Windows.Storage.h" #include + using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Storage; +using namespace winml; +using namespace wfc; +using namespace wgi; +using namespace wm; +using namespace ws; -static void LearningModelBindingAPITestSetup() { +static void LearningModelBindingAPITestsClassSetup() { init_apartment(); } -static void LearningModelBindingAPITestGpuSetup() { +static void LearningModelBindingAPITestsGpuMethodSetup() { GPUTEST; - init_apartment(); } static void CpuSqueezeNet() @@ -107,7 +104,7 @@ static void DictionaryVectorizerMapInt64() // Bind as IMap auto abiMap = winrt::single_threaded_map(std::move(map)); binding.Bind(mapInputName, abiMap); - auto mapInputInspectable = abiMap.as(); + auto mapInputInspectable = abiMap.as(); auto first = binding.First(); WINML_EXPECT_TRUE(first.Current().Key() == mapInputName); WINML_EXPECT_TRUE(first.Current().Value() == mapInputInspectable); @@ -116,7 +113,7 @@ static void DictionaryVectorizerMapInt64() // Bind as IMapView auto mapView = abiMap.GetView(); binding.Bind(mapInputName, mapView); - mapInputInspectable = mapView.as(); + mapInputInspectable = mapView.as(); first = binding.First(); WINML_EXPECT_TRUE(first.Current().Key() == mapInputName); WINML_EXPECT_TRUE(first.Current().Value() == mapView); @@ -152,7 +149,7 @@ static void DictionaryVectorizerMapString() auto abiMap = winrt::single_threaded_map(std::move(map)); binding.Bind(mapInputName, abiMap); - auto mapInputInspectable = abiMap.as(); + auto mapInputInspectable = abiMap.as(); auto first = binding.First(); WINML_EXPECT_TRUE(first.Current().Key() == mapInputName); WINML_EXPECT_TRUE(first.Current().Value() == mapInputInspectable); @@ -162,7 +159,7 @@ static void DictionaryVectorizerMapString() } static void RunZipMapInt64( - winrt::Windows::AI::MachineLearning::LearningModel model, + winml::LearningModel model, OutputBindingStrategy bindingStrategy) { auto outputFeatures = model.OutputFeatures(); @@ -697,7 +694,7 @@ static void SequenceConstructTensorString() WINML_EXPECT_NO_THROW(learningModelBinding.Bind(L"tensor2", input2)); auto results = learningModelSession.Evaluate(learningModelBinding, L""); - auto output_sequence = results.Outputs().Lookup(L"output_sequence").as>(); + auto output_sequence = results.Outputs().Lookup(L"output_sequence").as>(); WINML_EXPECT_EQUAL(static_cast(2), output_sequence.Size()); WINML_EXPECT_EQUAL(2, output_sequence.GetAt(0).Shape().GetAt(0)); WINML_EXPECT_EQUAL(3, output_sequence.GetAt(0).Shape().GetAt(1)); @@ -714,11 +711,11 @@ static void SequenceConstructTensorString() WINML_EXPECT_EQUAL(3, bound_output_sequence.GetAt(1).Shape().GetAt(1)); } -const LearningModelBindingAPITestApi& getapi() { - static constexpr LearningModelBindingAPITestApi api = +const LearningModelBindingAPITestsApi& getapi() { + static constexpr LearningModelBindingAPITestsApi api = { - LearningModelBindingAPITestSetup, - LearningModelBindingAPITestGpuSetup, + LearningModelBindingAPITestsClassSetup, + LearningModelBindingAPITestsGpuMethodSetup, CpuSqueezeNet, CpuSqueezeNetEmptyOutputs, CpuSqueezeNetUnboundOutputs, diff --git a/winml/test/api/LearningModelBindingAPITest.h b/winml/test/api/LearningModelBindingAPITest.h index eccacda971363..b752c37065e6b 100644 --- a/winml/test/api/LearningModelBindingAPITest.h +++ b/winml/test/api/LearningModelBindingAPITest.h @@ -3,9 +3,9 @@ #include "test.h" -struct LearningModelBindingAPITestApi { - SetupTest LearningModelBindingAPITestSetup; - SetupTest LearningModelBindingAPITestGpuSetup; +struct LearningModelBindingAPITestsApi { + SetupClass LearningModelBindingAPITestsClassSetup; + SetupTest LearningModelBindingAPITestsGpuMethodSetup; VoidTest CpuSqueezeNet; VoidTest CpuSqueezeNetEmptyOutputs; VoidTest CpuSqueezeNetUnboundOutputs; @@ -27,30 +27,35 @@ struct LearningModelBindingAPITestApi { VoidTest SequenceLengthTensorFloat; VoidTest SequenceConstructTensorString; }; -const LearningModelBindingAPITestApi& getapi(); +const LearningModelBindingAPITestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelBindingAPITest, LearningModelBindingAPITestSetup) -WINML_TEST(LearningModelBindingAPITest, CpuSqueezeNet) -WINML_TEST(LearningModelBindingAPITest, CpuSqueezeNetEmptyOutputs) -WINML_TEST(LearningModelBindingAPITest, CpuSqueezeNetUnboundOutputs) -WINML_TEST(LearningModelBindingAPITest, CpuSqueezeNetBindInputTensorAsInspectable) -WINML_TEST(LearningModelBindingAPITest, CastMapInt64) -WINML_TEST(LearningModelBindingAPITest, DictionaryVectorizerMapInt64) -WINML_TEST(LearningModelBindingAPITest, DictionaryVectorizerMapString) -WINML_TEST(LearningModelBindingAPITest, ZipMapInt64) -WINML_TEST(LearningModelBindingAPITest, ZipMapInt64Unbound) -WINML_TEST(LearningModelBindingAPITest, ZipMapString) -WINML_TEST(LearningModelBindingAPITest, VerifyOutputAfterEvaluateAsyncCalledTwice) -WINML_TEST(LearningModelBindingAPITest, VerifyOutputAfterImageBindCalledTwice) -WINML_TEST(LearningModelBindingAPITest, SequenceLengthTensorFloat) -WINML_TEST(LearningModelBindingAPITest, SequenceConstructTensorString) +WINML_TEST_CLASS_BEGIN(LearningModelBindingAPITests) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelBindingAPITestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelBindingAPITests, CpuSqueezeNet) +WINML_TEST(LearningModelBindingAPITests, CpuSqueezeNetEmptyOutputs) +WINML_TEST(LearningModelBindingAPITests, CpuSqueezeNetUnboundOutputs) +WINML_TEST(LearningModelBindingAPITests, CpuSqueezeNetBindInputTensorAsInspectable) +WINML_TEST(LearningModelBindingAPITests, CastMapInt64) +WINML_TEST(LearningModelBindingAPITests, DictionaryVectorizerMapInt64) +WINML_TEST(LearningModelBindingAPITests, DictionaryVectorizerMapString) +WINML_TEST(LearningModelBindingAPITests, ZipMapInt64) +WINML_TEST(LearningModelBindingAPITests, ZipMapInt64Unbound) +WINML_TEST(LearningModelBindingAPITests, ZipMapString) +WINML_TEST(LearningModelBindingAPITests, VerifyOutputAfterEvaluateAsyncCalledTwice) +WINML_TEST(LearningModelBindingAPITests, VerifyOutputAfterImageBindCalledTwice) +WINML_TEST(LearningModelBindingAPITests, SequenceLengthTensorFloat) +WINML_TEST(LearningModelBindingAPITests, SequenceConstructTensorString) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelBindingAPITestGpu, LearningModelBindingAPITestGpuSetup) -WINML_TEST(LearningModelBindingAPITestGpu, GpuSqueezeNet) -WINML_TEST(LearningModelBindingAPITestGpu, GpuSqueezeNetEmptyOutputs) -WINML_TEST(LearningModelBindingAPITestGpu, GpuSqueezeNetUnboundOutputs) -WINML_TEST(LearningModelBindingAPITestGpu, ImageBindingDimensions) -WINML_TEST(LearningModelBindingAPITestGpu, VerifyInvalidBindExceptions) -WINML_TEST(LearningModelBindingAPITestGpu, BindInvalidInputName) +WINML_TEST_CLASS_BEGIN(LearningModelBindingAPITestsGPU) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelBindingAPITestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(LearningModelBindingAPITestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelBindingAPITestsGPU, GpuSqueezeNet) +WINML_TEST(LearningModelBindingAPITestsGPU, GpuSqueezeNetEmptyOutputs) +WINML_TEST(LearningModelBindingAPITestsGPU, GpuSqueezeNetUnboundOutputs) +WINML_TEST(LearningModelBindingAPITestsGPU, ImageBindingDimensions) +WINML_TEST(LearningModelBindingAPITestsGPU, VerifyInvalidBindExceptions) +WINML_TEST(LearningModelBindingAPITestsGPU, BindInvalidInputName) WINML_TEST_CLASS_END() \ No newline at end of file diff --git a/winml/test/api/LearningModelSessionAPITest.cpp b/winml/test/api/LearningModelSessionAPITest.cpp index a95944d964285..2aa98d0b78497 100644 --- a/winml/test/api/LearningModelSessionAPITest.cpp +++ b/winml/test/api/LearningModelSessionAPITest.cpp @@ -14,23 +14,22 @@ #include "Psapi.h" using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; +using namespace winml; +using namespace wfc; -using winrt::Windows::Foundation::IPropertyValue; +using wf::IPropertyValue; -static void LearningModelSessionAPITestSetup() { +static void LearningModelSessionAPITestsClassSetup() { init_apartment(); } -static void LearningModelSessionAPITestGpuSetup() { +static void LearningModelSessionAPITestsGpuMethodSetup() { GPUTEST; - init_apartment(); } -static void LearningModelSessionAPITestsSkipEdgeCoreSetup() { - SKIP_EDGECORE; - LearningModelSessionAPITestGpuSetup(); +static void LearningModelSessionAPITestsGpuSkipEdgeCoreMethodSetup() { + LearningModelSessionAPITestsGpuMethodSetup(); + SKIP_EDGECORE } static void CreateSessionDeviceDefault() @@ -64,7 +63,7 @@ static void CreateSessionWithModelLoadedFromStream() LearningModel learningModel = nullptr; LearningModelDevice learningModelDevice = nullptr; std::wstring path = FileHelpers::GetModulePath() + L"model.onnx"; - auto storageFile = winrt::Windows::Storage::StorageFile::GetFileFromPathAsync(path).get(); + auto storageFile = ws::StorageFile::GetFileFromPathAsync(path).get(); WINML_EXPECT_NO_THROW(learningModel = LearningModel::LoadFromStream(storageFile)); @@ -167,7 +166,7 @@ static void EvaluateFeatures() auto outputTensor = TensorString::Create(); - std::map featuresstandardmap; + std::map featuresstandardmap; featuresstandardmap[L"X"] = tensor; featuresstandardmap[L"Y"] = outputTensor; auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); @@ -201,7 +200,7 @@ static void EvaluateFeaturesAsync() auto outputTensor = TensorString::Create(shape); - std::map featuresstandardmap; + std::map featuresstandardmap; featuresstandardmap[L"X"] = tensor; featuresstandardmap[L"Y"] = outputTensor; auto featureswinrtmap = winrt::single_threaded_map(std::move(featuresstandardmap)); @@ -221,7 +220,7 @@ static void EvaluationProperties() LearningModelSession learningModelSession = nullptr; learningModelSession = LearningModelSession(learningModel); // set a property - auto value = winrt::Windows::Foundation::PropertyValue::CreateBoolean(true); + auto value = wf::PropertyValue::CreateBoolean(true); learningModelSession.EvaluationProperties().Insert(L"propName1", value); // get the property and make sure it's there with the right value auto value2 = learningModelSession.EvaluationProperties().Lookup(L"propName1"); @@ -398,12 +397,12 @@ static void CloseSession() }); } -const LearningModelSesssionAPITestApi& getapi() { - static constexpr LearningModelSesssionAPITestApi api = +const LearningModelSesssionAPITestsApi& getapi() { + static constexpr LearningModelSesssionAPITestsApi api = { - LearningModelSessionAPITestSetup, - LearningModelSessionAPITestGpuSetup, - LearningModelSessionAPITestsSkipEdgeCoreSetup, + LearningModelSessionAPITestsClassSetup, + LearningModelSessionAPITestsGpuMethodSetup, + LearningModelSessionAPITestsGpuSkipEdgeCoreMethodSetup, CreateSessionDeviceDefault, CreateSessionDeviceCpu, CreateSessionWithModelLoadedFromStream, diff --git a/winml/test/api/LearningModelSessionAPITest.h b/winml/test/api/LearningModelSessionAPITest.h index 618fab0ea7628..02b7ad2d054b9 100644 --- a/winml/test/api/LearningModelSessionAPITest.h +++ b/winml/test/api/LearningModelSessionAPITest.h @@ -3,10 +3,10 @@ #include "test.h" -struct LearningModelSesssionAPITestApi { - SetupTest LearningModelSessionAPITestSetup; - SetupTest LearningModelSessionAPITestGpuSetup; - SetupTest LearningModelSessionAPITestsSkipEdgeCoreSetup; +struct LearningModelSesssionAPITestsApi { + SetupClass LearningModelSessionAPITestsClassSetup; + SetupTest LearningModelSessionAPITestsGpuMethodSetup; + SetupTest LearningModelSessionAPITestsGpuSkipEdgeCoreMethodSetup; VoidTest CreateSessionDeviceDefault; VoidTest CreateSessionDeviceCpu; VoidTest CreateSessionWithModelLoadedFromStream; @@ -22,26 +22,34 @@ struct LearningModelSesssionAPITestApi { VoidTest EvaluateSessionAndCloseModel; VoidTest CloseSession; }; -const LearningModelSesssionAPITestApi& getapi(); +const LearningModelSesssionAPITestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelSessionAPITest, LearningModelSessionAPITestSetup) -WINML_TEST(LearningModelSessionAPITest, CreateSessionDeviceDefault) -WINML_TEST(LearningModelSessionAPITest,CreateSessionDeviceCpu) -WINML_TEST(LearningModelSessionAPITest,CreateSessionWithModelLoadedFromStream) -WINML_TEST(LearningModelSessionAPITest,EvaluateFeatures) -WINML_TEST(LearningModelSessionAPITest,EvaluateFeaturesAsync) -WINML_TEST(LearningModelSessionAPITest,EvaluationProperties) -WINML_TEST(LearningModelSessionAPITest,EvaluateSessionAndCloseModel) +WINML_TEST_CLASS_BEGIN(LearningModelSessionAPITests) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelSessionAPITestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelSessionAPITests, CreateSessionDeviceDefault) +WINML_TEST(LearningModelSessionAPITests,CreateSessionDeviceCpu) +WINML_TEST(LearningModelSessionAPITests,CreateSessionWithModelLoadedFromStream) +WINML_TEST(LearningModelSessionAPITests,EvaluateFeatures) +WINML_TEST(LearningModelSessionAPITests,EvaluateFeaturesAsync) +WINML_TEST(LearningModelSessionAPITests,EvaluationProperties) +WINML_TEST(LearningModelSessionAPITests,EvaluateSessionAndCloseModel) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelSessionAPITestGpu, LearningModelSessionAPITestGpuSetup) -WINML_TEST(LearningModelSessionAPITestGpu, CreateSessionDeviceDirectX) -WINML_TEST(LearningModelSessionAPITestGpu, CreateSessionDeviceDirectXHighPerformance) -WINML_TEST(LearningModelSessionAPITestGpu, CreateSessionDeviceDirectXMinimumPower) -WINML_TEST(LearningModelSessionAPITestGpu, CreateSessionWithCastToFloat16InModel) -WINML_TEST(LearningModelSessionAPITestGpu, DISABLED_CreateSessionWithFloat16InitializersInModel) +WINML_TEST_CLASS_BEGIN(LearningModelSessionAPITestsGpu) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelSessionAPITestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(LearningModelSessionAPITestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelSessionAPITestsGpu, CreateSessionDeviceDirectX) +WINML_TEST(LearningModelSessionAPITestsGpu, CreateSessionDeviceDirectXHighPerformance) +WINML_TEST(LearningModelSessionAPITestsGpu, CreateSessionDeviceDirectXMinimumPower) +WINML_TEST(LearningModelSessionAPITestsGpu, CreateSessionWithCastToFloat16InModel) +WINML_TEST(LearningModelSessionAPITestsGpu, DISABLED_CreateSessionWithFloat16InitializersInModel) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(LearningModelSessionAPITestsSkipEdgeCore, LearningModelSessionAPITestsSkipEdgeCoreSetup) -WINML_TEST(LearningModelSessionAPITestsSkipEdgeCore, AdapterIdAndDevice) +WINML_TEST_CLASS_BEGIN(LearningModelSessionAPITestsGpuSkipEdgeCore) +WINML_TEST_CLASS_SETUP_CLASS(LearningModelSessionAPITestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(LearningModelSessionAPITestsGpuSkipEdgeCoreMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(LearningModelSessionAPITestsGpuSkipEdgeCore, AdapterIdAndDevice) WINML_TEST_CLASS_END() \ No newline at end of file diff --git a/winml/test/common/SqueezeNetValidator.cpp b/winml/test/common/SqueezeNetValidator.cpp index 95ee72c0386f0..3c85ef2c351dc 100644 --- a/winml/test/common/SqueezeNetValidator.cpp +++ b/winml/test/common/SqueezeNetValidator.cpp @@ -10,13 +10,13 @@ #include #include #include -// using namespace winrt::Windows::Foundation; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Storage; -using namespace winrt::Windows::Storage::Streams; + +using namespace wfc; +using namespace wgi; +using namespace wm; +using namespace ws; +using namespace wss; +using namespace winml; namespace WinML::Engine::Test{ diff --git a/winml/test/common/SqueezeNetValidator.h b/winml/test/common/SqueezeNetValidator.h index 5c6e6fa03d95b..df09641b3792d 100644 --- a/winml/test/common/SqueezeNetValidator.h +++ b/winml/test/common/SqueezeNetValidator.h @@ -4,6 +4,7 @@ #pragma once #include "std.h" +#include "winrt_headers.h" enum OutputBindingStrategy { Bound, Unbound, Empty }; @@ -11,14 +12,14 @@ namespace WinML::Engine::Test::ModelValidator { void FnsCandy16( const std::string& instance, - winrt::Windows::AI::MachineLearning::LearningModelDeviceKind deviceKind, + winml::LearningModelDeviceKind deviceKind, OutputBindingStrategy outputBindingStrategy, bool bindInputsAsIInspectable, float dataTolerance = false); void SqueezeNet( const std::string& instance, - winrt::Windows::AI::MachineLearning::LearningModelDeviceKind deviceKind, + winml::LearningModelDeviceKind deviceKind, float dataTolerance, bool bindAsImage = false, OutputBindingStrategy outputBindingStrategy = OutputBindingStrategy::Bound, diff --git a/winml/test/common/dllload.cpp b/winml/test/common/dllload.cpp index db27ad1741a07..00921fd72257d 100644 --- a/winml/test/common/dllload.cpp +++ b/winml/test/common/dllload.cpp @@ -27,12 +27,16 @@ HRESULT __stdcall WINRT_RoGetActivationFactory(HSTRING classId_hstring, GUID con std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; HMODULE library{ nullptr }; - std::wstring winmlDllPath = FileHelpers::GetWinMLPath() + L"Windows.AI.MachineLearning.dll"; + std::wostringstream dll; + dll << BINARY_NAME; - if (starts_with(name, L"Windows.AI.MachineLearning.")) + std::wstring winml_dll_name = dll.str(); + std::wstring winml_dll_path = FileHelpers::GetWinMLPath() + winml_dll_name; + std::wstring winml_dll_prefix = winml_dll_name.substr(0, winml_dll_name.size() - 3); + if (starts_with(name, winml_dll_prefix)) { - const wchar_t* libPath = winmlDllPath.c_str(); - library = LoadLibraryExW(libPath, nullptr, 0); + const wchar_t* lib_path = winml_dll_path.c_str(); + library = LoadLibraryExW(lib_path, nullptr, 0); } else { @@ -54,7 +58,7 @@ HRESULT __stdcall WINRT_RoGetActivationFactory(HSTRING classId_hstring, GUID con return hr; } - winrt::com_ptr activation_factory; + winrt::com_ptr activation_factory; HRESULT const hr = call(classId_hstring, activation_factory.put_void()); if (FAILED(hr)) @@ -63,7 +67,7 @@ HRESULT __stdcall WINRT_RoGetActivationFactory(HSTRING classId_hstring, GUID con return hr; } - if (winrt::guid(iid) != winrt::guid_of()) + if (winrt::guid(iid) != winrt::guid_of()) { return activation_factory->QueryInterface(iid, factory); } diff --git a/winml/test/common/fileHelpers.cpp b/winml/test/common/fileHelpers.cpp index 3fa43b44715cc..d1129de9d0b69 100644 --- a/winml/test/common/fileHelpers.cpp +++ b/winml/test/common/fileHelpers.cpp @@ -9,12 +9,12 @@ EXTERN_C IMAGE_DOS_HEADER __ImageBase; using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Storage; +using namespace winml; +using namespace wgi; +using namespace ws; namespace FileHelpers -{ +{ std::wstring GetModulePath() { std::wstring val; @@ -44,7 +44,7 @@ namespace FileHelpers } - winrt::Windows::Graphics::Imaging::SoftwareBitmap GetSoftwareBitmapFromFile(const std::wstring& filePath) + wgi::SoftwareBitmap GetSoftwareBitmapFromFile(const std::wstring& filePath) { auto storageFile = StorageFile::GetFileFromPathAsync(filePath).get(); auto stream = storageFile.OpenAsync(FileAccessMode::Read).get(); @@ -61,10 +61,10 @@ namespace FileHelpers return softwareBitmap; } - ImageFeatureValue LoadImageFeatureValue(const std::wstring& imagePath) + winml::ImageFeatureValue LoadImageFeatureValue(const std::wstring& imagePath) { auto softwareBitmap = FileHelpers::GetSoftwareBitmapFromFile(FileHelpers::GetModulePath() + imagePath); - auto videoFrame = winrt::Windows::Media::VideoFrame::CreateWithSoftwareBitmap(softwareBitmap); + auto videoFrame = wm::VideoFrame::CreateWithSoftwareBitmap(softwareBitmap); return ImageFeatureValue::CreateFromVideoFrame(videoFrame); } } diff --git a/winml/test/common/fileHelpers.h b/winml/test/common/fileHelpers.h index 9426744914cb1..e8a246b50359e 100644 --- a/winml/test/common/fileHelpers.h +++ b/winml/test/common/fileHelpers.h @@ -2,14 +2,17 @@ // Licensed under the MIT License. #pragma once + +#include "std.h" +#include "winrt_headers.h" + #include "winrt/Windows.Graphics.Imaging.h" -#include "winrt/Windows.AI.MachineLearning.h" namespace FileHelpers { std::wstring GetModulePath(); std::wstring GetWinMLPath(); - winrt::Windows::Graphics::Imaging::SoftwareBitmap GetSoftwareBitmapFromFile(const std::wstring& filePath); - winrt::Windows::AI::MachineLearning::ImageFeatureValue LoadImageFeatureValue(const std::wstring& imagePath); + wgi::SoftwareBitmap GetSoftwareBitmapFromFile(const std::wstring& filePath); + winml::ImageFeatureValue LoadImageFeatureValue(const std::wstring& imagePath); } diff --git a/winml/test/common/googleTestMacros.h b/winml/test/common/googleTestMacros.h index 8e55e6b762f51..19e054b0959be 100644 --- a/winml/test/common/googleTestMacros.h +++ b/winml/test/common/googleTestMacros.h @@ -12,22 +12,39 @@ getapi().test_name(); \ } -#define WINML_TEST_CLASS_BEGIN_NO_SETUP(test_class_name) \ - namespace { \ - class test_class_name : public ::testing::Test { \ - }; - -#define WINML_TEST_CLASS_BEGIN_WITH_SETUP(test_class_name, setup_method) \ - namespace { \ - class test_class_name : public ::testing::Test { \ - protected: \ - void SetUp() override { \ - getapi().setup_method(); \ - } \ - }; +#define WINML_TEST_CLASS_BEGIN(test_class_name) \ + namespace { \ + class test_class_name : public ::testing::Test { +#define WINML_TEST_CLASS_SETUP_CLASS(setup_class) \ + protected: \ + static void SetUpTestSuite() { \ + getapi().setup_class(); \ + } + +#define WINML_TEST_CLASS_TEARDOWN_CLASS(teardown_class) \ + protected: \ + static void TearDownTestSuite() { \ + getapi().teardown_class(); \ + } + +#define WINML_TEST_CLASS_SETUP_METHOD(setup_method) \ + protected: \ + void SetUp() override { \ + getapi().setup_method(); \ + } + +#define WINML_TEST_CLASS_TEARDOWN_METHOD(teardown_method) \ + protected: \ + void TearDown() override { \ + getapi().teardown_method(); \ + } + +#define WINML_TEST_CLASS_BEGIN_TESTS }; + #define WINML_TEST_CLASS_END() } + // For old versions of gtest without GTEST_SKIP, stream the message and return success instead #ifndef GTEST_SKIP #define GTEST_SKIP_(message) \ @@ -72,11 +89,13 @@ WINML_SKIP_TEST("GPU tests disabled because this is a WinML only build (no DML)") #define GPUTEST_ENABLED alwaysFalse() #else -#define GPUTEST \ - if (auto no_gpu_tests = RuntimeParameters::Parameters.find("noGPUtests"); \ - no_gpu_tests != RuntimeParameters::Parameters.end() && no_gpu_tests->second != "0") { \ - WINML_SKIP_TEST("GPU tests disabled"); \ - } +#define GPUTEST \ + do { \ + if (auto no_gpu_tests = RuntimeParameters::Parameters.find("noGPUtests"); \ + no_gpu_tests != RuntimeParameters::Parameters.end() && no_gpu_tests->second != "0") { \ + WINML_SKIP_TEST("GPU tests disabled"); \ + } \ + } while (0) #define GPUTEST_ENABLED auto _no_gpu_tests = RuntimeParameters::Parameters.find("noGPUtests"); \ _no_gpu_tests == RuntimeParameters::Parameters.end() || _no_gpu_tests->second == "0" #endif diff --git a/winml/test/common/protobufHelpers.cpp b/winml/test/common/protobufHelpers.cpp index f109dc22b8713..9fb5fd19c699b 100644 --- a/winml/test/common/protobufHelpers.cpp +++ b/winml/test/common/protobufHelpers.cpp @@ -18,11 +18,9 @@ #include -#include "winrt/Windows.Storage.Streams.h" - -using namespace winrt::Windows::Storage::Streams; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; +using namespace wss; +using namespace wfc; +using namespace winml; // Copy and pasted from LOTUS as is. temporary code to load tensors from protobufs int FdOpen(const std::string& name) { @@ -192,8 +190,8 @@ TensorFloat16Bit ProtobufHelpers::LoadTensorFloat16FromProtobufFile( return nullptr; } -winrt::Windows::AI::MachineLearning::LearningModel ProtobufHelpers::CreateModel( - winrt::Windows::AI::MachineLearning::TensorKind kind, +winml::LearningModel ProtobufHelpers::CreateModel( + winml::TensorKind kind, const std::vector& shape, uint32_t num_elements) { onnx::ModelProto model; diff --git a/winml/test/common/protobufHelpers.h b/winml/test/common/protobufHelpers.h index 84c1b20883bc9..2f51b8f54733d 100644 --- a/winml/test/common/protobufHelpers.h +++ b/winml/test/common/protobufHelpers.h @@ -4,17 +4,18 @@ #pragma once #include "std.h" +#include "winrt_headers.h" namespace ProtobufHelpers { // LoadTensorFromProtobufFile take a path to a FP32 data file and loads it into a 32bit array or // 16bit array based on isFp16 - winrt::Windows::AI::MachineLearning::ITensor LoadTensorFromProtobufFile(const std::wstring& filePath, bool isFp16); + winml::ITensor LoadTensorFromProtobufFile(const std::wstring& filePath, bool isFp16); // LoadTensorFloat16FromProtobufFile takes a path to a FP16 data file and loads it into a 16bit array - winrt::Windows::AI::MachineLearning::TensorFloat16Bit LoadTensorFloat16FromProtobufFile(const std::wstring& filePath); + winml::TensorFloat16Bit LoadTensorFloat16FromProtobufFile(const std::wstring& filePath); - winrt::Windows::AI::MachineLearning::LearningModel CreateModel( - winrt::Windows::AI::MachineLearning::TensorKind kind, + winml::LearningModel CreateModel( + winml::TensorKind kind, const std::vector& shape, uint32_t num_elements = 1); } diff --git a/winml/test/common/std.h b/winml/test/common/std.h index a51abb0e91912..b8435dd3278a0 100644 --- a/winml/test/common/std.h +++ b/winml/test/common/std.h @@ -15,16 +15,6 @@ #include #include #include +#include -#include "test.h" - -// IUnknown must be declared before winrt/base.h is included to light up support for native COM -// interfaces with C++/WinRT types (e.g. winrt::com_ptr). -#include -#include -#include "winrt/base.h" -#include "winrt/Windows.Foundation.Collections.h" -#include "comp_generated/winrt/windows.ai.machinelearning.h" - -// WinML -#include "Windows.AI.MachineLearning.Native.h" +#include "test.h" \ No newline at end of file diff --git a/winml/test/common/taefTestMacros.h b/winml/test/common/taefTestMacros.h index 4b63201c594e3..54bd4de0caa82 100644 --- a/winml/test/common/taefTestMacros.h +++ b/winml/test/common/taefTestMacros.h @@ -8,14 +8,36 @@ using namespace WEX::TestExecution; #define WINML_EXPECT_NO_THROW(statement) VERIFY_NO_THROW(statement) -#define WINML_TEST_CLASS_BEGIN_WITH_SETUP(test_class_name, setup_method) \ - class test_class_name { \ - TEST_CLASS(test_class_name); \ - TEST_CLASS_SETUP(TestClassSetup) { \ - getapi().setup_method(); \ - return true; \ +#define WINML_TEST_CLASS_BEGIN(test_class_name) \ + class test_class_name { \ + TEST_CLASS(test_class_name); + +#define WINML_TEST_CLASS_SETUP_CLASS(setup_class) \ + TEST_CLASS_SETUP(TestMethodSetup) { \ + getapi().setup_class(); \ + return true; \ + } + +#define WINML_TEST_CLASS_TEARDOWN_CLASS(teardown_class) \ + TEST_CLASS_CLEANUP(TestClassCleanup) { \ + getapi().teardown_class(); \ + return true; \ + } + +#define WINML_TEST_CLASS_SETUP_METHOD(setup_method) \ + TEST_METHOD_SETUP(TestMethodSetup) { \ + getapi().setup_method(); \ + return true; \ + } + +#define WINML_TEST_CLASS_TEARDOWN_METHOD(teardown_method) \ + TEST_METHOD_CLEANUP(TestClassCleanup) { \ + getapi().teardown_method(); \ + return true; \ } +#define WINML_TEST_CLASS_BEGIN_TESTS + #define WINML_TEST_CLASS_END() \ } \ ; diff --git a/winml/test/common/test.h b/winml/test/common/test.h index 7c122e992cbec..e30736b0f7cc1 100644 --- a/winml/test/common/test.h +++ b/winml/test/common/test.h @@ -4,7 +4,10 @@ #pragma once using VoidTest = void (*)(); +using SetupClass = VoidTest; +using TeardownClass = VoidTest; using SetupTest = VoidTest; +using TeardownTest = VoidTest; constexpr bool alwaysTrue() { return true; diff --git a/winml/test/common/winrt_headers.h b/winml/test/common/winrt_headers.h new file mode 100644 index 0000000000000..ff371d36356ad --- /dev/null +++ b/winml/test/common/winrt_headers.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// IUnknown must be declared before winrt/base.h is included to light up support for native COM +// interfaces with C++/WinRT types (e.g. winrt::com_ptr). +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define STRINGIFY(x) #x +#define XSTRINGIFY(x) STRINGIFY(x) +#define CPPWINRT_HEADER(root_ns) comp_generated/winrt/##root_ns##.AI.MachineLearning.h +#define NATIVE_HEADER(root_ns) root_ns##.AI.MachineLearning.native.h +#define NATIVE_INTERNAL_HEADER(root_ns) root_ns##.AI.MachineLearning.native.internal.h +#define CREATE_CPPWINRT_COMPONENT_HEADER() XSTRINGIFY(CPPWINRT_HEADER(WINML_ROOT_NS)) +#define CREATE_NATIVE_HEADER() XSTRINGIFY(NATIVE_HEADER(WINML_ROOT_NS)) +#define CREATE_NATIVE_INTERNAL_HEADER() XSTRINGIFY(NATIVE_INTERNAL_HEADER(WINML_ROOT_NS)) + +#include CREATE_CPPWINRT_COMPONENT_HEADER() + +// WinML Native Headers +#include CREATE_NATIVE_HEADER() +#include CREATE_NATIVE_INTERNAL_HEADER() + +namespace winml = winrt::WINML_ROOT_NS::AI::MachineLearning; +namespace wf = winrt::Windows::Foundation; +namespace wfc = winrt::Windows::Foundation::Collections; +namespace wm = winrt::Windows::Media; +namespace wgi = winrt::Windows::Graphics::Imaging; +namespace wgdx = winrt::Windows::Graphics::DirectX; +namespace ws = winrt::Windows::Storage; +namespace wss = winrt::Windows::Storage::Streams; \ No newline at end of file diff --git a/winml/test/concurrency/ConcurrencyTests.cpp b/winml/test/concurrency/ConcurrencyTests.cpp index 9200f8776ccca..aac227c92cb47 100644 --- a/winml/test/concurrency/ConcurrencyTests.cpp +++ b/winml/test/concurrency/ConcurrencyTests.cpp @@ -4,14 +4,13 @@ #include "model.h" #include "SqueezeNetValidator.h" #include "threadPool.h" -#include "windows.ai.machinelearning.native.internal.h" #include #include #include #include -using namespace winrt::Windows::AI::MachineLearning; +using namespace winml; using namespace winrt; namespace { @@ -37,16 +36,20 @@ void LoadBindEvalSqueezenetRealDataWithValidationConcurrently() { } } -void ConcurrencyTestsApiSetup() { +void ConcurrencyTestsClassSetup() { init_apartment(); std::srand(static_cast(std::time(nullptr))); } +void ConcurrencyTestsGpuMethodSetup() { + GPUTEST; +} + struct EvaluationUnit { LearningModel model; LearningModelSession session; LearningModelBinding binding; - winrt::Windows::Foundation::IAsyncOperation operation; + wf::IAsyncOperation operation; LearningModelEvaluationResult result; EvaluationUnit() : model(nullptr), session(nullptr), binding(nullptr), result(nullptr) {} @@ -156,7 +159,7 @@ void EvalAsyncDifferentBindings() { VerifyEvaluation(evaluation_units, { TABBY_CAT_INDEX, TENCH_INDEX }); } -winrt::Windows::AI::MachineLearning::ILearningModelFeatureDescriptor UnusedCreateFeatureDescriptor( +winml::ILearningModelFeatureDescriptor UnusedCreateFeatureDescriptor( std::shared_ptr model, const std::wstring& name, const std::wstring& description, @@ -253,7 +256,6 @@ void MultiThreadMultiSession() { } void MultiThreadMultiSessionGpu() { - GPUTEST MultiThreadMultiSessionOnDevice(LearningModelDeviceKind::DirectX); } @@ -323,14 +325,14 @@ void MultiThreadSingleSession() { } void MultiThreadSingleSessionGpu() { - GPUTEST MultiThreadSingleSessionOnDevice(LearningModelDeviceKind::DirectX); } } const ConcurrencyTestsApi& getapi() { static constexpr ConcurrencyTestsApi api = { - ConcurrencyTestsApiSetup, + ConcurrencyTestsClassSetup, + ConcurrencyTestsGpuMethodSetup, LoadBindEvalSqueezenetRealDataWithValidationConcurrently, MultiThreadLoadModel, MultiThreadMultiSession, diff --git a/winml/test/concurrency/ConcurrencyTests.h b/winml/test/concurrency/ConcurrencyTests.h index 2bfb8a4896c32..0fa8e4a45fd2c 100644 --- a/winml/test/concurrency/ConcurrencyTests.h +++ b/winml/test/concurrency/ConcurrencyTests.h @@ -6,7 +6,8 @@ struct ConcurrencyTestsApi { - SetupTest ConcurrencyTestsApiSetup; + SetupClass ConcurrencyTestsClassSetup; + SetupTest ConcurrencyTestsGpuMethodSetup; VoidTest LoadBindEvalSqueezenetRealDataWithValidationConcurrently; VoidTest MultiThreadLoadModel; VoidTest MultiThreadMultiSession; @@ -19,18 +20,26 @@ struct ConcurrencyTestsApi }; const ConcurrencyTestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(ConcurrencyTests, ConcurrencyTestsApiSetup) +WINML_TEST_CLASS_BEGIN(ConcurrencyTests) +WINML_TEST_CLASS_SETUP_CLASS(ConcurrencyTestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS WINML_TEST(ConcurrencyTests, LoadBindEvalSqueezenetRealDataWithValidationConcurrently) WINML_TEST(ConcurrencyTests, MultiThreadLoadModel) WINML_TEST(ConcurrencyTests, MultiThreadMultiSession) -WINML_TEST(ConcurrencyTests, MultiThreadMultiSessionGpu) WINML_TEST(ConcurrencyTests, MultiThreadSingleSession) -WINML_TEST(ConcurrencyTests, MultiThreadSingleSessionGpu) WINML_TEST(ConcurrencyTests, EvalAsyncDifferentModels) WINML_TEST(ConcurrencyTests, EvalAsyncDifferentSessions) WINML_TEST(ConcurrencyTests, EvalAsyncDifferentBindings) WINML_TEST_CLASS_END() +WINML_TEST_CLASS_BEGIN(ConcurrencyTestsGpu) +WINML_TEST_CLASS_SETUP_CLASS(ConcurrencyTestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(ConcurrencyTestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(ConcurrencyTestsGpu, MultiThreadMultiSessionGpu) +WINML_TEST(ConcurrencyTestsGpu, MultiThreadSingleSessionGpu) +WINML_TEST_CLASS_END() + // indices for imagenet label static constexpr uint32_t TABBY_CAT_INDEX = 281; static constexpr uint32_t TENCH_INDEX = 0; diff --git a/winml/test/image/imageTestHelper.cpp b/winml/test/image/imageTestHelper.cpp index 6be4cab7c65f0..ab055ab624f46 100644 --- a/winml/test/image/imageTestHelper.cpp +++ b/winml/test/image/imageTestHelper.cpp @@ -12,10 +12,10 @@ #define FENCE_SIGNAL_VALUE 1 using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; +using namespace winml; +using namespace wfc; +using namespace wm; +using namespace wgi; namespace ImageTestHelper { BitmapPixelFormat GetPixelFormat(const std::wstring& inputPixelFormat) { @@ -37,8 +37,8 @@ namespace ImageTestHelper { softwareBitmap = SoftwareBitmap::Convert(softwareBitmap, BitmapPixelFormat::Bgra8); BYTE* pData = nullptr; UINT32 size = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + wgi::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); auto spByteAccess = reference.as<::Windows::Foundation::IMemoryBufferByteAccess>(); spByteAccess->GetBuffer(&pData, &size); @@ -82,8 +82,8 @@ namespace ImageTestHelper { softwareBitmap = SoftwareBitmap::Convert(softwareBitmap, BitmapPixelFormat::Bgra8); BYTE* pData = nullptr; UINT32 size = 0; - BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); com_ptr<::Windows::Foundation::IMemoryBufferByteAccess> spByteAccess = reference.as<::Windows::Foundation::IMemoryBufferByteAccess>(); spByteAccess->GetBuffer(&pData, &size); @@ -214,7 +214,7 @@ namespace ImageTestHelper { wil::unique_event hDirectEvent(directEvent); //Create Fence - Microsoft::WRL::ComPtr spDirectFence = nullptr; + ::Microsoft::WRL::ComPtr spDirectFence = nullptr; WINML_EXPECT_HRESULT_SUCCEEDED(pD3D12Device->CreateFence( 0, D3D12_FENCE_FLAG_NONE, @@ -251,8 +251,8 @@ namespace ImageTestHelper { uint32_t size = 4 * softwareBitmapActual.PixelHeight() * softwareBitmapActual.PixelWidth(); - winrt::Windows::Storage::Streams::Buffer actualOutputBuffer(size); - winrt::Windows::Storage::Streams::Buffer expectedOutputBuffer(size); + ws::Streams::Buffer actualOutputBuffer(size); + ws::Streams::Buffer expectedOutputBuffer(size); softwareBitmapActual.CopyToBuffer(actualOutputBuffer); softwareBitmapExpected.CopyToBuffer(expectedOutputBuffer); diff --git a/winml/test/image/imageTestHelper.h b/winml/test/image/imageTestHelper.h index 2754dffd7b6ad..f5c397aeba72d 100644 --- a/winml/test/image/imageTestHelper.h +++ b/winml/test/image/imageTestHelper.h @@ -17,18 +17,18 @@ enum VideoFrameSource { FromSoftwareBitmap, FromDirect3DSurface, FromUnsupported namespace ImageTestHelper { - winrt::Windows::Graphics::Imaging::BitmapPixelFormat GetPixelFormat(const std::wstring& inputPixelFormat); + wgi::BitmapPixelFormat GetPixelFormat(const std::wstring& inputPixelFormat); - winrt::Windows::AI::MachineLearning::TensorFloat LoadInputImageFromCPU( - winrt::Windows::Graphics::Imaging::SoftwareBitmap softwareBitmap, + winml::TensorFloat LoadInputImageFromCPU( + wgi::SoftwareBitmap softwareBitmap, const std::wstring& modelPixelFormat); - winrt::Windows::AI::MachineLearning::TensorFloat LoadInputImageFromGPU( - winrt::Windows::Graphics::Imaging::SoftwareBitmap softwareBitmap, + winml::TensorFloat LoadInputImageFromGPU( + wgi::SoftwareBitmap softwareBitmap, const std::wstring& modelPixelFormat); bool VerifyHelper( - winrt::Windows::Media::VideoFrame actual, - winrt::Windows::Media::VideoFrame expected); + wm::VideoFrame actual, + wm::VideoFrame expected); } diff --git a/winml/test/image/imagetests.cpp b/winml/test/image/imagetests.cpp index 32efc43a29248..0f6eb76214996 100644 --- a/winml/test/image/imagetests.cpp +++ b/winml/test/image/imagetests.cpp @@ -3,9 +3,6 @@ #include "filehelpers.h" #include "imageTestHelper.h" #include "robuffer.h" -#include "windows.ai.machinelearning.native.internal.h" -#include "winrt/Windows.Storage.h" -#include "winrt/Windows.Storage.Streams.h" #include #include @@ -17,13 +14,13 @@ #endif using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX; -using namespace winrt::Windows::Storage; -using namespace winrt::Windows::Storage::Streams; +using namespace winml; +using namespace wfc; +using namespace wm; +using namespace wgi; +using namespace wgdx; +using namespace ws; +using namespace wss; enum BindingLocation { CPU, @@ -32,14 +29,14 @@ enum BindingLocation { class ImageTests : public ::testing::Test { protected: - winrt::Windows::AI::MachineLearning::LearningModel m_model = nullptr; - winrt::Windows::AI::MachineLearning::LearningModelDevice m_device = nullptr; - winrt::Windows::AI::MachineLearning::LearningModelSession m_session = nullptr; - winrt::Windows::AI::MachineLearning::LearningModelBinding m_model_binding = nullptr; - winrt::Windows::AI::MachineLearning::LearningModelEvaluationResult m_result = nullptr; - - void SetUp() override { - init_apartment(); + winml::LearningModel m_model = nullptr; + winml::LearningModelDevice m_device = nullptr; + winml::LearningModelSession m_session = nullptr; + winml::LearningModelBinding m_model_binding = nullptr; + winml::LearningModelEvaluationResult m_result = nullptr; + + static void SetUpTestSuite() { + init_apartment(); } void LoadModel(const std::wstring& model_path) { @@ -690,8 +687,8 @@ static void RunImageBindingInputAndOutput(bool bindInputAsIInspectable) { BYTE* data = nullptr; UINT32 ui_capacity = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer bitmap_buffer(output_image.SoftwareBitmap().LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = bitmap_buffer.CreateReference(); + wgi::BitmapBuffer bitmap_buffer(output_image.SoftwareBitmap().LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = bitmap_buffer.CreateReference(); auto spByteAccess = reference.as<::Windows::Foundation::IMemoryBufferByteAccess>(); WINML_EXPECT_HRESULT_SUCCEEDED(spByteAccess->GetBuffer(&data, &ui_capacity)); WINML_EXPECT_NOT_EQUAL(data[0], 0); diff --git a/winml/test/scenario/cppwinrt/CustomOperatorProvider.h b/winml/test/scenario/cppwinrt/CustomOperatorProvider.h index a280295be32eb..87c51c8412a18 100644 --- a/winml/test/scenario/cppwinrt/CustomOperatorProvider.h +++ b/winml/test/scenario/cppwinrt/CustomOperatorProvider.h @@ -9,7 +9,7 @@ struct CustomOperatorProvider : winrt::implements< CustomOperatorProvider, - winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider, + winml::ILearningModelOperatorProvider, ILearningModelOperatorProviderNative> { HMODULE m_library; @@ -17,10 +17,14 @@ struct CustomOperatorProvider : CustomOperatorProvider() { + std::wostringstream dll; + dll << BINARY_NAME; + auto winml_dll_name = dll.str(); + #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) - m_library = LoadLibraryW(L"windows.ai.machinelearning.dll"); + m_library = LoadLibraryW(winml_dll_name.c_str()); #elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP) - m_library = LoadPackagedLibrary(L"windows.ai.machinelearning.dll", 0 /*Reserved*/); + m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/); #endif using create_registry_delegate = HRESULT WINAPI (_COM_Outptr_ IMLOperatorRegistry** registry); auto create_registry = reinterpret_cast(GetProcAddress(m_library, "MLCreateOperatorRegistry")); diff --git a/winml/test/scenario/cppwinrt/CustomOps.cpp b/winml/test/scenario/cppwinrt/CustomOps.cpp index 2afd05d99b9ab..d7e057b93a36e 100644 --- a/winml/test/scenario/cppwinrt/CustomOps.cpp +++ b/winml/test/scenario/cppwinrt/CustomOps.cpp @@ -7,10 +7,6 @@ #include #include "filehelpers.h" #include -#include -#include -#include "winrt/Windows.Storage.h" -#include #include #include #include "CustomOperatorProvider.h" @@ -25,477 +21,450 @@ #include "CustomNullOp.h" #include -using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Storage; -using namespace winrt::Windows::Storage::Streams; +using namespace winml; +using namespace wfc; +using namespace wm; +using namespace wgi; +using namespace ws; +using namespace wss; -static void CustomOpsScenarioTestSetup() +static void CustomOpsScenarioTestsClassSetup() { - init_apartment(); + winrt::init_apartment(); } -static void CustomOpsScenarioGpuTestSetup() +static void CustomOpsScenarioTestsGpuMethodSetup() { - init_apartment(); GPUTEST; } // Tests that the execution provider correctly fuses operators together when custom ops are involved. static void CustomOperatorFusion() { - constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx"; - - // This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them - // all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same - // goes for the single Gemm+Sigmoid in the model too. - constexpr const uint32_t c_expectedConvOps = 0; - constexpr const uint32_t c_expectedReluOps = 0; - constexpr const uint32_t c_expectedFusedConvOps = 25; - constexpr const uint32_t c_expectedGemmOps = 0; - constexpr const uint32_t c_expectedSigmoidOps = 0; - constexpr const uint32_t c_expectedFusedGemmOps = 1; - - // These ops are also part of the model but shouldn't be fused - constexpr const uint32_t c_expectedBatchNormOps = 1; - constexpr const uint32_t c_expectedMaxPoolOps = 3; - constexpr const uint32_t c_expectedConcatOps = 8; - - struct CallbackOperatorProvider : - winrt::implements< - CallbackOperatorProvider, - winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider, - ILearningModelOperatorProviderNative> - { - struct CallCounts - { - std::atomic conv = 0; - std::atomic relu = 0; - std::atomic fusedConv = 0; - std::atomic gemm = 0; - std::atomic sigmoid = 0; - std::atomic fusedGemm = 0; - std::atomic batchNorm = 0; - std::atomic maxPool = 0; - std::atomic concat = 0; - }; + constexpr const wchar_t* c_modelFilename = L"squeezenet_tensor_input.onnx"; + + // This particular model has 25 Conv ops and 25 Relu ops, all of which are eligible for fusion so we expect them + // all to be fused (removing them from the graph) and replaced with the appropriate fused op instead. The same + // goes for the single Gemm+Sigmoid in the model too. + constexpr const uint32_t c_expectedConvOps = 0; + constexpr const uint32_t c_expectedReluOps = 0; + constexpr const uint32_t c_expectedFusedConvOps = 25; + constexpr const uint32_t c_expectedGemmOps = 0; + constexpr const uint32_t c_expectedSigmoidOps = 0; + constexpr const uint32_t c_expectedFusedGemmOps = 1; + + // These ops are also part of the model but shouldn't be fused + constexpr const uint32_t c_expectedBatchNormOps = 1; + constexpr const uint32_t c_expectedMaxPoolOps = 3; + constexpr const uint32_t c_expectedConcatOps = 8; + + struct CallbackOperatorProvider : winrt::implements< + CallbackOperatorProvider, + winml::ILearningModelOperatorProvider, + ILearningModelOperatorProviderNative> { + struct CallCounts { + std::atomic conv = 0; + std::atomic relu = 0; + std::atomic fusedConv = 0; + std::atomic gemm = 0; + std::atomic sigmoid = 0; + std::atomic fusedGemm = 0; + std::atomic batchNorm = 0; + std::atomic maxPool = 0; + std::atomic concat = 0; + }; - const CallCounts& GetCallCounts() - { - return m_callCounts; - } + const CallCounts& GetCallCounts() { + return m_callCounts; + } - CallbackOperatorProvider() - { - using namespace OperatorHelper; + CallbackOperatorProvider() { + using namespace OperatorHelper; + + std::wostringstream dll; + dll << BINARY_NAME; + auto winml_dll_name = dll.str(); - WINML_EXPECT_HRESULT_SUCCEEDED(MLCreateOperatorRegistry(m_registry.put())); +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + auto m_library = LoadLibraryW(winml_dll_name.c_str()); +#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP) + auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/); +#endif + using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry); + auto create_registry = reinterpret_cast(GetProcAddress(m_library, "MLCreateOperatorRegistry")); + WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put())); #pragma push_macro("REGISTER_KERNEL") #define REGISTER_KERNEL(_name, _domain, _opSet, _shapeInferrer, _callCount) \ - NullOperatorFactory::RegisterKernel( \ - #_name, \ - (_domain), \ - _opSet::sc_sinceVer_ ## _name, \ - m_registry, \ - winrt::make>(), \ - (_callCount)); - - REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv); - REGISTER_KERNEL(Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu); - REGISTER_KERNEL(FusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv); - - REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm); - REGISTER_KERNEL(Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid); - REGISTER_KERNEL(FusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm); - - REGISTER_KERNEL(BatchNormalization, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.batchNorm); - REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool); - REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat); + NullOperatorFactory::RegisterKernel( \ + #_name, \ + (_domain), \ + _opSet::sc_sinceVer_##_name, \ + m_registry, \ + winrt::make>(), \ + (_callCount)); + + REGISTER_KERNEL(Conv, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConvHelper, &m_callCounts.conv); + REGISTER_KERNEL(Relu, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.relu); + REGISTER_KERNEL(FusedConv, onnxruntime::kMSDmlDomain, MsftOperatorSet1, ConvHelper, &m_callCounts.fusedConv); + + REGISTER_KERNEL(Gemm, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GemmHelper, &m_callCounts.gemm); + REGISTER_KERNEL(Sigmoid, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.sigmoid); + REGISTER_KERNEL(FusedGemm, onnxruntime::kMSDmlDomain, MsftOperatorSet1, GemmHelper, &m_callCounts.fusedGemm); + + REGISTER_KERNEL(BatchNormalization, onnxruntime::kOnnxDomain, OnnxOperatorSet7, GetOutputShapeAsInputShapeHelper, &m_callCounts.batchNorm); + REGISTER_KERNEL(MaxPool, onnxruntime::kOnnxDomain, OnnxOperatorSet7, PoolingHelper, &m_callCounts.maxPool); + REGISTER_KERNEL(Concat, onnxruntime::kOnnxDomain, OnnxOperatorSet7, ConcatHelper, &m_callCounts.concat); #pragma pop_macro("REGISTER_KERNEL") - } - - STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry) - { - if (ppOperatorRegistry == nullptr) - { - return E_POINTER; - } - - m_registry.copy_to(ppOperatorRegistry); - return S_OK; - } - - private: - winrt::com_ptr m_registry; - CallCounts m_callCounts; - }; - - auto customOperatorProvider = winrt::make(); - auto provider = customOperatorProvider.as(); - - LearningModelDevice device = nullptr; - WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX)); - std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename; - auto model = LearningModel::LoadFromFilePath(fullPath, provider); - - auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png"); - - LearningModelSession session = nullptr; - WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device)); - LearningModelBinding modelBinding(session); + } - modelBinding.Bind(L"data", featureValue); - auto result = session.Evaluate(modelBinding, L""); + STDMETHOD(GetRegistry) + (IMLOperatorRegistry** ppOperatorRegistry) { + if (ppOperatorRegistry == nullptr) { + return E_POINTER; + } - const auto& callCounts = customOperatorProvider.as()->GetCallCounts(); + m_registry.copy_to(ppOperatorRegistry); + return S_OK; + } - // Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused) - WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv); - WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu); - WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv); - WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm); - WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid); - WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm); - WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm); - WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool); - WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat); + private: + winrt::com_ptr m_registry; + CallCounts m_callCounts; + }; + + auto customOperatorProvider = winrt::make(); + auto provider = customOperatorProvider.as(); + + LearningModelDevice device = nullptr; + WINML_EXPECT_NO_THROW(device = LearningModelDevice(LearningModelDeviceKind::DirectX)); + std::wstring fullPath = FileHelpers::GetModulePath() + c_modelFilename; + auto model = LearningModel::LoadFromFilePath(fullPath, provider); + + auto featureValue = FileHelpers::LoadImageFeatureValue(L"227x227.png"); + + LearningModelSession session = nullptr; + WINML_EXPECT_NO_THROW(session = LearningModelSession(model, device)); + LearningModelBinding modelBinding(session); + + modelBinding.Bind(L"data", featureValue); + auto result = session.Evaluate(modelBinding, L""); + + const auto& callCounts = customOperatorProvider.as()->GetCallCounts(); + + // Verify that the correct number of each operator was seen (i.e. that none were dropped / incorrectly fused) + WINML_EXPECT_EQUAL(c_expectedConvOps, callCounts.conv); + WINML_EXPECT_EQUAL(c_expectedReluOps, callCounts.relu); + WINML_EXPECT_EQUAL(c_expectedFusedConvOps, callCounts.fusedConv); + WINML_EXPECT_EQUAL(c_expectedGemmOps, callCounts.gemm); + WINML_EXPECT_EQUAL(c_expectedSigmoidOps, callCounts.sigmoid); + WINML_EXPECT_EQUAL(c_expectedFusedGemmOps, callCounts.fusedGemm); + WINML_EXPECT_EQUAL(c_expectedBatchNormOps, callCounts.batchNorm); + WINML_EXPECT_EQUAL(c_expectedMaxPoolOps, callCounts.maxPool); + WINML_EXPECT_EQUAL(c_expectedConcatOps, callCounts.concat); } -struct LocalCustomOperatorProvider : - winrt::implements< - LocalCustomOperatorProvider, - winrt::Windows::AI::MachineLearning::ILearningModelOperatorProvider, - ILearningModelOperatorProviderNative> -{ - LocalCustomOperatorProvider() - { - WINML_EXPECT_HRESULT_SUCCEEDED(MLCreateOperatorRegistry(m_registry.put())); +struct LocalCustomOperatorProvider : winrt::implements< + LocalCustomOperatorProvider, + winml::ILearningModelOperatorProvider, + ILearningModelOperatorProviderNative> { + LocalCustomOperatorProvider() { + + std::wostringstream dll; + dll << BINARY_NAME; + auto winml_dll_name = dll.str(); + +#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) + auto m_library = LoadLibraryW(winml_dll_name.c_str()); +#elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PC_APP) + auto m_library = LoadPackagedLibrary(winml_dll_name.c_str(), 0 /*Reserved*/); +#endif + using create_registry_delegate = HRESULT WINAPI(_COM_Outptr_ IMLOperatorRegistry * *registry); + auto create_registry = reinterpret_cast(GetProcAddress(m_library, "MLCreateOperatorRegistry")); + WINML_EXPECT_HRESULT_SUCCEEDED(create_registry(m_registry.put())); + } + + STDMETHOD(GetRegistry) + (IMLOperatorRegistry** ppOperatorRegistry) { + if (ppOperatorRegistry == nullptr) { + return E_POINTER; } - STDMETHOD(GetRegistry)(IMLOperatorRegistry** ppOperatorRegistry) - { - if (ppOperatorRegistry == nullptr) - { - return E_POINTER; - } - - m_registry.copy_to(ppOperatorRegistry); - return S_OK; - } + m_registry.copy_to(ppOperatorRegistry); + return S_OK; + } - IMLOperatorRegistry* GetRegistry() - { - return m_registry.get(); - } + IMLOperatorRegistry* GetRegistry() { + return m_registry.get(); + } -protected: - winrt::com_ptr m_registry; + protected: + winrt::com_ptr m_registry; }; // Checks test attributes set on ABI kernels can be queried with correct values -void VerifyTestAttributes(const MLOperatorAttributes& attrs) -{ - std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString"); - WINML_EXPECT_EQUAL(strAttr, "1"); +void VerifyTestAttributes(const MLOperatorAttributes& attrs) { + std::string strAttr = attrs.GetAttribute("DefaultedNonRequiredString"); + WINML_EXPECT_EQUAL(strAttr, "1"); - std::vector strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray"); - std::vector expected = std::vector({ "1", "2" }); - for (size_t i = 0; i < expected.size(); ++i) - { - WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]); - } + std::vector strArrayAttr = attrs.GetAttributeVector("DefaultedNonRequiredStringArray"); + std::vector expected = std::vector({"1", "2"}); + for (size_t i = 0; i < expected.size(); ++i) { + WINML_EXPECT_EQUAL(strArrayAttr[i], expected[i]); + } - WINML_EXPECT_EQUAL(1, attrs.GetAttribute("DefaultedNonRequiredInt")); - WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute("DefaultedNonRequiredFloat")); + WINML_EXPECT_EQUAL(1, attrs.GetAttribute("DefaultedNonRequiredInt")); + WINML_EXPECT_EQUAL(1.0f, attrs.GetAttribute("DefaultedNonRequiredFloat")); - WINML_EXPECT_EQUAL(std::vector({ 1, 2 }), attrs.GetAttributeVector("DefaultedNonRequiredIntArray")); - WINML_EXPECT_EQUAL(std::vector({ 1.0f, 2.0f }), attrs.GetAttributeVector("DefaultedNonRequiredFloatArray")); + WINML_EXPECT_EQUAL(std::vector({1, 2}), attrs.GetAttributeVector("DefaultedNonRequiredIntArray")); + WINML_EXPECT_EQUAL(std::vector({1.0f, 2.0f}), attrs.GetAttributeVector("DefaultedNonRequiredFloatArray")); } // Foo kernel which is doing Add and optionally truncates its output template -class FooKernel -{ -public: - FooKernel(const MLOperatorKernelCreationContext& info) - { - if (VerifyAttributes) - { - VerifyTestAttributes(info); - } - - VerifyShapeInfo(info); +class FooKernel { + public: + FooKernel(const MLOperatorKernelCreationContext& info) { + if (VerifyAttributes) { + VerifyTestAttributes(info); } - void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) - { - if (!Truncate) - { - com_ptr shapeInfo; - WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false); - WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put())); - } - else - { - com_ptr shapeInfo; - WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true); - WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK); - } + VerifyShapeInfo(info); + } + + void VerifyShapeInfo(const MLOperatorKernelCreationContext& info) { + if (!Truncate) { + winrt::com_ptr shapeInfo; + WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), false); + WINML_EXPECT_HRESULT_FAILED(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put())); + } else { + winrt::com_ptr shapeInfo; + WINML_EXPECT_EQUAL(info.GetInterface()->HasTensorShapeDescription(), true); + WINML_EXPECT_EQUAL(info.GetInterface()->GetTensorShapeDescription(shapeInfo.put()), S_OK); } + } - void Compute(const MLOperatorKernelContext& context) const - { - const auto X = context.GetInputTensor(0); - const auto W = context.GetInputTensor(1); + void Compute(const MLOperatorKernelContext& context) const { + const auto X = context.GetInputTensor(0); + const auto W = context.GetInputTensor(1); - auto xData = X.GetData(); - auto wData = W.GetData(); + auto xData = X.GetData(); + auto wData = W.GetData(); - auto shape = X.GetShape(); + auto shape = X.GetShape(); - // This is used to test shape inference - if (Truncate) - { - shape[0] -= 1; - } + // This is used to test shape inference + if (Truncate) { + shape[0] -= 1; + } - if (!Truncate) - { - com_ptr tensor; - WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put())); - } - else - { - MLOperatorTensor tensor = context.GetOutputTensor(0); - } + if (!Truncate) { + winrt::com_ptr tensor; + WINML_EXPECT_HRESULT_FAILED(context.GetInterface()->GetOutputTensor(0, tensor.put())); + } else { + MLOperatorTensor tensor = context.GetOutputTensor(0); + } - auto Y = context.GetOutputTensor(0, shape); - auto yData = Y.GetData(); + auto Y = context.GetOutputTensor(0, shape); + auto yData = Y.GetData(); - size_t size = 1; - for (size_t i = 0; i < shape.size(); i++) - { - size *= shape[i]; - } + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } - for (size_t i = 0; i < size; i++) - { - yData[i] = xData[i] + wData[i]; - } + for (size_t i = 0; i < size; i++) { + yData[i] = xData[i] + wData[i]; } + } }; template -void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) -{ - HRESULT hr = MLOperatorKernel>::CreateInstance(*kernelInfo, opKernel); - THROW_IF_FAILED(hr); +void CALLBACK CreateABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) { + HRESULT hr = MLOperatorKernel>::CreateInstance(*kernelInfo, opKernel); + THROW_IF_FAILED(hr); } -void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) -{ - HRESULT hr = MLOperatorKernel>::CreateInstance(*kernelInfo, opKernel); - THROW_IF_FAILED(hr); +void CALLBACK CreateTruncatedABIFooKernel(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) { + HRESULT hr = MLOperatorKernel>::CreateInstance(*kernelInfo, opKernel); + THROW_IF_FAILED(hr); } // Test using a foo kernel which is doing Add, but register it as "Mul". -static void CustomKernelWithBuiltInSchema() -{ - // Create the registry - auto operatorProvider = winrt::make(); - IMLOperatorRegistry* registry = operatorProvider.as()->GetRegistry(); +static void CustomKernelWithBuiltInSchema() { + // Create the registry + auto operatorProvider = winrt::make(); + IMLOperatorRegistry* registry = operatorProvider.as()->GetRegistry(); - // Register the kernel - MLOperatorEdgeDescription floatTensorType = - { - MLOperatorEdgeType::Tensor, - static_cast(MLOperatorTensorDataType::Float) - }; - - MLOperatorEdgeTypeConstrant constraint = { "T", &floatTensorType, 1 }; - - MLOperatorKernelDescription kernelDesc = - { - "", - "Mul", - 7, - MLOperatorExecutionType::Cpu, - &constraint, - 1, - nullptr, - 0, - MLOperatorKernelOptions::AllowDynamicInputShapes - }; - - Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateABIFooKernel); - WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr)); - - // Prepare inputs - std::vector dimsX = { 3, 2 }; - std::vector valuesX = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }; - - // Prepare expected inputs and outputs - std::vector expectedDimsY = { 3, 2 }; - - // The expected value should be Add's result. - std::vector expectedValuesY = { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f }; - - // Create the model and sessions - std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx"; - LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider); - - LearningModelSession session(model); - LearningModelBinding bindings(session); - - // Bind inputs and outputs - TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view(std::move(valuesX))); - bindings.Bind(winrt::hstring(L"X"), inputTensor); - - auto outputValue = TensorFloat::Create(); - WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue)); - - // Evaluate the model - hstring correlationId; - WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId)); - - // Check the result shape - WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size()); - for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) - { - WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j)); - } + // Register the kernel + MLOperatorEdgeDescription floatTensorType = + { + MLOperatorEdgeType::Tensor, + static_cast(MLOperatorTensorDataType::Float)}; - // Check the results - auto buffer = outputValue.GetAsVectorView(); - WINML_EXPECT_TRUE(buffer != nullptr); - WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer))); + MLOperatorEdgeTypeConstrant constraint = {"T", &floatTensorType, 1}; - // Release the model before operatorProvider goes out of scope - model = nullptr; + MLOperatorKernelDescription kernelDesc = + { + "", + "Mul", + 7, + MLOperatorExecutionType::Cpu, + &constraint, + 1, + nullptr, + 0, + MLOperatorKernelOptions::AllowDynamicInputShapes}; + + Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateABIFooKernel); + WINML_EXPECT_HRESULT_SUCCEEDED(registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr)); + + // Prepare inputs + std::vector dimsX = {3, 2}; + std::vector valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Prepare expected inputs and outputs + std::vector expectedDimsY = {3, 2}; + + // The expected value should be Add's result. + std::vector expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + + // Create the model and sessions + std::wstring fullPath = FileHelpers::GetModulePath() + L"mul.onnx"; + LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider); + + LearningModelSession session(model); + LearningModelBinding bindings(session); + + // Bind inputs and outputs + TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view(std::move(valuesX))); + bindings.Bind(winrt::hstring(L"X"), inputTensor); + + auto outputValue = TensorFloat::Create(); + WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue)); + + // Evaluate the model + winrt::hstring correlationId; + WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId)); + + // Check the result shape + WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size()); + for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) { + WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j)); + } + + // Check the results + auto buffer = outputValue.GetAsVectorView(); + WINML_EXPECT_TRUE(buffer != nullptr); + WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer))); + + // Release the model before operatorProvider goes out of scope + model = nullptr; } // Similar to MLOperatorShapeInferrer, but using an std::function class MLOperatorShapeInferrerFromFunc : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, IMLOperatorShapeInferrer> -{ -public: - MLOperatorShapeInferrerFromFunc(std::function shapeInferenceFn) : - m_func(shapeInferenceFn) - {} - - HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try - { - m_func(context); - return S_OK; - } - CATCH_RETURN(); - -private: - std::function m_func; + Microsoft::WRL::RuntimeClassFlags, IMLOperatorShapeInferrer> { + public: + MLOperatorShapeInferrerFromFunc(std::function shapeInferenceFn) : m_func(shapeInferenceFn) {} + + HRESULT STDMETHODCALLTYPE InferOutputShapes(IMLOperatorShapeInferenceContext* context) noexcept override try { + m_func(context); + return S_OK; + } + CATCH_RETURN(); + + private: + std::function m_func; }; // Test using a custom kernel and schema, while verifying attribute defaults, type mapping, and inference methods -static void CustomKernelWithCustomSchema() -{ - // Test cases - struct - { - // Whether the Foo kernel should truncate its output - bool truncateOutput; +static void CustomKernelWithCustomSchema() { + // Test cases + struct + { + // Whether the Foo kernel should truncate its output + bool truncateOutput; - // Whether a type label is used in the schema, versus a type description - bool useTypeLabel; + // Whether a type label is used in the schema, versus a type description + bool useTypeLabel; - // Whether the schema provides a type inference function, and uses an output type - // of Int32 instead of Float32 - bool useTypeInference; + // Whether the schema provides a type inference function, and uses an output type + // of Int32 instead of Float32 + bool useTypeInference; - // Whether a shape inference method is provided in the schema - bool useShapeInferenceInSchema; + // Whether a shape inference method is provided in the schema + bool useShapeInferenceInSchema; - // Whether a shape inference method is provided in the kernel - bool useShapeInferenceInKernel; + // Whether a shape inference method is provided in the kernel + bool useShapeInferenceInKernel; - // Whether attribute defaults are provided in the schema, instead of the kernel - bool attributeDefaultsInSchema; - } testCases[] = - { - {false, true, false, false, false, false}, - {false, false, false, false, false, false}, - {false, true, true, false, false, true}, - {true, false, false, false, true, false}, - {true, true, true, true, true, true}, - }; + // Whether attribute defaults are provided in the schema, instead of the kernel + bool attributeDefaultsInSchema; + } testCases[] = + { + {false, true, false, false, false, false}, + {false, false, false, false, false, false}, + {false, true, true, false, false, true}, + {true, false, false, false, true, false}, + {true, true, true, true, true, true}, + }; - for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) - { - // Create the registry - auto operatorProvider = winrt::make(); - IMLOperatorRegistry* registry = operatorProvider.as()->GetRegistry(); + for (size_t caseIndex = 0; caseIndex < std::size(testCases); ++caseIndex) { + // Create the registry + auto operatorProvider = winrt::make(); + IMLOperatorRegistry* registry = operatorProvider.as()->GetRegistry(); - // Create input and output parameters - MLOperatorSchemaEdgeDescription inputParam = {}; - inputParam.options = MLOperatorParameterOptions::Single; + // Create input and output parameters + MLOperatorSchemaEdgeDescription inputParam = {}; + inputParam.options = MLOperatorParameterOptions::Single; - if (!testCases[caseIndex].useTypeLabel) - { - assert(!testCases[caseIndex].useTypeInference); + if (!testCases[caseIndex].useTypeLabel) { + assert(!testCases[caseIndex].useTypeInference); - MLOperatorEdgeDescription edgeDesc = {}; - edgeDesc.edgeType = MLOperatorEdgeType::Tensor; - edgeDesc.tensorDataType = MLOperatorTensorDataType::Float; + MLOperatorEdgeDescription edgeDesc = {}; + edgeDesc.edgeType = MLOperatorEdgeType::Tensor; + edgeDesc.tensorDataType = MLOperatorTensorDataType::Float; - inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription; - inputParam.edgeDescription = edgeDesc; - } - else - { - inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label; - inputParam.typeLabel = "T1"; - } + inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::EdgeDescription; + inputParam.edgeDescription = edgeDesc; + } else { + inputParam.typeFormat = MLOperatorSchemaEdgeTypeFormat::Label; + inputParam.typeLabel = "T1"; + } - MLOperatorSchemaEdgeDescription outputParam = inputParam; + MLOperatorSchemaEdgeDescription outputParam = inputParam; - // Type inference should set this to tensor(float) even though T2 is not matched - // on an input label - if (testCases[caseIndex].useTypeInference) - { - if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) - { - outputParam.typeLabel = "T2"; - } - else - { - outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32; - } - } - - MLOperatorSchemaEdgeDescription inputs[] = { inputParam, inputParam }; - - MLOperatorEdgeDescription edgeTypes[6] = + // Type inference should set this to tensor(float) even though T2 is not matched + // on an input label + if (testCases[caseIndex].useTypeInference) { + if (inputParam.typeFormat == MLOperatorSchemaEdgeTypeFormat::Label) { + outputParam.typeLabel = "T2"; + } else { + outputParam.edgeDescription.tensorDataType = MLOperatorTensorDataType::Int32; + } + } + + MLOperatorSchemaEdgeDescription inputs[] = {inputParam, inputParam}; + + MLOperatorEdgeDescription edgeTypes[6] = { {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::UInt32)}, {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::UInt64)}, {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::Int32)}, {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::Int64)}, {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::Float)}, - {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::Double)} - }; + {MLOperatorEdgeType::Tensor, static_cast(MLOperatorTensorDataType::Double)}}; - // Type constraints. Only the first is used unless type inference is provided and - // the kernel emits a different output type as "T2" - MLOperatorEdgeTypeConstrant constraints[] = + // Type constraints. Only the first is used unless type inference is provided and + // the kernel emits a different output type as "T2" + MLOperatorEdgeTypeConstrant constraints[] = { {"T1", edgeTypes, static_cast(std::size(edgeTypes))}, - {"T2", edgeTypes, static_cast(std::size(edgeTypes))} - }; + {"T2", edgeTypes, static_cast(std::size(edgeTypes))}}; - // Test attributes - MLOperatorAttribute attributes[] = + // Test attributes + MLOperatorAttribute attributes[] = { {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, false}, {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, false}, @@ -507,9 +476,9 @@ static void CustomKernelWithCustomSchema() {"NonDefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, false}, }; - // Defaults. These are queried back during kernel creation, type and shape inference - // and tested against the same values - MLOperatorAttributeNameValue defaultAttributes[] = + // Defaults. These are queried back during kernel creation, type and shape inference + // and tested against the same values + MLOperatorAttributeNameValue defaultAttributes[] = { {"DefaultedNonRequiredInt", MLOperatorAttributeType::Int, 1}, {"DefaultedNonRequiredFloat", MLOperatorAttributeType::Float, 1}, @@ -519,219 +488,201 @@ static void CustomKernelWithCustomSchema() {"DefaultedNonRequiredStringArray", MLOperatorAttributeType::StringArray, 2}, }; - int64_t defaultInts[] = { 1, 2 }; - float defaultFloats[] = { 1.0f, 2.0f }; - const char* defaultStrings[] = { "1", "2" }; - defaultAttributes[0].ints = defaultInts; - defaultAttributes[1].floats = defaultFloats; - defaultAttributes[2].strings = defaultStrings; - defaultAttributes[3].ints = defaultInts; - defaultAttributes[4].floats = defaultFloats; - defaultAttributes[5].strings = defaultStrings; - - // Schema definition - MLOperatorSchemaDescription schemaDesc = {}; - schemaDesc.name = "Foo"; - schemaDesc.operatorSetVersionAtLastChange = 7; - schemaDesc.inputs = inputs; - schemaDesc.inputCount = 2; - schemaDesc.outputs = &outputParam; - schemaDesc.outputCount = 1; - schemaDesc.typeConstraints = constraints; - schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0; - schemaDesc.attributes = attributes; - schemaDesc.attributeCount = static_cast(std::size(attributes)); - - if (testCases[caseIndex].attributeDefaultsInSchema) - { - schemaDesc.defaultAttributes = defaultAttributes; - schemaDesc.defaultAttributeCount = static_cast(std::size(defaultAttributes)); - } + int64_t defaultInts[] = {1, 2}; + float defaultFloats[] = {1.0f, 2.0f}; + const char* defaultStrings[] = {"1", "2"}; + defaultAttributes[0].ints = defaultInts; + defaultAttributes[1].floats = defaultFloats; + defaultAttributes[2].strings = defaultStrings; + defaultAttributes[3].ints = defaultInts; + defaultAttributes[4].floats = defaultFloats; + defaultAttributes[5].strings = defaultStrings; + + // Schema definition + MLOperatorSchemaDescription schemaDesc = {}; + schemaDesc.name = "Foo"; + schemaDesc.operatorSetVersionAtLastChange = 7; + schemaDesc.inputs = inputs; + schemaDesc.inputCount = 2; + schemaDesc.outputs = &outputParam; + schemaDesc.outputCount = 1; + schemaDesc.typeConstraints = constraints; + schemaDesc.typeConstraintCount = testCases[caseIndex].useTypeLabel ? 2 : 0; + schemaDesc.attributes = attributes; + schemaDesc.attributeCount = static_cast(std::size(attributes)); + + if (testCases[caseIndex].attributeDefaultsInSchema) { + schemaDesc.defaultAttributes = defaultAttributes; + schemaDesc.defaultAttributeCount = static_cast(std::size(defaultAttributes)); + } - Microsoft::WRL::ComPtr typeInferrer; - Microsoft::WRL::ComPtr shapeInferrer; + Microsoft::WRL::ComPtr typeInferrer; + Microsoft::WRL::ComPtr shapeInferrer; - // Type inference function - if (testCases[caseIndex].useTypeInference) - { - typeInferrer = wil::MakeOrThrow([](IMLOperatorTypeInferenceContext* ctx) -> void - { - VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx)); - - MLOperatorEdgeDescription edgeDesc = {}; - edgeDesc.edgeType = MLOperatorEdgeType::Tensor; - edgeDesc.tensorDataType = MLOperatorTensorDataType::Float; - - MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc); - }); - } - - // Store the shape inference context with a reference following the call to InferOutputShapes. - // This will be called after loading the model as an isolated test for how ABI context objects - // are "closed." - Microsoft::WRL::ComPtr shapeInferenceContext; - - // Shape inference is tested by truncating the output size - bool truncateOutput = testCases[caseIndex].truncateOutput; - if (truncateOutput) - { - shapeInferrer = wil::MakeOrThrow([&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void - { - VerifyTestAttributes(MLShapeInferenceContext(ctx)); - MLShapeInferenceContext(ctx).SetOutputTensorShape(0, { 2, 2 }); - shapeInferenceContext = ctx; - }); - } - - // Register the schema - MLOperatorSetId opsetId = { "", 7 }; - MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc; - WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( - &opsetId, - 1, - &opSchemaDescs, - 1, - typeInferrer.Get(), - testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr - )); + // Type inference function + if (testCases[caseIndex].useTypeInference) { + typeInferrer = wil::MakeOrThrow([](IMLOperatorTypeInferenceContext* ctx) -> void { + VerifyTestAttributes(MLOperatorTypeInferenceContext(ctx)); - { - // Register a future version of the schema in the same domain, while setting its - // input count to zero to ensure it is not being used. - auto futureSchemaDesc = schemaDesc; - futureSchemaDesc.inputCount = 0; - - MLOperatorSetId id = { "", 9 }; - MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc; - WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( - &id, - 7, - &schemaDescs, - 1, - typeInferrer.Get(), - testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr - )); - } - { - // Register in another (unused) domain to the custom registry - auto otherSchemaDesc = schemaDesc; - otherSchemaDesc.inputCount = 0; - - MLOperatorSetId id = { "otherDomain", 7 }; - MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc; - WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( - &id, - 1, - &schemaDescs, - 1, - typeInferrer.Get(), - testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr - )); - } - // Register the Foo kernel - MLOperatorEdgeDescription floatTensorEdgeDesc = {}; - floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor; - floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float; - - MLOperatorEdgeTypeConstrant kernelConstraint = { "T", &floatTensorEdgeDesc, 1 }; - - MLOperatorKernelDescription kernelDesc = + MLOperatorEdgeDescription edgeDesc = {}; + edgeDesc.edgeType = MLOperatorEdgeType::Tensor; + edgeDesc.tensorDataType = MLOperatorTensorDataType::Float; + + MLOperatorTypeInferenceContext(ctx).SetOutputEdgeDescription(0, &edgeDesc); + }); + } + + // Store the shape inference context with a reference following the call to InferOutputShapes. + // This will be called after loading the model as an isolated test for how ABI context objects + // are "closed." + Microsoft::WRL::ComPtr shapeInferenceContext; + + // Shape inference is tested by truncating the output size + bool truncateOutput = testCases[caseIndex].truncateOutput; + if (truncateOutput) { + shapeInferrer = wil::MakeOrThrow([&shapeInferenceContext](IMLOperatorShapeInferenceContext* ctx) -> void { + VerifyTestAttributes(MLShapeInferenceContext(ctx)); + MLShapeInferenceContext(ctx).SetOutputTensorShape(0, {2, 2}); + shapeInferenceContext = ctx; + }); + } + + // Register the schema + MLOperatorSetId opsetId = {"", 7}; + MLOperatorSchemaDescription* opSchemaDescs = &schemaDesc; + WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( + &opsetId, + 1, + &opSchemaDescs, + 1, + typeInferrer.Get(), + testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr)); + + { + // Register a future version of the schema in the same domain, while setting its + // input count to zero to ensure it is not being used. + auto futureSchemaDesc = schemaDesc; + futureSchemaDesc.inputCount = 0; + + MLOperatorSetId id = {"", 9}; + MLOperatorSchemaDescription* schemaDescs = &futureSchemaDesc; + WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( + &id, + 7, + &schemaDescs, + 1, + typeInferrer.Get(), + testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr)); + } + { + // Register in another (unused) domain to the custom registry + auto otherSchemaDesc = schemaDesc; + otherSchemaDesc.inputCount = 0; + + MLOperatorSetId id = {"otherDomain", 7}; + MLOperatorSchemaDescription* schemaDescs = &otherSchemaDesc; + WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorSetSchema( + &id, + 1, + &schemaDescs, + 1, + typeInferrer.Get(), + testCases[caseIndex].useShapeInferenceInSchema ? shapeInferrer.Get() : nullptr)); + } + // Register the Foo kernel + MLOperatorEdgeDescription floatTensorEdgeDesc = {}; + floatTensorEdgeDesc.edgeType = MLOperatorEdgeType::Tensor; + floatTensorEdgeDesc.tensorDataType = MLOperatorTensorDataType::Float; + + MLOperatorEdgeTypeConstrant kernelConstraint = {"T", &floatTensorEdgeDesc, 1}; + + MLOperatorKernelDescription kernelDesc = { "", "Foo", 7, MLOperatorExecutionType::Cpu, &kernelConstraint, - 1 - }; + 1}; - if (!testCases[caseIndex].attributeDefaultsInSchema) - { - kernelDesc.defaultAttributes = defaultAttributes; - kernelDesc.defaultAttributeCount = static_cast(std::size(defaultAttributes)); - } + if (!testCases[caseIndex].attributeDefaultsInSchema) { + kernelDesc.defaultAttributes = defaultAttributes; + kernelDesc.defaultAttributeCount = static_cast(std::size(defaultAttributes)); + } - if (!truncateOutput) - { - kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes; - Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateABIFooKernel); + if (!truncateOutput) { + kernelDesc.options = MLOperatorKernelOptions::AllowDynamicInputShapes; + Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateABIFooKernel); + + WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr)); + } else { + Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateTruncatedABIFooKernel); + WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel( + &kernelDesc, + factory.Get(), + testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr)); + } - WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel(&kernelDesc, factory.Get(), nullptr)); - } - else - { - Microsoft::WRL::ComPtr factory = wil::MakeOrThrow(CreateTruncatedABIFooKernel); - WINML_EXPECT_EQUAL(S_OK, registry->RegisterOperatorKernel( - &kernelDesc, - factory.Get(), - testCases[caseIndex].useShapeInferenceInKernel ? shapeInferrer.Get() : nullptr - )); - } - - // Prepare inputs - std::vector dimsX = { 3, 2 }; - std::vector valuesX = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }; - - // Prepare expected inputs and outputs - std::vector expectedDimsY = { truncateOutput ? 2 : 3, 2 }; - // now the expected value should be Add's result. - std::vector expectedValuesY = { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f }; - if (truncateOutput) - { - // The leading dimension is truncated, and the second dimension has two elements over that dim - expectedValuesY.resize(expectedValuesY.size() - 2); - } + // Prepare inputs + std::vector dimsX = {3, 2}; + std::vector valuesX = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // Prepare expected inputs and outputs + std::vector expectedDimsY = {truncateOutput ? 2 : 3, 2}; + // now the expected value should be Add's result. + std::vector expectedValuesY = {2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}; + if (truncateOutput) { + // The leading dimension is truncated, and the second dimension has two elements over that dim + expectedValuesY.resize(expectedValuesY.size() - 2); + } - // Load the model and sessions - std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx"); - LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider); - LearningModelSession session(model); + // Load the model and sessions + std::wstring fullPath = FileHelpers::GetModulePath() + (truncateOutput ? L"foo_truncated.onnx" : L"foo.onnx"); + LearningModel model = LearningModel::LoadFromFilePath(fullPath, operatorProvider); + LearningModelSession session(model); - // Bind input and outputs - LearningModelBinding bindings(session); + // Bind input and outputs + LearningModelBinding bindings(session); - TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view(std::move(valuesX))); - bindings.Bind(winrt::hstring(L"X"), inputTensor); + TensorFloat inputTensor = TensorFloat::CreateFromArray(dimsX, winrt::array_view(std::move(valuesX))); + bindings.Bind(winrt::hstring(L"X"), inputTensor); - auto outputValue = TensorFloat::Create(); - WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue)); + auto outputValue = TensorFloat::Create(); + WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue)); - // Evaluate the model - hstring correlationId; - WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId)); + // Evaluate the model + winrt::hstring correlationId; + WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId)); - // Verify the result shape - WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size()); - for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) - { - WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j)); - } + // Verify the result shape + WINML_EXPECT_EQUAL(expectedDimsY.size(), outputValue.Shape().Size()); + for (uint32_t j = 0; j < outputValue.Shape().Size(); j++) { + WINML_EXPECT_EQUAL(expectedDimsY.at(j), outputValue.Shape().GetAt(j)); + } - // Verify the result values - auto buffer = outputValue.GetAsVectorView(); - WINML_EXPECT_TRUE(buffer != nullptr); - WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer))); + // Verify the result values + auto buffer = outputValue.GetAsVectorView(); + WINML_EXPECT_TRUE(buffer != nullptr); + WINML_EXPECT_TRUE(std::equal(expectedValuesY.cbegin(), expectedValuesY.cend(), begin(buffer))); - // Release the model before operatorProvider goes out of scope - model = nullptr; + // Release the model before operatorProvider goes out of scope + model = nullptr; - if (shapeInferenceContext) - { - // Check that the shape inference context is closed and safely fails - MLOperatorEdgeDescription edgeDesc; - WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc)); - } + if (shapeInferenceContext) { + // Check that the shape inference context is closed and safely fails + MLOperatorEdgeDescription edgeDesc; + WINML_EXPECT_EQUAL(E_INVALIDARG, shapeInferenceContext->GetInputEdgeDescription(0, &edgeDesc)); } + } } -const CustomOpsTestApi& getapi() { - static constexpr CustomOpsTestApi api = +const CustomOpsTestsApi& getapi() { + static constexpr CustomOpsTestsApi api = { - CustomOpsScenarioTestSetup, - CustomOpsScenarioGpuTestSetup, + CustomOpsScenarioTestsClassSetup, + CustomOpsScenarioTestsGpuMethodSetup, CustomOperatorFusion, CustomKernelWithBuiltInSchema, - CustomKernelWithCustomSchema - }; + CustomKernelWithCustomSchema}; return api; } \ No newline at end of file diff --git a/winml/test/scenario/cppwinrt/CustomOps.h b/winml/test/scenario/cppwinrt/CustomOps.h index 02447bac9d84f..4659234f0d6d2 100644 --- a/winml/test/scenario/cppwinrt/CustomOps.h +++ b/winml/test/scenario/cppwinrt/CustomOps.h @@ -2,21 +2,26 @@ // Licensed under the MIT License. #include "test.h" -struct CustomOpsTestApi +struct CustomOpsTestsApi { - SetupTest CustomOpsScenarioTestSetup; - SetupTest CustomOpsScenarioGpuTestSetup; + SetupTest CustomOpsScenarioTestsClassSetup; + SetupTest CustomOpsScenarioTestsGpuMethodSetup; VoidTest CustomOperatorFusion; VoidTest CustomKernelWithBuiltInSchema; VoidTest CustomKernelWithCustomSchema; }; -const CustomOpsTestApi& getapi(); +const CustomOpsTestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(CustomOpsScenarioTest, CustomOpsScenarioTestSetup) -WINML_TEST(CustomOpsScenarioTest, CustomKernelWithBuiltInSchema) -WINML_TEST(CustomOpsScenarioTest, CustomKernelWithCustomSchema) +WINML_TEST_CLASS_BEGIN(CustomOpsScenarioTests) +WINML_TEST_CLASS_SETUP_CLASS(CustomOpsScenarioTestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(CustomOpsScenarioTests, CustomKernelWithBuiltInSchema) +WINML_TEST(CustomOpsScenarioTests, CustomKernelWithCustomSchema) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(CustomOpsScenarioGpuTest, CustomOpsScenarioGpuTestSetup) -WINML_TEST(CustomOpsScenarioGpuTest, CustomOperatorFusion) +WINML_TEST_CLASS_BEGIN(CustomOpsScenarioGpuTests) +WINML_TEST_CLASS_SETUP_CLASS(CustomOpsScenarioTestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(CustomOpsScenarioTestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(CustomOpsScenarioGpuTests, CustomOperatorFusion) WINML_TEST_CLASS_END() \ No newline at end of file diff --git a/winml/test/scenario/cppwinrt/scenariotestscppwinrt.cpp b/winml/test/scenario/cppwinrt/scenariotestscppwinrt.cpp index 2b573d5ba275e..31f01268821ce 100644 --- a/winml/test/scenario/cppwinrt/scenariotestscppwinrt.cpp +++ b/winml/test/scenario/cppwinrt/scenariotestscppwinrt.cpp @@ -5,29 +5,13 @@ #include -#include "winrt/Windows.Devices.Enumeration.Pnp.h" -#include "winrt/Windows.Graphics.DirectX.Direct3D11.h" -#include "winrt/Windows.Media.Capture.h" -#include "winrt/Windows.Media.h" -#include "winrt/Windows.Security.Cryptography.Core.h" -#include "winrt/Windows.Security.Cryptography.h" -#include "winrt/Windows.Storage.h" -#include "winrt/Windows.Storage.Streams.h" - -// lame, but WinBase.h redefines this, which breaks winrt headers later -#ifdef GetCurrentTime -#undef GetCurrentTime -#endif #include "CommonDeviceHelpers.h" #include "CustomOperatorProvider.h" #include "filehelpers.h" #include "robuffer.h" #include "scenariotestscppwinrt.h" -#include "Windows.AI.MachineLearning.Native.h" #include "Windows.Graphics.DirectX.Direct3D11.interop.h" #include "windows.ui.xaml.media.dxinterop.h" -#include "winrt/Windows.UI.Xaml.Controls.h" -#include "winrt/Windows.UI.Xaml.Media.Imaging.h" #include #include @@ -41,29 +25,43 @@ #include #endif -using namespace winrt; -using namespace winrt::Windows::AI::MachineLearning; -using namespace winrt::Windows::Foundation::Collections; -using namespace winrt::Windows::Media; -using namespace winrt::Windows::Graphics::Imaging; -using namespace winrt::Windows::Graphics::DirectX; -using namespace ::Windows::Graphics::DirectX::Direct3D11; -using namespace winrt::Windows::Storage; -using namespace winrt::Windows::Storage::Streams; +// lame, but WinBase.h redefines this, which breaks winrt headers later +#ifdef GetCurrentTime +#undef GetCurrentTime +#endif + +#include +#include +#include +#include +#include + +using namespace winml; +using namespace wfc; +using namespace wm; +using namespace wgi; +using namespace wgdx; +using namespace ws; +using namespace wss; using namespace winrt::Windows::UI::Xaml::Media::Imaging; +using namespace Windows::Graphics::DirectX::Direct3D11; -static void ScenarioCppWinrtTestSetup() { - init_apartment(); +static void ScenarioCppWinrtTestsClassSetup() { + winrt::init_apartment(); } -static void ScenarioCppWinrtGpuTestSetup() { + +static void ScenarioCppWinrtTestsGpuMethodSetup() { GPUTEST; - ScenarioCppWinrtTestSetup(); }; -static void ScenarioCppWinrtGpuSkipEdgeCoreTestSetup() { - SKIP_EDGECORE; - ScenarioCppWinrtGpuTestSetup(); +static void ScenarioCppWinrtTestsSkipEdgeCoreMethodSetup() { + SKIP_EDGECORE +}; + +static void ScenarioCppWinrtTestsGpuSkipEdgeCoreMethodSetup() { + ScenarioCppWinrtTestsGpuMethodSetup(); + SKIP_EDGECORE }; static void Sample1() { @@ -90,12 +88,12 @@ ILearningModelFeatureValue MakeTensor(const ITensorFeatureDescriptor& descriptor return ftv; } default: - throw_hresult(E_NOTIMPL); + winrt::throw_hresult(E_NOTIMPL); break; } } -ILearningModelFeatureValue MakeImage(const IImageFeatureDescriptor& /*descriptor*/, winrt::Windows::Foundation::IInspectable data) { +ILearningModelFeatureValue MakeImage(const IImageFeatureDescriptor& /*descriptor*/, wf::IInspectable data) { VideoFrame videoFrame = nullptr; if (data != nullptr) { SoftwareBitmap sb = nullptr; @@ -109,7 +107,7 @@ ILearningModelFeatureValue MakeImage(const IImageFeatureDescriptor& /*descriptor return imageValue; } -ILearningModelFeatureValue FeatureValueFromFeatureValueDescriptor(ILearningModelFeatureDescriptor descriptor, winrt::Windows::Foundation::IInspectable data = nullptr) { +ILearningModelFeatureValue FeatureValueFromFeatureValueDescriptor(ILearningModelFeatureDescriptor descriptor, wf::IInspectable data = nullptr) { auto kind = descriptor.Kind(); switch (kind) { case LearningModelFeatureKind::Image: { @@ -118,10 +116,10 @@ ILearningModelFeatureValue FeatureValueFromFeatureValueDescriptor(ILearningModel return MakeImage(imageDescriptor, data); } case LearningModelFeatureKind::Map: - throw_hresult(E_NOTIMPL); + winrt::throw_hresult(E_NOTIMPL); break; case LearningModelFeatureKind::Sequence: - throw_hresult(E_NOTIMPL); + winrt::throw_hresult(E_NOTIMPL); break; case LearningModelFeatureKind::Tensor: { TensorFeatureDescriptor tensorDescriptor = nullptr; @@ -129,7 +127,7 @@ ILearningModelFeatureValue FeatureValueFromFeatureValueDescriptor(ILearningModel return MakeTensor(tensorDescriptor); } default: - throw_hresult(E_INVALIDARG); + winrt::throw_hresult(E_INVALIDARG); break; } } @@ -204,7 +202,7 @@ static void Scenario3SoftwareBitmapInputBinding() { } //! Scenario5: run an async eval -winrt::Windows::Foundation::IAsyncOperation DoEvalAsync() { +wf::IAsyncOperation DoEvalAsync() { // load a model std::wstring filePath = FileHelpers::GetModulePath() + L"model.onnx"; LearningModel model = LearningModel::LoadFromFilePath(filePath); @@ -226,7 +224,7 @@ winrt::Windows::Foundation::IAsyncOperation DoEva static void Scenario5AsyncEval() { auto task = DoEvalAsync(); - while (task.Status() == winrt::Windows::Foundation::AsyncStatus::Started) { + while (task.Status() == wf::AsyncStatus::Started) { std::cout << "Waiting...\n"; Sleep(30); } @@ -261,7 +259,7 @@ static void Scenario6BindWithProperties() { bounds.Height = 100; bounds.Width = 100; - auto bitmapsBoundsProperty = winrt::Windows::Foundation::PropertyValue::CreateUInt32Array({bounds.X, bounds.Y, bounds.Width, bounds.Height}); + auto bitmapsBoundsProperty = wf::PropertyValue::CreateUInt32Array({bounds.X, bounds.Y, bounds.Width, bounds.Height}); // insert it in the property set propertySet.Insert(L"BitmapBounds", bitmapsBoundsProperty); @@ -269,7 +267,7 @@ static void Scenario6BindWithProperties() { BitmapPixelFormat bitmapPixelFormat = BitmapPixelFormat::Bgra8; // translate it to an int so it can be used as a PropertyValue; int intFromBitmapPixelFormat = static_cast(bitmapPixelFormat); - auto bitmapPixelFormatProperty = winrt::Windows::Foundation::PropertyValue::CreateInt32(intFromBitmapPixelFormat); + auto bitmapPixelFormatProperty = wf::PropertyValue::CreateInt32(intFromBitmapPixelFormat); // insert it in the property set propertySet.Insert(L"BitmapPixelFormat", bitmapPixelFormatProperty); @@ -282,7 +280,7 @@ static void Scenario6BindWithProperties() { //! Scenario7: run eval without creating a binding object static void Scenario7EvalWithNoBind() { - auto map = winrt::single_threaded_map(); + auto map = winrt::single_threaded_map(); // load a model std::wstring filePath = FileHelpers::GetModulePath() + L"model.onnx"; @@ -357,15 +355,15 @@ static void Scenario8SetDeviceSampleMyCameraDevice() { LearningModel model = LearningModel::LoadFromFilePath(filePath); auto devices = winrt::Windows::Devices::Enumeration::DeviceInformation::FindAllAsync(winrt::Windows::Devices::Enumeration::DeviceClass::VideoCapture).get(); - hstring deviceId; + winrt::hstring deviceId; if (devices.Size() > 0) { auto device = devices.GetAt(0); deviceId = device.Id(); auto deviceName = device.Name(); auto enabled = device.IsEnabled(); std::cout << "Found device " << deviceName.c_str() << ", enabled = " << enabled << "\n"; - winrt::Windows::Media::Capture::MediaCapture captureManager; - winrt::Windows::Media::Capture::MediaCaptureInitializationSettings settings; + wm::Capture::MediaCapture captureManager; + wm::Capture::MediaCaptureInitializationSettings settings; settings.VideoDeviceId(deviceId); captureManager.InitializeAsync(settings).get(); auto mediaCaptureSettings = captureManager.MediaCaptureSettings(); @@ -383,8 +381,8 @@ static void Scenario8SetDeviceSampleD3D11Device() { std::wstring filePath = FileHelpers::GetModulePath() + L"model.onnx"; LearningModel model = LearningModel::LoadFromFilePath(filePath); - com_ptr pD3D11Device = nullptr; - com_ptr pContext = nullptr; + winrt::com_ptr pD3D11Device = nullptr; + winrt::com_ptr pContext = nullptr; D3D_FEATURE_LEVEL fl; HRESULT result = D3D11CreateDevice( nullptr, D3D_DRIVER_TYPE::D3D_DRIVER_TYPE_HARDWARE, nullptr, 0, nullptr, 0, @@ -394,14 +392,14 @@ static void Scenario8SetDeviceSampleD3D11Device() { } // get dxgiDevice from d3ddevice - com_ptr pDxgiDevice; + winrt::com_ptr pDxgiDevice; pD3D11Device.get()->QueryInterface(pDxgiDevice.put()); - com_ptr<::IInspectable> pInspectable; + winrt::com_ptr<::IInspectable> pInspectable; CreateDirect3D11DeviceFromDXGIDevice(pDxgiDevice.get(), pInspectable.put()); LearningModelDevice device = LearningModelDevice::CreateFromDirect3D11Device( - pInspectable.as()); + pInspectable.as()); LearningModelSession session(model, device); } @@ -411,7 +409,7 @@ static void Scenario8SetDeviceSampleCustomCommandQueue() { std::wstring filePath = FileHelpers::GetModulePath() + L"model.onnx"; LearningModel model = LearningModel::LoadFromFilePath(filePath); - com_ptr pD3D12Device = nullptr; + winrt::com_ptr pD3D12Device = nullptr; CommonDeviceHelpers::AdapterEnumerationSupport support; if (FAILED(CommonDeviceHelpers::GetAdapterEnumerationSupport(&support))) { WINML_LOG_ERROR("Unable to load DXGI or DXCore"); @@ -423,12 +421,12 @@ static void Scenario8SetDeviceSampleCustomCommandQueue() { } #ifdef ENABLE_DXCORE if (support.has_dxgi == false) { - com_ptr spFactory; + winrt::com_ptr spFactory; DXCoreCreateAdapterFactory(IID_PPV_ARGS(spFactory.put())); const GUID gpuFilter[] = {DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS}; - com_ptr spAdapterList; + winrt::com_ptr spAdapterList; spFactory->CreateAdapterList(1, gpuFilter, IID_PPV_ARGS(spAdapterList.put())); - com_ptr spAdapter; + winrt::com_ptr spAdapter; WINML_EXPECT_NO_THROW(spAdapterList->GetAdapter(0, IID_PPV_ARGS(spAdapter.put()))); ::IUnknown* pAdapter = spAdapter.get(); WINML_EXPECT_NO_THROW(result = D3D12CreateDevice(pAdapter, D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_12_0, __uuidof(ID3D12Device), reinterpret_cast(pD3D12Device.put()))); @@ -439,13 +437,13 @@ static void Scenario8SetDeviceSampleCustomCommandQueue() { WINML_SKIP_TEST("Test skipped because d3d12 device is missing"); return; } - com_ptr dxQueue = nullptr; + winrt::com_ptr dxQueue = nullptr; D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; pD3D12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), reinterpret_cast(&dxQueue)); - auto factory = get_activation_factory(); + auto factory = winrt::get_activation_factory(); - com_ptr<::IUnknown> spUnk; + winrt::com_ptr<::IUnknown> spUnk; factory->CreateFromD3D12CommandQueue(dxQueue.get(), spUnk.put()); auto dmlDeviceCustom = spUnk.as(); @@ -458,16 +456,16 @@ static void Scenario9LoadBindEvalInputTensorGPU() { std::wstring filePath = FileHelpers::GetModulePath() + L"fns-candy.onnx"; LearningModel model = LearningModel::LoadFromFilePath(filePath); - com_ptr pD3D12Device; + winrt::com_ptr pD3D12Device; WINML_EXPECT_NO_THROW(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0, __uuidof(ID3D12Device), pD3D12Device.put_void())); - com_ptr dxQueue; + winrt::com_ptr dxQueue; D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; pD3D12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), dxQueue.put_void()); - auto devicefactory = get_activation_factory(); - auto tensorfactory = get_activation_factory(); + auto devicefactory = winrt::get_activation_factory(); + auto tensorfactory = winrt::get_activation_factory(); - com_ptr<::IUnknown> spUnk; + winrt::com_ptr<::IUnknown> spUnk; WINML_EXPECT_NO_THROW(devicefactory->CreateFromD3D12CommandQueue(dxQueue.get(), spUnk.put())); LearningModelDevice dmlDeviceCustom = nullptr; @@ -496,7 +494,7 @@ static void Scenario9LoadBindEvalInputTensorGPU() { D3D12_TEXTURE_LAYOUT_ROW_MAJOR, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - com_ptr pGPUResource = nullptr; + winrt::com_ptr pGPUResource = nullptr; pD3D12Device->CreateCommittedResource( &heapProperties, D3D12_HEAP_FLAG_NONE, @@ -505,7 +503,7 @@ static void Scenario9LoadBindEvalInputTensorGPU() { nullptr, __uuidof(ID3D12Resource), pGPUResource.put_void()); - com_ptr<::IUnknown> spUnkTensor; + winrt::com_ptr<::IUnknown> spUnkTensor; TensorFloat input1imagetensor(nullptr); __int64 shape[4] = {1, 3, 720, 720}; tensorfactory->CreateFromD3D12Resource(pGPUResource.get(), shape, 4, spUnkTensor.put()); @@ -525,7 +523,7 @@ static void Scenario9LoadBindEvalInputTensorGPU() { WINML_EXPECT_NO_THROW(modelBinding.Bind(model.OutputFeatures().First().Current().Name(), outputTensor)); // Testing GetAsD3D12Resource - com_ptr pReturnedResource; + winrt::com_ptr pReturnedResource; input1imagetensor.as()->GetD3D12Resource(pReturnedResource.put()); WINML_EXPECT_EQUAL(pReturnedResource.get(), pGPUResource.get()); @@ -602,7 +600,7 @@ static void Scenario11FreeDimensionsImage() { struct SwapChainEntry { LearningModelSession session; LearningModelBinding binding; - winrt::Windows::Foundation::IAsyncOperation activetask; + wf::IAsyncOperation activetask; SwapChainEntry() : session(nullptr), binding(nullptr), activetask(nullptr) {} }; void SubmitEval(LearningModel model, SwapChainEntry* sessionBindings, int swapchaindex) { @@ -666,13 +664,13 @@ static void LoadBindEval_CustomOperator_CPU(const wchar_t* fileName) { auto inputValue = TensorFloat::CreateFromIterable( inputShape, - single_threaded_vector(std::move(inputData)).GetView()); + winrt::single_threaded_vector(std::move(inputData)).GetView()); WINML_EXPECT_NO_THROW(bindings.Bind(L"X", inputValue)); auto outputValue = TensorFloat::Create(); WINML_EXPECT_NO_THROW(bindings.Bind(L"Y", outputValue)); - hstring correlationId; + winrt::hstring correlationId; WINML_EXPECT_NO_THROW(session.Evaluate(bindings, correlationId)); auto buffer = outputValue.GetAsVectorView(); @@ -767,8 +765,8 @@ bool VerifyHelper(ImageFeatureValue actual, ImageFeatureValue expected) { // 4 means 4 channels uint32_t size = 4 * softwareBitmapActual.PixelHeight() * softwareBitmapActual.PixelWidth(); - winrt::Windows::Storage::Streams::Buffer actualOutputBuffer(size); - winrt::Windows::Storage::Streams::Buffer expectedOutputBuffer(size); + ws::Streams::Buffer actualOutputBuffer(size); + ws::Streams::Buffer expectedOutputBuffer(size); softwareBitmapActual.CopyToBuffer(actualOutputBuffer); softwareBitmapExpected.CopyToBuffer(expectedOutputBuffer); @@ -814,8 +812,8 @@ static void Scenario22ImageBindingAsCPUTensor() { // Put softwareBitmap into buffer BYTE* pData = nullptr; UINT32 size = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + wgi::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); auto spByteAccess = reference.as<::Windows::Foundation::IMemoryBufferByteAccess>(); spByteAccess->GetBuffer(&pData, &size); @@ -823,7 +821,7 @@ static void Scenario22ImageBindingAsCPUTensor() { float* pCPUTensor; uint32_t uCapacity; TensorFloat tf = TensorFloat::Create(shape); - com_ptr itn = tf.as(); + winrt::com_ptr itn = tf.as(); itn->GetBuffer(reinterpret_cast(&pCPUTensor), &uCapacity); uint32_t height = softwareBitmap.PixelHeight(); @@ -883,8 +881,8 @@ static void Scenario22ImageBindingAsGPUTensor() { // Put softwareBitmap into cpu buffer BYTE* pData = nullptr; UINT32 size = 0; - winrt::Windows::Graphics::Imaging::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(winrt::Windows::Graphics::Imaging::BitmapBufferAccessMode::Read)); - winrt::Windows::Foundation::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); + wgi::BitmapBuffer spBitmapBuffer(softwareBitmap.LockBuffer(wgi::BitmapBufferAccessMode::Read)); + wf::IMemoryBufferReference reference = spBitmapBuffer.CreateReference(); auto spByteAccess = reference.as<::Windows::Foundation::IMemoryBufferByteAccess>(); spByteAccess->GetBuffer(&pData, &size); @@ -894,7 +892,7 @@ static void Scenario22ImageBindingAsGPUTensor() { // CPU tensorization TensorFloat tf = TensorFloat::Create(shape); - com_ptr itn = tf.as(); + winrt::com_ptr itn = tf.as(); itn->GetBuffer(reinterpret_cast(&pCPUTensor), &uCapacity); uint32_t height = softwareBitmap.PixelHeight(); @@ -907,17 +905,17 @@ static void Scenario22ImageBindingAsGPUTensor() { } // create the d3d device. - com_ptr pD3D12Device = nullptr; + winrt::com_ptr pD3D12Device = nullptr; WINML_EXPECT_NO_THROW(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0, __uuidof(ID3D12Device), reinterpret_cast(&pD3D12Device))); // create the command queue. - com_ptr dxQueue = nullptr; + winrt::com_ptr dxQueue = nullptr; D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {}; commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; pD3D12Device->CreateCommandQueue(&commandQueueDesc, __uuidof(ID3D12CommandQueue), reinterpret_cast(&dxQueue)); - auto devicefactory = get_activation_factory(); - auto tensorfactory = get_activation_factory(); - com_ptr<::IUnknown> spUnk; + auto devicefactory = winrt::get_activation_factory(); + auto tensorfactory = winrt::get_activation_factory(); + winrt::com_ptr<::IUnknown> spUnk; devicefactory->CreateFromD3D12CommandQueue(dxQueue.get(), spUnk.put()); LearningModel model(nullptr); @@ -931,8 +929,8 @@ static void Scenario22ImageBindingAsGPUTensor() { // Create ID3D12GraphicsCommandList and Allocator D3D12_COMMAND_LIST_TYPE queuetype = dxQueue->GetDesc().Type; - com_ptr alloctor; - com_ptr cmdList; + winrt::com_ptr alloctor; + winrt::com_ptr cmdList; pD3D12Device->CreateCommandAllocator( queuetype, @@ -968,8 +966,8 @@ static void Scenario22ImageBindingAsGPUTensor() { D3D12_TEXTURE_LAYOUT_ROW_MAJOR, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; - com_ptr pGPUResource = nullptr; - com_ptr imageUploadHeap; + winrt::com_ptr pGPUResource = nullptr; + winrt::com_ptr imageUploadHeap; pD3D12Device->CreateCommittedResource( &heapProperties, D3D12_HEAP_FLAG_NONE, @@ -1004,7 +1002,7 @@ static void Scenario22ImageBindingAsGPUTensor() { dxQueue->ExecuteCommandLists(_countof(ppCommandLists), ppCommandLists); // GPU tensorize - com_ptr<::IUnknown> spUnkTensor; + winrt::com_ptr<::IUnknown> spUnkTensor; TensorFloat input1imagetensor(nullptr); __int64 shapes[4] = {1, 3, softwareBitmap.PixelWidth(), softwareBitmap.PixelHeight()}; tensorfactory->CreateFromD3D12Resource(pGPUResource.get(), shapes, 4, spUnkTensor.put()); @@ -1129,7 +1127,7 @@ static void SyncVsAsync() { std::cout << "Synchronous time for " << N << " evaluations: " << syncTime.count() << " milliseconds\n"; // evaluate N times Asynchronously and time it - std::vector> tasks; + std::vector> tasks; std::vector bindings(N, nullptr); for (size_t i = 0; i < bindings.size(); i++) { @@ -1158,21 +1156,21 @@ static void CustomCommandQueueWithFence() { static const wchar_t* const modelFileName = L"fns-candy.onnx"; static const wchar_t* const inputDataImageFileName = L"fish_720.png"; - com_ptr d3d12Device; + winrt::com_ptr d3d12Device; WINML_EXPECT_HRESULT_SUCCEEDED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, __uuidof(ID3D12Device), d3d12Device.put_void())); D3D12_COMMAND_QUEUE_DESC queueDesc = {}; queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - com_ptr queue; + winrt::com_ptr queue; WINML_EXPECT_HRESULT_SUCCEEDED(d3d12Device->CreateCommandQueue(&queueDesc, __uuidof(ID3D12CommandQueue), queue.put_void())); - com_ptr fence; + winrt::com_ptr fence; WINML_EXPECT_HRESULT_SUCCEEDED(d3d12Device->CreateFence(0, D3D12_FENCE_FLAG_NONE, __uuidof(ID3D12Fence), fence.put_void())); - auto devicefactory = get_activation_factory(); + auto devicefactory = winrt::get_activation_factory(); - com_ptr<::IUnknown> learningModelDeviceUnknown; + winrt::com_ptr<::IUnknown> learningModelDeviceUnknown; WINML_EXPECT_HRESULT_SUCCEEDED(devicefactory->CreateFromD3D12CommandQueue(queue.get(), learningModelDeviceUnknown.put())); LearningModelDevice device = nullptr; @@ -1219,13 +1217,13 @@ static void CustomCommandQueueWithFence() { WINML_EXPECT_HRESULT_SUCCEEDED(queue->Signal(fence.get(), 2)); winrt::hstring correlationId; - winrt::Windows::Foundation::IAsyncOperation asyncOp; + wf::IAsyncOperation asyncOp; WINML_EXPECT_NO_THROW(asyncOp = modelSession.EvaluateAsync(modelBinding, correlationId)); Sleep(1000); // Give the model a chance to run (which it shouldn't if everything is working correctly) // Because we haven't unblocked the wait yet, model evaluation must not have completed (nor the fence signal) - WINML_EXPECT_NOT_EQUAL(asyncOp.Status(), winrt::Windows::Foundation::AsyncStatus::Completed); + WINML_EXPECT_NOT_EQUAL(asyncOp.Status(), wf::AsyncStatus::Completed); WINML_EXPECT_EQUAL(fence->GetCompletedValue(), 0); // Unblock the queue @@ -1298,7 +1296,7 @@ static void EncryptedStream() { // get a stream std::wstring path = FileHelpers::GetModulePath() + L"model.onnx"; auto storageFile = StorageFile::GetFileFromPathAsync(path).get(); - auto fileBuffer = winrt::Windows::Storage::FileIO::ReadBufferAsync(storageFile).get(); + auto fileBuffer = ws::FileIO::ReadBufferAsync(storageFile).get(); // encrypt auto algorithmName = winrt::Windows::Security::Cryptography::Core::SymmetricAlgorithmNames::AesCbcPkcs7(); @@ -1371,16 +1369,16 @@ static void D2DInterop() { std::wstring filePath = FileHelpers::GetModulePath() + L"model.onnx"; LearningModel model = LearningModel::LoadFromFilePath(filePath); // create a dx12 device - com_ptr device = nullptr; + winrt::com_ptr device = nullptr; WINML_EXPECT_HRESULT_SUCCEEDED(D3D12CreateDevice(NULL, D3D_FEATURE_LEVEL_11_0, __uuidof(ID3D12Device1), device.put_void())); // now create a command queue from it - com_ptr commandQueue = nullptr; + winrt::com_ptr commandQueue = nullptr; D3D12_COMMAND_QUEUE_DESC queueDesc = {}; queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; WINML_EXPECT_HRESULT_SUCCEEDED(device->CreateCommandQueue(&queueDesc, winrt::guid_of(), commandQueue.put_void())); // create a winml learning device based on that dx12 queue - auto factory = get_activation_factory(); - com_ptr<::IUnknown> spUnk; + auto factory = winrt::get_activation_factory(); + winrt::com_ptr<::IUnknown> spUnk; WINML_EXPECT_HRESULT_SUCCEEDED(factory->CreateFromD3D12CommandQueue(commandQueue.get(), spUnk.put())); auto learningDevice = spUnk.as(); // create a winml session from that dx device @@ -1393,14 +1391,14 @@ static void D2DInterop() { session.Device().Direct3D11Device()); // create a D2D factory D2D1_FACTORY_OPTIONS options = {}; - com_ptr d2dFactory; + winrt::com_ptr d2dFactory; WINML_EXPECT_HRESULT_SUCCEEDED(D2D1CreateFactory(D2D1_FACTORY_TYPE_SINGLE_THREADED, __uuidof(ID2D1Factory), &options, d2dFactory.put_void())); // grab the dxgi surface back from our video frame - com_ptr dxgiSurface; - com_ptr dxgiInterfaceAccess = frame.Direct3DSurface().as(); + winrt::com_ptr dxgiSurface; + winrt::com_ptr dxgiInterfaceAccess = frame.Direct3DSurface().as(); WINML_EXPECT_HRESULT_SUCCEEDED(dxgiInterfaceAccess->GetInterface(__uuidof(IDXGISurface), dxgiSurface.put_void())); // and try and use our surface to create a render targer - com_ptr renderTarget; + winrt::com_ptr renderTarget; D2D1_RENDER_TARGET_PROPERTIES props = D2D1::RenderTargetProperties(); props.pixelFormat = D2D1::PixelFormat( DXGI_FORMAT_B8G8R8A8_UNORM, @@ -1411,12 +1409,13 @@ static void D2DInterop() { renderTarget.put())); } -const ScenarioTestApi& getapi() { - static constexpr ScenarioTestApi api = +const ScenarioTestsApi& getapi() { + static constexpr ScenarioTestsApi api = { - ScenarioCppWinrtTestSetup, - ScenarioCppWinrtGpuTestSetup, - ScenarioCppWinrtGpuSkipEdgeCoreTestSetup, + ScenarioCppWinrtTestsClassSetup, + ScenarioCppWinrtTestsGpuMethodSetup, + ScenarioCppWinrtTestsSkipEdgeCoreMethodSetup, + ScenarioCppWinrtTestsGpuSkipEdgeCoreMethodSetup, Sample1, Scenario1LoadBindEvalDefault, Scenario2LoadModelFromStream, diff --git a/winml/test/scenario/cppwinrt/scenariotestscppwinrt.h b/winml/test/scenario/cppwinrt/scenariotestscppwinrt.h index 2409de5fd60c2..9999b4465ca03 100644 --- a/winml/test/scenario/cppwinrt/scenariotestscppwinrt.h +++ b/winml/test/scenario/cppwinrt/scenariotestscppwinrt.h @@ -2,11 +2,12 @@ // Licensed under the MIT License. #include "test.h" -struct ScenarioTestApi +struct ScenarioTestsApi { - SetupTest ScenarioCppWinrtTestSetup; - SetupTest ScenarioCppWinrtGpuTestSetup; - SetupTest ScenarioCppWinrtGpuSkipEdgeCoreTestSetup; + SetupClass ScenarioCppWinrtTestsClassSetup; + SetupTest ScenarioCppWinrtTestsGpuMethodSetup; + SetupTest ScenarioCppWinrtTestsSkipEdgeCoreMethodSetup; + SetupTest ScenarioCppWinrtTestsGpuSkipEdgeCoreMethodSetup; VoidTest Sample1; VoidTest Scenario1LoadBindEvalDefault; VoidTest Scenario2LoadModelFromStream; @@ -42,47 +43,61 @@ struct ScenarioTestApi VoidTest Scenario8SetDeviceSampleD3D11Device; VoidTest D2DInterop; }; -const ScenarioTestApi& getapi(); +const ScenarioTestsApi& getapi(); -WINML_TEST_CLASS_BEGIN_WITH_SETUP(ScenarioCppWinrtTest, ScenarioCppWinrtTestSetup) -WINML_TEST(ScenarioCppWinrtTest, Sample1) -WINML_TEST(ScenarioCppWinrtTest, Scenario1LoadBindEvalDefault) -WINML_TEST(ScenarioCppWinrtTest, Scenario2LoadModelFromStream) -WINML_TEST(ScenarioCppWinrtTest, Scenario5AsyncEval) -WINML_TEST(ScenarioCppWinrtTest, Scenario7EvalWithNoBind) -WINML_TEST(ScenarioCppWinrtTest, Scenario8SetDeviceSampleDefault) -WINML_TEST(ScenarioCppWinrtTest, Scenario8SetDeviceSampleCPU) -WINML_TEST(ScenarioCppWinrtTest, Scenario17DevDiagnostics) -WINML_TEST(ScenarioCppWinrtTest, DISABLED_Scenario22ImageBindingAsCPUTensor) -WINML_TEST(ScenarioCppWinrtTest, QuantizedModels) -WINML_TEST(ScenarioCppWinrtTest, EncryptedStream) +WINML_TEST_CLASS_BEGIN(ScenarioCppWinrtTests) +WINML_TEST_CLASS_SETUP_CLASS(ScenarioCppWinrtTestsClassSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(ScenarioCppWinrtTests, Sample1) +WINML_TEST(ScenarioCppWinrtTests, Scenario1LoadBindEvalDefault) +WINML_TEST(ScenarioCppWinrtTests, Scenario2LoadModelFromStream) +WINML_TEST(ScenarioCppWinrtTests, Scenario5AsyncEval) +WINML_TEST(ScenarioCppWinrtTests, Scenario7EvalWithNoBind) +WINML_TEST(ScenarioCppWinrtTests, Scenario8SetDeviceSampleDefault) +WINML_TEST(ScenarioCppWinrtTests, Scenario8SetDeviceSampleCPU) +WINML_TEST(ScenarioCppWinrtTests, Scenario17DevDiagnostics) +WINML_TEST(ScenarioCppWinrtTests, DISABLED_Scenario22ImageBindingAsCPUTensor) +WINML_TEST(ScenarioCppWinrtTests, QuantizedModels) +WINML_TEST(ScenarioCppWinrtTests, EncryptedStream) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(ScenarioCppWinrtGpuTest, ScenarioCppWinrtGpuTestSetup) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario3SoftwareBitmapInputBinding) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario6BindWithProperties) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario8SetDeviceSampleDefaultDirectX) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario8SetDeviceSampleMinPower) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario8SetDeviceSampleMaxPerf) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario8SetDeviceSampleCustomCommandQueue) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_Scenario9LoadBindEvalInputTensorGPU) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario13SingleModelOnCPUandGPU) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario11FreeDimensionsTensor) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario11FreeDimensionsImage) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario14RunModelSwapchain) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario20aLoadBindEvalCustomOperatorCPU) -WINML_TEST(ScenarioCppWinrtGpuTest, Scenario20bLoadBindEvalReplacementCustomOperatorCPU) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_Scenario21RunModel2ChainZ) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_Scenario22ImageBindingAsGPUTensor) -WINML_TEST(ScenarioCppWinrtGpuTest, MsftQuantizedModels) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_SyncVsAsync) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_CustomCommandQueueWithFence) -WINML_TEST(ScenarioCppWinrtGpuTest, DISABLED_ReuseVideoFrame) -WINML_TEST(ScenarioCppWinrtGpuTest, DeviceLostRecovery) +WINML_TEST_CLASS_BEGIN(ScenarioCppWinrtTestsGpu) +WINML_TEST_CLASS_SETUP_CLASS(ScenarioCppWinrtTestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(ScenarioCppWinrtTestsGpuMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario3SoftwareBitmapInputBinding) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario6BindWithProperties) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario8SetDeviceSampleDefaultDirectX) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario8SetDeviceSampleMinPower) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario8SetDeviceSampleMaxPerf) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario8SetDeviceSampleCustomCommandQueue) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_Scenario9LoadBindEvalInputTensorGPU) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario13SingleModelOnCPUandGPU) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario11FreeDimensionsTensor) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario11FreeDimensionsImage) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario14RunModelSwapchain) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario20aLoadBindEvalCustomOperatorCPU) +WINML_TEST(ScenarioCppWinrtTestsGpu, Scenario20bLoadBindEvalReplacementCustomOperatorCPU) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_Scenario21RunModel2ChainZ) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_Scenario22ImageBindingAsGPUTensor) +WINML_TEST(ScenarioCppWinrtTestsGpu, MsftQuantizedModels) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_SyncVsAsync) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_CustomCommandQueueWithFence) +WINML_TEST(ScenarioCppWinrtTestsGpu, DISABLED_ReuseVideoFrame) +WINML_TEST(ScenarioCppWinrtTestsGpu, DeviceLostRecovery) WINML_TEST_CLASS_END() -WINML_TEST_CLASS_BEGIN_WITH_SETUP(ScenarioCppWinrtGpuSkipEdgeCoreTest, ScenarioCppWinrtGpuSkipEdgeCoreTestSetup) -WINML_TEST(ScenarioCppWinrtGpuSkipEdgeCoreTest, Scenario8SetDeviceSampleMyCameraDevice) -WINML_TEST(ScenarioCppWinrtGpuSkipEdgeCoreTest, Scenario8SetDeviceSampleD3D11Device ) -WINML_TEST(ScenarioCppWinrtGpuSkipEdgeCoreTest, D2DInterop) +WINML_TEST_CLASS_BEGIN(ScenarioCppWinrtTestsSkipEdgeCore) +WINML_TEST_CLASS_SETUP_CLASS(ScenarioCppWinrtTestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(ScenarioCppWinrtTestsSkipEdgeCoreMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(ScenarioCppWinrtTestsSkipEdgeCore, Scenario8SetDeviceSampleMyCameraDevice) +WINML_TEST_CLASS_END() + +WINML_TEST_CLASS_BEGIN(ScenarioCppWinrtTestsGpuSkipEdgeCore) +WINML_TEST_CLASS_SETUP_CLASS(ScenarioCppWinrtTestsClassSetup) +WINML_TEST_CLASS_SETUP_METHOD(ScenarioCppWinrtTestsGpuSkipEdgeCoreMethodSetup) +WINML_TEST_CLASS_BEGIN_TESTS +WINML_TEST(ScenarioCppWinrtTestsGpuSkipEdgeCore, Scenario8SetDeviceSampleD3D11Device) +WINML_TEST(ScenarioCppWinrtTestsGpuSkipEdgeCore, D2DInterop) WINML_TEST_CLASS_END() \ No newline at end of file