Commit e23892d

Support disabling support for the optional type in ORT builds (micros…

hariharans29 authored Nov 18, 2021
1 parent 9fb3fac commit e23892d
Showing 34 changed files with 1,558 additions and 1,223 deletions.
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
@@ -108,6 +108,7 @@ option(onnxruntime_USE_ROCM "Build with AMD GPU support" OFF)
option(onnxruntime_DISABLE_CONTRIB_OPS "Disable contrib ops" OFF)
option(onnxruntime_DISABLE_ML_OPS "Disable traditional ML ops" OFF)
option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF)
# For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
@@ -817,6 +818,10 @@ if (onnxruntime_DISABLE_SPARSE_TENSORS)
add_compile_definitions(DISABLE_SPARSE_TENSORS)
endif()

if (onnxruntime_DISABLE_OPTIONAL_TYPE)
add_compile_definitions(DISABLE_OPTIONAL_TYPE)
endif()

if (onnxruntime_USE_CUDA AND "${onnxruntime_CUDNN_HOME}" STREQUAL "")
message(FATAL_ERROR "onnxruntime_CUDNN_HOME required for onnxruntime_USE_CUDA")
endif()
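Because onnxruntime_DISABLE_OPTIONAL_TYPE is an ordinary CMake cache option whose only effect is to inject the DISABLE_OPTIONAL_TYPE compile definition, it can be toggled at configure time. A minimal sketch of such a configure step (the path and any other flags are placeholders, not from this commit):

cmake -Donnxruntime_DISABLE_OPTIONAL_TYPE=ON <other configure options> <path-to-onnxruntime>/cmake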
68 changes: 65 additions & 3 deletions include/onnxruntime/core/framework/data_types.h
@@ -60,7 +60,9 @@ class SparseTensorTypeBase;
#endif
class SequenceTensorTypeBase;
class NonTensorTypeBase;
#if !defined(DISABLE_OPTIONAL_TYPE)
class OptionalTypeBase;
#endif
class PrimitiveDataTypeBase;
class Tensor;
class TensorSeq;
@@ -132,9 +134,11 @@ class DataTypeImpl {
}
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
virtual const OptionalTypeBase* AsOptionalType() const {
return nullptr;
}
#endif

virtual const NonTensorTypeBase* AsNonTensorType() const {
return nullptr;
@@ -319,11 +323,13 @@ struct IsSparseTensorContainedType : public IsAnyOf<T, float, uint8_t, int8_t, u
};
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
/// Tells if the specified type is one of ORT types
/// that can be contained within an optional struct.
template <typename T>
struct IsOptionalOrtType : public IsAnyOf<T, Tensor, TensorSeq> {
};
#endif
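The IsAnyOf machinery behind IsOptionalOrtType can be sketched standalone. A minimal C++17 equivalent (IsAnyOfSketch is a hypothetical name, not the ORT trait):

#include <type_traits>

// Fold-expression equivalent of an IsAnyOf<T, Candidates...> membership trait.
template <typename T, typename... Candidates>
struct IsAnyOfSketch : std::bool_constant<(std::is_same_v<T, Candidates> || ...)> {};

struct Tensor;     // forward declarations are enough for std::is_same
struct TensorSeq;

// Mirrors IsOptionalOrtType: only Tensor and TensorSeq qualify.
static_assert(IsAnyOfSketch<Tensor, Tensor, TensorSeq>::value);
static_assert(!IsAnyOfSketch<int, Tensor, TensorSeq>::value);

int main() {}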

/// This template's Get() returns a corresponding MLDataType
/// It dispatches the call to either GetTensorType<>() or
@@ -505,6 +511,49 @@ class TensorType : public TensorTypeBase {
}
};

#if defined(DISABLE_OPTIONAL_TYPE)

/// Common base-class for all disabled types. We need DataTypeImpl::ToString to work in a minimal build
/// with disabled types to keep the ORT format model kernel hashes stable.
class DisabledTypeBase : public DataTypeImpl {
public:
static MLDataType Type();

bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override {
// We always want to return false for the IsCompatible() for a disabled type
// because this will ensure that no kernel supporting the disabled type will
// be matched to a model node requiring that type and the model load will
// result in failure.
return false;
}

size_t Size() const override {
ORT_THROW("Type is disabled in this build.");
}

DeleteFunc GetDeleteFunc() const override {
ORT_THROW("Type is disabled in this build.");
}

// This must work even for a disabled type: DataTypeImpl::ToString needs the TypeProto (see class comment above)
const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override;

ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DisabledTypeBase);

protected:
// This must work even for a disabled type: derived types populate the TypeProto through this accessor
ONNX_NAMESPACE::TypeProto& MutableTypeProto();

DisabledTypeBase();
~DisabledTypeBase() override;

private:
struct Impl;
Impl* impl_;
};

#endif
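To make the IsCompatible() contract concrete: a self-contained illustration with stub types (hypothetical, not ORT code) of why a type that always reports incompatible can never be matched during kernel resolution, so model load fails up front instead of running with an unsupported type:

#include <cassert>
#include <vector>

struct TypeProtoStub {};  // stand-in for ONNX_NAMESPACE::TypeProto

struct MLTypeStub {
  virtual ~MLTypeStub() = default;
  virtual bool IsCompatible(const TypeProtoStub&) const = 0;
};

// Mirrors DisabledTypeBase: never compatible with any required type.
struct DisabledStub : MLTypeStub {
  bool IsCompatible(const TypeProtoStub&) const override { return false; }
};

// A registry lookup fails for a node whose required type only the
// disabled type could have satisfied.
bool HasMatchingKernel(const std::vector<const MLTypeStub*>& declared,
                       const TypeProtoStub& required) {
  for (const auto* t : declared)
    if (t->IsCompatible(required)) return true;
  return false;
}

int main() {
  DisabledStub disabled;
  std::vector<const MLTypeStub*> declared{&disabled};
  assert(!HasMatchingKernel(declared, TypeProtoStub{}));
}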

#if !defined(DISABLE_SPARSE_TENSORS)
/// Common base-class for all sparse-tensors (with different element types).
class SparseTensorTypeBase : public DataTypeImpl {
@@ -569,6 +618,8 @@ class SparseTensorType : public SparseTensorTypeBase {
#endif // !defined(DISABLE_SPARSE_TENSORS)

#if !defined(DISABLE_OPTIONAL_TYPE)
/// Common base-class for all optional types.
class OptionalTypeBase : public DataTypeImpl {
public:
static MLDataType Type();
@@ -613,18 +664,28 @@ class OptionalTypeBase : public DataTypeImpl {
struct Impl;
Impl* impl_;
};
#endif

// Derive from OptionalTypeBase if the Optional type support is enabled,
// else derive from DisabledTypeBase
template <typename T, typename elemT>
class OptionalType :
#if !defined(DISABLE_OPTIONAL_TYPE)
public OptionalTypeBase
#else
public DisabledTypeBase
#endif
{
public:
static MLDataType Type();

#if !defined(DISABLE_OPTIONAL_TYPE)
static_assert(data_types_internal::IsOptionalOrtType<T>::value,
"Requires one of the supported types: Tensor or TensorSeq");

static_assert(data_types_internal::IsTensorContainedType<elemT>::value,
"Requires one of the tensor fundamental types");

MLDataType GetElementType() const override {
if (std::is_same<T, Tensor>::value) {
return DataTypeImpl::GetTensorType<elemT>();
@@ -635,6 +696,7 @@ class OptionalType : public OptionalTypeBase {
ORT_ENFORCE(false, "Unsupported optional type");
}
}
#endif

private:
OptionalType() {
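The conditional base class used by OptionalType above is plain preprocessor selection. A standalone sketch of the pattern (hypothetical EnabledBase/DisabledBase/Feature names, not ORT code):

#include <iostream>

struct EnabledBase  { static constexpr const char* kState = "enabled"; };
struct DisabledBase { static constexpr const char* kState = "disabled"; };

// Same shape as OptionalType: the base class is chosen by a build-time macro.
template <typename T>
class Feature :
#if !defined(DISABLE_FEATURE)
    public EnabledBase
#else
    public DisabledBase
#endif
{};

int main() {
  // Prints "enabled" normally, "disabled" when built with -DDISABLE_FEATURE.
  std::cout << Feature<int>::kState << "\n";
}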
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/op_kernel_context.h
@@ -81,6 +81,7 @@ class OpKernelContext {
SparseTensor* OutputSparse(int index, const TensorShape& shape);
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
// Use this API to output a "None" of a specific type (e.g. Tensor) at specified index
template <typename T>
void OutputOptionalWithoutData(int index) {
@@ -92,6 +93,7 @@
type,
type->GetDeleteFunc());
}
#endif
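A usage sketch for the guarded API above. The kernel below is hypothetical (OptionalIdentity and its registration are not part of this commit); only the OpKernelContext calls visible in this diff plus well-known Tensor accessors are assumed:

#include <cstring>

namespace onnxruntime {

// Hypothetical optional-aware identity kernel body (illustration only).
Status OptionalIdentity::Compute(OpKernelContext* ctx) const {
  const Tensor* input = ctx->Input<Tensor>(0);  // nullptr when the optional input is None
  if (input == nullptr) {
#if !defined(DISABLE_OPTIONAL_TYPE)
    ctx->OutputOptionalWithoutData<Tensor>(0);  // emit a typed "None" at output 0
    return Status::OK();
#else
    ORT_THROW("Optional type is disabled in this build.");
#endif
  }
  Tensor* output = ctx->Output(0, input->Shape());
  std::memcpy(output->MutableDataRaw(), input->DataRaw(), input->SizeInBytes());
  return Status::OK();
}

}  // namespace onnxruntime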

// Retrieve indexed shape obtained from memory planning before actual
// computation. If the indexed shape cannot be inferred, this function returns
1 change: 0 additions & 1 deletion include/onnxruntime/core/framework/ort_value.h
@@ -49,7 +49,6 @@ struct OrtValue {
template <typename T>
const T& Get() const {
ORT_ENFORCE(onnxruntime::DataTypeImpl::GetType<T>() == type_, onnxruntime::DataTypeImpl::GetType<T>(), " != ", type_);
ORT_ENFORCE(IsAllocated(), "OrtValue contains no data");
return *static_cast<T*>(data_.get());
}

6 changes: 6 additions & 0 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -447,12 +447,16 @@ class PlannerImpl {
// TODO this should be an error case, needs more investigation
continue;
}

#if !defined(DISABLE_OPTIONAL_TYPE)
// Make sure optional types are not up for re-use as we aren't quite
// sure if the re-used tensor will be a None or otherwise. This cannot
// be determined statically.
if (IsOptionalType(*p_node_arg)) {
continue;
}
#endif

auto& available_memory_info = AllocPlan(p_node_arg->Name()).location;
if (!(available_memory_info == required_memory_info)) continue;
auto p_available_buffer_shape = context_.GetShape(*p_node_arg);
@@ -1142,10 +1146,12 @@ class PlannerImpl {
return !utils::HasTensorType(type_proto);
}

#if !defined(DISABLE_OPTIONAL_TYPE)
static bool IsOptionalType(const onnxruntime::NodeArg& nodearg) {
const auto* type_proto = nodearg.TypeAsProto();
return type_proto->value_case() == ONNX_NAMESPACE::TypeProto::kOptionalType;
}
#endif
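The value_case() discriminator checked above can be exercised against a hand-built proto. A sketch assuming standard protobuf-generated accessors and an ONNX build recent enough that optional_type exists (IR level 8+):

#include <cassert>
#include "onnx/onnx_pb.h"

int main() {
  // Hand-build optional<tensor<float>> and confirm the discriminator
  // the planner's IsOptionalType checks.
  ONNX_NAMESPACE::TypeProto type_proto;
  type_proto.mutable_optional_type()
      ->mutable_elem_type()
      ->mutable_tensor_type()
      ->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
  assert(type_proto.value_case() == ONNX_NAMESPACE::TypeProto::kOptionalType);
}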

// For in-place reuse tensors, the lifetime is the union of all the tensors that use that buffer
#if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE)
67 changes: 59 additions & 8 deletions onnxruntime/core/framework/data_types.cc
@@ -151,6 +151,11 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_SparseTensor& tensor_proto,
const ONNX_NAMESPACE::TypeProto_SparseTensor& type_proto);
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Optional& optional_proto,
const ONNX_NAMESPACE::TypeProto_Optional& type_proto);
#endif

#if !defined(DISABLE_ML_OPS)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto,
const ONNX_NAMESPACE::TypeProto_Map& type_proto);
@@ -196,6 +201,11 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Map& map_proto,
case TypeProto::ValueCase::kSparseTensorType:
result = IsCompatible(lhs.value_type().sparse_tensor_type(), rhs.value_type().sparse_tensor_type());
break;
#endif
#if !defined(DISABLE_OPTIONAL_TYPE)
case TypeProto::ValueCase::kOptionalType:
result = IsCompatible(lhs.value_type().optional_type(), rhs.value_type().optional_type());
break;
#endif
default:
ORT_ENFORCE(false);
@@ -231,6 +241,11 @@ static bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto_1,
case TypeProto::ValueCase::kSparseTensorType:
result = IsCompatible(type_proto_1.sparse_tensor_type(), type_proto_2.sparse_tensor_type());
break;
#endif
#if !defined(DISABLE_OPTIONAL_TYPE)
case TypeProto::ValueCase::kOptionalType:
result = IsCompatible(type_proto_1.optional_type(), type_proto_2.optional_type());
break;
#endif
default:
ORT_ENFORCE(false);
@@ -247,10 +262,12 @@ bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Sequence& sequence_proto,
return IsCompatible(sequence_proto.elem_type(), type_proto.elem_type());
}

#if !defined(DISABLE_OPTIONAL_TYPE)
bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Optional& optional_proto,
const ONNX_NAMESPACE::TypeProto_Optional& type_proto) {
return IsCompatible(optional_proto.elem_type(), type_proto.elem_type());
}
#endif

bool IsCompatible(const ONNX_NAMESPACE::TypeProto_Opaque& opaque_proto,
const ONNX_NAMESPACE::TypeProto_Opaque& type_proto) {
@@ -493,6 +510,7 @@ MLDataType SequenceTensorTypeBase::Type() {
return &sequence_tensor_base;
}

#if !defined(DISABLE_OPTIONAL_TYPE)
///// OptionalTypeBase

struct OptionalTypeBase::Impl : public data_types_internal::TypeProtoImpl {
@@ -531,6 +549,33 @@ MLDataType OptionalTypeBase::Type() {
static OptionalTypeBase optional_type_base;
return &optional_type_base;
}
#endif

/// DisabledTypeBase

#if defined(DISABLE_OPTIONAL_TYPE)
struct DisabledTypeBase::Impl : public data_types_internal::TypeProtoImpl {
};

DisabledTypeBase::DisabledTypeBase() : impl_(new Impl()) {}

DisabledTypeBase::~DisabledTypeBase() {
delete impl_;
}

const ONNX_NAMESPACE::TypeProto* DisabledTypeBase::GetTypeProto() const {
return impl_->GetProto();
}

ONNX_NAMESPACE::TypeProto& DisabledTypeBase::MutableTypeProto() {
return impl_->MutableTypeProto();
}

MLDataType DisabledTypeBase::Type() {
static DisabledTypeBase disabled_base;
return &disabled_base;
}
#endif

/// NonTensorTypeBase
struct NonTensorTypeBase::Impl : public data_types_internal::TypeProtoImpl {};
@@ -695,11 +740,13 @@ ORT_REGISTER_OPTIONAL_ORT_TYPE(TensorSeq)
reg_fn(mltype); \
}

#if !defined(DISABLE_OPTIONAL_TYPE)
#define REGISTER_OPTIONAL_PROTO(ORT_TYPE, TYPE, reg_fn) \
{ \
MLDataType mltype = DataTypeImpl::GetOptionalType<ORT_TYPE, TYPE>(); \
reg_fn(mltype); \
}
#endif

#if !defined(DISABLE_SPARSE_TENSORS)
#define REGISTER_SPARSE_TENSOR_PROTO(TYPE, reg_fn) \
@@ -781,6 +828,7 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {
REGISTER_ONNX_PROTO(VectorMapInt64ToFloat, reg_fn);
#endif

#if !defined(DISABLE_OPTIONAL_TYPE)
#define REGISTER_OPTIONAL_PROTO_ORT_TYPE(ORT_TYPE, reg_fn) \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, int32_t, reg_fn); \
REGISTER_OPTIONAL_PROTO(ORT_TYPE, float, reg_fn); \
@@ -799,6 +847,7 @@ void RegisterAllProtos(const std::function<void(MLDataType)>& reg_fn) {

REGISTER_OPTIONAL_PROTO_ORT_TYPE(Tensor, reg_fn);
REGISTER_OPTIONAL_PROTO_ORT_TYPE(TensorSeq, reg_fn);
#endif
}
} // namespace data_types_internal
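As a concrete reference, one invocation of the inner macro above, REGISTER_OPTIONAL_PROTO(Tensor, float, reg_fn), expands by straight substitution of the macro body to:

{
  MLDataType mltype = DataTypeImpl::GetOptionalType<Tensor, float>();
  reg_fn(mltype);
}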

@@ -1031,27 +1080,27 @@ std::vector<MLDataType> GetOptionalTensorTypesFromTypeList() {
}

template <typename... ElementTypes>
struct GetOptionalSequenceTensorTypesImpl {
std::vector<MLDataType> operator()() const {
return {DataTypeImpl::GetOptionalType<TensorSeq, ElementTypes>()...};
}
};

template <typename L>
std::vector<MLDataType> GetOptionalSequenceTensorTypesFromTypeList() {
return boost::mp11::mp_apply<GetOptionalSequenceTensorTypesImpl, L>{}();
}

template <typename... ElementTypes>
struct GetSequenceTensorTypesImpl {
std::vector<MLDataType> operator()() const {
return {DataTypeImpl::GetSequenceTensorType<ElementTypes>()...};
}
};

template <typename L>
std::vector<MLDataType> GetSequenceTensorTypesFromTypeList() {
return boost::mp11::mp_apply<GetSequenceTensorTypesImpl, L>{}();
}
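The boost::mp11::mp_apply idiom above turns a compile-time type list into a pack expansion. A self-contained sketch (assuming Boost.Mp11 is on the include path; CountTypesImpl is a hypothetical functor):

#include <boost/mp11.hpp>
#include <cstddef>

template <typename... Ts>
struct CountTypesImpl {
  std::size_t operator()() const { return sizeof...(Ts); }
};

// mp_apply<F, mp_list<A, B, C>> names the type F<A, B, C>.
using List = boost::mp11::mp_list<int, float, double>;

int main() {
  return boost::mp11::mp_apply<CountTypesImpl, List>{}() == 3 ? 0 : 1;
}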

} // namespace
@@ -1200,10 +1249,12 @@ ContainerChecker::ContainerChecker(MLDataType ml_type) {
types_.emplace_back(ContainerType::kSequence, TensorProto_DataType_UNDEFINED);
type_proto = &type_proto->sequence_type().elem_type();
break;
#if !defined(DISABLE_OPTIONAL_TYPE)
case TypeProto::ValueCase::kOptionalType:
types_.emplace_back(ContainerType::kOptional, TensorProto_DataType_UNDEFINED);
type_proto = &type_proto->optional_type().elem_type();
break;
#endif
case TypeProto::ValueCase::kOpaqueType:
// We do not handle this and terminate here
types_.emplace_back(ContainerType::kOpaque,