From 9af0d485241d3f037b94ebf876e457f396b6e583 Mon Sep 17 00:00:00 2001
From: "M. Zeeshan Siddiqui"
Date: Sun, 1 Nov 2020 23:05:46 -0800
Subject: [PATCH] Memory planner and pattern generation enhancements. (#4443)

* static allocation.
* changes.
* contiguous dynamic allocation.
* contiguous dynamic allocation.
* fix bugs.
* fix bug.
* build errors.
* PR feedback.
* PR feedback.
* Update Graph builder for nccl_allreduce, mps.
* misc.
* fix windows build break.
* changes.
* fine-grained memory-time scheduling.
* merge.
* fix misc stuff.
* fix windows build.
* fix windows build.
* fix merge bug.
* merge conflicts.
* revert onnx-tensorrt submodule commit.
* fix submodule commit.
* misc.
* merge conflicts.
* Revert "merge conflicts."

This reverts commit 319a071a6e63e79d001afe40faba5adccd137902.

* merge conflict.
* merge conflict.
* merge conflicts.
* fixes.
* PR feedback.
* build break.
* build break.
* Add asserts.
* Add asserts.
* asserts.
* asserts.
* asserts.
* asserts.
* asserts.
* fixes.
* fixes.

Co-authored-by: Ubuntu
Co-authored-by: root
---
 .../core/framework/kernel_def_builder.h       |  15 ++
 .../core/framework/allocation_planner.cc      | 142 ++++++++++++++++-
 onnxruntime/core/framework/execution_frame.cc |   4 +-
 .../core/framework/mem_pattern_planner.h      | 146 +++++++++++++++++-
 .../framework/ort_value_pattern_planner.cc    |  12 ++
 .../framework/ort_value_pattern_planner.h     |   1 +
 .../core/framework/parallel_executor.cc       |   3 +
 .../framework/sequential_execution_plan.h     |   9 ++
 .../core/framework/sequential_executor.cc     |   4 +
 onnxruntime/core/framework/session_state.cc   |  96 ++++++++++--
 .../core/framework/session_state_utils.cc     |  11 +-
 .../core/framework/session_state_utils.h      |   3 +-
 onnxruntime/core/framework/utils.cc           |  21 +++
 onnxruntime/core/framework/utils.h            |   2 +
 .../allreduce_optimizer_graph_builder.cc      |  61 ++------
 .../core/graph/optimizer_graph_builder.cc     |  41 ++++-
 .../core/graph/optimizer_graph_builder.h      |   8 +
 orttraining/orttraining/models/bert/main.cc   |   7 +-
 .../cuda/collective/nccl_kernels.cc           |  20 +--
 .../tools/ci_test/run_convergence_test.py     |   2 +
 20 files changed, 519 insertions(+), 89 deletions(-)

diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h
index 4ce1d01af28e9..fc1656d47b91e 100644
--- a/include/onnxruntime/core/framework/kernel_def_builder.h
+++ b/include/onnxruntime/core/framework/kernel_def_builder.h
@@ -75,6 +75,8 @@ class KernelDef {
   bool IsOutputOnCpu(size_t output_index) const { return MemTypeOnCpuExplicitly(OutputMemoryType(output_index)); }
+  bool AllocateInputsContiguously() const { return allocate_inputs_contiguously_; }
+
   OrtMemType OutputMemoryType(size_t output_index) const {
     auto it = output_memory_type_args_.find(output_index);
     if (it == output_memory_type_args_.end())
@@ -128,6 +130,9 @@ class KernelDef {
   // An element means that output j is an alias of input i.
   std::vector> alias_map_;
+
+  // Require input tensors to be allocated contiguously.
+  bool allocate_inputs_contiguously_ = false;
   // The memory types of inputs/outputs of this kernel
   MemTypeMap input_memory_type_args_;
@@ -215,6 +220,16 @@ class KernelDefBuilder {
   KernelDefBuilder& Alias(const std::vector>& aliases);
   KernelDefBuilder& Alias(int input_index, int output_index);
+
+  /**
+  Specify that this kernel requires input tensors to be allocated
+  contiguously. This allows kernels to execute as a single large
+  computation, rather than numerous smaller computations.
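+
+  For illustration only, a kernel might opt in through its registration as sketched
+  below; the operator name "MyFusedOp" is hypothetical, while the builder calls mirror
+  the NCCL kernel registrations updated later in this patch:
+
+    ONNX_OPERATOR_KERNEL_EX(
+        MyFusedOp,
+        kMSDomain,
+        1,
+        kCudaExecutionProvider,
+        KernelDefBuilder()
+            .AllocateInputsContiguously()
+            .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
+        MyFusedOp);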
+ */ + KernelDefBuilder& AllocateInputsContiguously() { + kernel_def_->allocate_inputs_contiguously_ = true; + return *this; + } + /** Specify that this kernel requires an input arg in certain memory type (instead of the default, device memory). diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 711c6c28c0ce7..e7568df886d9a 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -147,8 +147,8 @@ class PlannerImpl { // This is initialized to -1 to ensure that if ProcessDef is somehow not called, planning // will fail more cleanly. This is also used as a temporary workaround to detect the - // case that the DML provider has removed initilizers from the graph during partitioning. - // Removing initializers is a temporary measure needed to limit the number of copies of + // case that the DML provider has removed initilizers from the graph during partitioning. + // Removing initializers is a temporary measure needed to limit the number of copies of // tensors in GPU memory. OrtValueIndex reused_buffer_index = -1; // index of original buffer to reuse }; @@ -637,6 +637,8 @@ class PlannerImpl { } else if (IsNonTensor(*node_output)) { // we do not try sharing-optimization for non-tensors AllocPlan(current).alloc_kind = AllocKind::kAllocate; + AllocPlan(current).program_counter_start.emplace_back(program_counter); + AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX); } else if (FindReusableInput(*pnode, static_cast(output_arg_def_index), &reused)) { // Reuse one of this node's input buffers as the output buffer (for in-place update) Reuse(reused, current, AllocKind::kReuse); @@ -644,9 +646,20 @@ class PlannerImpl { FindReusableTensor(*node_output, &reused)) { // Reuse an available (dead) buffer for this output, this is only for sequential execution. Reuse(reused, current, AllocKind::kReuse); + OrtValueIndex original = Buffer(reused); + if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); + ORT_ENFORCE(AllocPlan(original).program_counter_end.back() != SIZE_MAX); + ORT_ENFORCE(AllocPlan(original).program_counter_end.back() < program_counter); + + AllocPlan(original).program_counter_start.emplace_back(program_counter); + AllocPlan(original).program_counter_end.emplace_back(SIZE_MAX); + } } else { // otherwise: allocate a new buffer for this output AllocPlan(current).alloc_kind = AllocKind::kAllocate; + AllocPlan(current).program_counter_start.emplace_back(program_counter); + AllocPlan(current).program_counter_end.emplace_back(SIZE_MAX); } } @@ -657,8 +670,14 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. 
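+        // A value's live range is tracked as paired [program_counter_start, program_counter_end]
+        // intervals, and an end of SIZE_MAX marks an interval that is still open. When the last
+        // use of the value is reached below, the open interval is closed at the current program
+        // counter so the buffer can later be shared with values whose lifetimes do not overlap.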
- if ((original != -1) && (0 == DecrementUseCount(original))) + if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); + if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); + ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); + AllocPlan(original).program_counter_end.back() = program_counter; + } + } } } @@ -668,8 +687,14 @@ class PlannerImpl { auto original = Buffer(Index(sym)); // The index will be -1 if it's an initializer that was removed as part of a temporary workaround. // See comments in the OrtValueInfo definition. - if ((original != -1) && (0 == DecrementUseCount(original))) + if ((original != -1) && (0 == DecrementUseCount(original))) { freelist_.push_front(FreeBufferInfo(original, program_counter)); + if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); + ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); + AllocPlan(original).program_counter_end.back() = program_counter; + } + } } } @@ -678,14 +703,101 @@ class PlannerImpl { if (node_output->Exists()) { auto& sym = node_output->Name(); auto original = Buffer(Index(sym)); - if (0 == DecrementUseCount(original)) + if (0 == DecrementUseCount(original)) { freelist_.push_front(FreeBufferInfo(original, program_counter)); + if (AllocPlan(original).alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(AllocPlan(original).program_counter_end.size() > 0); + ORT_ENFORCE(AllocPlan(original).program_counter_end.back() == SIZE_MAX); + AllocPlan(original).program_counter_end.back() = program_counter; + } + } + } + } + } + return Status::OK(); + } + + bool AllocateInputsContiguously(const Node& node) const { + const KernelCreateInfo& ci = GetKernelCreateInfo(kernel_create_info_map_, node.Index()); + if (ci.kernel_def == nullptr) { + return false; + } + + return ci.kernel_def->AllocateInputsContiguously(); + } + + // Compute allocation order for tensors that are required to be allocated contiguously. + Status ComputeAllocationOrder() { + std::vector& execution_plan(plan_.execution_plan); + std::vector& initializer_allocation_order(plan_.initializer_allocation_order); + std::vector& activation_allocation_order(plan_.activation_allocation_order); + for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) { + SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter]; + const auto* pnode = graph_viewer_.GetNode(step.node_index); + if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index); + if (!AllocateInputsContiguously(*pnode)) continue; + // This node has requested inputs be allocated contiguously. + const auto& input_defs = pnode->InputDefs(); + onnxruntime::AllocKind input_kind = AllocKind::kAllocateStatically; + bool set_input_kind = true; + for (int input_arg_def_index = 0; static_cast(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) { + const auto& node_input = input_defs[input_arg_def_index]; + if (!node_input->Exists()) continue; + const auto current_idx = Index(node_input->Name()); + const auto& current_plan = AllocPlan(current_idx); + const auto actual_idx = current_plan.alloc_kind == AllocKind::kReuse ? 
current_plan.reused_buffer : current_idx; + const auto& actual_plan = AllocPlan(actual_idx); + if (set_input_kind) { + input_kind = actual_plan.alloc_kind; + set_input_kind = false; + } + + if ((actual_plan.alloc_kind == AllocKind::kAllocateStatically) && (input_kind != AllocKind::kAllocateStatically)) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "AllocateInputsContiguously() requires all inputs to be initializers, or all inputs to be non-initializers."); + + if (actual_plan.alloc_kind == AllocKind::kAllocateStatically) { + if (std::find(initializer_allocation_order.begin(), initializer_allocation_order.end(), actual_idx) == initializer_allocation_order.end()) + initializer_allocation_order.push_back(actual_idx); + } else { + if (std::find(activation_allocation_order.begin(), activation_allocation_order.end(), actual_idx) == activation_allocation_order.end()) + activation_allocation_order.push_back(actual_idx); } } } return Status::OK(); } + // Ensure memory time schedule is sorted. + Status VerifyMemoryTimeSchedule() { + std::vector& execution_plan(plan_.execution_plan); + for (size_t program_counter = 0; program_counter < execution_plan.size(); ++program_counter) { + SequentialExecutionPlan::NodeExecutionPlan step = execution_plan[program_counter]; + const auto* pnode = graph_viewer_.GetNode(step.node_index); + if (pnode == nullptr) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cannot find the node ", step.node_index); + const auto& input_defs = pnode->InputDefs(); + for (int input_arg_def_index = 0; static_cast(input_arg_def_index) < input_defs.size(); ++input_arg_def_index) { + const auto& node_input = input_defs[input_arg_def_index]; + if (!node_input->Exists()) continue; + const auto& current_plan = AllocPlan(Index(node_input->Name())); + if (current_plan.alloc_kind != AllocKind::kAllocate) continue; + + ORT_ENFORCE(current_plan.program_counter_start.size() == current_plan.program_counter_end.size()); + + size_t start = 0; + for (size_t index = 0; index < current_plan.program_counter_start.size(); index += 1) { + ORT_ENFORCE((current_plan.program_counter_start[index] > start) || (start == 0)); + ORT_ENFORCE(current_plan.program_counter_start[index] <= current_plan.program_counter_end[index]); + ORT_ENFORCE((current_plan.program_counter_start[index] < SIZE_MAX) || (index == 0)); + ORT_ENFORCE((current_plan.program_counter_end[index] > 0) || (index == 0)); + + start = current_plan.program_counter_start[index]; + } + } + } + + return Status::OK(); + } + // Whether a given NodeArg has fence or not. // If the buffer is reused, need to check whether original OrtValue has fence or not. bool HasFence(const onnxruntime::NodeArg* arg) { @@ -757,6 +869,19 @@ class PlannerImpl { if (has_prev_dealloc_point) plan_.execution_plan[prev_dealloc_point].free_to_index = current - 1; + + size_t program_counter = 0; + for (auto& node_plan : plan_.execution_plan) { + for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) { + auto ml_value_idx = plan_.to_be_freed[index]; + if (AllocPlan(ml_value_idx).alloc_kind == AllocKind::kAllocate) { + ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_start.back() <= program_counter); + ORT_ENFORCE(AllocPlan(ml_value_idx).program_counter_end.back() == program_counter); + } + } + + program_counter += 1; + } } static bool IsNonTensor(const onnxruntime::NodeArg& nodearg) { @@ -789,9 +914,16 @@ Status PlannerImpl::CreatePlan() { // Determine nodes that need fence check. This needs to be done after ComputeUseCounts and ComputeReusePlan. 
ORT_RETURN_IF_ERROR(ComputeFenceCheck()); + // Determine allocation order for weights and activations. This needs to be done after ComputeReusePlan. + ORT_RETURN_IF_ERROR(ComputeAllocationOrder()); + // convert information in the freelist_ into a deallocation plan in required format GenerateDeallocationPlan(); + // Ensure Memory-Time schedule is sorted. This should be called at the end because memory start/end timestamps + // are updated until GenerateDeallocationPlan is finished. + ORT_RETURN_IF_ERROR(VerifyMemoryTimeSchedule()); + return Status::OK(); } diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index f9184c9eaa6c1..67a538393259c 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -292,8 +292,7 @@ ExecutionFrame::ExecutionFrame(const std::vector& feed_mlvalue_idxs, const } // log size of activation. Keep it commented out for now to avoid log flooding. - // VLOGS(session_state_.Logger(), 1) << "Allocated memory for activations, size: " - // << mem_patterns_->patterns[i].PeakSize(); + printf("\n **** Allocated memory for activations, size: %zu ***\n", mem_patterns_->patterns[i].PeakSize()); } } } @@ -350,6 +349,7 @@ Status ExecutionFrame::AllocateMLValueTensorSelfOwnBufferHelper(OrtValue& ort_va // if we have pre-calculated memory pattern, and the ort_value is not output mlvalue // try to allocated on pre-allocated big chunk. const auto& per_alloc_plan = GetAllocationPlan(ort_value_index); + if (mem_patterns_ && per_alloc_plan.alloc_kind != AllocKind::kAllocateOutput) { auto pattern = mem_patterns_->GetPatterns(location); if (pattern) { diff --git a/onnxruntime/core/framework/mem_pattern_planner.h b/onnxruntime/core/framework/mem_pattern_planner.h index 1fb64cce86052..cb57ae31d3f3f 100644 --- a/onnxruntime/core/framework/mem_pattern_planner.h +++ b/onnxruntime/core/framework/mem_pattern_planner.h @@ -31,7 +31,38 @@ class MemPatternPlanner { public: MemPatternPlanner() = default; - void TraceAllocation(int ml_value_idx, size_t size) { + // Returns true if there is an intersection between two time schedules. + // ASSUMES EACH TIME SCHEDULE IS SORTED. THIS IS VALIDATED AT THE END OF MEMORY PLANNING. 
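+  // For example, the schedule {[1,4], [9,12]} (start/end pairs) does not overlap {[5,8]},
+  // but it does overlap {[3,6]} because the closed intervals [1,4] and [3,6] intersect.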
+ bool OverlappingTimeSchedules(const std::vector& program_counter_start_1, const std::vector& program_counter_end_1, + const std::vector& program_counter_start_2, const std::vector& program_counter_end_2) { + ORT_ENFORCE(program_counter_start_1.size() > 0); + ORT_ENFORCE(program_counter_start_2.size() > 0); + ORT_ENFORCE(program_counter_start_1.size() == program_counter_end_1.size()); + ORT_ENFORCE(program_counter_start_2.size() == program_counter_end_2.size()); + + size_t index_1 = 0; + size_t index_2 = 0; + bool overlap = false; + while ((index_1 < program_counter_start_1.size()) && (index_2 < program_counter_start_2.size())) { + if (program_counter_start_1[index_1] <= program_counter_start_2[index_2]) { + if (program_counter_end_1[index_1] >= program_counter_start_2[index_2]) { + overlap = true; + break; + } + index_1 += 1; + } else { + if (program_counter_end_2[index_2] >= program_counter_start_1[index_1]) { + overlap = true; + break; + } + index_2 += 1; + } + } + + return overlap; + } + + void TraceAllocation(int ml_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size) { std::lock_guard lock(lock_); if (size == 0) { @@ -47,23 +78,96 @@ class MemPatternPlanner { best_offset = last_block.block_.offset_ + last_block.block_.size_; } + for (auto it = blocks_.begin(); it != blocks_.end(); it++) { + // Memory block can be re-used as long as there is no overlap between their time schedules. + if (allocs_[*it].reuse_ && !OverlappingTimeSchedules(program_counter_start, program_counter_end, + allocs_[*it].program_counter_start_, allocs_[*it].program_counter_end_)) { + continue; + } + + if (allocs_[*it].block_.offset_ >= current) { + auto gap = allocs_[*it].block_.offset_ - current; + if (gap >= size && (gap - size) < waste_bytes) { + waste_bytes = gap - size; + best_offset = current; + } + } + + current = std::max(current, allocs_[*it].block_.offset_ + allocs_[*it].block_.size_); + } + + if (current < buffer_size_) { + auto gap = buffer_size_ - current; + if ((gap >= size) && ((gap - size) < waste_bytes)) + best_offset = current; + } + + // we only need to bounds check the addition of size to best_offset as that is the only time we extend + // the maximum size of the buffer. 
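+    // Grow the planned arena only if the chosen block extends past its current end, record the
+    // allocation together with its lifetime intervals, and insert it into blocks_, which is kept
+    // ordered by offset (and by size for equal offsets) so later gap searches stay cheap.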
+ buffer_size_ = std::max(buffer_size_, SafeInt(best_offset) + size); + allocs_.emplace_back(ml_value_idx, program_counter_start, program_counter_end, MemoryBlock(best_offset, size)); std::list::iterator best_fit_it = blocks_.end(); + for (auto it = blocks_.begin(); it != blocks_.end(); it++) { + if (allocs_[*it].block_.offset_ < best_offset) + continue; + + if ((allocs_[*it].block_.offset_ > best_offset) || (allocs_[*it].block_.size_ >= size)) { + best_fit_it = it; + break; + } + } + + blocks_.insert(best_fit_it, (static_cast(allocs_.size()) - 1)); + } + + void TraceAllocation(int ml_value_idx, size_t size) { + std::lock_guard lock(lock_); + + if (size == 0) { + allocs_.emplace_back(ml_value_idx, MemoryBlock(0, 0)); + return; + } + + size_t current = 0; + size_t waste_bytes = std::numeric_limits::max(); + size_t best_offset = 0; + if (!blocks_.empty()) { + auto last_block = allocs_[*blocks_.rbegin()]; + best_offset = last_block.block_.offset_ + last_block.block_.size_; + } + for (auto it = blocks_.begin(); it != blocks_.end(); it++) { if (allocs_[*it].block_.offset_ >= current) { auto gap = allocs_[*it].block_.offset_ - current; if (gap >= size && (gap - size) < waste_bytes) { - best_fit_it = it; waste_bytes = gap - size; best_offset = current; } } - current = allocs_[*it].block_.offset_ + allocs_[*it].block_.size_; + current = std::max(current, allocs_[*it].block_.offset_ + allocs_[*it].block_.size_); + } + + if (current < buffer_size_) { + auto gap = buffer_size_ - current; + if ((gap >= size) && ((gap - size) < waste_bytes)) + best_offset = current; } // we only need to bounds check the addition of size to best_offset as that is the only time we extend // the maximum size of the buffer. buffer_size_ = std::max(buffer_size_, SafeInt(best_offset) + size); allocs_.emplace_back(ml_value_idx, MemoryBlock(best_offset, size)); + std::list::iterator best_fit_it = blocks_.end(); + for (auto it = blocks_.begin(); it != blocks_.end(); it++) { + if (allocs_[*it].block_.offset_ < best_offset) + continue; + + if ((allocs_[*it].block_.offset_ > best_offset) || (allocs_[*it].block_.size_ >= size)) { + best_fit_it = it; + break; + } + } + blocks_.insert(best_fit_it, (static_cast(allocs_.size()) - 1)); } @@ -78,9 +182,36 @@ class MemPatternPlanner { } } - MemoryPattern GenerateMemPattern() const { + MemoryPattern GenerateMemPattern() { std::lock_guard lock(lock_); + // Time schedules of overlapping memory blocks SHOULD NOT intersect. 
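+    // In other words: two blocks may share bytes of the arena only because their lifetimes are
+    // disjoint. The quadratic scan below re-checks that invariant for every pair of reused
+    // (lifetime-annotated) blocks before the pattern is returned.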
+ for (size_t index_1 = 0; index_1 < allocs_.size(); index_1 += 1) { + if (!allocs_[index_1].reuse_) + continue; + + for (size_t index_2 = index_1 + 1; index_2 < allocs_.size(); index_2 += 1) { + if (!allocs_[index_2].reuse_) + continue; + + size_t alloc_1_start = allocs_[index_1].block_.offset_; + size_t alloc_1_end = alloc_1_start + allocs_[index_1].block_.size_ - 1; + + ORT_ENFORCE(alloc_1_start <= alloc_1_end); + + size_t alloc_2_start = allocs_[index_2].block_.offset_; + size_t alloc_2_end = alloc_2_start + allocs_[index_2].block_.size_ - 1; + + ORT_ENFORCE(alloc_2_start <= alloc_2_end); + + if (((alloc_1_start >= alloc_2_start) && (alloc_1_start <= alloc_2_end)) || + ((alloc_2_start >= alloc_1_start) && (alloc_2_start <= alloc_1_end))) { + ORT_ENFORCE(!OverlappingTimeSchedules(allocs_[index_1].program_counter_start_, allocs_[index_1].program_counter_end_, + allocs_[index_2].program_counter_start_, allocs_[index_2].program_counter_end_)); + } + } + } + MemoryPattern pattern; pattern.peak_size_ = buffer_size_; for (auto& alloc : allocs_) { @@ -94,9 +225,12 @@ class MemPatternPlanner { struct OrtValueAllocationBlock { int index_{-1}; MemoryBlock block_; - + const std::vector program_counter_start_; + const std::vector program_counter_end_; + bool reuse_{false}; OrtValueAllocationBlock() = default; - OrtValueAllocationBlock(int index, const MemoryBlock& block) : index_(index), block_(block) {} + OrtValueAllocationBlock(int index, const MemoryBlock& block) : index_(index), block_(block), reuse_{false} {} + OrtValueAllocationBlock(int index, std::vector program_counter_start, std::vector program_counter_end, const MemoryBlock& block) : index_(index), block_(block), program_counter_start_(program_counter_start), program_counter_end_(program_counter_end), reuse_{true} {} }; std::vector allocs_; diff --git a/onnxruntime/core/framework/ort_value_pattern_planner.cc b/onnxruntime/core/framework/ort_value_pattern_planner.cc index e4df7c65608a0..6fd8b140f50b3 100644 --- a/onnxruntime/core/framework/ort_value_pattern_planner.cc +++ b/onnxruntime/core/framework/ort_value_pattern_planner.cc @@ -13,6 +13,18 @@ OrtValuePatternPlanner::OrtValuePatternPlanner(const ExecutionPlanBase& executio } } +common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size) { + // TODO(codemzs): refactor code. 
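+  // Look up which memory location this OrtValue lives in and forward the lifetime-annotated
+  // allocation to that location's MemPatternPlanner.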
+ auto location = execution_planner_.GetLocation(ort_value_idx); + auto it = planner_map_.find(location); + if (it == planner_map_.end()) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); + } + + it->second->TraceAllocation(ort_value_idx, program_counter_start, program_counter_end, size); + return common::Status::OK(); +} + common::Status OrtValuePatternPlanner::TraceAllocation(int ort_value_idx, size_t size) { auto location = execution_planner_.GetLocation(ort_value_idx); auto it = planner_map_.find(location); diff --git a/onnxruntime/core/framework/ort_value_pattern_planner.h b/onnxruntime/core/framework/ort_value_pattern_planner.h index e6748f92ff9af..fd25e49e92dba 100644 --- a/onnxruntime/core/framework/ort_value_pattern_planner.h +++ b/onnxruntime/core/framework/ort_value_pattern_planner.h @@ -19,6 +19,7 @@ class ExecutionPlanBase; class OrtValuePatternPlanner { public: explicit OrtValuePatternPlanner(const ExecutionPlanBase& execution_plan); + common::Status TraceAllocation(int ort_value_idx, const std::vector& program_counter_start, const std::vector& program_counter_end, size_t size); common::Status TraceAllocation(int ort_value_idx, size_t size); common::Status TraceFree(int ort_value_index); common::Status GeneratePatterns(MemoryPatternGroup* out); diff --git a/onnxruntime/core/framework/parallel_executor.cc b/onnxruntime/core/framework/parallel_executor.cc index d7865b27673b9..10ddb74856bdc 100644 --- a/onnxruntime/core/framework/parallel_executor.cc +++ b/onnxruntime/core/framework/parallel_executor.cc @@ -191,6 +191,9 @@ Status ParallelExecutor::RunNodeAsync(size_t p_node_index, // Execute the kernel. ORT_TRY { + if (p_op_kernel->KernelDef().AllocateInputsContiguously()) + utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context); + status = p_op_kernel->Compute(&op_kernel_context); } ORT_CATCH(const std::exception& ex) { diff --git a/onnxruntime/core/framework/sequential_execution_plan.h b/onnxruntime/core/framework/sequential_execution_plan.h index 9d76856aeb514..e806e65cd2597 100644 --- a/onnxruntime/core/framework/sequential_execution_plan.h +++ b/onnxruntime/core/framework/sequential_execution_plan.h @@ -7,6 +7,7 @@ #include "core/framework/alloc_kind.h" #include "core/framework/data_types.h" #include "core/framework/execution_plan_base.h" +#include "core/graph/graph.h" namespace onnxruntime { // Every ml-value has a unique name and is assigned a unique integral number. @@ -30,6 +31,8 @@ struct AllocPlanPerValue { // if the value is used in async kernel, a fence object would be created // note the fence object would be shared between MLValues reusing the same buffer bool create_fence_if_async{false}; + std::vector program_counter_start; + std::vector program_counter_end; public: AllocPlanPerValue() : location(CPU, Invalid) {} @@ -45,6 +48,12 @@ struct SequentialExecutionPlan : public ExecutionPlanBase { // The following vector is indexed by OrtValueIndex std::vector allocation_plan; + // The following vector contains any initializer tensors that must be allocated sequentially. + std::vector initializer_allocation_order; + + // The following vector contains any activation tensors that must be allocated sequentially. 
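+  // Entries are OrtValue indices, listed in the order the contiguous-input kernels consume them;
+  // SessionState traces these activations into the memory pattern first, in exactly this order.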
+ std::vector activation_allocation_order; + // The following indicates the order in which nodes should be executed and the // ml-values to be free after each node's execution: diff --git a/onnxruntime/core/framework/sequential_executor.cc b/onnxruntime/core/framework/sequential_executor.cc index cbcef6a52121a..6877dac0de269 100644 --- a/onnxruntime/core/framework/sequential_executor.cc +++ b/onnxruntime/core/framework/sequential_executor.cc @@ -13,6 +13,7 @@ #include "core/framework/execution_frame.h" #include "core/framework/session_state.h" #include "core/framework/op_kernel_context_internal.h" +#include "core/framework/utils.h" #if defined DEBUG_NODE_INPUTS_OUTPUTS #include "core/framework/debug_node_inputs_outputs_utils.h" @@ -306,6 +307,9 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std: node_compute_range.Begin(); #endif ORT_TRY { + if (p_op_kernel->KernelDef().AllocateInputsContiguously()) + utils::VerifyInputTensorsAllocatedContiguously(&op_kernel_context); + compute_status = p_op_kernel->Compute(&op_kernel_context); } ORT_CATCH(const std::exception& ex) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index c5fe24cd56574..324e6432b13aa 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -314,14 +314,11 @@ Status ResolveDimParams(const GraphViewer& graph, return Status::OK(); } -Status ResolveSizeAndShape( +Status TryResolveShape( const NodeArg* arg, const std::unordered_map& symbolic_dimensions, - size_t& size, // total number of elements. It's 0 if shape is unknown. std::vector& resolved_shape) { if (!arg->Shape()) { - // 0 means no shape information. - size = 0; return Status::OK(); } @@ -345,15 +342,24 @@ Status ResolveSizeAndShape( } } - size = safe_size; - // Only assign shape if all symbolic dimensions are resolved. - if (size != 0) { + if (safe_size != 0) { resolved_shape = std::move(shape); } return Status::OK(); } + +void TryCalculateSizeFromResolvedShape(int ml_value_idx, std::unordered_map& resolved_shapes, size_t& size) { + size = 0; + auto shape = resolved_shapes.find(ml_value_idx); + if (shape != resolved_shapes.end()) { + size = 1; + for (auto dim : shape->second.GetDims()) + size *= dim; + } +} + } // namespace Status SessionState::GeneratePatternGroupCache(const std::vector>& input_shape, @@ -371,31 +377,83 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorexecution_plan) { int node_index = node_index_info.GetNodeOffset(node_plan.node_index); auto* node = graph_viewer_->GetNode(node_plan.node_index); int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); - //allocate output + for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); if (ml_value_idx == NodeIndexInfo::kInvalidEntry) continue; + const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type; if (!ml_type->IsTensorType()) continue; - const auto* ml_data_type = static_cast(ml_type)->GetElementType(); auto* arg = node->OutputDefs()[i]; - size_t size = 0; std::vector resolved_shape; - ORT_RETURN_IF_ERROR(ResolveSizeAndShape(arg, map, size, resolved_shape)); + ORT_RETURN_IF_ERROR(TryResolveShape(arg, map, resolved_shape)); // Store all valid resolved shapes. They will be queried in, for example, // Recv operator to bypass the dependency of output shapes on inputs. 
- if (size != 0) { + if (resolved_shape.size() > 0) { resolved_shapes[ml_value_idx] = resolved_shape; } + } + } + + // Allocate activations that want to be laid out contigously in memory. + for (auto ml_value_idx : exe_plan->activation_allocation_order) { + ORT_ENFORCE(ml_value_idx >= 0); + + const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type; + if (!ml_type->IsTensorType()) + continue; + const auto* ml_data_type = static_cast(ml_type)->GetElementType(); + if (exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate && + ml_data_type != DataTypeImpl::GetType()) { + size_t size = 0; + TryCalculateSizeFromResolvedShape(ml_value_idx, resolved_shapes, size); + if (size == 0) { + return Status(ONNXRUNTIME, FAIL, "Unknown shape found in memory pattern compute"); + } + + if (!IAllocator::CalcMemSizeForArrayWithAlignment<64>(size, ml_data_type->Size(), &size)) { + return Status(ONNXRUNTIME, FAIL, "Size overflow"); + } + + ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate); + ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size()); + + for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1) + ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]); + + mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start, + exe_plan->allocation_plan[ml_value_idx].program_counter_end, size); + } + } + + // Allocate all other activations. + for (auto& node_plan : exe_plan->execution_plan) { + int node_index = node_index_info.GetNodeOffset(node_plan.node_index); + auto* node = graph_viewer_->GetNode(node_plan.node_index); + int output_start = node_index + static_cast(node->InputDefs().size()) + static_cast(node->ImplicitInputDefs().size()); + //allocate output + for (int i = 0, end = static_cast(node->OutputDefs().size()); i < end; ++i) { + const auto ml_value_idx = node_index_info.GetMLValueIndex(output_start + i); + if (ml_value_idx == NodeIndexInfo::kInvalidEntry || + (std::find(exe_plan->activation_allocation_order.begin(), exe_plan->activation_allocation_order.end(), ml_value_idx) != exe_plan->activation_allocation_order.end())) + continue; + const auto* ml_type = exe_plan->allocation_plan[ml_value_idx].value_type; + if (!ml_type->IsTensorType()) + continue; + const auto* ml_data_type = static_cast(ml_type)->GetElementType(); + size_t size = 0; + TryCalculateSizeFromResolvedShape(ml_value_idx, resolved_shapes, size); // Plan memory if conditions are met. 
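+          // i.e. the value is planner-allocated (kAllocate, not a reuse or a graph output), and its
+          // element type passes the DataTypeImpl check; values already traced through
+          // activation_allocation_order were skipped by the std::find check above.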
if (exe_plan->allocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate && @@ -405,9 +463,17 @@ Status SessionState::GeneratePatternGroupCache(const std::vectorallocation_plan[ml_value_idx].alloc_kind == AllocKind::kAllocate); + ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start.size() == exe_plan->allocation_plan[ml_value_idx].program_counter_end.size()); + + for (size_t index = 0; index < exe_plan->allocation_plan[ml_value_idx].program_counter_start.size(); index += 1) + ORT_ENFORCE(exe_plan->allocation_plan[ml_value_idx].program_counter_start[index] <= exe_plan->allocation_plan[ml_value_idx].program_counter_end[index]); + + mem_planner.TraceAllocation(ml_value_idx, exe_plan->allocation_plan[ml_value_idx].program_counter_start, + exe_plan->allocation_plan[ml_value_idx].program_counter_end, aligned_size); } } + //release nodes for (int index = node_plan.free_from_index; index <= node_plan.free_to_index; ++index) { auto ml_value_idx = exe_plan->to_be_freed[index]; @@ -874,12 +940,14 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string tensor_allocator_( ITensorAllocator::Create(enable_mem_pattern_, *p_seq_exec_plan_, *this, weights_buffers_)); + const auto& initializer_allocation_order = p_seq_exec_plan_->initializer_allocation_order; + // move initializers from TensorProto instances in Graph to OrtValue instances in SessionState ORT_RETURN_IF_ERROR( session_state_utils::SaveInitializedTensors( Env::Default(), graph_location, *graph_viewer_, execution_providers_.GetDefaultCpuMemoryInfo(), - ort_value_name_idx_map_, *tensor_allocator_, + ort_value_name_idx_map_, initializer_allocation_order, *tensor_allocator_, [this](int idx, const OrtValue& value, const OrtCallback& d, bool constant) -> Status { return AddInitializedTensor(idx, value, &d, constant); }, diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 1bcdcba04657f..a749312c6b41b 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -90,7 +90,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st common::Status SaveInitializedTensors( const Env& env, const std::basic_string& graph_loc, const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, ITensorAllocator& planner, + const OrtValueNameIdxMap& ort_value_name_idx_map, const std::vector& initializer_allocation_order, ITensorAllocator& planner, const std::function& save_tensor_func, const logging::Logger& logger, const DataTransferManager& data_transfer_mgr, const ExecutionPlanBase& exec_plan, @@ -139,6 +139,15 @@ common::Status SaveInitializedTensors( } id_to_initialized_tensor[ort_value_index] = entry.second; } + + // tensors requiring a specific allocation order are traced first, to ensure they are allocated in order + auto initialized_tensors_to_allocate = id_to_initialized_tensor; + for (int ort_value_index : initializer_allocation_order) { + const auto entry = initialized_tensors_to_allocate.find(ort_value_index); + ORT_ENFORCE(entry != initialized_tensors_to_allocate.end()); + ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second)); + initialized_tensors_to_allocate.erase(entry); + } for (const auto& entry : id_to_initialized_tensor) { // We don't want to trace shared initializers since their memory is provided by the user diff --git a/onnxruntime/core/framework/session_state_utils.h 
b/onnxruntime/core/framework/session_state_utils.h index b499c8426734a..a05c8910d6d02 100644 --- a/onnxruntime/core/framework/session_state_utils.h +++ b/onnxruntime/core/framework/session_state_utils.h @@ -9,6 +9,7 @@ #include "core/framework/tensor.h" #include "core/framework/tensor_allocator.h" #include "core/framework/session_options.h" +#include "core/framework/sequential_execution_plan.h" #include "core/platform/path_lib.h" namespace onnxruntime { @@ -29,7 +30,7 @@ namespace session_state_utils { common::Status SaveInitializedTensors( const Env& env, const std::basic_string& graph_loc, const GraphViewer& graph, const OrtMemoryInfo& default_cpu_memory_info, - const OrtValueNameIdxMap& ort_value_name_idx_map, + const OrtValueNameIdxMap& ort_value_name_idx_map, const std::vector& initializer_allocation_order, ITensorAllocator& planner, const std::function& save_tensor_func, const logging::Logger& logger, diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index ab573c0b40ec4..2b95fe31468dc 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -568,5 +568,26 @@ int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType onn } } +common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context) { + const Tensor* prev_input = context->Input(0); + for (int i = 1; i < context->InputCount(); i++) { + const Tensor* curr_input = context->Input(i); + + ORT_ENFORCE(prev_input->Shape().Size() >= 0); + + size_t input_element_count = static_cast(prev_input->Shape().Size()); + size_t input_element_size = prev_input->DataType()->Size(); + size_t input_aligned_bytes = 0; + + ORT_RETURN_IF_NOT(IAllocator::CalcMemSizeForArrayWithAlignment<256>(input_element_count, input_element_size, &input_aligned_bytes)); + + ORT_RETURN_IF_NOT(curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + input_aligned_bytes || + curr_input->DataRaw() == static_cast(prev_input->DataRaw()) + prev_input->SizeInBytes()); + + prev_input = curr_input; + } + return Status::OK(); +} + } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 9a81215d0d1d9..cfe1087a098ae 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -149,5 +149,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); +common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); + } // namespace utils } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/allreduce_optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/allreduce_optimizer_graph_builder.cc index ed4d483a5f9df..5c837b4ae404b 100644 --- a/orttraining/orttraining/core/graph/allreduce_optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/allreduce_optimizer_graph_builder.cc @@ -92,61 +92,29 @@ Status AllreduceOptimizerGraphBuilder::AddHorovodAllReduceForGradients(std::vect static Status AddNcclAllReduceForGradients( std::vector& gradient_argdefs, - ArgDef& fused_gradient_argdef, - GraphAugmenter::GraphDefs& graph_defs, - ArgDef& fused_allreduce_output) { - fused_allreduce_output = ArgDef(fused_gradient_argdef.name + "AllReduce_Out", fused_gradient_argdef.type_proto); - - // Add NCCL Allreduce node. 
- graph_defs.AddNodeDefs({NodeDef(OpDef{"NcclAllReduce", kMSDomain, 1}, - {fused_gradient_argdef}, - {fused_allreduce_output}, - NodeAttributes(), - "NcclAllReduce")}); - - std::vector view_inputs(gradient_argdefs.size() + 1); - view_inputs[0] = fused_allreduce_output; - - for (size_t i = 0; i < gradient_argdefs.size(); i++) { - ArgDef& gradient_shape = view_inputs[i + 1]; - gradient_shape = ArgDef(gradient_argdefs[i].name + "_Shape"); - - graph_defs.AddNodeDefs({NodeDef("Shape", - {gradient_argdefs[i]}, - {gradient_shape}, - NodeAttributes(), - gradient_shape.name)}); - } + std::vector& input_gradient_argdef, + GraphAugmenter::GraphDefs& graph_defs) { std::vector allreduce_outputs(gradient_argdefs.size()); for (size_t i = 0; i < gradient_argdefs.size(); i++) { TypeProto* allreduced_gradient_type_proto = graph_defs.CopyTypeProto(gradient_argdefs[i]); allreduced_gradient_type_proto->mutable_tensor_type()->set_elem_type( - fused_gradient_argdef.type_proto->tensor_type().elem_type()); + input_gradient_argdef[0].type_proto->tensor_type().elem_type()); allreduce_outputs[i] = ArgDef(gradient_argdefs[i].name + "_AllReduce_Out", allreduced_gradient_type_proto); } - graph_defs.AddNodeDefs({NodeDef(OpDef{"View", kMSDomain, 1}, - view_inputs, + // Add NCCL Allreduce node. + graph_defs.AddNodeDefs({NodeDef(OpDef{"NcclAllReduce", kMSDomain, 1}, + input_gradient_argdef, allreduce_outputs, NodeAttributes(), - "AllReduceOutputView")}); + "NcclAllReduce")}); gradient_argdefs = allreduce_outputs; return Status::OK(); } -static std::vector GetGradientNormInputs( - const std::vector& gradient_argdefs, - ArgDef fused_gradient_argdef) { - if (!fused_gradient_argdef.name.empty()) { - return {fused_gradient_argdef}; - } else { - return gradient_argdefs; - } -} - AllreduceOptimizerGraphBuilder::AllreduceOptimizerGraphBuilder( const OptimizerBuilderRegistry& opt_builder_registry, const OptimizerGraphConfig& opt_graph_config, @@ -174,24 +142,20 @@ Status AllreduceOptimizerGraphBuilder::BuildInternal( return graph.GenerateNodeArgName(base_name); }; - const bool overlap_compute_allreduce = !opt_graph_config_.use_nccl; const int64_t horovod_reduce_op = opt_graph_config_.horovod_reduce_op; // add gradient scaling - ArgDef fused_gradient_argdef; + std::vector output_gradient_argdef; const auto total_num_accumulations = opt_graph_config_.gradient_accumulation_steps * opt_graph_config_.data_parallel_group_size; ORT_RETURN_IF_NOT(total_num_accumulations > 0); const float scale = 1.0f / total_num_accumulations; - const bool fuse_scaling_outputs = !overlap_compute_allreduce; - ORT_RETURN_IF_ERROR(AddGradientScalingNodes(nodearg_name_generator, scale, gradient_argdefs, fused_gradient_argdef, graph_defs, - opt_graph_config_.AllReduceDataType(), fuse_scaling_outputs)); + ORT_RETURN_IF_ERROR(AddGradientScalingNodes(nodearg_name_generator, scale, gradient_argdefs, output_gradient_argdef, graph_defs, + opt_graph_config_.AllReduceDataType())); // add Allreduce for gradients - ArgDef reduced_fused_gradient_argdef; - if (opt_graph_config_.use_nccl) { - ORT_RETURN_IF_ERROR(AddNcclAllReduceForGradients(gradient_argdefs, fused_gradient_argdef, graph_defs, reduced_fused_gradient_argdef)); + ORT_RETURN_IF_ERROR(AddNcclAllReduceForGradients(gradient_argdefs, output_gradient_argdef, graph_defs)); } else { ORT_RETURN_IF_ERROR(AddHorovodAllReduceForGradients(gradient_argdefs, graph_defs, horovod_reduce_op)); } @@ -200,9 +164,8 @@ Status AllreduceOptimizerGraphBuilder::BuildInternal( ArgDef global_grad_norm_argdef; ArgDef 
global_grad_norm_finite_argdef; if (opt_graph_config_.use_mixed_precision) { - auto gradient_norm_inputs = GetGradientNormInputs(gradient_argdefs, reduced_fused_gradient_argdef); ORT_RETURN_IF_ERROR(AddGradientNorm( - nodearg_name_generator, gradient_norm_inputs, graph_defs, global_grad_norm_argdef)); + nodearg_name_generator, gradient_argdefs, graph_defs, global_grad_norm_argdef)); optimizer_graph_outputs[OptimizerOutputKey::GlobalGradientNorm] = global_grad_norm_argdef.name; ORT_RETURN_IF_ERROR(AddFiniteGradientCheck( diff --git a/orttraining/orttraining/core/graph/optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/optimizer_graph_builder.cc index 0e8f4e7219740..280c424a316f5 100644 --- a/orttraining/orttraining/core/graph/optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/optimizer_graph_builder.cc @@ -129,6 +129,45 @@ Status OptimizerGraphBuilder::AddGradientScalingNodes( return Status::OK(); } +Status OptimizerGraphBuilder::AddGradientScalingNodes( + const NodeArgNameGeneratorFn& nodearg_name_generator, + const float scale, + std::vector& input_gradient_argdefs, // update argdefs in place + std::vector& output_gradient_argdef, // update argdef in place + GraphAugmenter::GraphDefs& graph_defs, + ONNX_NAMESPACE::TensorProto_DataType target_type) { + ArgDef pre_allreduce_scale(nodearg_name_generator("pre_allreduce_scale"), + graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + graph_defs.AddInitializers({CreateTensorProto(pre_allreduce_scale.name, scale, {})}); + + TypeProto* fused_gradient_type_proto = graph_defs.CreateTypeProto(); + fused_gradient_type_proto->mutable_tensor_type()->set_elem_type(target_type); + + std::vector inputs; + inputs.emplace_back(pre_allreduce_scale); + for (size_t i = 0; i < input_gradient_argdefs.size(); ++i) { + inputs.emplace_back(input_gradient_argdefs[i]); + } + + for (size_t i = 0; i < input_gradient_argdefs.size(); ++i) { + ArgDef& gradient_argdef = input_gradient_argdefs[i]; + + TypeProto* scaled_gradient_type_proto = graph_defs.CopyTypeProto(gradient_argdef); + scaled_gradient_type_proto->mutable_tensor_type()->set_elem_type(target_type); + + output_gradient_argdef.emplace_back(ArgDef(nodearg_name_generator(gradient_argdef.name + "_scaled"), scaled_gradient_type_proto)); + } + + graph_defs.AddNodeDefs({NodeDef(OpDef{"MixedPrecisionScale", kMSDomain, 1}, + inputs, + output_gradient_argdef, + std::vector({ONNX_NAMESPACE::MakeAttribute("to", static_cast(target_type))}), + pre_allreduce_scale.name)}); + + return Status::OK(); +} + ArgDef AddGradientAccumulationNodes(const NodeArgNameGeneratorFn& nodearg_name_generator, std::vector& gradient_argdefs, // update argdefs in place std::vector& gradient_accumulation_buffers, // output @@ -163,7 +202,7 @@ Status AddZeroGradientNodes(const NodeArgNameGeneratorFn& nodearg_name_generator const std::vector& control_signals, std::vector& gradient_argdefs, // update argdefs in place GraphAugmenter::GraphDefs& graph_defs) { - assert(gradient_argdefs.size() == control_signals.size()); + //assert(gradient_argdefs.size() == control_signals.size()); for (size_t i = 0; i < gradient_argdefs.size(); ++i) { gradient_argdefs[i] = BuildZeroGradientNode(nodearg_name_generator, control_signals[i], gradient_argdefs[i], graph_defs); } diff --git a/orttraining/orttraining/core/graph/optimizer_graph_builder.h b/orttraining/orttraining/core/graph/optimizer_graph_builder.h index 226aa0dcb8716..f3fea149d6f5e 100644 --- a/orttraining/orttraining/core/graph/optimizer_graph_builder.h 
+++ b/orttraining/orttraining/core/graph/optimizer_graph_builder.h @@ -82,6 +82,14 @@ class OptimizerGraphBuilder { ONNX_NAMESPACE::TensorProto_DataType allreduce_element_type, const bool fuse_scaling_outputs); + Status AddGradientScalingNodes( + const NodeArgNameGeneratorFn& nodearg_name_generator, + const float scale, + std::vector& gradient_argdefs, // update argdefs in place + std::vector& output_gradient_argdef, // update argdef in place + GraphAugmenter::GraphDefs& graph_defs, + ONNX_NAMESPACE::TensorProto_DataType target_type); + Status AddGradientNorm( const NodeArgNameGeneratorFn& nodearg_name_generator, const std::vector& grad_argdefs, diff --git a/orttraining/orttraining/models/bert/main.cc b/orttraining/orttraining/models/bert/main.cc index ecda50577e73c..d020539627f8c 100644 --- a/orttraining/orttraining/models/bert/main.cc +++ b/orttraining/orttraining/models/bert/main.cc @@ -54,7 +54,7 @@ struct BertParameters : public TrainingRunner::Parameters { size_t num_train_steps_phase2; float warmup_ratio_phase2; float cuda_mem_limit_in_gb = -1; - + bool debug_break = false; PathString train_data_dir_phase2; PathString test_data_dir_phase2; @@ -191,6 +191,8 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet ("number_recompute_layers", "Number of layers to apply recompute.", cxxopts::value()->default_value("0")) ("use_invertible_layernorm_grad", "Specify whether to use invertible laynorm(dropping the input activation)", + cxxopts::value()->default_value("false")) + ("debug_break", "Specify whether to break at app start, useful for multi-gpu debugging.", cxxopts::value()->default_value("false")); options .add_options("ORT configuration") @@ -204,6 +206,7 @@ Status ParseArguments(int argc, char* argv[], BertParameters& params, OrtParamet auto flags = options.parse(argc, argv); params.model_name = flags["model_name"].as(); + params.debug_break = flags["debug_break"].as(); float lr = flags["learning_rate"].as(); if (lr > 1.f || lr < 0.f) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "learning_rate is not in valid range [0.0, 1.0]"); @@ -801,6 +804,8 @@ int main(int argc, char* argv[]) { BertParameters params; OrtParameters ort_params{}; RETURN_IF_FAIL(ParseArguments(argc, argv, params, ort_params)); + bool keep_looping = params.debug_break; + while(keep_looping); // setup logger, be noted: LOGS_DEFAULT must be after logging manager initialization. 
string default_logger_id{"Default"}; diff --git a/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc b/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc index 72ac650638c7d..0dc3209d12d30 100644 --- a/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/collective/nccl_kernels.cc @@ -13,19 +13,18 @@ Status NcclAllReduce::ComputeInternal(OpKernelContext* context) const { cudaStream_t stream = nullptr; // Default stream ncclComm_t comm = nccl_->Comm(group_type_); + size_t input_count = 0; + const void* input_data = context->Input(0)->DataRaw(); + void* output_data = context->Output(0, context->Input(0)->Shape())->MutableDataRaw(); + MLDataType onnx_type = context->Input(0)->DataType(); for (int i = 0; i < context->InputCount(); i++) { const Tensor* input_tensor = context->Input(i); - auto onnx_type = input_tensor->DataType(); - const void* input_data = input_tensor->DataRaw(); - size_t input_count = input_tensor->Shape().Size(); - - Tensor* output_tensor = context->Output(i, input_tensor->Shape()); - void* output_data = output_tensor->MutableDataRaw(); - - ncclDataType_t dtype = GetNcclDataType(onnx_type); - NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream)); + input_count += input_tensor->Shape().Size(); + context->Output(i, input_tensor->Shape()); } + ncclDataType_t dtype = GetNcclDataType(onnx_type); + NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream)); return Status::OK(); } @@ -219,6 +218,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .Alias(AliasRange(0, 1024)) + .AllocateInputsContiguously() .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), NcclAllReduce); @@ -229,6 +229,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .Alias(AliasRange(0, 1024)) + .AllocateInputsContiguously() .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), NcclAllGather); @@ -239,6 +240,7 @@ ONNX_OPERATOR_KERNEL_EX( kCudaExecutionProvider, KernelDefBuilder() .Alias(AliasRange(0, 1024)) + .AllocateInputsContiguously() .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()), NcclReduceScatter); diff --git a/orttraining/tools/ci_test/run_convergence_test.py b/orttraining/tools/ci_test/run_convergence_test.py index 68528e2897ca6..9b100d7fc2d3f 100755 --- a/orttraining/tools/ci_test/run_convergence_test.py +++ b/orttraining/tools/ci_test/run_convergence_test.py @@ -31,6 +31,8 @@ def main(): # run BERT training subprocess.run([ + "/bert_ort/openmpi/bin/mpirun", + "-n", "2", os.path.join(args.binary_dir, "onnxruntime_training_bert"), "--model_name", os.path.join( args.model_root, "nv/bert-base/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12"),
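A note on the contiguity check wired into both executors above: utils::VerifyInputTensorsAllocatedContiguously accepts each input only if it starts immediately after the previous input's 256-byte-aligned extent or immediately after its unpadded extent. The sketch below restates that check over a simplified stand-in record instead of the real Tensor/OpKernelContext types; InputSpan and InputsAreContiguous are illustrative names, not part of the patch.

#include <cstddef>
#include <vector>

// Simplified stand-in for one kernel input: where its buffer starts, how many bytes it holds,
// and how many bytes it occupies once padded out to the 256-byte allocation alignment.
struct InputSpan {
  const char* data = nullptr;
  std::size_t size_in_bytes = 0;
  std::size_t aligned_bytes = 0;  // size_in_bytes rounded up for 256-byte alignment
};

// Returns true when every input starts either right after the previous input's aligned extent
// or right after its packed (unpadded) extent -- the two layouts the patch accepts.
inline bool InputsAreContiguous(const std::vector<InputSpan>& inputs) {
  for (std::size_t i = 1; i < inputs.size(); ++i) {
    const InputSpan& prev = inputs[i - 1];
    const char* aligned_end = prev.data + prev.aligned_bytes;
    const char* packed_end = prev.data + prev.size_in_bytes;
    if (inputs[i].data != aligned_end && inputs[i].data != packed_end) {
      return false;
    }
  }
  return true;
}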