
Memory planner and pattern generation enhancements. (microsoft#4443)
* static allocation.

* changes.

* contiguous dynamic allocation.

* contiguous dynamic allocation.

* fix bugs.

* fix bug.

* build errors.

* PR feedback.

* PR feedback.

* Update Graph builder for nccl_allreduce, mps.

* misc.

* fix windows build break.

* changes.

* fine-grained memory-time scheduling.

* merge.

* fix misc stuff.

* fix windows build.

* fix windows build.

* fix merge bug.

* merge conflicts.

* revert onnx-tensorrt submodule commit.

* fix submodule commit.

* misc.

* merge conflicts.

* Revert "merge conflicts."

This reverts commit 319a071.

* merge conflict.

* merge conflict.

* merge conflicts.

* fixes.

* PR feedback.

* build break.

* build break.

* Add asserts.

* Add asserts.

* asserts.

* asserts.

* asserts.

* asserts.

* asserts.

* fixes.

* fixes.

Co-authored-by: Ubuntu <OrtTrainingDev3@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
Co-authored-by: root <root@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
3 people authored Nov 2, 2020
1 parent d98062d commit 9af0d48
Showing 20 changed files with 519 additions and 89 deletions.
15 changes: 15 additions & 0 deletions include/onnxruntime/core/framework/kernel_def_builder.h
@@ -75,6 +75,8 @@ class KernelDef {

bool IsOutputOnCpu(size_t output_index) const { return MemTypeOnCpuExplicitly(OutputMemoryType(output_index)); }

bool AllocateInputsContiguously() const { return allocate_inputs_contiguously_; }

OrtMemType OutputMemoryType(size_t output_index) const {
auto it = output_memory_type_args_.find(output_index);
if (it == output_memory_type_args_.end())
@@ -128,6 +130,9 @@ class KernelDef {

// An element <i, j> means that output j is an alias of input i.
std::vector<std::pair<int, int>> alias_map_;

// Require input tensors to be allocated contiguously.
bool allocate_inputs_contiguously_ = false;

// The memory types of inputs/outputs of this kernel
MemTypeMap input_memory_type_args_;
@@ -215,6 +220,16 @@ class KernelDefBuilder {
KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases);
KernelDefBuilder& Alias(int input_index, int output_index);

/**
Specify that this kernel requires input tensors to be allocated
contiguously. This allows kernels to execute as a single large
computation, rather than numerous smaller computations.
*/
KernelDefBuilder& AllocateInputsContiguously() {
kernel_def_->allocate_inputs_contiguously_ = true;
return *this;
}

/**
Specify that this kernel requires an input arg
in certain memory type (instead of the default, device memory).
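As a usage sketch (not part of the diff): a kernel registration might opt in to the new flag as below. The op name, domain, version, and provider are placeholders for illustration; only the AllocateInputsContiguously() call is the API added in this file.

KernelDefBuilder builder;
builder.SetName("MyContiguousOp")                               // hypothetical op name
    .SetDomain(kMSDomain)                                       // illustrative domain choice
    .SinceVersion(1)
    .Provider(kCudaExecutionProvider)                           // illustrative provider choice
    .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
    .AllocateInputsContiguously();                              // plan all inputs into one contiguous block

The planner then groups this kernel's inputs into a single block; see ComputeAllocationOrder in allocation_planner.cc below.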
142 changes: 137 additions & 5 deletions onnxruntime/core/framework/allocation_planner.cc
@@ -147,8 +147,8 @@ class PlannerImpl {

// This is initialized to -1 to ensure that if ProcessDef is somehow not called, planning
// will fail more cleanly. This is also used as a temporary workaround to detect the
// case that the DML provider has removed initializers from the graph during partitioning.
// Removing initializers is a temporary measure needed to limit the number of copies of
// tensors in GPU memory.
OrtValueIndex reused_buffer_index = -1; // index of original buffer to reuse
};
@@ -637,16 +637,29 @@ class PlannerImpl {
} else if (IsNonTensor(*node_output)) {
// we do not try sharing-optimization for non-tensors
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
AllocPlan(current).program_counter_start.emplace_back(program_counter);
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
} else if (FindReusableInput(*pnode, static_cast<int>(output_arg_def_index), &reused)) {
// Reuse one of this node's input buffers as the output buffer (for in-place update)
Reuse(reused, current, AllocKind::kReuse);
} else if (!context_.IsParallelExecutionEnabled() &&
FindReusableTensor(*node_output, &reused)) {
// Reuse an available (dead) buffer for this output; this is only done for sequential execution.
Reuse(reused, current, AllocKind::kReuse);
OrtValueIndex original = Buffer(reused);
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() != SIZE_MAX);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() < program_counter);

AllocPlan(original).program_counter_start.emplace_back(program_counter);
AllocPlan(original).program_counter_end.emplace_back(SIZE_MAX);
}
} else {
// otherwise: allocate a new buffer for this output
AllocPlan(current).alloc_kind = AllocKind::kAllocate;
AllocPlan(current).program_counter_start.emplace_back(program_counter);
AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX);
}
}

@@ -657,8 +670,14 @@ class PlannerImpl {
auto original = Buffer(Index(sym));
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original)))
if ((original != -1) && (0 == DecrementUseCount(original))) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
}
}
}
}

@@ -668,8 +687,14 @@ class PlannerImpl {
auto original = Buffer(Index(sym));
// The index will be -1 if it's an initializer that was removed as part of a temporary workaround.
// See comments in the OrtValueInfo definition.
if ((original != -1) && (0 == DecrementUseCount(original)))
if ((original != -1) && (0 == DecrementUseCount(original))) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
}
}
}
}

@@ -678,14 +703,101 @@ class PlannerImpl {
if (node_output->Exists()) {
auto& sym = node_output->Name();
auto original = Buffer(Index(sym));
if (0 == DecrementUseCount(original))
if (0 == DecrementUseCount(original)) {
freelist_.push_front(FreeBufferInfo(original, program_counter));
if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0);
ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX);
AllocPlan(original).program_counter_end.back() = program_counter;
}
}
}
}
}
return Status::OK();
}

bool AllocateInputsContiguously(const Node& node) const {
const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index());
if (ci.kernel_def == nullptr) {
return false;
}

return ci.kernel_def->AllocateInputsContiguously();
}

// Compute allocation order for tensors that are required to be allocated contiguously.
Status ComputeAllocationOrder() {
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
std::vector<OrtValueIndex>& initializer_allocation_order(plan_.initializer_allocation_order);
std::vector<OrtValueIndex>& activation_allocation_order(plan_.activation_allocation_order);
for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) {
SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter];
const auto* pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index);
if (!AllocateInputsContiguously(*pnode)) continue;
// This node has requested inputs be allocated contiguously.
const auto& input_defs = pnode->InputDefs();
onnxruntime::AllocKind input_kind = AllocKind::kAllocateStatically;
bool set_input_kind = true;
for (int input_arg_def_index = 0; static_cast<size_t>(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) {
const auto& node_input = input_defs[input_arg_def_index];
if (!node_input->Exists()) continue;
const auto current_idx = Index(node_input->Name());
const auto& current_plan = AllocPlan(current_idx);
const auto actual_idx = current_plan.alloc_kind == AllocKind::kReuse ? current_plan.reused_buffer : current_idx;
const auto& actual_plan = AllocPlan(actual_idx);
if (set_input_kind) {
input_kind = actual_plan.alloc_kind;
set_input_kind = false;
}

if ((actual_plan.alloc_kind == AllocKind::kAllocateStatically) && (input_kind != AllocKind::kAllocateStatically))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "AllocateInputsContiguously() requires all inputs to be initializers, or all inputs to be non-initializers.");

if (actual_plan.alloc_kind == AllocKind::kAllocateStatically) {
if (std::find(initializer_allocation_order.begin(), initializer_allocation_order.end(), actual_idx) == initializer_allocation_order.end())
initializer_allocation_order.push_back(actual_idx);
} else {
if (std::find(activation_allocation_order.begin(), activation_allocation_order.end(), actual_idx) == activation_allocation_order.end())
activation_allocation_order.push_back(actual_idx);
}
}
}
return Status::OK();
}

// Ensure the memory-time schedule is sorted.
Status VerifyMemoryTimeSchedule() {
std::vector<SequentialExecutionPlan::NodeExecutionPlan>& execution_plan(plan_.execution_plan);
for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) {
SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter];
const auto* pnode = graph_viewer_.GetNode(step.node_index);
if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index);
const auto& input_defs = pnode->InputDefs();
for (int input_arg_def_index = 0; static_cast<size_t>(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) {
const auto& node_input = input_defs[input_arg_def_index];
if (!node_input->Exists()) continue;
const auto& current_plan = AllocPlan(Index(node_input->Name()));
if (current_plan.alloc_kind != AllocKind::kAllocate) continue;

ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size());

size_t start = 0;
for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) {
ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0));
ORT_ENFORCE(current_plan.program_counter_start[index] <= current_plan.program_counter_end[index]);
ORT_ENFORCE((current_plan.program_counter_start[index] < SIZE_MAX) || (index == 0));
ORT_ENFORCE((current_plan.program_counter_end[index] > 0) || (index == 0));

start = current_plan.program_counter_start[index];
}
}
}

return Status::OK();
}

// Whether a given NodeArg has a fence or not.
// If the buffer is reused, we need to check whether the original OrtValue has a fence.
bool HasFence(const onnxruntime::NodeArg* arg) {
@@ -757,6 +869,19 @@ class PlannerImpl {

if (has_prev_dealloc_point)
plan_.execution_plan[prev_dealloc_point].free_to_index = current - 1;

size_t program_counter = 0;
for (auto& node_plan : plan_.execution_plan) {
for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) {
auto ml_value_idx = plan_.to_be_freed[index];
if (AllocPlan(ml_value_idx).alloc_kind == AllocKind::kAllocate) {
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_start.back() <= program_counter);
ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_end.back() == program_counter);
}
}

program_counter += 1;
}
}

static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) {
@@ -789,9 +914,16 @@ Status PlannerImpl::CreatePlan() {
// Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeFenceCheck());

// Determine allocation order for weights and activations. This needs to be done after ComputeReusePlan.
ORT_RETURN_IF_ERROR(ComputeAllocationOrder());

// Convert information in the freelist_ into a deallocation plan in the required format.
GenerateDeallocationPlan();

// Ensure the memory-time schedule is sorted. This must be called last because memory start/end
// timestamps are still being updated until GenerateDeallocationPlan has finished.
ORT_RETURN_IF_ERROR(VerifyMemoryTimeSchedule());

return Status::OK();
}

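Taken together, the program_counter_start/program_counter_end bookkeeping in the hunks above is per-buffer lifetime-interval tracking: allocating or reusing a buffer opens an interval with a SIZE_MAX sentinel, the last consumer closes it, and VerifyMemoryTimeSchedule checks that the result is well formed. A minimal self-contained sketch of that invariant (not part of the diff; LifetimeIntervals, Open, and Close are invented names, not the ORT types):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct LifetimeIntervals {
  std::vector<size_t> start;  // program counter at which a use of the buffer begins
  std::vector<size_t> end;    // matching program counter at which that use ends

  // On allocation or reuse at pc: open a new interval; SIZE_MAX means "not yet freed".
  void Open(size_t pc) {
    assert(end.empty() || end.back() != SIZE_MAX);  // previous interval must be closed
    start.push_back(pc);
    end.push_back(SIZE_MAX);
  }

  // When the last consumer at pc releases the buffer: close the open interval.
  void Close(size_t pc) {
    assert(!end.empty() && end.back() == SIZE_MAX);
    end.back() = pc;
  }

  // Mirrors the checks in VerifyMemoryTimeSchedule: starts strictly increase,
  // and every interval satisfies start <= end.
  void Verify() const {
    assert(start.size() == end.size());
    size_t prev = 0;
    for (size_t i = 0; i < start.size(); ++i) {
      assert(i == 0 || start[i] > prev);
      assert(start[i] <= end[i]);
      prev = start[i];
    }
  }
};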
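And a simplified sketch of what ComputeAllocationOrder does for each node that requested contiguous inputs: every input, or the buffer it reuses, is routed in first-seen order into either the initializer list (statically allocated weights) or the activation list. Again not part of the diff; ValuePlan and AppendUnique are invented for illustration:

#include <algorithm>
#include <vector>

enum class AllocKind { kAllocate, kAllocateStatically, kReuse };

struct ValuePlan {
  AllocKind alloc_kind;
  int reused_buffer;  // valid only when alloc_kind == kReuse
};

// Appends idx to order if absent, preserving first-seen order.
static void AppendUnique(std::vector<int>& order, int idx) {
  if (std::find(order.begin(), order.end(), idx) == order.end())
    order.push_back(idx);
}

// Routes each input (or the buffer it reuses) of one contiguous-inputs node
// into the initializer or activation allocation-order list.
void OrderContiguousInputs(const std::vector<int>& input_indices,
                           const std::vector<ValuePlan>& plans,
                           std::vector<int>& initializer_order,
                           std::vector<int>& activation_order) {
  for (int idx : input_indices) {
    const int actual =
        plans[idx].alloc_kind == AllocKind::kReuse ? plans[idx].reused_buffer : idx;
    if (plans[actual].alloc_kind == AllocKind::kAllocateStatically)
      AppendUnique(initializer_order, actual);  // weights end up in one static block
    else
      AppendUnique(activation_order, actual);   // activations in one dynamic block
  }
}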
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/execution_frame.cc
@@ -292,8 +292,7 @@ ExecutionFrame::ExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, const
}

// log size of activation. Keep it commented out for now to avoid log flooding.
// VLOGS(session_state_.Logger(), 1) << "Allocated memory for activations, size: "
// << mem_patterns_->patterns[i].PeakSize();
printf("\n **** Allocated memory for activations, size: %zu ***\n", mem_patterns_->patterns[i].PeakSize());
}
}
}
@@ -350,6 +349,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va
// If we have a pre-calculated memory pattern and the ort_value is not an output mlvalue,
// try to allocate it on the pre-allocated big chunk.
const auto& per_alloc_plan = GetAllocationPlan(ort_value_index);

if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput) {
auto pattern = mem_patterns_->GetPatterns(location);
if (pattern) {
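The execution-frame side consumes the planner's output: once lifetimes are known, a precomputed memory pattern maps each ort_value index to an offset inside a single chunk of PeakSize() bytes, so per-tensor allocation reduces to pointer arithmetic. A rough sketch of that idea (not part of the diff; MemoryPattern and AddressOf are invented names, not ORT's actual MemoryPatternGroup types):

#include <cstddef>
#include <unordered_map>

struct MemoryBlock {
  size_t offset;  // byte offset inside the big chunk
  size_t size;    // planned tensor buffer size in bytes
};

struct MemoryPattern {
  size_t peak_size = 0;                         // size of the single big allocation
  std::unordered_map<int, MemoryBlock> blocks;  // ort_value index -> planned block

  // Returns the tensor's storage address inside the pre-allocated chunk,
  // or nullptr if the value was not planned into this pattern.
  void* AddressOf(int ort_value_index, void* chunk_base) const {
    auto it = blocks.find(ort_value_index);
    if (it == blocks.end()) return nullptr;
    return static_cast<char*>(chunk_base) + it->second.offset;
  }
};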

