Skip to content

Commit

Permalink
Add example of registering custom cuda op as shared lib (microsoft#10025)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems authored Jan 5, 2022
1 parent 2078210 commit 2803a94
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 3 deletions.
14 changes: 11 additions & 3 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()
endif()

onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
# Build the custom-op test shared library. With CUDA enabled, the CUDA kernel
# source (cuda_ops.cu) is compiled into the same module and the CUDA toolkit
# include directories are added; otherwise only the CPU implementation is built.
if (onnxruntime_USE_CUDA)
# NOTE(review): ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR is defined elsewhere in this
# file — presumably the shared_lib test source dir; confirm it is set before here.
onnxruntime_add_shared_library_module(custom_op_library ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu
${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
# CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES supplies <cuda_runtime.h> for the .cc file.
target_include_directories(custom_op_library PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
else()
onnxruntime_add_shared_library_module(custom_op_library ${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.cc)
endif()
target_include_directories(custom_op_library PRIVATE ${REPO_ROOT}/include)
if(UNIX)
if (APPLE)
Expand All @@ -1175,8 +1181,10 @@ if(UNIX)
endif()
else()
set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def")
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
# Suppress MSVC code-analysis warning C26409 ("avoid naked new/delete") for the
# test library. The $<COMPILE_LANGUAGE:CUDA> generator expression routes the
# flag through nvcc's -Xcompiler for CUDA sources and passes it directly otherwise.
# NOTE(review): skipped entirely when onnxruntime_USE_CUDA is ON — presumably
# the flag conflicts with the nvcc-driven build added by this commit; confirm.
if (NOT onnxruntime_USE_CUDA)
target_compile_options(custom_op_library PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
endif()
endif()
set_property(TARGET custom_op_library APPEND_STRING PROPERTY LINK_FLAGS ${ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,12 @@ private void TestRegisterCustomOpLibrary()
string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");

var ortEnvInstance = OrtEnv.Instance();
string[] providers = ortEnvInstance.GetAvailableProviders();
if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider")) {
option.AppendExecutionProvider_CUDA(0);
}

IntPtr libraryHandle = IntPtr.Zero;
try
{
Expand Down
3 changes: 3 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,9 @@ public void testLoadCustomLibrary() throws OrtException {
try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
SessionOptions options = new SessionOptions()) {
options.registerCustomOpLibrary(customLibraryName);
if (OnnxRuntime.extractCUDA()) {
options.addCUDA();
}
try (OrtSession session = env.createSession(customOpLibraryTestModel, options)) {
Map<String, OnnxTensor> container = new HashMap<>();

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -778,8 +778,13 @@ lib_name = "./libcustom_op_library.so";
#endif

void* library_handle = nullptr;
#ifdef USE_CUDA
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 1, nullptr, lib_name.c_str(), &library_handle);
#else
TestInference<int32_t>(*ort_env, CUSTOM_OP_LIBRARY_TEST_MODEL_URI, inputs, "output", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str(), &library_handle);
#endif

#ifdef _WIN32
bool success = ::FreeLibrary(reinterpret_cast<HMODULE>(library_handle));
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
#include <cmath>
#include <mutex>

#ifdef USE_CUDA
#include <cuda_runtime.h>
template <typename T1, typename T2, typename T3>
void cuda_add(int64_t, T3*, const T1*, const T2*, cudaStream_t compute_stream);
#endif

static const char* c_OpDomain = "test.customop";

struct OrtCustomOpDomainDeleter {
Expand Down Expand Up @@ -63,9 +69,14 @@ struct KernelOne {
ort_.ReleaseTensorTypeAndShapeInfo(output_info);

// Do computation
#ifdef USE_CUDA
cudaStream_t stream = reinterpret_cast<cudaStream_t>(ort_.KernelContext_GetGPUComputeStream(context));
cuda_add(size, out, X, Y, stream);
#else
for (int64_t i = 0; i < size; i++) {
out[i] = X[i] + Y[i];
}
#endif
}

private:
Expand Down Expand Up @@ -112,6 +123,10 @@ struct CustomOpOne : Ort::CustomOpBase<CustomOpOne, KernelOne> {

const char* GetName() const { return "CustomOpOne"; };

#ifdef USE_CUDA
const char* GetExecutionProviderType() const { return "CUDAExecutionProvider"; };
#endif

size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };

Expand Down

0 comments on commit 2803a94

Please sign in to comment.