Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Addressing PR comments (#3334)" #3412

Merged
merged 1 commit into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ static int64_t CalculateMemoryPatternsKey(const std::vector<std::reference_wrapp
#ifdef ENABLE_TRAINING
namespace {
Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, TensorShape>& feeds, std::unordered_map<std::string, int64_t>& out) {
for (const auto* input : graph.GetInputs()) {
for (size_t i = 0; i < graph.GetInputs().size(); ++i) {
auto* input = graph.GetInputs()[i];
auto* shape = input->Shape();
auto it = feeds.find(input->Name());
if (it == feeds.end())
Expand All @@ -200,7 +201,7 @@ Status ResolveDimParams(const GraphViewer& graph, const std::map<std::string, Te
return Status(ONNXRUNTIME, FAIL, "Graph input " + input->Name() +
"'s shape is not present or its shape doesn't match feed's shape."
"Unable to resolve the value for dynamic shape");
for (int k = 0, end = shape->dim_size(); k < end; ++k) {
for (int k = 0; k < shape->dim_size(); ++k) {
if (shape->dim()[k].has_dim_param()) {
out.insert({shape->dim()[k].dim_param(), it->second.GetDims()[k]});
}
Expand All @@ -214,7 +215,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
const std::vector<int>& feed_mlvalue_idxs,
MemoryPatternGroup* output) const {
std::map<std::string, TensorShape> feeds;
for (size_t i = 0, end = feed_mlvalue_idxs.size(); i < end; ++i) {
for (size_t i = 0; i < feed_mlvalue_idxs.size(); ++i) {
std::string name;
ORT_RETURN_IF_ERROR(this->ort_value_name_idx_map_.GetName(feed_mlvalue_idxs[i], name));
feeds.insert({name, input_shape[i]});
Expand All @@ -230,7 +231,7 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
auto* node = graph_viewer_->GetNode(node_plan.node_index);
int output_start = node_index + static_cast<int>(node->InputDefs().size()) + static_cast<int>(node->ImplicitInputDefs().size());
//allocate output
for (int i = 0, end = static_cast<int>(node->OutputDefs().size()); i < end; ++i) {
for (int i = 0; i < static_cast<int>(node->OutputDefs().size()); ++i) {
const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i);
if (ml_value_idx == NodeIndexInfo::kInvalidEntry)
continue;
Expand All @@ -253,10 +254,8 @@ Status SessionState::GeneratePatternGroupCache(const std::vector<std::reference_
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
}
len *= it->second;
} else if (dim.has_dim_value()) {
len *= dim.dim_value();
} else {
return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute");
len *= dim.dim_value();
}
}
if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(len, ml_data_type->Size(), &size)) {
Expand Down
32 changes: 16 additions & 16 deletions onnxruntime/core/providers/cpu/tensor/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,21 +260,21 @@ Status SliceBase::PrepareForCompute(const std::vector<int64_t>& raw_starts,
}

// Slice V10 & DynamicSlice
void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
void SliceBase::FillVectorsFromInput(const Tensor* start_tensor,
const Tensor* ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends,
std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
ORT_ENFORCE(start_tensor.Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(ends_tensor.Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor.Shape() == ends_tensor.Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor.Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr == steps_tensor || start_tensor.Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");
ORT_ENFORCE(nullptr != start_tensor && start_tensor->Shape().NumDimensions() == 1, "Starts must be a 1-D array");
ORT_ENFORCE(nullptr != ends_tensor && ends_tensor->Shape().NumDimensions() == 1, "Ends must be a 1-D array");
ORT_ENFORCE(start_tensor->Shape() == ends_tensor->Shape(), "Starts and ends shape mismatch");
ORT_ENFORCE(nullptr == axes_tensor || start_tensor->Shape() == axes_tensor->Shape(), "Starts and axes shape mismatch");
ORT_ENFORCE(nullptr == steps_tensor || start_tensor->Shape() == steps_tensor->Shape(), "Starts and steps shape mismatch");

const auto& size = start_tensor.Shape().Size();
const auto& size = start_tensor->Shape().Size();
input_starts.resize(size);
input_ends.resize(size);
if (nullptr != axes_tensor)
Expand All @@ -283,19 +283,19 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,
if (nullptr != steps_tensor)
input_steps.resize(size);

if (start_tensor.IsDataType<int32_t>()) {
std::copy(start_tensor.Data<int32_t>(), start_tensor.Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int32_t>(), ends_tensor.Data<int32_t>() + size, input_ends.begin());
if (start_tensor->IsDataType<int32_t>()) {
std::copy(start_tensor->Data<int32_t>(), start_tensor->Data<int32_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int32_t>(), ends_tensor->Data<int32_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int32_t>(), axes_tensor->Data<int32_t>() + size, input_axes.begin());
// Slice V10
if (nullptr != steps_tensor)
std::copy(steps_tensor->Data<int32_t>(), steps_tensor->Data<int32_t>() + size, input_steps.begin());
}

else if (start_tensor.IsDataType<int64_t>()) {
std::copy(start_tensor.Data<int64_t>(), start_tensor.Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor.Data<int64_t>(), ends_tensor.Data<int64_t>() + size, input_ends.begin());
else if (start_tensor->IsDataType<int64_t>()) {
std::copy(start_tensor->Data<int64_t>(), start_tensor->Data<int64_t>() + size, input_starts.begin());
std::copy(ends_tensor->Data<int64_t>(), ends_tensor->Data<int64_t>() + size, input_ends.begin());
if (nullptr != axes_tensor)
std::copy(axes_tensor->Data<int64_t>(), axes_tensor->Data<int64_t>() + size, input_axes.begin());
// Slice V10
Expand All @@ -305,7 +305,7 @@ void SliceBase::FillVectorsFromInput(const Tensor& start_tensor,

// should not reach this as no kernel is registered for this condition to be triggered - just an additional safety check
else {
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor.DataType());
ORT_THROW("Data type for starts and ends inputs' need to be int32_t or int64_t, but instead got ", start_tensor->DataType());
}
}

Expand Down Expand Up @@ -379,7 +379,7 @@ Status Slice<T, dynamic>::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
std::vector<int64_t> input_steps;
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);

ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/tensor/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class SliceBase {
std::vector<int64_t>*& flattened_output_dims) const;

// Slice V10 & DynamicSlice
void FillVectorsFromInput(const Tensor& start_tensor,
const Tensor& ends_tensor,
void FillVectorsFromInput(const Tensor* start_tensor,
const Tensor* ends_tensor,
const Tensor* axes_tensor,
const Tensor* steps_tensor,
std::vector<int64_t>& input_starts,
Expand Down
5 changes: 1 addition & 4 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ namespace onnxruntime {
struct CUDAProviderFactory : IExecutionProviderFactory {
CUDAProviderFactory(OrtDevice::DeviceId device_id,
size_t cuda_mem_limit = std::numeric_limits<size_t>::max(),
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo)
: device_id_(device_id),
cuda_mem_limit_(cuda_mem_limit),
arena_extend_strategy_(arena_extend_strategy) {}
ArenaExtendStrategy arena_extend_strategy = ArenaExtendStrategy::kNextPowerOfTwo) : device_id_(device_id), cuda_mem_limit_(cuda_mem_limit), arena_extend_strategy_(arena_extend_strategy) {}
~CUDAProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ template <bool dynamic>
void Slice<dynamic>::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
FillVectorsFromInput(*ctx->Input<Tensor>(1), *ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
FillVectorsFromInput(ctx->Input<Tensor>(1), ctx->Input<Tensor>(2), ctx->Input<Tensor>(3),
ctx->Input<Tensor>(4), input_starts, input_ends, input_axes, input_steps);
}

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,9 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());

// Pass fused function manager to subgraph
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());

// recurse
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,8 @@ class InferenceSession {
// The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx
std::basic_string<ORTCHAR_T> model_location_;

SessionOptions session_options_;

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession);

Expand Down Expand Up @@ -428,8 +430,6 @@ class InferenceSession {
template <typename T>
void StartProfiling(const std::basic_string<T>& file_prefix);

SessionOptions session_options_;

onnxruntime::GraphTransformerManager graph_transformation_mgr_;

// List of transformers to run. When this list is not empty only the transformers in this list
Expand Down
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/session/training_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ void TrainingSession::AddPredefinedTransformers(GraphTransformerManager& transfo
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GenerateTransformers(level, GetSessionOptions().free_dimension_overrides, custom_list);
auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Status SliceGrad::Compute(OpKernelContext* context) const {
std::vector<int64_t> input_ends;
std::vector<int64_t> input_axes;
std::vector<int64_t> input_steps;
FillVectorsFromInput(*context->Input<Tensor>(2), *context->Input<Tensor>(3), context->Input<Tensor>(4),
FillVectorsFromInput(context->Input<Tensor>(2), context->Input<Tensor>(3), context->Input<Tensor>(4),
context->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);

ORT_RETURN_IF_ERROR(PrepareForCompute(input_starts, input_ends, input_axes, input_steps,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const Tensor* SliceGrad::GetSlicedOrUnslicedTensor(OpKernelContext* ctx) const {
void SliceGrad::FillInputVectors(OpKernelContext* ctx, std::vector<int64_t>& input_starts,
std::vector<int64_t>& input_ends, std::vector<int64_t>& input_axes,
std::vector<int64_t>& input_steps) const {
FillVectorsFromInput(*ctx->Input<Tensor>(2), *ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
FillVectorsFromInput(ctx->Input<Tensor>(2), ctx->Input<Tensor>(3), ctx->Input<Tensor>(4),
ctx->Input<Tensor>(5), input_starts, input_ends, input_axes, input_steps);
}

Expand Down