Skip to content

Commit

Permalink
Minor changes to AMD element-wise kernels to converge with CUDA eleme…
Browse files Browse the repository at this point in the history
…nt-wise kernels.
  • Loading branch information
jessebenson committed Dec 15, 2020
1 parent a954828 commit a8d549e
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include <cuda_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "contrib_ops/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include <hip/hip_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "contrib_ops/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include <cuda_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/cuda/math/binary_elementwise_ops_impl_functors.cuh"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
BINARY_LOGICALOP_TYPED(Or, 7, bool)
BINARY_LOGICALOP_TYPED(Xor, 7, bool)
BINARY_OP_HFD(PRelu, 7)
BINARY_OP_VERSIONED_HFD(PRelu, 7, 8)
BINARY_OP_HFD(PRelu, 9)

// Pow since version 12
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include <hip/hip_runtime.h>
#include "binary_elementwise_ops_impl.h"
#include "core/providers/rocm/math/binary_elementwise_ops_impl.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/cu_inc/binary_elementwise_impl.cuh"
#include "core/providers/rocm/math/binary_elementwise_ops_impl_functors.cuh"
Expand Down

This file was deleted.

18 changes: 12 additions & 6 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor);
Expand Down Expand Up @@ -1063,9 +1066,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, PRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, And)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Or)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, bool, Xor)>,
Expand Down
1 change: 0 additions & 1 deletion tools/ci_build/amd_hipify.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@
'math/binary_elementwise_ops.h',
'math/binary_elementwise_ops_impl.cu',
'math/binary_elementwise_ops_impl.h',
'math/binary_elementwise_ops_impl_functors.cuh',
'math/cumsum.cc',
'math/cumsum.h',
'math/cumsum_impl.cu',
Expand Down

0 comments on commit a8d549e

Please sign in to comment.